1use crate::{
19 arrow_wrappers::{WrappedArray, WrappedSchema},
20 df_result, rresult, rresult_return,
21 util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
22 volatility::FFI_Volatility,
23};
24use abi_stable::{
25 std_types::{RResult, RString, RVec},
26 StableAbi,
27};
28use arrow::datatypes::{DataType, Field};
29use arrow::{
30 array::ArrayRef,
31 error::ArrowError,
32 ffi::{from_ffi, to_ffi, FFI_ArrowSchema},
33};
34use arrow_schema::FieldRef;
35use datafusion::logical_expr::ReturnFieldArgs;
36use datafusion::{
37 error::DataFusionError,
38 logical_expr::type_coercion::functions::data_types_with_scalar_udf,
39};
40use datafusion::{
41 error::Result,
42 logical_expr::{
43 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
44 },
45};
46use return_type_args::{
47 FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned,
48};
49use std::{ffi::c_void, sync::Arc};
50
51pub mod return_type_args;
52
/// ABI-stable, `#[repr(C)]` handle for sharing a [`ScalarUDF`] across an FFI boundary.
///
/// The providing side builds one with `From<Arc<ScalarUDF>>`; the consuming side wraps it
/// in a [`ForeignScalarUDF`]. Every operation is dispatched through the function pointers
/// below, each of which receives this struct as its first argument. Types, fields and
/// arrays are exchanged through the Arrow C Data Interface wrappers, and errors come back
/// as `RString` messages.
#[repr(C)]
#[derive(Debug, StableAbi)]
#[allow(non_camel_case_types)]
pub struct FFI_ScalarUDF {
    /// Name of the underlying UDF.
    pub name: RString,

    /// Alternative names of the underlying UDF.
    pub aliases: RVec<RString>,

    /// Volatility copied from the underlying UDF's signature.
    pub volatility: FFI_Volatility,

    /// Computes the UDF's return type from its argument types (both exchanged as
    /// Arrow C schemas).
    pub return_type: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<WrappedSchema, RString>,

    /// Computes the UDF's return field from [`FFI_ReturnFieldArgs`]; the field is
    /// returned as an Arrow C schema.
    pub return_field_from_args: unsafe extern "C" fn(
        udf: &Self,
        args: FFI_ReturnFieldArgs,
    )
        -> RResult<WrappedSchema, RString>,

    /// Evaluates the UDF. Arguments travel as Arrow C arrays together with their
    /// field schemas, the row count and the expected return field; the result is
    /// always an array.
    #[allow(clippy::type_complexity)]
    pub invoke_with_args: unsafe extern "C" fn(
        udf: &Self,
        args: RVec<WrappedArray>,
        arg_fields: RVec<WrappedSchema>,
        num_rows: usize,
        return_field: WrappedSchema,
    ) -> RResult<WrappedArray, RString>,

    /// Mirrors `short_circuits()` of the underlying UDF.
    pub short_circuits: bool,

    /// Coerces argument types. [`ForeignScalarUDF`] always declares a user-defined
    /// signature, so argument coercion on the consuming side is routed through here.
    pub coerce_types: unsafe extern "C" fn(
        udf: &Self,
        arg_types: RVec<WrappedSchema>,
    ) -> RResult<RVec<WrappedSchema>, RString>,

    /// Produces an independent copy of this struct; used by the `Clone` impl.
    pub clone: unsafe extern "C" fn(udf: &Self) -> Self,

    /// Frees `private_data`; invoked by the `Drop` impl.
    pub release: unsafe extern "C" fn(udf: &mut Self),

    /// Opaque state owned by the providing side (a boxed `ScalarUDFPrivateData` when
    /// built via `From<Arc<ScalarUDF>>`). Only the provider's own function pointers may
    /// dereference it.
    pub private_data: *mut c_void,
}

