datafusion_ffi/udf/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{
19    arrow_wrappers::{WrappedArray, WrappedSchema},
20    df_result, rresult, rresult_return,
21    util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
22    volatility::FFI_Volatility,
23};
24use abi_stable::{
25    std_types::{RResult, RString, RVec},
26    StableAbi,
27};
28use arrow::datatypes::{DataType, Field};
29use arrow::{
30    array::ArrayRef,
31    error::ArrowError,
32    ffi::{from_ffi, to_ffi, FFI_ArrowSchema},
33};
34use arrow_schema::FieldRef;
35use datafusion::logical_expr::ReturnFieldArgs;
36use datafusion::{
37    error::DataFusionError,
38    logical_expr::type_coercion::functions::data_types_with_scalar_udf,
39};
40use datafusion::{
41    error::Result,
42    logical_expr::{
43        ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
44    },
45};
46use return_type_args::{
47    FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned,
48};
49use std::{ffi::c_void, sync::Arc};
50
51pub mod return_type_args;
52
/// A stable struct for sharing a [`ScalarUDF`] across FFI boundaries.
///
/// The provider of the UDF fills in the function pointers and `private_data`
/// (see `From<Arc<ScalarUDF>> for FFI_ScalarUDF`). The receiver interacts with
/// the UDF only through those function pointers, for example via
/// [`ForeignScalarUDF`].
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_ScalarUDF {
    /// FFI equivalent to the `name` of a [`ScalarUDF`]
    pub name: RString,

    /// FFI equivalent to the `aliases` of a [`ScalarUDF`]
    pub aliases: RVec<RString>,

    /// FFI equivalent to the `volatility` of a [`ScalarUDF`]
    pub volatility: FFI_Volatility,

    /// Determines the return type of the underlying [`ScalarUDF`] based on the
    /// argument types.
    pub return_type: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<WrappedSchema, RString>,

    /// Determines the return field of the underlying [`ScalarUDF`] from the full
    /// set of return-field arguments. Either this or `return_type` may be
    /// implemented on a UDF.
    pub return_field_from_args: unsafe extern "C" fn(
        udf: &Self,
        args: FFI_ReturnFieldArgs,
    )
        -> RResult<WrappedSchema, RString>,

    /// Execute the underlying [`ScalarUDF`] and return the result as an
    /// `FFI_ArrowArray` within an AbiStable wrapper.
    ///
    /// `args` holds one exported array per argument, `arg_fields` the matching
    /// field for each argument, `num_rows` the number of rows to produce, and
    /// `return_field` the expected output field.
    #[allow(clippy::type_complexity)]
    pub invoke_with_args: unsafe extern "C" fn(
        udf: &Self,
        args: RVec<WrappedArray>,
        arg_fields: RVec<WrappedSchema>,
        num_rows: usize,
        return_field: WrappedSchema,
    ) -> RResult<WrappedArray, RString>,

    /// See [`ScalarUDFImpl`] for details on short_circuits
    pub short_circuits: bool,

    /// Performs type coercion. To simplify this interface, all foreign UDFs are
    /// treated as having user-defined signatures, which causes `coerce_types`
    /// to be called. This should be transparent to most users, because the
    /// provider-side function performs the appropriate calls on the underlying
    /// [`ScalarUDF`].
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<RVec<WrappedSchema>, RString>,

    /// Used to create a clone on the provider of the udf. This should
    /// only need to be called by the receiver of the udf.
    pub clone: unsafe extern "C" fn(udf: &Self) -> Self,

    /// Release the memory of the private data when it is no longer being used.
    /// Invoked automatically when this struct is dropped.
    pub release: unsafe extern "C" fn(udf: &mut Self),

    /// Internal data. This is only to be accessed by the provider of the udf.
    /// A [`ForeignScalarUDF`] should never attempt to access this data.
    pub private_data: *mut c_void,
}

