Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ use crate::aggregates::{

/// Marker for raw rows -> partial state aggregation.
pub(in crate::aggregates) struct PartialMarker;
/// Marker for partial state -> partial state aggregation.
pub(in crate::aggregates) struct PartialReduceMarker;
/// Marker for raw rows -> partial state conversion without aggregation.
pub(in crate::aggregates) struct PartialSkipMarker;
/// Marker for partial state -> final value aggregation.
Expand Down Expand Up @@ -182,7 +184,7 @@ impl<AggrMode> AggregateHashTable<AggrMode> {
acc + state.group_values.size()
+ state.batch_group_indices.allocated_size()
}
AggregateHashTableState::OutputtingMaterializedFinal(output) => {
AggregateHashTableState::OutputtingMaterialized(output) => {
output.memory_size()
}
AggregateHashTableState::Done => 0,
Expand Down Expand Up @@ -304,24 +306,25 @@ pub(super) enum AggregateHashTableState {
Building(AggregateHashTableBuffer),
/// Emitting results directly from group keys and aggregate state.
Outputting(AggregateHashTableBuffer),
/// Materialize all the output results, and then incrementally output in the `OutputtingMaterializedFinal` state.
/// Materialize all output rows, then incrementally output slices from the
/// `OutputtingMaterialized` state.
///
/// Note this is a temporary solution until the `GroupValues` issue is solved:
/// Issue: <https://github.com/apache/datafusion/issues/23178>
OutputtingMaterializedFinal(MaterializedFinalOutput),
OutputtingMaterialized(MaterializedOutput),
Done,
}

/// Fully evaluated final aggregate output and the next row offset to emit.
/// Fully materialized aggregate output and the next row offset to emit.
///
/// Final aggregate evaluation consumes accumulator state, so final output is
/// materialized once and then sliced to honor `batch_size` across output polls.
pub(super) struct MaterializedFinalOutput {
/// Some output paths consume accumulator state when emitting. Materialize those
/// rows once, then slice them to honor `batch_size` across output polls.
pub(super) struct MaterializedOutput {
batch: RecordBatch,
offset: usize,
}

impl MaterializedFinalOutput {
impl MaterializedOutput {
pub(super) fn new(batch: RecordBatch) -> Self {
Self { batch, offset: 0 }
}
Expand Down Expand Up @@ -496,7 +499,7 @@ mod tests {
use super::*;

#[test]
fn materialized_final_output_slices_batches_until_exhausted() -> Result<()> {
fn materialized_output_slices_batches_until_exhausted() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new(
"group_col",
DataType::Int32,
Expand All @@ -506,7 +509,7 @@ mod tests {
schema,
vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
)?;
let mut output = MaterializedFinalOutput::new(batch);
let mut output = MaterializedOutput::new(batch);

assert_eq!(int32_values(&output.next_batch(2).unwrap(), 0), vec![1, 2]);
assert_eq!(int32_values(&output.next_batch(2).unwrap(), 0), vec![3, 4]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::aggregates::AggregateExec;

use super::common::{
AggregateHashTable, AggregateHashTableBuffer, AggregateHashTableState, FinalMarker,
MaterializedFinalOutput,
MaterializedOutput,
};

/// Methods specific to the aggregate hash table used in the final aggregation stage.
Expand Down Expand Up @@ -68,7 +68,7 @@ impl AggregateHashTable<FinalMarker> {
let output = self.materialize_final_output(state, output_schema)?;
Ok(self.emit_next_materialized_batch(output, batch_size))
}
AggregateHashTableState::OutputtingMaterializedFinal(output) => {
AggregateHashTableState::OutputtingMaterialized(output) => {
Ok(self.emit_next_materialized_batch(output, batch_size))
}
AggregateHashTableState::Done => Ok(None),
Expand All @@ -82,7 +82,7 @@ impl AggregateHashTable<FinalMarker> {
&self,
mut state: AggregateHashTableBuffer,
output_schema: SchemaRef,
) -> Result<MaterializedFinalOutput> {
) -> Result<MaterializedOutput> {
// Final aggregate evaluation consumes accumulator state. Evaluate all
// groups once, then slice the materialized batch on subsequent polls.
let emit_to = EmitTo::All;
Expand All @@ -96,19 +96,19 @@ impl AggregateHashTable<FinalMarker> {

let batch = RecordBatch::try_new(output_schema, output)?;
debug_assert!(batch.num_rows() > 0);
Ok(MaterializedFinalOutput::new(batch))
Ok(MaterializedOutput::new(batch))
}

fn emit_next_materialized_batch(
&mut self,
mut output: MaterializedFinalOutput,
mut output: MaterializedOutput,
batch_size: usize,
) -> Option<RecordBatch> {
let batch = output.next_batch(batch_size);
if output.is_exhausted() {
self.state = AggregateHashTableState::Done;
} else {
self.state = AggregateHashTableState::OutputtingMaterializedFinal(output);
self.state = AggregateHashTableState::OutputtingMaterialized(output);
}
batch
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

mod common;
mod final_table;
mod partial_reduce_table;
mod partial_table;

pub(super) use common::{
AggregateHashTable, FinalMarker, PartialMarker, PartialSkipMarker,
AggregateHashTable, FinalMarker, PartialMarker, PartialReduceMarker,
PartialSkipMarker,
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::{Result, internal_err};
use datafusion_expr::EmitTo;

use crate::aggregates::AggregateExec;

use super::common::{
AggregateHashTable, AggregateHashTableBuffer, AggregateHashTableState,
MaterializedOutput, PartialReduceMarker,
};

/// Methods specific to the aggregate hash table used in the partial-reduce stage.
///
/// In this stage both input and output are partial aggregate states: input
/// batches are folded into the accumulators with `merge_batch`, and output
/// batches contain the emitted group keys followed by each accumulator's
/// `state` columns.
impl AggregateHashTable<PartialReduceMarker> {
    /// Creates a partial-reduce hash table for `partition` of `agg`.
    ///
    /// Delegates to `new_with_filters`, supplying one `None` entry per
    /// aggregate expression as the filter list.
    pub(in crate::aggregates) fn new(
        agg: &AggregateExec,
        partition: usize,
        output_schema: SchemaRef,
        batch_size: usize,
    ) -> Result<Self> {
        let no_filters = vec![None; agg.aggr_expr.len()];
        Self::new_with_filters(agg, partition, output_schema, batch_size, no_filters)
    }

    /// Emits the next batch of aggregated group keys and aggregate states.
    ///
    /// The output batch size is determined by `self.batch_size`.
    ///
    /// Returns `Some(batch)` for each emitted batch and `None` when output is
    /// exhausted (including when there are no groups at all).
    ///
    /// # Errors
    ///
    /// Returns an internal error if polled in the `Building` state. In that
    /// case the building buffer is left in place, so the misuse does not
    /// discard accumulated groups or make the table look finished.
    pub(in crate::aggregates) fn next_output_batch(
        &mut self,
    ) -> Result<Option<RecordBatch>> {
        let batch_size = self.batch_size;
        // Take ownership of the current state, leaving `Done` behind. Each arm
        // either keeps `Done` or installs the correct follow-up state; note
        // that `emit_next_materialized_batch` updates the state itself after
        // it emits a materialized slice.
        match std::mem::replace(&mut self.state, AggregateHashTableState::Done) {
            AggregateHashTableState::Outputting(buffer) => {
                if buffer.group_values.is_empty() {
                    return Ok(None);
                }

                let output_schema = Arc::clone(&self.output_schema);
                let output =
                    self.materialize_partial_reduce_output(buffer, output_schema)?;
                Ok(self.emit_next_materialized_batch(output, batch_size))
            }
            AggregateHashTableState::OutputtingMaterialized(output) => {
                Ok(self.emit_next_materialized_batch(output, batch_size))
            }
            AggregateHashTableState::Done => Ok(None),
            AggregateHashTableState::Building(buffer) => {
                // Restore the buffer taken above. Without this the table would
                // stay in `Done`, dropping every accumulated group and turning
                // the next poll into a silent "output exhausted".
                self.state = AggregateHashTableState::Building(buffer);
                internal_err!("next_output_batch must be called in the outputting state")
            }
        }
    }

    /// Emits every group key and accumulator state from `state` into a single
    /// `RecordBatch` built with `output_schema`, wrapped as a
    /// `MaterializedOutput` that can be sliced on later polls.
    ///
    /// Time spent emitting is recorded in `group_by_metrics.emitting_time`.
    fn materialize_partial_reduce_output(
        &self,
        mut state: AggregateHashTableBuffer,
        output_schema: SchemaRef,
    ) -> Result<MaterializedOutput> {
        // `state(EmitTo::All)` consumes accumulator state. Emit all groups once,
        // then slice the materialized batch on subsequent polls.
        let columns = {
            // Guard is dropped at the end of this scope, which bounds the
            // measured region to the emit calls only.
            let _emit_timer = self.group_by_metrics.emitting_time.timer();
            let mut columns = state.group_values.emit(EmitTo::All)?;
            for acc in state.accumulators.iter_mut() {
                columns.extend(acc.state(EmitTo::All)?);
            }
            columns
        };

        let batch = RecordBatch::try_new(output_schema, columns)?;
        // The caller returns early for an empty group set, so an empty batch
        // here would indicate a logic error.
        debug_assert!(batch.num_rows() > 0);
        Ok(MaterializedOutput::new(batch))
    }

    /// Takes the next slice of up to `batch_size` rows from `output` and
    /// records the follow-up state: `Done` once `output` is exhausted,
    /// otherwise `OutputtingMaterialized` holding the remaining rows.
    fn emit_next_materialized_batch(
        &mut self,
        mut output: MaterializedOutput,
        batch_size: usize,
    ) -> Option<RecordBatch> {
        let batch = output.next_batch(batch_size);
        self.state = if output.is_exhausted() {
            AggregateHashTableState::Done
        } else {
            AggregateHashTableState::OutputtingMaterialized(output)
        };
        batch
    }

    /// Merges one input batch of partial states into the hash table.
    ///
    /// The batch is evaluated with `evaluate_batch`; for each grouping set the
    /// group keys are interned and every accumulator merges its argument
    /// columns via `merge_batch`. Time spent is recorded in
    /// `group_by_metrics.aggregation_time`.
    ///
    /// NOTE(review): `building_mut` presumably requires the table to be in the
    /// `Building` state — confirm how it behaves otherwise.
    pub(in crate::aggregates) fn aggregate_batch(
        &mut self,
        batch: &RecordBatch,
    ) -> Result<()> {
        let evaluated_batch = self.evaluate_batch(batch)?;
        let state = self.state.building_mut();

        // Guard records the elapsed time when it goes out of scope.
        let _aggregation_timer = self.group_by_metrics.aggregation_time.timer();
        for group_values in &evaluated_batch.grouping_set_args {
            // Interning fills `batch_group_indices` with the group index of
            // each input row.
            state
                .group_values
                .intern(group_values, &mut state.batch_group_indices)?;
            let group_indices = &state.batch_group_indices;
            let total_num_groups = state.group_values.len();

            let accumulators_with_args = state
                .accumulators
                .iter_mut()
                .zip(evaluated_batch.accumulator_args.iter());
            for (acc, values) in accumulators_with_args {
                acc.merge_batch(values, group_indices, total_num_groups)?;
            }
        }

        Ok(())
    }

    /// Switches the table to its output phase by delegating to
    /// `start_outputting`. Always returns `Ok(())`.
    pub(in crate::aggregates) fn start_output(&mut self) -> Result<()> {
        self.start_outputting();
        Ok(())
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,8 @@ impl AggregateHashTable<PartialMarker> {
AggregateHashTableState::Building(_) => {
internal_err!("next_output_batch must be called in the outputting state")
}
AggregateHashTableState::OutputtingMaterializedFinal(_) => {
internal_err!(
"partial aggregate output should not materialize final output"
)
AggregateHashTableState::OutputtingMaterialized(_) => {
internal_err!("partial aggregate output should not materialize output")
}
}
}
Expand Down
20 changes: 20 additions & 0 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use super::{DisplayAs, ExecutionPlanProperties, PlanProperties};
use crate::aggregates::{
hash_aggregate::{FinalHashAggregateStream, PartialHashAggregateStream},
no_grouping::AggregateStream,
partial_reduce_stream::PartialReduceHashAggregateStream,
row_hash::GroupedHashAggregateStream,
topk_stream::GroupedTopKAggregateStream,
};
Expand Down Expand Up @@ -77,6 +78,7 @@ pub mod group_values;
mod hash_aggregate;
mod no_grouping;
pub mod order;
mod partial_reduce_stream;
mod row_hash;
mod skip_partial;
mod topk;
Expand Down Expand Up @@ -527,6 +529,9 @@ enum StreamType {
/// Partial stage of the hash aggregation
/// Input output scheme: initial input -> partial state
PartialHash(PartialHashAggregateStream),
/// Partial-reduce stage of the hash aggregation
/// Input output scheme: partial state -> partial state
PartialReduceHash(PartialReduceHashAggregateStream),
/// Final stage of the hash aggregation
/// Input output scheme: partial state -> final result
FinalHash(FinalHashAggregateStream),
Expand All @@ -550,6 +555,7 @@ impl From<StreamType> for SendableRecordBatchStream {
match stream {
StreamType::AggregateStream(stream) => Box::pin(stream),
StreamType::PartialHash(stream) => Box::pin(stream),
StreamType::PartialReduceHash(stream) => Box::pin(stream),
StreamType::FinalHash(stream) => Box::pin(stream),
StreamType::GroupedHash(stream) => Box::pin(stream),
StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
Expand Down Expand Up @@ -1028,6 +1034,12 @@ impl AggregateExec {
)?));
}

if self.should_use_partial_reduce_hash_stream() {
return Ok(StreamType::PartialReduceHash(
PartialReduceHashAggregateStream::new(self, context, partition)?,
));
}

if self.should_use_final_hash_stream(context) {
return Ok(StreamType::FinalHash(FinalHashAggregateStream::new(
self, context, partition,
Expand Down Expand Up @@ -1069,6 +1081,14 @@ impl AggregateExec {
&& self.group_by.is_single()
}

/// Returns `true` when this aggregate should be executed by
/// `PartialReduceHashAggregateStream`.
///
/// All of the following must hold:
/// - the mode is `AggregateMode::PartialReduce`,
/// - no limit options are set,
/// - the input order mode is `InputOrderMode::Linear`,
/// - the group-by is not a "true no grouping" and is a single grouping set.
fn should_use_partial_reduce_hash_stream(&self) -> bool {
    // Checks run in the same order as a plain `&&` chain, so later
    // conditions are only evaluated when the earlier ones pass.
    let is_partial_reduce = self.mode == AggregateMode::PartialReduce;
    if !is_partial_reduce {
        return false;
    }

    let is_unlimited_linear_input = self.limit_options.is_none()
        && self.input_order_mode == InputOrderMode::Linear;
    if !is_unlimited_linear_input {
        return false;
    }

    !self.group_by.is_true_no_grouping() && self.group_by.is_single()
}

/// See comments in `PartialHashAggregateStream` limit optimization section
fn limit_options_supported_by_hash_stream(&self) -> bool {
self.limit_options.is_none() || self.is_unordered_unfiltered_group_by_distinct()
Expand Down
Loading
Loading