// SAFETY (review): these are asserted manually because `private_data` is a raw pointer.
// When built via `From<Arc<ScalarUDF>>` it points at an `Arc<ScalarUDF>` that the wrapper
// functions only access through shared references.
// NOTE(review): soundness also relies on every provider's UDF being thread-safe — confirm.
unsafe impl Send for FFI_ScalarUDF {}
unsafe impl Sync for FFI_ScalarUDF {}
119
/// Provider-side state stored behind [`FFI_ScalarUDF::private_data`].
///
/// `From<Arc<ScalarUDF>>` boxes it and leaks it into a raw pointer, the `*_fn_wrapper`
/// functions read it, and `release_fn_wrapper` frees it.
pub struct ScalarUDFPrivateData {
    /// The UDF that every FFI call is forwarded to.
    pub udf: Arc<ScalarUDF>,
}
123
124unsafe extern "C" fn return_type_fn_wrapper(
125 udf: &FFI_ScalarUDF,
126 arg_types: RVec<WrappedSchema>,
127) -> RResult<WrappedSchema, RString> {
128 let private_data = udf.private_data as *const ScalarUDFPrivateData;
129 let udf = &(*private_data).udf;
130
131 let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
132
133 let return_type = udf
134 .return_type(&arg_types)
135 .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from))
136 .map(WrappedSchema);
137
138 rresult!(return_type)
139}
140
141unsafe extern "C" fn return_field_from_args_fn_wrapper(
142 udf: &FFI_ScalarUDF,
143 args: FFI_ReturnFieldArgs,
144) -> RResult<WrappedSchema, RString> {
145 let private_data = udf.private_data as *const ScalarUDFPrivateData;
146 let udf = &(*private_data).udf;
147
148 let args: ForeignReturnFieldArgsOwned = rresult_return!((&args).try_into());
149 let args_ref: ForeignReturnFieldArgs = (&args).into();
150
151 let return_type = udf
152 .return_field_from_args((&args_ref).into())
153 .and_then(|f| FFI_ArrowSchema::try_from(&f).map_err(DataFusionError::from))
154 .map(WrappedSchema);
155
156 rresult!(return_type)
157}
158
159unsafe extern "C" fn coerce_types_fn_wrapper(
160 udf: &FFI_ScalarUDF,
161 arg_types: RVec<WrappedSchema>,
162) -> RResult<RVec<WrappedSchema>, RString> {
163 let private_data = udf.private_data as *const ScalarUDFPrivateData;
164 let udf = &(*private_data).udf;
165
166 let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types));
167
168 let return_types = rresult_return!(data_types_with_scalar_udf(&arg_types, udf));
169
170 rresult!(vec_datatype_to_rvec_wrapped(&return_types))
171}
172
173unsafe extern "C" fn invoke_with_args_fn_wrapper(
174 udf: &FFI_ScalarUDF,
175 args: RVec<WrappedArray>,
176 arg_fields: RVec<WrappedSchema>,
177 number_rows: usize,
178 return_field: WrappedSchema,
179) -> RResult<WrappedArray, RString> {
180 let private_data = udf.private_data as *const ScalarUDFPrivateData;
181 let udf = &(*private_data).udf;
182
183 let args = args
184 .into_iter()
185 .map(|arr| {
186 from_ffi(arr.array, &arr.schema.0)
187 .map(|v| ColumnarValue::Array(arrow::array::make_array(v)))
188 })
189 .collect::<std::result::Result<_, _>>();
190
191 let args = rresult_return!(args);
192 let return_field = rresult_return!(Field::try_from(&return_field.0)).into();
193
194 let arg_fields = arg_fields
195 .into_iter()
196 .map(|wrapped_field| {
197 Field::try_from(&wrapped_field.0)
198 .map(Arc::new)
199 .map_err(DataFusionError::from)
200 })
201 .collect::<Result<Vec<FieldRef>>>();
202 let arg_fields = rresult_return!(arg_fields);
203
204 let args = ScalarFunctionArgs {
205 args,
206 arg_fields,
207 number_rows,
208 return_field,
209 };
210
211 let result = rresult_return!(udf
212 .invoke_with_args(args)
213 .and_then(|r| r.to_array(number_rows)));
214
215 let (result_array, result_schema) = rresult_return!(to_ffi(&result.to_data()));
216
217 RResult::ROk(WrappedArray {
218 array: result_array,
219 schema: WrappedSchema(result_schema),
220 })
221}
222
223unsafe extern "C" fn release_fn_wrapper(udf: &mut FFI_ScalarUDF) {
224 let private_data = Box::from_raw(udf.private_data as *mut ScalarUDFPrivateData);
225 drop(private_data);
226}
227
228unsafe extern "C" fn clone_fn_wrapper(udf: &FFI_ScalarUDF) -> FFI_ScalarUDF {
229 let private_data = udf.private_data as *const ScalarUDFPrivateData;
230 let udf_data = &(*private_data);
231
232 Arc::clone(&udf_data.udf).into()
233}
234
235impl Clone for FFI_ScalarUDF {
236 fn clone(&self) -> Self {
237 unsafe { (self.clone)(self) }
238 }
239}
240
241impl From<Arc<ScalarUDF>> for FFI_ScalarUDF {
242 fn from(udf: Arc<ScalarUDF>) -> Self {
243 let name = udf.name().into();
244 let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect();
245 let volatility = udf.signature().volatility.into();
246 let short_circuits = udf.short_circuits();
247
248 let private_data = Box::new(ScalarUDFPrivateData { udf });
249
250 Self {
251 name,
252 aliases,
253 volatility,
254 short_circuits,
255 invoke_with_args: invoke_with_args_fn_wrapper,
256 return_type: return_type_fn_wrapper,
257 return_field_from_args: return_field_from_args_fn_wrapper,
258 coerce_types: coerce_types_fn_wrapper,
259 clone: clone_fn_wrapper,
260 release: release_fn_wrapper,
261 private_data: Box::into_raw(private_data) as *mut c_void,
262 }
263 }
264}
265
266impl Drop for FFI_ScalarUDF {
267 fn drop(&mut self) {
268 unsafe { (self.release)(self) }
269 }
270}
271
/// Consumer-side adapter that implements [`ScalarUDFImpl`] by forwarding every call to
/// an [`FFI_ScalarUDF`].
///
/// The name and aliases are copied out once at construction. The signature is always
/// `Signature::user_defined`, so argument coercion is delegated to the providing side
/// through the FFI `coerce_types` entry.
#[derive(Debug)]
pub struct ForeignScalarUDF {
    /// Owned copy of the FFI struct's `name`.
    name: String,
    /// Owned copy of the FFI struct's `aliases`.
    aliases: Vec<String>,
    /// The FFI handle all calls are forwarded to (a clone made at construction).
    udf: FFI_ScalarUDF,
    /// User-defined signature carrying the advertised volatility.
    signature: Signature,
}