// SAFETY: `private_data` points to a boxed `ScalarUDFPrivateData` (holding an
// `Arc<ScalarUDF>`) created by the `From<Arc<ScalarUDF>>` impl. The provider-side
// callbacks read it only through shared references, except `release`, which
// requires `&mut Self`. This relies on `ScalarUDF` itself being `Send + Sync`.
// NOTE(review): confirm that assumption.
unsafe impl Send for FFI_ScalarUDF {}
unsafe impl Sync for FFI_ScalarUDF {}
119
/// Provider-side state referenced by [`FFI_ScalarUDF::private_data`].
///
/// A boxed instance is turned into a raw pointer when an [`FFI_ScalarUDF`] is
/// created from an `Arc<ScalarUDF>`. It is reclaimed by the `release` callback.
pub struct ScalarUDFPrivateData {
    /// The UDF that every provider-side FFI callback forwards to.
    pub udf: Arc<ScalarUDF>,
}
123
124unsafe extern "C" fn return_type_fn_wrapper(
125    udf: &FFI_ScalarUDF,
126    arg_types: RVec<WrappedSchema>,
127) -> RResult<WrappedSchema, RString> {
128    let private_data = udf.private_data as *const ScalarUDFPrivateData;
129    let udf = &(*private_data).udf;
130
131    let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
132
133    let return_type = udf
134        .return_type(&arg_types)
135        .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from))
136        .map(WrappedSchema);
137
138    rresult!(return_type)
139}
140
141unsafe extern "C" fn return_field_from_args_fn_wrapper(
142    udf: &FFI_ScalarUDF,
143    args: FFI_ReturnFieldArgs,
144) -> RResult<WrappedSchema, RString> {
145    let private_data = udf.private_data as *const ScalarUDFPrivateData;
146    let udf = &(*private_data).udf;
147
148    let args: ForeignReturnFieldArgsOwned = rresult_return!((&args).try_into());
149    let args_ref: ForeignReturnFieldArgs = (&args).into();
150
151    let return_type = udf
152        .return_field_from_args((&args_ref).into())
153        .and_then(|f| FFI_ArrowSchema::try_from(&f).map_err(DataFusionError::from))
154        .map(WrappedSchema);
155
156    rresult!(return_type)
157}
158
159unsafe extern "C" fn coerce_types_fn_wrapper(
160    udf: &FFI_ScalarUDF,
161    arg_types: RVec<WrappedSchema>,
162) -> RResult<RVec<WrappedSchema>, RString> {
163    let private_data = udf.private_data as *const ScalarUDFPrivateData;
164    let udf = &(*private_data).udf;
165
166    let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
167
168    let return_types = rresult_return!(data_types_with_scalar_udf(&arg_types, udf));
169
170    rresult!(vec_datatype_to_rvec_wrapped(&return_types))
171}
172
173unsafe extern "C" fn invoke_with_args_fn_wrapper(
174    udf: &FFI_ScalarUDF,
175    args: RVec<WrappedArray>,
176    arg_fields: RVec<WrappedSchema>,
177    number_rows: usize,
178    return_field: WrappedSchema,
179) -> RResult<WrappedArray, RString> {
180    let private_data = udf.private_data as *const ScalarUDFPrivateData;
181    let udf = &(*private_data).udf;
182
183    let args = args
184        .into_iter()
185        .map(|arr| {
186            from_ffi(arr.array, &arr.schema.0)
187                .map(|v| ColumnarValue::Array(arrow::array::make_array(v)))
188        })
189        .collect::<std::result::Result<_, _>>();
190
191    let args = rresult_return!(args);
192    let return_field = rresult_return!(Field::try_from(&return_field.0)).into();
193
194    let arg_fields = arg_fields
195        .into_iter()
196        .map(|wrapped_field| {
197            Field::try_from(&wrapped_field.0)
198                .map(Arc::new)
199                .map_err(DataFusionError::from)
200        })
201        .collect::<Result<Vec<FieldRef>>>();
202    let arg_fields = rresult_return!(arg_fields);
203
204    let args = ScalarFunctionArgs {
205        args,
206        arg_fields,
207        number_rows,
208        return_field,
209    };
210
211    let result = rresult_return!(udf
212        .invoke_with_args(args)
213        .and_then(|r| r.to_array(number_rows)));
214
215    let (result_array, result_schema) = rresult_return!(to_ffi(&result.to_data()));
216
217    RResult::ROk(WrappedArray {
218        array: result_array,
219        schema: WrappedSchema(result_schema),
220    })
221}
222
223unsafe extern "C" fn release_fn_wrapper(udf: &mut FFI_ScalarUDF) {
224    let private_data = Box::from_raw(udf.private_data as *mut ScalarUDFPrivateData);
225    drop(private_data);
226}
227
228unsafe extern "C" fn clone_fn_wrapper(udf: &FFI_ScalarUDF) -> FFI_ScalarUDF {
229    let private_data = udf.private_data as *const ScalarUDFPrivateData;
230    let udf_data = &(*private_data);
231
232    Arc::clone(&udf_data.udf).into()
233}
234
235impl Clone for FFI_ScalarUDF {
236    fn clone(&self) -> Self {
237        unsafe { (self.clone)(self) }
238    }
239}
240
241impl From<Arc<ScalarUDF>> for FFI_ScalarUDF {
242    fn from(udf: Arc<ScalarUDF>) -> Self {
243        let name = udf.name().into();
244        let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
245        let volatility = udf.signature().volatility.into();
246        let short_circuits = udf.short_circuits();
247
248        let private_data = Box::new(ScalarUDFPrivateData { udf });
249
250        Self {
251            name,
252            aliases,
253            volatility,
254            short_circuits,
255            invoke_with_args: invoke_with_args_fn_wrapper,
256            return_type: return_type_fn_wrapper,
257            return_field_from_args: return_field_from_args_fn_wrapper,
258            coerce_types: coerce_types_fn_wrapper,
259            clone: clone_fn_wrapper,
260            release: release_fn_wrapper,
261            private_data: Box::into_raw(private_data) as *mut c_void,
262        }
263    }
264}
265
266impl Drop for FFI_ScalarUDF {
267    fn drop(&mut self) {
268        unsafe { (self.release)(self) }
269    }
270}
271
/// This struct is used to access a UDF provided by a foreign
/// library across an FFI boundary.
///
/// The ForeignScalarUDF is to be used by the caller of the UDF, so it has
/// no knowledge of, or access to, the private data. All interaction with the UDF
/// must occur through the functions defined in FFI_ScalarUDF.
#[derive(Debug)]
pub struct ForeignScalarUDF {
    /// Name copied from the FFI struct when this wrapper is created.
    name: String,
    /// Aliases copied from the FFI struct when this wrapper is created.
    aliases: Vec<String>,
    /// This wrapper's own clone of the FFI struct; every call goes through its
    /// function pointers.
    udf: FFI_ScalarUDF,
    /// Always a user-defined signature (built from the FFI volatility), so that
    /// argument coercion is delegated to the provider via `coerce_types`.
    signature: Signature,
}

