diff --git a/datafusion/functions/src/math/ceil.rs b/datafusion/functions/src/math/ceil.rs index 5961b3cb27fed..443a5060707c5 100644 --- a/datafusion/functions/src/math/ceil.rs +++ b/datafusion/functions/src/math/ceil.rs @@ -89,6 +89,18 @@ impl ScalarUDFImpl for CeilFunc { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { + DataType::Decimal32(precision, _scale) => { + Ok(DataType::Decimal32(*precision, 0)) + } + DataType::Decimal64(precision, _scale) => { + Ok(DataType::Decimal64(*precision, 0)) + } + DataType::Decimal128(precision, _scale) => { + Ok(DataType::Decimal128(*precision, 0)) + } + DataType::Decimal256(precision, _scale) => { + Ok(DataType::Decimal256(*precision, 0)) + } DataType::Null => Ok(DataType::Float64), other => Ok(other.clone()), } diff --git a/datafusion/functions/src/math/decimal.rs b/datafusion/functions/src/math/decimal.rs index abaded4568a93..d68612c411706 100644 --- a/datafusion/functions/src/math/decimal.rs +++ b/datafusion/functions/src/math/decimal.rs @@ -41,14 +41,15 @@ where let factor = decimal_scale_factor::(scale, fn_name)?; let decimal = array.as_primitive::(); - let data_type = array.data_type().clone(); + let data_type = T::TYPE_CONSTRUCTOR(precision, 0); let result: PrimitiveArray = decimal.try_unary(|value| { let new_value = op(value, factor); - T::validate_decimal_precision(new_value, precision, scale).map_err(|_| { + let rescaled = new_value.div_wrapping(factor); + T::validate_decimal_precision(rescaled, precision, 0).map_err(|_| { ArrowError::ComputeError(format!("Decimal overflow while applying {fn_name}")) })?; - Ok::<_, ArrowError>(new_value) + Ok::<_, ArrowError>(rescaled) })?; let result = result.with_data_type(data_type); diff --git a/datafusion/functions/src/math/floor.rs b/datafusion/functions/src/math/floor.rs index d4f25716ff7ee..39bc52367829d 100644 --- a/datafusion/functions/src/math/floor.rs +++ b/datafusion/functions/src/math/floor.rs @@ -129,6 +129,18 @@ impl ScalarUDFImpl for FloorFunc { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { + DataType::Decimal32(precision, _scale) => { + Ok(DataType::Decimal32(*precision, 0)) + } + DataType::Decimal64(precision, _scale) => { + Ok(DataType::Decimal64(*precision, 0)) + } + DataType::Decimal128(precision, _scale) => { + Ok(DataType::Decimal128(*precision, 0)) + } + DataType::Decimal256(precision, _scale) => { + Ok(DataType::Decimal256(*precision, 0)) + } DataType::Null => Ok(DataType::Float64), other => Ok(other.clone()), }