Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ use datafusion_common::{
DFSchema, DataFusionError, Result, ScalarValue, assert_batches_eq,
assert_batches_sorted_eq, assert_contains, exec_datafusion_err, exec_err,
not_impl_err, plan_err,
types::{NativeType, logical_int16},
};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
use datafusion_expr::{
Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder,
OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
Signature, Volatility, lit_with_metadata,
Signature, TypeSignatureClass, Volatility, lit_with_metadata,
};
use datafusion_expr_common::signature::Coercion;
use datafusion_expr_common::signature::TypeSignature;
use datafusion_functions_nested::range::range_udf;
use parking_lot::Mutex;
Expand Down Expand Up @@ -2078,6 +2080,74 @@ AS t(string, extension)
Ok(())
}

/// https://github.com/apache/datafusion/issues/19982
#[tokio::test]
async fn test_return_field_args_scalar_argument_types_match_arg_fields() -> Result<()> {
#[derive(Debug, PartialEq, Eq, Hash)]
struct TestUdf {
signature: Signature,
}

impl Default for TestUdf {
fn default() -> Self {
Self {
signature: Signature::coercible(
vec![Coercion::new_implicit(
TypeSignatureClass::Native(logical_int16()),
vec![TypeSignatureClass::Numeric],
NativeType::Int16,
)],
Volatility::Immutable,
),
}
}
}

impl ScalarUDFImpl for TestUdf {
fn name(&self) -> &str {
"test_udf"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
unreachable!("return_field_from_args is implemented")
}

fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
assert_eq!(args.arg_fields.len(), 1);
assert_eq!(args.scalar_arguments.len(), 1);
assert_eq!(
args.arg_fields[0].data_type(),
&args.scalar_arguments[0]
.expect("literal argument")
.data_type()
);
Ok(
Field::new(self.name(), args.arg_fields[0].data_type().clone(), true)
.into(),
)
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
assert!(matches!(
args.args[0],
ColumnarValue::Scalar(ScalarValue::Int16(Some(_)))
));
Ok(args.args[0].clone())
}
}

let ctx = SessionContext::new();
ctx.register_udf(TestUdf::default().into());

ctx.sql("select test_udf(1)").await?.collect().await?;

Ok(())
}

/// https://github.com/apache/datafusion/issues/17422
#[tokio::test]
async fn test_extension_metadata_preserve_in_subquery() -> Result<()> {
Expand Down
36 changes: 28 additions & 8 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,30 @@ fn cast_output_field(
Arc::new(f)
}

fn scalar_arguments_for_fields(
args: &[Expr],
arg_fields: &[FieldRef],
) -> Result<Vec<Option<ScalarValue>>> {
args.iter()
.zip(arg_fields)
.map(|(expr, field)| {
literal_scalar_value(expr)
.map(|sv| sv.cast_to(field.data_type()))
.transpose()
})
.collect()
}

fn literal_scalar_value(expr: &Expr) -> Option<&ScalarValue> {
match expr {
Expr::Literal(sv, _) => Some(sv),
Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) => {
literal_scalar_value(expr)
}
_ => None,
}
}

impl ExprSchemable for Expr {
/// Returns the [arrow::datatypes::DataType] of the expression
/// based on [ExprSchema]
Expand Down Expand Up @@ -580,16 +604,12 @@ impl ExprSchemable for Expr {
.collect::<Result<Vec<_>>>()?;
let new_fields = verify_function_arguments(func.as_ref(), &fields)?;

let arguments = args
.iter()
.map(|e| match e {
Expr::Literal(sv, _) => Some(sv),
_ => None,
})
.collect::<Vec<_>>();
let arguments = scalar_arguments_for_fields(args, &new_fields)?;
let argument_refs =
arguments.iter().map(Option::as_ref).collect::<Vec<_>>();
let args = ReturnFieldArgs {
arg_fields: &new_fields,
scalar_arguments: &arguments,
scalar_arguments: &argument_refs,
};

func.return_field_from_args(args)
Expand Down