// NOTE(review): every field above appears to be `Send + Sync` already. The
// `FFI_ScalarUDF` impls are declared explicitly in this file, so these impls may
// be redundant; confirm before removing.
unsafe impl Send for ForeignScalarUDF {}
unsafe impl Sync for ForeignScalarUDF {}
288
289impl TryFrom<&FFI_ScalarUDF> for ForeignScalarUDF {
290    type Error = DataFusionError;
291
292    fn try_from(udf: &FFI_ScalarUDF) -> Result<Self, Self::Error> {
293        let name = udf.name.to_owned().into();
294        let signature = Signature::user_defined((&udf.volatility).into());
295
296        let aliases = udf.aliases.iter().map(|s| s.to_string()).collect();
297
298        Ok(Self {
299            name,
300            udf: udf.clone(),
301            aliases,
302            signature,
303        })
304    }
305}
306
307impl ScalarUDFImpl for ForeignScalarUDF {
308    fn as_any(&self) -> &dyn std::any::Any {
309        self
310    }
311
312    fn name(&self) -> &str {
313        &self.name
314    }
315
316    fn signature(&self) -> &Signature {
317        &self.signature
318    }
319
320    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
321        let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
322
323        let result = unsafe { (self.udf.return_type)(&self.udf, arg_types) };
324
325        let result = df_result!(result);
326
327        result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from))
328    }
329
330    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
331        let args: FFI_ReturnFieldArgs = args.try_into()?;
332
333        let result = unsafe { (self.udf.return_field_from_args)(&self.udf, args) };
334
335        let result = df_result!(result);
336
337        result.and_then(|r| {
338            Field::try_from(&r.0)
339                .map(Arc::new)
340                .map_err(DataFusionError::from)
341        })
342    }
343
344    fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result<ColumnarValue> {
345        let ScalarFunctionArgs {
346            args,
347            arg_fields,
348            number_rows,
349            return_field,
350        } = invoke_args;
351
352        let args = args
353            .into_iter()
354            .map(|v| v.to_array(number_rows))
355            .collect::<Result<Vec<_>>>()?
356            .into_iter()
357            .map(|v| {
358                to_ffi(&v.to_data()).map(|(ffi_array, ffi_schema)| WrappedArray {
359                    array: ffi_array,
360                    schema: WrappedSchema(ffi_schema),
361                })
362            })
363            .collect::<std::result::Result<Vec<_>, ArrowError>>()?
364            .into();
365
366        let arg_fields_wrapped = arg_fields
367            .iter()
368            .map(FFI_ArrowSchema::try_from)
369            .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
370
371        let arg_fields = arg_fields_wrapped
372            .into_iter()
373            .map(WrappedSchema)
374            .collect::<RVec<_>>();
375
376        let return_field = return_field.as_ref().clone();
377        let return_field = WrappedSchema(FFI_ArrowSchema::try_from(return_field)?);
378
379        let result = unsafe {
380            (self.udf.invoke_with_args)(
381                &self.udf,
382                args,
383                arg_fields,
384                number_rows,
385                return_field,
386            )
387        };
388
389        let result = df_result!(result)?;
390        let result_array: ArrayRef = result.try_into()?;
391
392        Ok(ColumnarValue::Array(result_array))
393    }
394
395    fn aliases(&self) -> &[String] {
396        &self.aliases
397    }
398
399    fn short_circuits(&self) -> bool {
400        self.udf.short_circuits
401    }
402
403    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
404        unsafe {
405            let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
406            let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
407            Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
408        }
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, Int32Array};

    /// Exports DataFusion's built-in `abs` through [`FFI_ScalarUDF`] and
    /// re-imports it as a [`ForeignScalarUDF`]. Returns both the original UDF
    /// and the foreign wrapper.
    fn abs_round_trip() -> Result<(Arc<ScalarUDF>, ForeignScalarUDF)> {
        let original_udf = datafusion::functions::math::abs::AbsFunc::new();
        let original_udf = Arc::new(ScalarUDF::from(original_udf));

        let local_udf: FFI_ScalarUDF = Arc::clone(&original_udf).into();

        let foreign_udf: ForeignScalarUDF = (&local_udf).try_into()?;

        Ok((original_udf, foreign_udf))
    }

    #[test]
    fn test_round_trip_scalar_udf() -> Result<()> {
        let (original_udf, foreign_udf) = abs_round_trip()?;

        assert_eq!(original_udf.name(), foreign_udf.name());
        assert_eq!(original_udf.aliases(), foreign_udf.aliases());
        assert_eq!(original_udf.short_circuits(), foreign_udf.short_circuits());
        assert_eq!(
            original_udf.signature().volatility,
            foreign_udf.signature().volatility
        );

        Ok(())
    }

    #[test]
    fn test_round_trip_scalar_udf_invoke() -> Result<()> {
        let (_original_udf, foreign_udf) = abs_round_trip()?;

        let input: ArrayRef = Arc::new(Int32Array::from(vec![-1, 2, -3]));
        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input)],
            arg_fields: vec![Arc::new(Field::new("a", DataType::Int32, true))],
            number_rows: 3,
            return_field: Arc::new(Field::new("abs", DataType::Int32, true)),
        };

        let ColumnarValue::Array(result) = foreign_udf.invoke_with_args(args)? else {
            panic!("ForeignScalarUDF::invoke_with_args always returns an array");
        };
        let result = result
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("abs of an Int32 input should produce an Int32 array");
        assert_eq!(result, &Int32Array::from(vec![1, 2, 3]));

        Ok(())
    }
}