// SAFETY (review): every field is owned; the only raw pointer lives inside `udf`, for
// which `Send`/`Sync` are asserted on `FFI_ScalarUDF` itself.
unsafe impl Send for ForeignScalarUDF {}
unsafe impl Sync for ForeignScalarUDF {}
288
289impl TryFrom<&FFI_ScalarUDF> for ForeignScalarUDF {
290 type Error = DataFusionError;
291
292 fn try_from(udf: &FFI_ScalarUDF) -> Result<Self, Self::Error> {
293 let name = udf.name.to_owned().into();
294 let signature = Signature::user_defined((&udf.volatility).into());
295
296 let aliases = udf.aliases.iter().map(|s| s.to_string()).collect();
297
298 Ok(Self {
299 name,
300 udf: udf.clone(),
301 aliases,
302 signature,
303 })
304 }
305}
306
307impl ScalarUDFImpl for ForeignScalarUDF {
308 fn as_any(&self) -> &dyn std::any::Any {
309 self
310 }
311
312 fn name(&self) -> &str {
313 &self.name
314 }
315
316 fn signature(&self) -> &Signature {
317 &self.signature
318 }
319
320 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
321 let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
322
323 let result = unsafe { (self.udf.return_type)(&self.udf, arg_types) };
324
325 let result = df_result!(result);
326
327 result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from))
328 }
329
330 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
331 let args: FFI_ReturnFieldArgs = args.try_into()?;
332
333 let result = unsafe { (self.udf.return_field_from_args)(&self.udf, args) };
334
335 let result = df_result!(result);
336
337 result.and_then(|r| {
338 Field::try_from(&r.0)
339 .map(Arc::new)
340 .map_err(DataFusionError::from)
341 })
342 }
343
344 fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result<ColumnarValue> {
345 let ScalarFunctionArgs {
346 args,
347 arg_fields,
348 number_rows,
349 return_field,
350 } = invoke_args;
351
352 let args = args
353 .into_iter()
354 .map(|v| v.to_array(number_rows))
355 .collect::<Result<Vec<_>>>()?
356 .into_iter()
357 .map(|v| {
358 to_ffi(&v.to_data()).map(|(ffi_array, ffi_schema)| WrappedArray {
359 array: ffi_array,
360 schema: WrappedSchema(ffi_schema),
361 })
362 })
363 .collect::<std::result::Result<Vec<_>, ArrowError>>()?
364 .into();
365
366 let arg_fields_wrapped = arg_fields
367 .iter()
368 .map(FFI_ArrowSchema::try_from)
369 .collect::<std::result::Result<Vec<_>, ArrowError>>()?;
370
371 let arg_fields = arg_fields_wrapped
372 .into_iter()
373 .map(WrappedSchema)
374 .collect::<RVec<_>>();
375
376 let return_field = return_field.as_ref().clone();
377 let return_field = WrappedSchema(FFI_ArrowSchema::try_from(return_field)?);
378
379 let result = unsafe {
380 (self.udf.invoke_with_args)(
381 &self.udf,
382 args,
383 arg_fields,
384 number_rows,
385 return_field,
386 )
387 };
388
389 let result = df_result!(result)?;
390 let result_array: ArrayRef = result.try_into()?;
391
392 Ok(ColumnarValue::Array(result_array))
393 }
394
395 fn aliases(&self) -> &[String] {
396 &self.aliases
397 }
398
399 fn short_circuits(&self) -> bool {
400 self.udf.short_circuits
401 }
402
403 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
404 unsafe {
405 let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?;
406 let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?;
407 Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
408 }
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Exports a built-in UDF through `FFI_ScalarUDF`, re-imports it as a
    /// `ForeignScalarUDF`, and checks both the copied metadata and a call that actually
    /// travels through the FFI vtable.
    #[test]
    fn test_round_trip_scalar_udf() -> Result<()> {
        let original_udf = datafusion::functions::math::abs::AbsFunc::new();
        let original_udf = Arc::new(ScalarUDF::from(original_udf));

        let local_udf: FFI_ScalarUDF = Arc::clone(&original_udf).into();

        let foreign_udf: ForeignScalarUDF = (&local_udf).try_into()?;

        // Metadata copied at construction time.
        assert_eq!(original_udf.name(), foreign_udf.name());
        assert_eq!(original_udf.aliases(), foreign_udf.aliases());
        assert_eq!(original_udf.short_circuits(), foreign_udf.short_circuits());

        // A call that goes through the `return_type` function pointer.
        let arg_types = [DataType::Float64];
        assert_eq!(
            original_udf.return_type(&arg_types)?,
            foreign_udf.return_type(&arg_types)?
        );

        Ok(())
    }
}