datafusion_common_runtime/
join_set.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use crate::trace_utils::{trace_block, trace_future};
19use std::future::Future;
20use std::task::{Context, Poll};
21use tokio::runtime::Handle;
22use tokio::task::{AbortHandle, Id, JoinError, LocalSet};
23
24/// A wrapper around Tokio's JoinSet that forwards all API calls while optionally
25/// instrumenting spawned tasks and blocking closures with custom tracing behavior.
26/// If no tracer is injected via `trace_utils::set_tracer`, tasks and closures are executed
27/// without any instrumentation.
28#[derive(Debug)]
29pub struct JoinSet<T> {
30    inner: tokio::task::JoinSet<T>,
31}
32
33impl<T> Default for JoinSet<T> {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl<T> JoinSet<T> {
40    /// [JoinSet::new](tokio::task::JoinSet::new) - Create a new JoinSet.
41    pub fn new() -> Self {
42        Self {
43            inner: tokio::task::JoinSet::new(),
44        }
45    }
46
47    /// [JoinSet::len](tokio::task::JoinSet::len) - Return the number of tasks.
48    pub fn len(&self) -> usize {
49        self.inner.len()
50    }
51
52    /// [JoinSet::is_empty](tokio::task::JoinSet::is_empty) - Check if the JoinSet is empty.
53    pub fn is_empty(&self) -> bool {
54        self.inner.is_empty()
55    }
56}
57
58impl<T: 'static> JoinSet<T> {
59    /// [JoinSet::spawn](tokio::task::JoinSet::spawn) - Spawn a new task.
60    pub fn spawn<F>(&mut self, task: F) -> AbortHandle
61    where
62        F: Future<Output = T>,
63        F: Send + 'static,
64        T: Send,
65    {
66        self.inner.spawn(trace_future(task))
67    }
68
69    /// [JoinSet::spawn_on](tokio::task::JoinSet::spawn_on) - Spawn a task on a provided runtime.
70    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
71    where
72        F: Future<Output = T>,
73        F: Send + 'static,
74        T: Send,
75    {
76        self.inner.spawn_on(trace_future(task), handle)
77    }
78
79    /// [JoinSet::spawn_local](tokio::task::JoinSet::spawn_local) - Spawn a local task.
80    pub fn spawn_local<F>(&mut self, task: F) -> AbortHandle
81    where
82        F: Future<Output = T>,
83        F: 'static,
84    {
85        self.inner.spawn_local(task)
86    }
87
88    /// [JoinSet::spawn_local_on](tokio::task::JoinSet::spawn_local_on) - Spawn a local task on a provided LocalSet.
89    pub fn spawn_local_on<F>(&mut self, task: F, local_set: &LocalSet) -> AbortHandle
90    where
91        F: Future<Output = T>,
92        F: 'static,
93    {
94        self.inner.spawn_local_on(task, local_set)
95    }
96
97    /// [JoinSet::spawn_blocking](tokio::task::JoinSet::spawn_blocking) - Spawn a blocking task.
98    pub fn spawn_blocking<F>(&mut self, f: F) -> AbortHandle
99    where
100        F: FnOnce() -> T,
101        F: Send + 'static,
102        T: Send,
103    {
104        self.inner.spawn_blocking(trace_block(f))
105    }
106
107    /// [JoinSet::spawn_blocking_on](tokio::task::JoinSet::spawn_blocking_on) - Spawn a blocking task on a provided runtime.
108    pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle) -> AbortHandle
109    where
110        F: FnOnce() -> T,
111        F: Send + 'static,
112        T: Send,
113    {
114        self.inner.spawn_blocking_on(trace_block(f), handle)
115    }
116
117    /// [JoinSet::join_next](tokio::task::JoinSet::join_next) - Await the next completed task.
118    pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
119        self.inner.join_next().await
120    }
121
122    /// [JoinSet::try_join_next](tokio::task::JoinSet::try_join_next) - Try to join the next completed task.
123    pub fn try_join_next(&mut self) -> Option<Result<T, JoinError>> {
124        self.inner.try_join_next()
125    }
126
127    /// [JoinSet::abort_all](tokio::task::JoinSet::abort_all) - Abort all tasks.
128    pub fn abort_all(&mut self) {
129        self.inner.abort_all()
130    }
131
132    /// [JoinSet::detach_all](tokio::task::JoinSet::detach_all) - Detach all tasks.
133    pub fn detach_all(&mut self) {
134        self.inner.detach_all()
135    }
136
137    /// [JoinSet::poll_join_next](tokio::task::JoinSet::poll_join_next) - Poll for the next completed task.
138    pub fn poll_join_next(
139        &mut self,
140        cx: &mut Context<'_>,
141    ) -> Poll<Option<Result<T, JoinError>>> {
142        self.inner.poll_join_next(cx)
143    }
144
145    /// [JoinSet::join_next_with_id](tokio::task::JoinSet::join_next_with_id) - Await the next completed task with its ID.
146    pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
147        self.inner.join_next_with_id().await
148    }
149
150    /// [JoinSet::try_join_next_with_id](tokio::task::JoinSet::try_join_next_with_id) - Try to join the next completed task with its ID.
151    pub fn try_join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
152        self.inner.try_join_next_with_id()
153    }
154
155    /// [JoinSet::poll_join_next_with_id](tokio::task::JoinSet::poll_join_next_with_id) - Poll for the next completed task with its ID.
156    pub fn poll_join_next_with_id(
157        &mut self,
158        cx: &mut Context<'_>,
159    ) -> Poll<Option<Result<(Id, T), JoinError>>> {
160        self.inner.poll_join_next_with_id(cx)
161    }
162
163    /// [JoinSet::shutdown](tokio::task::JoinSet::shutdown) - Abort all tasks and wait for shutdown.
164    pub async fn shutdown(&mut self) {
165        self.inner.shutdown().await
166    }
167
168    /// [JoinSet::join_all](tokio::task::JoinSet::join_all) - Await all tasks.
169    pub async fn join_all(self) -> Vec<T> {
170        self.inner.join_all().await
171    }
172}