diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 54bb84f03d3d5..4c766b2cc50c9 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -1585,6 +1585,7 @@ mod tests { vec![DataType::UInt16, DataType::UInt16], vec![DataType::UInt32, DataType::UInt32], vec![DataType::UInt64, DataType::UInt64], + vec![DataType::Float16, DataType::Float16], vec![DataType::Float32, DataType::Float32], vec![DataType::Float64, DataType::Float64] ] diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs index 01d093950d471..ab4d086e4ca5f 100644 --- a/datafusion/expr-common/src/type_coercion/aggregates.rs +++ b/datafusion/expr-common/src/type_coercion/aggregates.rs @@ -42,6 +42,7 @@ pub static NUMERICS: &[DataType] = &[ DataType::UInt16, DataType::UInt32, DataType::UInt64, + DataType::Float16, DataType::Float32, DataType::Float64, ]; diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index e6a1b53418e67..d1c0940aeb306 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -851,10 +851,13 @@ fn coerced_from<'a>( (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()), (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()), (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()), + (Float16, Null | Int8 | Int16 | UInt8 | UInt16 | Float16) => { + Some(type_into.clone()) + } ( Float32, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 - | Float32, + | Float16 | Float32, ) => Some(type_into.clone()), ( Float64, @@ -867,6 +870,7 @@ fn coerced_from<'a>( | UInt16 | UInt32 | UInt64 + | Float16 | Float32 | Float64 | Decimal32(_, _) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 5ce2c3b6af6a6..63fed1f82c5dc 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -1929,7 +1929,7 @@ mod test { .err() .unwrap() .strip_backtrace(); - assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed")); + assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float16, Float32, Float64]) failed")); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 036bb93283cc6..8b641e02efc25 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -571,6 +571,16 @@ SELECT covar(c2, c12) FROM aggregate_test_100 ---- -0.079969012479 +query R +SELECT covar_pop(arrow_cast(c2, 'Float16'), arrow_cast(c12, 'Float16')) FROM aggregate_test_100 +---- +-0.079163311005 + +query R +SELECT covar(arrow_cast(c2, 'Float16'), arrow_cast(c12, 'Float16')) FROM aggregate_test_100 +---- +-0.079962940409 + # single_row_query_covar_1 query R select covar_samp(sq.column1, sq.column2) from (values (1.1, 2.2)) as sq @@ -1313,6 +1323,24 @@ select approx_median(arrow_cast(col_f32, 'Float16')), arrow_typeof(approx_median ---- 2.75 Float16 +# This shouldn't be NaN, see: +# https://github.com/apache/datafusion/issues/18945 +query RT +select + percentile_cont(0.5) within group (order by arrow_cast(col_f32, 'Float16')), + arrow_typeof(percentile_cont(0.5) within group (order by arrow_cast(col_f32, 'Float16'))) +from median_table; +---- +NaN Float16 + +query RT +select + approx_percentile_cont(0.5) within group (order by arrow_cast(col_f32, 'Float16')), + arrow_typeof(approx_percentile_cont(0.5) within group (order by arrow_cast(col_f32, 'Float16'))) +from median_table; +---- +2.75 Float16 + query ?T select approx_median(NULL), arrow_typeof(approx_median(NULL)) from median_table; ---- @@ -6718,7 +6746,12 @@ from aggregate_test_100; ---- 0.051534002628 0.48427355347 100 0.001929150558 0.479274948239 0.508972509913 6.707779292571 9.234223721582 0.345678715695 - +query R +select + regr_slope(arrow_cast(c12, 'Float16'), arrow_cast(c11, 'Float16')) +from aggregate_test_100; +---- +0.051477733249 # regr_*() functions ignore NULLs query RRIRRRRRR