Skip to content

Resume Driver on cancelled or early finished #120020

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/120020.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 120020
summary: Resume Driver on cancelled or early finished
area: ES|QL
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.compute.Describable;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.exchange.ExchangeSinkOperator;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
Expand Down Expand Up @@ -74,10 +75,9 @@ public class Driver implements Releasable, Describable {
private final long statusNanos;

private final AtomicReference<String> cancelReason = new AtomicReference<>();
private final AtomicReference<SubscribableListener<Void>> blocked = new AtomicReference<>();

private final AtomicBoolean started = new AtomicBoolean();
private final SubscribableListener<Void> completionListener = new SubscribableListener<>();
private final DriverScheduler scheduler = new DriverScheduler();

/**
* Status reported to the tasks API. We write the status at most once every
Expand Down Expand Up @@ -245,31 +245,33 @@ private IsBlockedResult runSingleLoopIteration() {
ensureNotCancelled();
boolean movedPage = false;

for (int i = 0; i < activeOperators.size() - 1; i++) {
Operator op = activeOperators.get(i);
Operator nextOp = activeOperators.get(i + 1);
if (activeOperators.isEmpty() == false && activeOperators.getLast().isFinished() == false) {
for (int i = 0; i < activeOperators.size() - 1; i++) {
Operator op = activeOperators.get(i);
Operator nextOp = activeOperators.get(i + 1);

// skip blocked operator
if (op.isBlocked().listener().isDone() == false) {
continue;
}
// skip blocked operator
if (op.isBlocked().listener().isDone() == false) {
continue;
}

if (op.isFinished() == false && nextOp.needsInput()) {
Page page = op.getOutput();
if (page == null) {
// No result, just move to the next iteration
} else if (page.getPositionCount() == 0) {
// Empty result, release any memory it holds immediately and move to the next iteration
page.releaseBlocks();
} else {
// Non-empty result from the previous operation, move it to the next operation
nextOp.addInput(page);
movedPage = true;
if (op.isFinished() == false && nextOp.needsInput()) {
Page page = op.getOutput();
if (page == null) {
// No result, just move to the next iteration
} else if (page.getPositionCount() == 0) {
// Empty result, release any memory it holds immediately and move to the next iteration
page.releaseBlocks();
} else {
// Non-empty result from the previous operation, move it to the next operation
nextOp.addInput(page);
movedPage = true;
}
}
}

if (op.isFinished()) {
nextOp.finish();
if (op.isFinished()) {
nextOp.finish();
}
}
}

Expand Down Expand Up @@ -312,19 +314,10 @@ private IsBlockedResult runSingleLoopIteration() {

public void cancel(String reason) {
if (cancelReason.compareAndSet(null, reason)) {
synchronized (this) {
SubscribableListener<Void> fut = this.blocked.get();
if (fut != null) {
fut.onFailure(new TaskCancelledException(reason));
}
}
scheduler.runPendingTasks();
}
}

private boolean isCancelled() {
return cancelReason.get() != null;
}

private void ensureNotCancelled() {
String reason = cancelReason.get();
if (reason != null) {
Expand All @@ -342,6 +335,15 @@ public static void start(
driver.completionListener.addListener(listener);
if (driver.started.compareAndSet(false, true)) {
driver.updateStatus(0, 0, DriverStatus.Status.STARTING, "driver starting");
// Register a listener to an exchange sink to handle early completion scenarios:
// 1. When the query accumulates sufficient data (e.g., reaching the LIMIT).
// 2. When users abort the query but want to retain the current result.
// This allows the Driver to finish early without waiting for the scheduled task.
if (driver.activeOperators.isEmpty() == false) {
if (driver.activeOperators.getLast() instanceof ExchangeSinkOperator sinkOperator) {
sinkOperator.addCompletionListener(ActionListener.running(driver.scheduler::runPendingTasks));
}
}
schedule(DEFAULT_TIME_BEFORE_YIELDING, maxIterations, threadContext, executor, driver, driver.completionListener);
}
}
Expand Down Expand Up @@ -371,7 +373,7 @@ private static void schedule(
Driver driver,
ActionListener<Void> listener
) {
executor.execute(new AbstractRunnable() {
final var task = new AbstractRunnable() {

@Override
protected void doRun() {
Expand All @@ -383,16 +385,12 @@ protected void doRun() {
if (fut.isDone()) {
schedule(maxTime, maxIterations, threadContext, executor, driver, listener);
} else {
synchronized (driver) {
if (driver.isCancelled() == false) {
driver.blocked.set(fut);
}
}
ActionListener<Void> readyListener = ActionListener.wrap(
ignored -> schedule(maxTime, maxIterations, threadContext, executor, driver, listener),
this::onFailure
);
fut.addListener(ContextPreservingActionListener.wrapPreservingContext(readyListener, threadContext));
driver.scheduler.addOrRunDelayedTask(() -> fut.onResponse(null));
}
}

Expand All @@ -405,7 +403,8 @@ public void onFailure(Exception e) {
void onComplete(ActionListener<Void> listener) {
driver.driverContext.waitForAsyncActions(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext));
}
});
};
driver.scheduler.scheduleOrRunTask(executor, task);
}

private static IsBlockedResult oneOf(List<IsBlockedResult> results) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute.operator;

import org.elasticsearch.common.util.concurrent.EsExecutors;

import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

/**
 * A Driver can be put to sleep while its sink is full or its source is empty, or it can be rescheduled after running
 * several iterations. This scheduler keeps hold of the delayed task and the scheduled task so that they can be run
 * directly, without waking up the driver or waiting for the thread pool to pick up the task. This enables fast
 * cancellation or early finishing without discarding the current result.
 */
final class DriverScheduler {
    /** Task registered through {@link #addOrRunDelayedTask}; whoever swaps it out for {@code null} gets to run it. */
    private final AtomicReference<Runnable> delayedTask = new AtomicReference<>();
    /** Task registered through {@link #scheduleOrRunTask}; whoever swaps it out for {@code null} gets to run it. */
    private final AtomicReference<Runnable> scheduledTask = new AtomicReference<>();
    /** Flipped to {@code true} by {@link #runPendingTasks} and never reset. */
    private final AtomicBoolean completing = new AtomicBoolean();

    /**
     * Stores {@code task} as the delayed task. If {@link #runPendingTasks} has already been called, the task is
     * claimed back and run immediately on the calling thread instead of being left parked.
     *
     * @param task the task to park, or to run right away when this scheduler is already completing
     */
    void addOrRunDelayedTask(Runnable task) {
        // Publish first, then check the flag: a concurrent runPendingTasks() either sees the task or we see the flag.
        delayedTask.set(task);
        if (completing.get() == false) {
            return;
        }
        final Runnable claimed = delayedTask.getAndSet(null);
        if (claimed != null) {
            assert task == claimed;
            claimed.run();
        }
    }

    /**
     * Stores {@code task} as the scheduled task and submits a wrapper that runs it unless it has already been claimed
     * by {@link #runPendingTasks}. Once this scheduler is completing, the wrapper is submitted to
     * {@code EsExecutors.DIRECT_EXECUTOR_SERVICE} rather than to {@code executor}.
     *
     * @param executor the executor used while this scheduler is not completing
     * @param task     the task to schedule; no other scheduled task may be pending
     */
    void scheduleOrRunTask(Executor executor, Runnable task) {
        final Runnable previous = scheduledTask.getAndSet(task);
        assert previous == null : previous;
        final Executor target;
        if (completing.get()) {
            target = EsExecutors.DIRECT_EXECUTOR_SERVICE;
        } else {
            target = executor;
        }
        target.execute(() -> {
            // The task may already have been taken and run by runPendingTasks(); in that case there is nothing to do.
            final Runnable claimed = scheduledTask.getAndSet(null);
            if (claimed == null) {
                return;
            }
            assert claimed == task;
            claimed.run();
        });
    }

    /**
     * Marks this scheduler as completing and runs, on the calling thread, the delayed task and then the scheduled
     * task, skipping whichever of them is absent or has already been claimed.
     */
    void runPendingTasks() {
        completing.set(true);
        for (AtomicReference<Runnable> holder : List.of(delayedTask, scheduledTask)) {
            final Runnable pending = holder.getAndSet(null);
            if (pending == null) {
                continue;
            }
            pending.run();
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.compute.operator.exchange;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.IsBlockedResult;

Expand All @@ -30,6 +31,11 @@ public interface ExchangeSink {
*/
boolean isFinished();

/**
* Adds a listener that will be notified when this exchange sink is finished.
*
* @param listener the listener to notify
*/
void addCompletionListener(ActionListener<Void> listener);

/**
* Whether the sink is blocked on adding more pages
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ public ExchangeSinkHandler(BlockFactory blockFactory, int maxBufferSize, LongSup

private class ExchangeSinkImpl implements ExchangeSink {
boolean finished;
private final SubscribableListener<Void> onFinished = new SubscribableListener<>();

/**
 * Creates a sink: calls {@code onChanged()}, registers {@code onFinished} as a completion listener on the buffer,
 * and increments the outstanding-sinks counter.
 */
ExchangeSinkImpl() {
onChanged();
// Presumably this lets onFinished complete when the buffer itself completes — confirm against the buffer implementation.
buffer.addCompletionListener(onFinished);
outstandingSinks.incrementAndGet();
}

Expand All @@ -68,6 +70,7 @@ public void addPage(Page page) {
public void finish() {
if (finished == false) {
finished = true;
onFinished.onResponse(null);
onChanged();
if (outstandingSinks.decrementAndGet() == 0) {
buffer.finish(false);
Expand All @@ -78,7 +81,12 @@ public void finish() {

@Override
public boolean isFinished() {
return finished || buffer.isFinished();
return onFinished.isDone();
}

/**
 * Subscribes {@code listener} to {@code onFinished}, the listener that {@code finish()} completes and that the
 * constructor registers as a completion listener on the buffer.
 */
@Override
public void addCompletionListener(ActionListener<Void> listener) {
onFinished.addListener(listener);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
Expand Down Expand Up @@ -59,6 +60,10 @@ public boolean isFinished() {
return sink.isFinished();
}

/**
 * Registers {@code listener} on the wrapped sink by delegating to {@code sink.addCompletionListener}.
 *
 * @param listener the listener to hand to the sink
 */
public void addCompletionListener(ActionListener<Void> listener) {
sink.addCompletionListener(listener);
}

@Override
public void finish() {
sink.finish();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
Expand All @@ -21,6 +22,10 @@
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler;
import org.elasticsearch.compute.operator.exchange.ExchangeSinkOperator;
import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler;
import org.elasticsearch.compute.operator.exchange.ExchangeSourceOperator;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
Expand All @@ -35,8 +40,10 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.LongSupplier;

import static org.hamcrest.Matchers.either;
import static org.hamcrest.Matchers.equalTo;

public class DriverTests extends ESTestCase {
Expand Down Expand Up @@ -273,6 +280,33 @@ public Page getOutput() {
}
}

public void testResumeOnEarlyFinish() throws Exception {
    final DriverContext context = driverContext();
    final ThreadPool pool = threadPool();
    try {
        // A driver wired straight from an exchange source to an exchange sink, with no operators in between.
        final PlainActionFuture<Void> sourceCompleted = new PlainActionFuture<>();
        final ExchangeSourceHandler exchangeSource = new ExchangeSourceHandler(
            between(1, 5),
            pool.executor("esql"),
            sourceCompleted
        );
        final ExchangeSinkHandler exchangeSink = new ExchangeSinkHandler(
            context.blockFactory(),
            between(1, 5),
            System::currentTimeMillis
        );
        final ExchangeSourceOperator input = new ExchangeSourceOperator(exchangeSource.createExchangeSource());
        final ExchangeSinkOperator output = new ExchangeSinkOperator(exchangeSink.createExchangeSink(), Function.identity());
        final Driver driver = new Driver(context, input, List.of(), output, () -> {});

        final PlainActionFuture<Void> driverCompleted = new PlainActionFuture<>();
        Driver.start(pool.getThreadContext(), pool.executor("esql"), driver, between(1, 1000), driverCompleted);

        // Wait until the driver reports either STARTING or ASYNC.
        assertBusy(() -> {
            assertThat(
                driver.status().status(),
                either(equalTo(DriverStatus.Status.ASYNC)).or(equalTo(DriverStatus.Status.STARTING))
            );
        });

        // Presumably the `true` flag tells the sink handler that its consumer is done, finishing the sink — confirm
        // against ExchangeSinkHandler#fetchPageAsync.
        exchangeSink.fetchPageAsync(true, ActionListener.noop());

        // The driver must then complete promptly and report DONE, and the source handler's future must complete too.
        driverCompleted.actionGet(5, TimeUnit.SECONDS);
        assertThat(driver.status().status(), equalTo(DriverStatus.Status.DONE));
        sourceCompleted.actionGet(5, TimeUnit.SECONDS);
    } finally {
        terminate(pool);
    }
}

private static void assertRunningWithRegularUser(ThreadPool threadPool) {
String user = threadPool.getThreadContext().getHeader("user");
assertThat(user, equalTo("user1"));
Expand Down