Skip to content

[ML] Change format for Unified Chat error responses #121396

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/121396.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 121396
summary: Change format for Unified Chat
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ToXContent;

import java.util.Iterator;
import java.util.Locale;
import java.util.Objects;

import static java.util.Collections.emptyIterator;
import static org.elasticsearch.ExceptionsHelper.maybeError;
import static org.elasticsearch.common.collect.Iterators.concat;
import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.endObject;
import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.startObject;

public class UnifiedChatCompletionException extends XContentFormattedException {

private static final Logger log = LogManager.getLogger(UnifiedChatCompletionException.class);
private final String message;
private final String type;
@Nullable
private final String code;
@Nullable
private final String param;

public UnifiedChatCompletionException(RestStatus status, String message, String type, @Nullable String code) {
this(status, message, type, code, null);
}

public UnifiedChatCompletionException(RestStatus status, String message, String type, @Nullable String code, @Nullable String param) {
super(message, status);
this.message = Objects.requireNonNull(message);
this.type = Objects.requireNonNull(type);
this.code = code;
this.param = param;
}

public UnifiedChatCompletionException(
Throwable cause,
RestStatus status,
String message,
String type,
@Nullable String code,
@Nullable String param
) {
super(message, cause, status);
this.message = Objects.requireNonNull(message);
this.type = Objects.requireNonNull(type);
this.code = code;
this.param = param;
}

@Override
public Iterator<? extends ToXContent> toXContentChunked(Params params) {
return concat(
startObject(),
startObject("error"),
optionalField("code", code),
field("message", message),
optionalField("param", param),
field("type", type),
endObject(),
endObject()
);
}

private static Iterator<ToXContent> field(String key, String value) {
return ChunkedToXContentHelper.chunk((b, p) -> b.field(key, value));
}

private static Iterator<ToXContent> optionalField(String key, String value) {
return value != null ? ChunkedToXContentHelper.chunk((b, p) -> b.field(key, value)) : emptyIterator();
}

public static UnifiedChatCompletionException fromThrowable(Throwable t) {
if (ExceptionsHelper.unwrapCause(t) instanceof UnifiedChatCompletionException e) {
return e;
} else {
return maybeError(t).map(error -> {
// we should never be throwing Error, but just in case we are, rethrow it on another thread so the JVM can handle it and
// return a vague error to the user so that they at least see something went wrong but don't leak JVM details to users
ExceptionsHelper.maybeDieOnAnotherThread(error);
var e = new RuntimeException("Fatal error while streaming response. Please retry the request.");
log.error(e.getMessage(), t);
return new UnifiedChatCompletionException(
RestStatus.INTERNAL_SERVER_ERROR,
e.getMessage(),
getExceptionName(e),
RestStatus.INTERNAL_SERVER_ERROR.name().toLowerCase(Locale.ROOT)
);
}).orElseGet(() -> {
log.atDebug().withThrowable(t).log("UnifiedChatCompletionException stack trace for debugging purposes.");
var status = ExceptionsHelper.status(t);
return new UnifiedChatCompletionException(
t,
status,
t.getMessage(),
getExceptionName(t),
status.name().toLowerCase(Locale.ROOT),
null
);
});
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.results;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.core.RestApiVersion;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;

import java.util.Iterator;
import java.util.Objects;

/**
* Similar to {@link org.elasticsearch.ElasticsearchWrapperException}, this will wrap an Exception to generate an xContent using
* {@link ElasticsearchException#generateFailureXContent(XContentBuilder, Params, Exception, boolean)}.
* Extends {@link ElasticsearchException} to provide REST handlers the {@link #status()} method in order to set the response header.
*/
public class XContentFormattedException extends ElasticsearchException implements ChunkedToXContent {

public static final String X_CONTENT_PARAM = "detailedErrorsEnabled";
private final RestStatus status;
private final Throwable cause;

public XContentFormattedException(String message, RestStatus status) {
super(message);
this.status = Objects.requireNonNull(status);
this.cause = null;
}

public XContentFormattedException(Throwable cause, RestStatus status) {
super(cause);
this.status = Objects.requireNonNull(status);
this.cause = cause;
}

public XContentFormattedException(String message, Throwable cause, RestStatus status) {
super(message, cause);
this.status = Objects.requireNonNull(status);
this.cause = cause;
}

@Override
public RestStatus status() {
return status;
}

@Override
public Iterator<? extends ToXContent> toXContentChunked(Params params) {
return Iterators.concat(
ChunkedToXContentHelper.startObject(),
Iterators.single(
(b, p) -> ElasticsearchException.generateFailureXContent(
b,
p,
cause instanceof Exception e ? e : this,
params.paramAsBoolean(X_CONTENT_PARAM, false)
)
),
Iterators.single((b, p) -> b.field("status", status.getStatus())),
ChunkedToXContentHelper.endObject()
);
}

@Override
public Iterator<? extends ToXContent> toXContentChunked(RestApiVersion restApiVersion, Params params) {
return ChunkedToXContent.super.toXContentChunked(restApiVersion, params);
}

@Override
public Iterator<? extends ToXContent> toXContentChunkedV8(Params params) {
return ChunkedToXContent.super.toXContentChunkedV8(params);
}

@Override
public boolean isFragment() {
return super.isFragment();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.XContentFormattedException;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
Expand Down Expand Up @@ -80,6 +82,7 @@ public class ServerSentEventsRestActionListenerTests extends ESIntegTestCase {
private static final String REQUEST_COUNT = "request_count";
private static final String WITH_ERROR = "with_error";
private static final String ERROR_ROUTE = "/_inference_error";
private static final String FORMATTED_ERROR_ROUTE = "/_formatted_inference_error";
private static final String NO_STREAM_ROUTE = "/_inference_no_stream";
private static final Exception expectedException = new IllegalStateException("hello there");
private static final String expectedExceptionAsServerSentEvent = """
Expand All @@ -88,6 +91,11 @@ public class ServerSentEventsRestActionListenerTests extends ESIntegTestCase {
"type":"illegal_state_exception","reason":"hello there"},"status":500\
}""";

private static final Exception expectedFormattedException = new XContentFormattedException(
expectedException,
RestStatus.INTERNAL_SERVER_ERROR
);

@Override
protected boolean addMockHttpTransport() {
return false;
Expand Down Expand Up @@ -145,6 +153,16 @@ public List<Route> routes() {
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedException);
}
}, new RestHandler() {
@Override
public List<Route> routes() {
return List.of(new Route(RestRequest.Method.POST, FORMATTED_ERROR_ROUTE));
}

@Override
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) {
new ServerSentEventsRestActionListener(channel, threadPool).onFailure(expectedFormattedException);
}
}, new RestHandler() {
@Override
public List<Route> routes() {
Expand Down Expand Up @@ -424,6 +442,21 @@ public void testErrorMidStream() {
assertThat(collector.stringsVerified.getLast(), equalTo(expectedExceptionAsServerSentEvent));
}

public void testFormattedError() throws IOException {
var request = new Request(RestRequest.Method.POST.name(), FORMATTED_ERROR_ROUTE);

try {
getRestClient().performRequest(request);
fail("Expected an exception to be thrown from the error route");
} catch (ResponseException e) {
var response = e.getResponse();
assertThat(response.getStatusLine().getStatusCode(), is(HttpStatus.SC_INTERNAL_SERVER_ERROR));
assertThat(EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8), equalTo("""
\uFEFFevent: error
data:\s""" + expectedExceptionAsServerSentEvent + "\n\n"));
}
}

public void testNoStream() {
var collector = new RandomStringCollector();
var expectedTestCount = randomIntBetween(2, 30);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@
import java.io.IOException;
import java.util.Random;
import java.util.concurrent.Executor;
import java.util.concurrent.Flow;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static org.elasticsearch.ExceptionsHelper.unwrapCause;
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
Expand Down Expand Up @@ -280,7 +282,9 @@ private void inferOnServiceWithMetrics(
var instrumentedStream = new PublisherWithMetrics(timer, model);
taskProcessor.subscribe(instrumentedStream);

listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream));
var streamErrorHandler = streamErrorHandler(instrumentedStream);

listener.onResponse(new InferenceAction.Response(inferenceResults, streamErrorHandler));
} else {
recordMetrics(model, timer, null);
listener.onResponse(new InferenceAction.Response(inferenceResults));
Expand All @@ -291,9 +295,13 @@ private void inferOnServiceWithMetrics(
}));
}

protected Flow.Publisher<ChunkedToXContent> streamErrorHandler(Flow.Processor<ChunkedToXContent, ChunkedToXContent> upstream) {
return upstream;
}

private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
try {
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, unwrapCause(t)));
} catch (Exception e) {
log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.InferenceServiceResults;
Expand All @@ -20,14 +21,19 @@
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

import java.util.concurrent.Flow;

public class TransportUnifiedCompletionInferenceAction extends BaseTransportInferenceAction<UnifiedCompletionAction.Request> {

@Inject
Expand Down Expand Up @@ -86,4 +92,40 @@ protected void doInference(
) {
service.unifiedCompletionInfer(model, request.getUnifiedCompletionRequest(), null, listener);
}

@Override
protected void doExecute(Task task, UnifiedCompletionAction.Request request, ActionListener<InferenceAction.Response> listener) {
super.doExecute(task, request, listener.delegateResponse((l, e) -> l.onFailure(UnifiedChatCompletionException.fromThrowable(e))));
}

/**
* If we get any errors, either in {@link #doExecute} via the listener.onFailure or while streaming, make sure that they are formatted
* as {@link UnifiedChatCompletionException}.
*/
@Override
protected Flow.Publisher<ChunkedToXContent> streamErrorHandler(Flow.Processor<ChunkedToXContent, ChunkedToXContent> upstream) {
return downstream -> {
upstream.subscribe(new Flow.Subscriber<>() {
@Override
public void onSubscribe(Flow.Subscription subscription) {
downstream.onSubscribe(subscription);
}

@Override
public void onNext(ChunkedToXContent item) {
downstream.onNext(item);
}

@Override
public void onError(Throwable throwable) {
downstream.onError(UnifiedChatCompletionException.fromThrowable(throwable));
}

@Override
public void onComplete() {
downstream.onComplete();
}
});
};
}
}
Loading