diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 6a511db9da00d..4746ac9114733 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -219,14 +219,15 @@ cargo run --example dataframe -- dataframe #### Category: Single Process -| Subcommand | File Path | Description | -| --------------- | ----------------------------------------------------------- | ----------------------------------------------- | -| adv_udaf | [`udf/advanced_udaf.rs`](examples/udf/advanced_udaf.rs) | Advanced User Defined Aggregate Function (UDAF) | -| adv_udf | [`udf/advanced_udf.rs`](examples/udf/advanced_udf.rs) | Advanced User Defined Scalar Function (UDF) | -| adv_udwf | [`udf/advanced_udwf.rs`](examples/udf/advanced_udwf.rs) | Advanced User Defined Window Function (UDWF) | -| async_udf | [`udf/async_udf.rs`](examples/udf/async_udf.rs) | Asynchronous User Defined Scalar Function | -| udaf | [`udf/simple_udaf.rs`](examples/udf/simple_udaf.rs) | Simple UDAF example | -| udf | [`udf/simple_udf.rs`](examples/udf/simple_udf.rs) | Simple UDF example | -| udtf | [`udf/simple_udtf.rs`](examples/udf/simple_udtf.rs) | Simple UDTF example | -| udwf | [`udf/simple_udwf.rs`](examples/udf/simple_udwf.rs) | Simple UDWF example | -| table_list_udtf | [`udf/table_list_udtf.rs`](examples/udf/table_list_udtf.rs) | Session-aware UDTF table list example | +| Subcommand | File Path | Description | +| --------------- | ----------------------------------------------------------------------- | ----------------------------------------------- | +| adv_udaf | [`udf/advanced_udaf.rs`](examples/udf/advanced_udaf.rs) | Advanced User Defined Aggregate Function (UDAF) | +| adv_udf | [`udf/advanced_udf.rs`](examples/udf/advanced_udf.rs) | Advanced User Defined Scalar Function (UDF) | +| adv_udwf | [`udf/advanced_udwf.rs`](examples/udf/advanced_udwf.rs) | Advanced User Defined Window Function (UDWF) | +| async_udf | [`udf/async_udf.rs`](examples/udf/async_udf.rs) | Asynchronous User Defined Scalar Function | +| struct_udaf | [`udf/struct_returning_udaf.rs`](examples/udf/struct_returning_udaf.rs) | Struct-returning UDAF with window metadata | +| udaf | [`udf/simple_udaf.rs`](examples/udf/simple_udaf.rs) | Simple UDAF example | +| udf | [`udf/simple_udf.rs`](examples/udf/simple_udf.rs) | Simple UDF example | +| udtf | [`udf/simple_udtf.rs`](examples/udf/simple_udtf.rs) | Simple UDTF example | +| udwf | [`udf/simple_udwf.rs`](examples/udf/simple_udwf.rs) | Simple UDWF example | +| table_list_udtf | [`udf/table_list_udtf.rs`](examples/udf/table_list_udtf.rs) | Session-aware UDTF table list example | diff --git a/datafusion-examples/examples/udf/main.rs b/datafusion-examples/examples/udf/main.rs index 89f3fd801deec..0eff5f7a30a2c 100644 --- a/datafusion-examples/examples/udf/main.rs +++ b/datafusion-examples/examples/udf/main.rs @@ -39,6 +39,9 @@ //! - `async_udf` //! (file: async_udf.rs, desc: Asynchronous User Defined Scalar Function) //! +//! - `struct_udaf` +//! (file: struct_returning_udaf.rs, desc: Struct-returning UDAF with window metadata) +//! //! - `udaf` //! (file: simple_udaf.rs, desc: Simple UDAF example) //! @@ -62,6 +65,7 @@ mod simple_udaf; mod simple_udf; mod simple_udtf; mod simple_udwf; +mod struct_returning_udaf; mod table_list_udtf; use datafusion::error::{DataFusionError, Result}; @@ -76,6 +80,7 @@ enum ExampleKind { AdvUdf, AdvUdwf, AsyncUdf, + StructUdaf, Udf, Udaf, Udwf, @@ -102,6 +107,9 @@ impl ExampleKind { ExampleKind::AdvUdf => advanced_udf::advanced_udf().await?, ExampleKind::AdvUdwf => advanced_udwf::advanced_udwf().await?, ExampleKind::AsyncUdf => async_udf::async_udf().await?, + ExampleKind::StructUdaf => { + struct_returning_udaf::struct_returning_udaf().await? + } ExampleKind::Udaf => simple_udaf::simple_udaf().await?, ExampleKind::Udf => simple_udf::simple_udf().await?, ExampleKind::Udtf => simple_udtf::simple_udtf().await?, diff --git a/datafusion-examples/examples/udf/struct_returning_udaf.rs b/datafusion-examples/examples/udf/struct_returning_udaf.rs new file mode 100644 index 0000000000000..5bb32b9ef28a3 --- /dev/null +++ b/datafusion-examples/examples/udf/struct_returning_udaf.rs @@ -0,0 +1,280 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! See `main.rs` for how to run it. +//! +//! This example shows how an extension can return window metadata from an +//! aggregate by passing the relevant input columns directly to the aggregate. + +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, Float64Array, StructArray, TimestampNanosecondArray, UInt64Array, +}; +use arrow::datatypes::{DataType, Field, Fields, Schema, TimeUnit}; +use arrow::record_batch::RecordBatch; +use datafusion::assert_batches_eq; +use datafusion::common::{cast::as_primitive_array, exec_err}; +use datafusion::datasource::MemTable; +use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::{AccumulatorFactoryFunction, Volatility, create_udaf}; +use datafusion::physical_plan::Accumulator; +use datafusion::prelude::*; +use datafusion::scalar::ScalarValue; + +pub async fn struct_returning_udaf() -> Result<()> { + let ctx = create_context()?; + + register_augmented_avg(&ctx); + + // The `augmented_avg` aggregate returns both the average and metadata about + // the time window from which the average was computed. + let sql = " + SELECT + augmented_avg(time, value)['window_start'] AS window_start, + augmented_avg(time, value)['window_end'] AS window_end, + augmented_avg(time, value)['window_duration'] AS window_duration, + augmented_avg(time, value)['avg_value'] AS avg_value + FROM t + GROUP BY date_bin(INTERVAL '5 microseconds', time) + ORDER BY window_start + "; + + let results = ctx.sql(sql).await?.collect().await?; + let expected = [ + "+----------------------------+----------------------------+-----------------+-----------+", + "| window_start | window_end | window_duration | avg_value |", + "+----------------------------+----------------------------+-----------------+-----------+", + "| 1970-01-01T00:00:00.000001 | 1970-01-01T00:00:00.000002 | 1000 | 15.0 |", + "| 1970-01-01T00:00:00.000005 | 1970-01-01T00:00:00.000009 | 4000 | 3.0 |", + "+----------------------------+----------------------------+-----------------+-----------+", + ]; + assert_batches_eq!(expected, &results); + + println!("Struct-returning aggregate produced window metadata:"); + ctx.sql(sql).await?.show().await?; + + Ok(()) +} + +fn create_context() -> Result { + let schema = Arc::new(Schema::new(vec![ + Field::new( + "time", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new("value", DataType::Float64, false), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(TimestampNanosecondArray::from(vec![ + 1000, 2000, 5000, 7000, 9000, + ])) as ArrayRef, + Arc::new(Float64Array::from(vec![10.0, 20.0, 1.0, 3.0, 5.0])), + ], + )?; + + let ctx = SessionContext::new(); + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + Ok(ctx) +} + +fn register_augmented_avg(ctx: &SessionContext) { + let accumulator: AccumulatorFactoryFunction = + Arc::new(|_| Ok(Box::new(AugmentedAvg::new()))); + + let augmented_avg = create_udaf( + "augmented_avg", + vec![ + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Float64, + ], + Arc::new(AugmentedAvg::output_datatype()), + Volatility::Immutable, + accumulator, + Arc::new(AugmentedAvg::state_datatypes()), + ); + + ctx.register_udaf(augmented_avg); +} + +#[derive(Debug, Clone)] +struct AugmentedAvg { + window_start: Option, + window_end: Option, + sum: f64, + count: u64, +} + +impl AugmentedAvg { + fn new() -> Self { + Self { + window_start: None, + window_end: None, + sum: 0.0, + count: 0, + } + } + + fn fields() -> Fields { + vec![ + Field::new( + "window_start", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + Field::new( + "window_end", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + Field::new("window_duration", DataType::Int64, true), + Field::new("avg_value", DataType::Float64, true), + ] + .into() + } + + fn output_datatype() -> DataType { + DataType::Struct(Self::fields()) + } + + fn state_datatypes() -> Vec { + vec![ + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Float64, + DataType::UInt64, + ] + } + + fn update_one(&mut self, time: i64, value: f64) { + self.window_start = Some(self.window_start.map_or(time, |start| start.min(time))); + self.window_end = Some(self.window_end.map_or(time, |end| end.max(time))); + self.sum += value; + self.count += 1; + } +} + +impl Accumulator for AugmentedAvg { + fn state(&mut self) -> Result> { + // DataFusion can merge partial aggregate results across execution + // stages, so all values needed to reconstruct the final struct are + // included in the state. + Ok(vec![ + ScalarValue::TimestampNanosecond(self.window_start, None), + ScalarValue::TimestampNanosecond(self.window_end, None), + ScalarValue::Float64(Some(self.sum)), + ScalarValue::UInt64(Some(self.count)), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let [times, values] = values else { + return exec_err!("augmented_avg expects time and value arrays"); + }; + let times = + as_primitive_array::(times)?; + let values = as_primitive_array::(values)?; + + // Track the window bounds and aggregate values directly from the input + // rows assigned to each group by `date_bin`. + for (time, value) in times.iter().zip(values.iter()) { + if let (Some(time), Some(value)) = (time, value) { + self.update_one(time, value); + } + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let [starts, ends, sums, counts] = states else { + return exec_err!("augmented_avg expects four state arrays"); + }; + let starts = + as_primitive_array::(starts)?; + let ends = as_primitive_array::(ends)?; + let sums = as_primitive_array::(sums)?; + let counts = counts + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Execution("Expected UInt64Array".to_string()) + })?; + + // Combine partial states by preserving the earliest start, latest end, + // and additive average components. + for (((start, end), sum), count) in starts + .iter() + .zip(ends.iter()) + .zip(sums.iter()) + .zip(counts.iter()) + { + let Some(count) = count else { + continue; + }; + if count == 0 { + continue; + } + if let (Some(start), Some(end), Some(sum)) = (start, end, sum) { + self.window_start = Some( + self.window_start + .map_or(start, |current| current.min(start)), + ); + self.window_end = + Some(self.window_end.map_or(end, |current| current.max(end))); + self.sum += sum; + self.count += count; + } + } + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let duration = self + .window_start + .zip(self.window_end) + .map(|(start, end)| end - start); + let avg = (self.count > 0).then_some(self.sum / self.count as f64); + + // Return one Struct scalar whose fields can be projected from SQL with + // expressions like `augmented_avg(time, value)['window_start']`. + let struct_array = StructArray::try_new( + AugmentedAvg::fields(), + vec![ + Arc::new(TimestampNanosecondArray::from(vec![self.window_start])) + as ArrayRef, + Arc::new(TimestampNanosecondArray::from(vec![self.window_end])) + as ArrayRef, + Arc::new(arrow::array::Int64Array::from(vec![duration])) as ArrayRef, + Arc::new(Float64Array::from(vec![avg])) as ArrayRef, + ], + None, + )?; + + Ok(ScalarValue::Struct(Arc::new(struct_array))) + } + + fn size(&self) -> usize { + size_of_val(self) + } +} diff --git a/docs/source/library-user-guide/functions/adding-udfs.md b/docs/source/library-user-guide/functions/adding-udfs.md index 0221e2e5adeb0..b6021c9dbb7b4 100644 --- a/docs/source/library-user-guide/functions/adding-udfs.md +++ b/docs/source/library-user-guide/functions/adding-udfs.md @@ -1229,6 +1229,36 @@ The `create_udaf` has six arguments to check: - The fifth argument is the function implementation. This is the function that we defined above. - The sixth argument is the description of the state, which will by passed between execution stages. +### Returning multiple values from an Aggregate UDF + +An aggregate UDF can return a `DataType::Struct` when one aggregate result needs +to carry multiple values. This is useful for time-windowing extensions that +need to return metadata such as the window start, window end, and the aggregate +value together. + +Pass the relevant input columns to the aggregate so the accumulator has enough +information to update and merge state normally in multi-stage aggregate plans. +For example, rows can be grouped into time buckets with the built-in `date_bin` +function, while a struct-returning aggregate computes the value and carries +metadata about each bucket: + +```sql +SELECT + augmented_avg(time, value)['window_start'] AS window_start, + augmented_avg(time, value)['window_end'] AS window_end, + augmented_avg(time, value)['window_duration'] AS window_duration, + augmented_avg(time, value)['avg_value'] AS avg_value +FROM t +GROUP BY date_bin(INTERVAL '30 seconds', time) +ORDER BY window_start; +``` + +In this pattern `date_bin(...)` assigns rows to a time bucket, while +`augmented_avg(time, value)` is a normal aggregate UDF whose accumulator stores +mergeable state such as `window_start`, `window_end`, `sum`, and `count`. +The aggregate's `evaluate` method returns a `ScalarValue::Struct`, and callers +can project individual fields from that struct. + ```rust # use datafusion::arrow::array::ArrayRef;