datafusion_common_runtime/join_set.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use crate::trace_utils::{trace_block, trace_future};
19use std::future::Future;
20use std::task::{Context, Poll};
21use tokio::runtime::Handle;
22use tokio::task::{AbortHandle, Id, JoinError, LocalSet};
23
24/// A wrapper around Tokio's JoinSet that forwards all API calls while optionally
25/// instrumenting spawned tasks and blocking closures with custom tracing behavior.
26/// If no tracer is injected via `trace_utils::set_tracer`, tasks and closures are executed
27/// without any instrumentation.
28#[derive(Debug)]
29pub struct JoinSet<T> {
30 inner: tokio::task::JoinSet<T>,
31}
32
33impl<T> Default for JoinSet<T> {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl<T> JoinSet<T> {
40 /// [JoinSet::new](tokio::task::JoinSet::new) - Create a new JoinSet.
41 pub fn new() -> Self {
42 Self {
43 inner: tokio::task::JoinSet::new(),
44 }
45 }
46
47 /// [JoinSet::len](tokio::task::JoinSet::len) - Return the number of tasks.
48 pub fn len(&self) -> usize {
49 self.inner.len()
50 }
51
52 /// [JoinSet::is_empty](tokio::task::JoinSet::is_empty) - Check if the JoinSet is empty.
53 pub fn is_empty(&self) -> bool {
54 self.inner.is_empty()
55 }
56}
57
58impl<T: 'static> JoinSet<T> {
59 /// [JoinSet::spawn](tokio::task::JoinSet::spawn) - Spawn a new task.
60 pub fn spawn<F>(&mut self, task: F) -> AbortHandle
61 where
62 F: Future<Output = T>,
63 F: Send + 'static,
64 T: Send,
65 {
66 self.inner.spawn(trace_future(task))
67 }
68
69 /// [JoinSet::spawn_on](tokio::task::JoinSet::spawn_on) - Spawn a task on a provided runtime.
70 pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
71 where
72 F: Future<Output = T>,
73 F: Send + 'static,
74 T: Send,
75 {
76 self.inner.spawn_on(trace_future(task), handle)
77 }
78
79 /// [JoinSet::spawn_local](tokio::task::JoinSet::spawn_local) - Spawn a local task.
80 pub fn spawn_local<F>(&mut self, task: F) -> AbortHandle
81 where
82 F: Future<Output = T>,
83 F: 'static,
84 {
85 self.inner.spawn_local(task)
86 }
87
88 /// [JoinSet::spawn_local_on](tokio::task::JoinSet::spawn_local_on) - Spawn a local task on a provided LocalSet.
89 pub fn spawn_local_on<F>(&mut self, task: F, local_set: &LocalSet) -> AbortHandle
90 where
91 F: Future<Output = T>,
92 F: 'static,
93 {
94 self.inner.spawn_local_on(task, local_set)
95 }
96
97 /// [JoinSet::spawn_blocking](tokio::task::JoinSet::spawn_blocking) - Spawn a blocking task.
98 pub fn spawn_blocking<F>(&mut self, f: F) -> AbortHandle
99 where
100 F: FnOnce() -> T,
101 F: Send + 'static,
102 T: Send,
103 {
104 self.inner.spawn_blocking(trace_block(f))
105 }
106
107 /// [JoinSet::spawn_blocking_on](tokio::task::JoinSet::spawn_blocking_on) - Spawn a blocking task on a provided runtime.
108 pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle) -> AbortHandle
109 where
110 F: FnOnce() -> T,
111 F: Send + 'static,
112 T: Send,
113 {
114 self.inner.spawn_blocking_on(trace_block(f), handle)
115 }
116
117 /// [JoinSet::join_next](tokio::task::JoinSet::join_next) - Await the next completed task.
118 pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
119 self.inner.join_next().await
120 }
121
122 /// [JoinSet::try_join_next](tokio::task::JoinSet::try_join_next) - Try to join the next completed task.
123 pub fn try_join_next(&mut self) -> Option<Result<T, JoinError>> {
124 self.inner.try_join_next()
125 }
126
127 /// [JoinSet::abort_all](tokio::task::JoinSet::abort_all) - Abort all tasks.
128 pub fn abort_all(&mut self) {
129 self.inner.abort_all()
130 }
131
132 /// [JoinSet::detach_all](tokio::task::JoinSet::detach_all) - Detach all tasks.
133 pub fn detach_all(&mut self) {
134 self.inner.detach_all()
135 }
136
137 /// [JoinSet::poll_join_next](tokio::task::JoinSet::poll_join_next) - Poll for the next completed task.
138 pub fn poll_join_next(
139 &mut self,
140 cx: &mut Context<'_>,
141 ) -> Poll<Option<Result<T, JoinError>>> {
142 self.inner.poll_join_next(cx)
143 }
144
145 /// [JoinSet::join_next_with_id](tokio::task::JoinSet::join_next_with_id) - Await the next completed task with its ID.
146 pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
147 self.inner.join_next_with_id().await
148 }
149
150 /// [JoinSet::try_join_next_with_id](tokio::task::JoinSet::try_join_next_with_id) - Try to join the next completed task with its ID.
151 pub fn try_join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
152 self.inner.try_join_next_with_id()
153 }
154
155 /// [JoinSet::poll_join_next_with_id](tokio::task::JoinSet::poll_join_next_with_id) - Poll for the next completed task with its ID.
156 pub fn poll_join_next_with_id(
157 &mut self,
158 cx: &mut Context<'_>,
159 ) -> Poll<Option<Result<(Id, T), JoinError>>> {
160 self.inner.poll_join_next_with_id(cx)
161 }
162
163 /// [JoinSet::shutdown](tokio::task::JoinSet::shutdown) - Abort all tasks and wait for shutdown.
164 pub async fn shutdown(&mut self) {
165 self.inner.shutdown().await
166 }
167
168 /// [JoinSet::join_all](tokio::task::JoinSet::join_all) - Await all tasks.
169 pub async fn join_all(self) -> Vec<T> {
170 self.inner.join_all().await
171 }
172}