Skip to content

Commit b30f6a0

Browse files
[ML] Refactor inference request executor to leverage scheduled execution (#126858) (#126950)
* Using threadpool schedule and fixing tests
* Update docs/changelog/126858.yaml
* Clean up
* change log

(cherry picked from commit 7a0f63c)

# Conflicts:
#	x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java
1 parent da9e149 commit b30f6a0

File tree

4 files changed

+57
-74
lines changed

4 files changed

+57
-74
lines changed

docs/changelog/126858.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 126858
2+
summary: Leverage threadpool schedule for inference api to avoid long running thread
3+
area: Machine Learning
4+
type: bug
5+
issues:
6+
- 126853

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,6 @@
5555
*/
5656
class RequestExecutorService implements RequestExecutor {
5757

58-
/**
59-
* Provides dependency injection mainly for testing
60-
*/
61-
interface Sleeper {
62-
void sleep(TimeValue sleepTime) throws InterruptedException;
63-
}
64-
65-
// default for tests
66-
static final Sleeper DEFAULT_SLEEPER = sleepTime -> sleepTime.timeUnit().sleep(sleepTime.duration());
6758
// default for tests
6859
static final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> DEFAULT_QUEUE_CREATOR =
6960
new AdjustableCapacityBlockingQueue.QueueCreator<>() {
@@ -106,7 +97,6 @@ interface RateLimiterCreator {
10697
private final Clock clock;
10798
private final AtomicBoolean shutdown = new AtomicBoolean(false);
10899
private final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> queueCreator;
109-
private final Sleeper sleeper;
110100
private final RateLimiterCreator rateLimiterCreator;
111101
private final AtomicReference<Scheduler.Cancellable> cancellableCleanupTask = new AtomicReference<>();
112102
private final AtomicBoolean started = new AtomicBoolean(false);
@@ -117,16 +107,7 @@ interface RateLimiterCreator {
117107
RequestExecutorServiceSettings settings,
118108
RequestSender requestSender
119109
) {
120-
this(
121-
threadPool,
122-
DEFAULT_QUEUE_CREATOR,
123-
startupLatch,
124-
settings,
125-
requestSender,
126-
Clock.systemUTC(),
127-
DEFAULT_SLEEPER,
128-
DEFAULT_RATE_LIMIT_CREATOR
129-
);
110+
this(threadPool, DEFAULT_QUEUE_CREATOR, startupLatch, settings, requestSender, Clock.systemUTC(), DEFAULT_RATE_LIMIT_CREATOR);
130111
}
131112

132113
RequestExecutorService(
@@ -136,7 +117,6 @@ interface RateLimiterCreator {
136117
RequestExecutorServiceSettings settings,
137118
RequestSender requestSender,
138119
Clock clock,
139-
Sleeper sleeper,
140120
RateLimiterCreator rateLimiterCreator
141121
) {
142122
this.threadPool = Objects.requireNonNull(threadPool);
@@ -145,7 +125,6 @@ interface RateLimiterCreator {
145125
this.requestSender = Objects.requireNonNull(requestSender);
146126
this.settings = Objects.requireNonNull(settings);
147127
this.clock = Objects.requireNonNull(clock);
148-
this.sleeper = Objects.requireNonNull(sleeper);
149128
this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator);
150129
}
151130

@@ -188,15 +167,10 @@ public void start() {
188167
startCleanupTask();
189168
signalStartInitiated();
190169

191-
while (isShutdown() == false) {
192-
handleTasks();
193-
}
194-
} catch (InterruptedException e) {
195-
Thread.currentThread().interrupt();
196-
} finally {
197-
shutdown();
198-
notifyRequestsOfShutdown();
199-
terminationLatch.countDown();
170+
handleTasks();
171+
} catch (Exception e) {
172+
logger.warn("Failed to start request executor", e);
173+
cleanup();
200174
}
201175
}
202176

@@ -231,13 +205,44 @@ void removeStaleGroupings() {
231205
}
232206
}
233207

234-
private void handleTasks() throws InterruptedException {
235-
var timeToWait = settings.getTaskPollFrequency();
236-
for (var endpoint : rateLimitGroupings.values()) {
237-
timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
208+
private void scheduleNextHandleTasks(TimeValue timeToWait) {
209+
if (shutdown.get()) {
210+
logger.debug("Shutdown requested while scheduling next handle task call, cleaning up");
211+
cleanup();
212+
return;
238213
}
239214

240-
sleeper.sleep(timeToWait);
215+
threadPool.schedule(this::handleTasks, timeToWait, threadPool.executor(UTILITY_THREAD_POOL_NAME));
216+
}
217+
218+
private void cleanup() {
219+
try {
220+
shutdown();
221+
notifyRequestsOfShutdown();
222+
terminationLatch.countDown();
223+
} catch (Exception e) {
224+
logger.warn("Encountered an error while cleaning up", e);
225+
}
226+
}
227+
228+
private void handleTasks() {
229+
try {
230+
if (shutdown.get()) {
231+
logger.debug("Shutdown requested while handling tasks, cleaning up");
232+
cleanup();
233+
return;
234+
}
235+
236+
var timeToWait = settings.getTaskPollFrequency();
237+
for (var endpoint : rateLimitGroupings.values()) {
238+
timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
239+
}
240+
241+
scheduleNextHandleTasks(timeToWait);
242+
} catch (Exception e) {
243+
logger.warn("Encountered an error while handling tasks", e);
244+
cleanup();
245+
}
241246
}
242247

243248
private void notifyRequestsOfShutdown() {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
4141
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
4242
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
43+
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
4344
import static org.hamcrest.Matchers.equalTo;
4445
import static org.hamcrest.Matchers.hasSize;
4546
import static org.hamcrest.Matchers.instanceOf;
@@ -77,7 +78,7 @@ public void shutdown() throws IOException, InterruptedException {
7778
}
7879

7980
public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
80-
var senderFactory = createSenderFactory(clientManager, threadRef);
81+
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
8182

8283
try (var sender = createSender(senderFactory)) {
8384
sender.start();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java

Lines changed: 8 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
import static org.mockito.ArgumentMatchers.any;
5151
import static org.mockito.ArgumentMatchers.anyInt;
5252
import static org.mockito.Mockito.doAnswer;
53-
import static org.mockito.Mockito.doThrow;
5453
import static org.mockito.Mockito.mock;
5554
import static org.mockito.Mockito.times;
5655
import static org.mockito.Mockito.verify;
@@ -195,7 +194,7 @@ public void testExecute_Throws_WhenQueueIsFull() {
195194
assertFalse(thrownException.isExecutorShutdown());
196195
}
197196

198-
public void testTaskThrowsError_CallsOnFailure() {
197+
public void testTaskThrowsError_CallsOnFailure() throws InterruptedException {
199198
var requestSender = mock(RetryingHttpSender.class);
200199

201200
var service = createRequestExecutorService(null, requestSender);
@@ -218,6 +217,8 @@ public void testTaskThrowsError_CallsOnFailure() {
218217
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
219218
assertThat(thrownException.getMessage(), is(format("Failed to send request from inference entity id [%s]", "id")));
220219
assertThat(thrownException.getCause(), instanceOf(IllegalArgumentException.class));
220+
service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
221+
221222
assertTrue(service.isTerminated());
222223
}
223224

@@ -342,7 +343,6 @@ public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws I
342343
createRequestExecutorServiceSettingsEmpty(),
343344
requestSender,
344345
Clock.systemUTC(),
345-
RequestExecutorService.DEFAULT_SLEEPER,
346346
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
347347
);
348348

@@ -356,36 +356,7 @@ public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws I
356356
});
357357
service.start();
358358

359-
assertTrue(service.isTerminated());
360-
}
361-
362-
public void testSleep_ThrowingInterruptedException_TerminatesService() throws Exception {
363-
@SuppressWarnings("unchecked")
364-
BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
365-
var sleeper = mock(RequestExecutorService.Sleeper.class);
366-
doThrow(new InterruptedException("failed")).when(sleeper).sleep(any());
367-
368-
var service = new RequestExecutorService(
369-
threadPool,
370-
mockQueueCreator(queue),
371-
null,
372-
createRequestExecutorServiceSettingsEmpty(),
373-
mock(RetryingHttpSender.class),
374-
Clock.systemUTC(),
375-
sleeper,
376-
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
377-
);
378-
379-
Future<?> executorTermination = threadPool.generic().submit(() -> {
380-
try {
381-
service.start();
382-
} catch (Exception e) {
383-
fail(Strings.format("Failed to shutdown executor: %s", e));
384-
}
385-
});
386-
387-
executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
388-
359+
service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
389360
assertTrue(service.isTerminated());
390361
}
391362

@@ -552,7 +523,6 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens() {
552523
settings,
553524
requestSender,
554525
Clock.systemUTC(),
555-
RequestExecutorService.DEFAULT_SLEEPER,
556526
rateLimiterCreator
557527
);
558528
var requestManager = RequestManagerTests.createMock(requestSender);
@@ -585,7 +555,6 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_And
585555
settings,
586556
requestSender,
587557
Clock.systemUTC(),
588-
RequestExecutorService.DEFAULT_SLEEPER,
589558
rateLimiterCreator
590559
);
591560
var requestManager = RequestManagerTests.createMock(requestSender);
@@ -597,11 +566,15 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_And
597566

598567
doAnswer(invocation -> {
599568
service.shutdown();
569+
ActionListener<InferenceServiceResults> passedListener = invocation.getArgument(4);
570+
passedListener.onResponse(null);
571+
600572
return Void.TYPE;
601573
}).when(requestSender).send(any(), any(), any(), any(), any());
602574

603575
service.start();
604576

577+
listener.actionGet(TIMEOUT);
605578
verify(requestSender, times(1)).send(any(), any(), any(), any(), any());
606579
}
607580

@@ -619,7 +592,6 @@ public void testRemovesRateLimitGroup_AfterStaleDuration() {
619592
settings,
620593
requestSender,
621594
clock,
622-
RequestExecutorService.DEFAULT_SLEEPER,
623595
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
624596
);
625597
var requestManager = RequestManagerTests.createMock(requestSender, "id1");
@@ -653,7 +625,6 @@ public void testStartsCleanupThread() {
653625
settings,
654626
requestSender,
655627
Clock.systemUTC(),
656-
RequestExecutorService.DEFAULT_SLEEPER,
657628
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
658629
);
659630

0 commit comments

Comments
 (0)