From 6a008e6f72b6c768cd0bfb90951570856831261b Mon Sep 17 00:00:00 2001 From: Kevin-Li-2025 <2242139@qq.com> Date: Mon, 29 Jun 2026 14:59:24 +0800 Subject: [PATCH] Align scalar UDF return-field literal args Signed-off-by: Kevin-Li-2025 <2242139@qq.com> --- .../user_defined_scalar_functions.rs | 72 ++++++++++++++++++- datafusion/expr/src/expr_schema.rs | 36 +++++++--- 2 files changed, 99 insertions(+), 9 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index b758aeb5209e8..6d4f7948bd4d9 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -40,13 +40,15 @@ use datafusion_common::{ DFSchema, DataFusionError, Result, ScalarValue, assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_datafusion_err, exec_err, not_impl_err, plan_err, + types::{NativeType, logical_int16}, }; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, - Signature, Volatility, lit_with_metadata, + Signature, TypeSignatureClass, Volatility, lit_with_metadata, }; +use datafusion_expr_common::signature::Coercion; use datafusion_expr_common::signature::TypeSignature; use datafusion_functions_nested::range::range_udf; use parking_lot::Mutex; @@ -2078,6 +2080,74 @@ AS t(string, extension) Ok(()) } +/// https://github.com/apache/datafusion/issues/19982 +#[tokio::test] +async fn test_return_field_args_scalar_argument_types_match_arg_fields() -> Result<()> { + #[derive(Debug, PartialEq, Eq, Hash)] + struct TestUdf { + signature: Signature, + } + + impl Default for TestUdf { + fn default() -> Self { + Self { + signature: Signature::coercible( + vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_int16()), + vec![TypeSignatureClass::Numeric], + NativeType::Int16, + )], + Volatility::Immutable, + ), + } + } + } + + impl ScalarUDFImpl for TestUdf { + fn name(&self) -> &str { + "test_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unreachable!("return_field_from_args is implemented") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + assert_eq!(args.arg_fields.len(), 1); + assert_eq!(args.scalar_arguments.len(), 1); + assert_eq!( + args.arg_fields[0].data_type(), + &args.scalar_arguments[0] + .expect("literal argument") + .data_type() + ); + Ok( + Field::new(self.name(), args.arg_fields[0].data_type().clone(), true) + .into(), + ) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + assert!(matches!( + args.args[0], + ColumnarValue::Scalar(ScalarValue::Int16(Some(_))) + )); + Ok(args.args[0].clone()) + } + } + + let ctx = SessionContext::new(); + ctx.register_udf(TestUdf::default().into()); + + ctx.sql("select test_udf(1)").await?.collect().await?; + + Ok(()) +} + /// https://github.com/apache/datafusion/issues/17422 #[tokio::test] async fn test_extension_metadata_preserve_in_subquery() -> Result<()> { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 039bbad65a660..ac6b6d3728953 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -87,6 +87,30 @@ fn cast_output_field( Arc::new(f) } +fn scalar_arguments_for_fields( + args: &[Expr], + arg_fields: &[FieldRef], +) -> Result>> { + args.iter() + .zip(arg_fields) + .map(|(expr, field)| { + literal_scalar_value(expr) + .map(|sv| sv.cast_to(field.data_type())) + .transpose() + }) + .collect() +} + +fn literal_scalar_value(expr: &Expr) -> Option<&ScalarValue> { + match expr { + Expr::Literal(sv, _) => Some(sv), + Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) => { + literal_scalar_value(expr) + } + _ => None, + } +} + impl ExprSchemable for Expr { /// Returns the [arrow::datatypes::DataType] of the expression /// based on [ExprSchema] @@ -580,16 +604,12 @@ impl ExprSchemable for Expr { .collect::>>()?; let new_fields = verify_function_arguments(func.as_ref(), &fields)?; - let arguments = args - .iter() - .map(|e| match e { - Expr::Literal(sv, _) => Some(sv), - _ => None, - }) - .collect::>(); + let arguments = scalar_arguments_for_fields(args, &new_fields)?; + let argument_refs = + arguments.iter().map(Option::as_ref).collect::>(); let args = ReturnFieldArgs { arg_fields: &new_fields, - scalar_arguments: &arguments, + scalar_arguments: &argument_refs, }; func.return_field_from_args(args)