diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 1ae6ef5c4a8b5..625aef4251685 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -4801,7 +4801,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { @r" Projection: shapes.shape_id [shape_id:UInt32] Unnest: lists[shape_id2|depth=1] structs[] [shape_id:UInt32, shape_id2:UInt32;N] - Aggregate: groupBy=[[shapes.shape_id]], aggr=[[array_agg(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { data_type: UInt32, nullable: true });N] + Aggregate: groupBy=[[shapes.shape_id]], aggr=[[array_agg(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { data_type: UInt32 });N] TableScan: shapes projection=[shape_id] [shape_id:UInt32] " ); diff --git a/datafusion/core/tests/sql/aggregates/basic.rs b/datafusion/core/tests/sql/aggregates/basic.rs index d1b376b735ab9..cf186f5b9dfa1 100644 --- a/datafusion/core/tests/sql/aggregates/basic.rs +++ b/datafusion/core/tests/sql/aggregates/basic.rs @@ -35,11 +35,12 @@ async fn csv_query_array_agg_distinct() -> Result<()> { // | [4, 2, 3, 5, 1] | // +------------------------------------------+ // Since ARRAY_AGG(DISTINCT) ordering is nondeterministic, check the schema and contents. + // The inner field nullability matches the input column c2 which is NOT NULL assert_eq!( *actual[0].schema(), Schema::new(vec![Field::new_list( "array_agg(DISTINCT aggregate_test_100.c2)", - Field::new_list_field(DataType::UInt32, true), + Field::new_list_field(DataType::UInt32, false), true ),]) ); diff --git a/datafusion/functions-aggregate/benches/array_agg.rs b/datafusion/functions-aggregate/benches/array_agg.rs index d7f687386333f..78e788a6a1db8 100644 --- a/datafusion/functions-aggregate/benches/array_agg.rs +++ b/datafusion/functions-aggregate/benches/array_agg.rs @@ -45,7 +45,7 @@ fn merge_batch_bench(c: &mut Criterion, name: &str, values: ArrayRef) { b.iter(|| { #[allow(clippy::unit_arg)] black_box( - ArrayAggAccumulator::try_new(&list_item_data_type, false) + ArrayAggAccumulator::try_new(&list_item_data_type, false, true) .unwrap() .merge_batch(std::slice::from_ref(&values)) .unwrap(), diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index c07958a858ed4..b30be3fb4652a 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -32,7 +32,9 @@ use datafusion_common::cast::as_list_array; use datafusion_common::utils::{ SingleRowListArrayBuilder, compare_rows, get_row_at_idx, take_function_args, }; -use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err, exec_err}; +use datafusion_common::{ + Result, ScalarValue, assert_eq_or_internal_err, exec_err, not_impl_err, +}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ @@ -104,11 +106,23 @@ impl AggregateUDFImpl for ArrayAgg { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(DataType::List(Arc::new(Field::new_list_field( - arg_types[0].clone(), + fn return_type(&self, _arg_types: &[DataType]) -> Result { + not_impl_err!("Not called because return_field is implemented") + } + + fn return_field(&self, arg_fields: &[FieldRef]) -> Result { + // Outer field is always nullable in case of empty groups + // Inner list field nullability depends on input field + let input_field = &arg_fields[0]; + let list_field = Field::new_list_field( + input_field.data_type().clone(), + input_field.is_nullable(), + ); + Ok(Arc::new(Field::new( + self.name(), + DataType::List(Arc::new(list_field)), true, - )))) + ))) } fn state_fields(&self, args: StateFieldsArgs) -> Result> { @@ -169,6 +183,7 @@ impl AggregateUDFImpl for ArrayAgg { let field = &acc_args.expr_fields[0]; let data_type = field.data_type(); let ignore_nulls = acc_args.ignore_nulls && field.is_nullable(); + let input_nullable = field.is_nullable(); if acc_args.is_distinct { // Limitation similar to Postgres. The aggregation function can only mix @@ -198,6 +213,7 @@ impl AggregateUDFImpl for ArrayAgg { data_type, sort_option, ignore_nulls, + input_nullable, )?)); } @@ -205,6 +221,7 @@ impl AggregateUDFImpl for ArrayAgg { return Ok(Box::new(ArrayAggAccumulator::try_new( data_type, ignore_nulls, + input_nullable, )?)); }; @@ -220,6 +237,7 @@ impl AggregateUDFImpl for ArrayAgg { self.is_input_pre_ordered, acc_args.is_reversed, ignore_nulls, + input_nullable, ) .map(|acc| Box::new(acc) as _) } @@ -242,15 +260,22 @@ pub struct ArrayAggAccumulator { values: Vec, datatype: DataType, ignore_nulls: bool, + /// Whether the input field is nullable (preserved in result list elements) + input_nullable: bool, } impl ArrayAggAccumulator { /// new array_agg accumulator based on given item data type - pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result { + pub fn try_new( + datatype: &DataType, + ignore_nulls: bool, + input_nullable: bool, + ) -> Result { Ok(Self { values: vec![], datatype: datatype.clone(), ignore_nulls, + input_nullable, }) } @@ -373,21 +398,46 @@ impl Accumulator for ArrayAggAccumulator { } fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) + // State uses nullable inner elements to match state_fields() schema + // This is required for proper merging across partitions + let element_arrays: Vec<&dyn Array> = + self.values.iter().map(|a| a.as_ref()).collect(); + + if element_arrays.is_empty() { + return Ok(vec![ScalarValue::new_null_list( + self.datatype.clone(), + true, // state always uses nullable inner + 1, + )]); + } + + let concated_array = arrow::compute::concat(&element_arrays)?; + + Ok(vec![ + SingleRowListArrayBuilder::new(concated_array) + .with_nullable(true) // state always uses nullable inner + .build_list_scalar(), + ]) } fn evaluate(&mut self) -> Result { - // Transform Vec to ListArr + // Final output uses input_nullable to preserve nullability from input let element_arrays: Vec<&dyn Array> = self.values.iter().map(|a| a.as_ref()).collect(); if element_arrays.is_empty() { - return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); + return Ok(ScalarValue::new_null_list( + self.datatype.clone(), + self.input_nullable, + 1, + )); } let concated_array = arrow::compute::concat(&element_arrays)?; - Ok(SingleRowListArrayBuilder::new(concated_array).build_list_scalar()) + Ok(SingleRowListArrayBuilder::new(concated_array) + .with_nullable(self.input_nullable) + .build_list_scalar()) } fn size(&self) -> usize { @@ -420,6 +470,8 @@ pub struct DistinctArrayAggAccumulator { datatype: DataType, sort_options: Option, ignore_nulls: bool, + /// Whether the input field is nullable (preserved in result list elements) + input_nullable: bool, } impl DistinctArrayAggAccumulator { @@ -427,19 +479,60 @@ impl DistinctArrayAggAccumulator { datatype: &DataType, sort_options: Option, ignore_nulls: bool, + input_nullable: bool, ) -> Result { Ok(Self { values: HashSet::new(), datatype: datatype.clone(), sort_options, ignore_nulls, + input_nullable, }) } } impl Accumulator for DistinctArrayAggAccumulator { fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) + // State uses nullable inner elements to match state_fields() schema + let mut values: Vec = self.values.iter().cloned().collect(); + if values.is_empty() { + return Ok(vec![ScalarValue::new_null_list( + self.datatype.clone(), + true, // state always uses nullable inner + 1, + )]); + } + + // Sort if needed (same logic as evaluate) + if let Some(opts) = self.sort_options { + let mut delayed_cmp_err = Ok(()); + values.sort_by(|a, b| { + if a.is_null() { + return match opts.nulls_first { + true => Ordering::Less, + false => Ordering::Greater, + }; + } + if b.is_null() { + return match opts.nulls_first { + true => Ordering::Greater, + false => Ordering::Less, + }; + } + match opts.descending { + true => b.try_cmp(a), + false => a.try_cmp(b), + } + .unwrap_or_else(|err| { + delayed_cmp_err = Err(err); + Ordering::Equal + }) + }); + delayed_cmp_err?; + } + + let arr = ScalarValue::new_list(&values, &self.datatype, true); // state always uses nullable inner + Ok(vec![ScalarValue::List(arr)]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { @@ -484,7 +577,11 @@ impl Accumulator for DistinctArrayAggAccumulator { fn evaluate(&mut self) -> Result { let mut values: Vec = self.values.iter().cloned().collect(); if values.is_empty() { - return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); + return Ok(ScalarValue::new_null_list( + self.datatype.clone(), + self.input_nullable, + 1, + )); } if let Some(opts) = self.sort_options { @@ -514,7 +611,7 @@ impl Accumulator for DistinctArrayAggAccumulator { delayed_cmp_err?; }; - let arr = ScalarValue::new_list(&values, &self.datatype, true); + let arr = ScalarValue::new_list(&values, &self.datatype, self.input_nullable); Ok(ScalarValue::List(arr)) } @@ -551,6 +648,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator { reverse: bool, /// Whether the aggregation should ignore null values. ignore_nulls: bool, + /// Whether the input field is nullable (preserved in result list elements) + input_nullable: bool, } impl OrderSensitiveArrayAggAccumulator { @@ -563,6 +662,7 @@ impl OrderSensitiveArrayAggAccumulator { is_input_pre_ordered: bool, reverse: bool, ignore_nulls: bool, + input_nullable: bool, ) -> Result { let mut datatypes = vec![datatype.clone()]; datatypes.extend(ordering_dtypes.iter().cloned()); @@ -574,6 +674,7 @@ impl OrderSensitiveArrayAggAccumulator { is_input_pre_ordered, reverse, ignore_nulls, + input_nullable, }) } @@ -741,13 +842,39 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { self.sort(); } - let mut result = vec![self.evaluate()?]; + // State uses nullable inner elements to match state_fields() schema + let state_value = if self.values.is_empty() { + ScalarValue::new_null_list( + self.datatypes[0].clone(), + true, // state always uses nullable inner + 1, + ) + } else { + let values = self.values.clone(); + let array = if self.reverse { + ScalarValue::new_list_from_iter( + values.into_iter().rev(), + &self.datatypes[0], + true, // state always uses nullable inner + ) + } else { + ScalarValue::new_list_from_iter( + values.into_iter(), + &self.datatypes[0], + true, // state always uses nullable inner + ) + }; + ScalarValue::List(array) + }; + + let mut result = vec![state_value]; result.push(self.evaluate_orderings()?); Ok(result) } fn evaluate(&mut self) -> Result { + // Final output uses input_nullable to preserve nullability from input if !self.is_input_pre_ordered { self.sort(); } @@ -755,7 +882,7 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { if self.values.is_empty() { return Ok(ScalarValue::new_null_list( self.datatypes[0].clone(), - true, + self.input_nullable, 1, )); } @@ -765,10 +892,14 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { ScalarValue::new_list_from_iter( values.into_iter().rev(), &self.datatypes[0], - true, + self.input_nullable, ) } else { - ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true) + ScalarValue::new_list_from_iter( + values.into_iter(), + &self.datatypes[0], + self.input_nullable, + ) }; Ok(ScalarValue::List(array)) } @@ -799,7 +930,7 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { #[cfg(test)] mod tests { use super::*; - use arrow::array::{ListBuilder, StringBuilder}; + use arrow::array::{Int64Array, ListBuilder, StringBuilder}; use arrow::datatypes::{FieldRef, Schema}; use datafusion_common::cast::as_generic_string_array; use datafusion_common::internal_err; @@ -1116,6 +1247,135 @@ mod tests { Ok(()) } + #[test] + fn return_field_preserves_input_nullability() -> Result<()> { + let array_agg = ArrayAgg::default(); + + // Test with nullable input field + let nullable_field: FieldRef = + Arc::new(Field::new("input", DataType::Int64, true)); + let result_field = + array_agg.return_field(std::slice::from_ref(&nullable_field))?; + // List itself should always be nullable (NULL for empty groups) + assert!( + result_field.is_nullable(), + "List result should always be nullable" + ); + // Check inner field nullability is preserved + match result_field.data_type() { + DataType::List(inner) => { + assert!( + inner.is_nullable(), + "Inner field should be nullable when input is nullable" + ); + } + _ => panic!("Expected List type"), + } + + // Test with non-nullable input field + let non_nullable_field: FieldRef = + Arc::new(Field::new("input", DataType::Int64, false)); + let result_field = + array_agg.return_field(std::slice::from_ref(&non_nullable_field))?; + // List itself should still be nullable + assert!( + result_field.is_nullable(), + "List result should always be nullable" + ); + // Check inner field nullability is preserved + match result_field.data_type() { + DataType::List(inner) => { + assert!( + !inner.is_nullable(), + "Inner field should be non-nullable when input is non-nullable" + ); + } + _ => panic!("Expected List type"), + } + + Ok(()) + } + + #[test] + fn accumulator_output_matches_return_field() -> Result<()> { + // Test that the ListArray returned by evaluate() has a data type + // that matches the return_field() specification + + // Create a schema with a non-nullable input column + let input_schema = Schema::new(vec![Field::new( + "input", + DataType::Int64, + false, // non-nullable + )]); + + // Get the expected return field + let array_agg = ArrayAgg::default(); + let input_field: FieldRef = Arc::new(Field::new("input", DataType::Int64, false)); + let expected_return_field = + array_agg.return_field(std::slice::from_ref(&input_field))?; + + // Verify the expected field has non-nullable inner + if let DataType::List(inner) = expected_return_field.data_type() { + assert!( + !inner.is_nullable(), + "Expected non-nullable inner field from return_field" + ); + } else { + panic!("Expected List type"); + } + + // Create an accumulator for non-nullable input + let expr: Arc = Arc::new(Column::new("input", 0)); + let expr_field = expr.return_field(&input_schema)?; + let mut acc = array_agg.accumulator(AccumulatorArgs { + return_field: Arc::clone(&expected_return_field), + schema: &input_schema, + expr_fields: &[expr_field], + ignore_nulls: false, + order_bys: &[], + is_reversed: false, + name: "test", + is_distinct: false, + exprs: &[expr], + })?; + + // Add some values + let values: ArrayRef = Arc::new(Int64Array::from(vec![1i64, 2, 3])); + acc.update_batch(&[values])?; + + // Evaluate and check the result + let result = acc.evaluate()?; + let result_array = result.to_array()?; + + // Check the result array's data type matches the expected + let result_list = result_array + .as_any() + .downcast_ref::() + .expect("Expected ListArray"); + + // For ListArray, get the inner field from the data type + let DataType::List(result_inner_field) = result_list.data_type() else { + panic!("Expected List data type"); + }; + + assert!( + !result_inner_field.is_nullable(), + "Result list's inner field should be non-nullable, but got nullable. \ + Expected data_type: {:?}, got: {:?}", + expected_return_field.data_type(), + result_array.data_type() + ); + + // Verify data types match exactly + assert_eq!( + expected_return_field.data_type(), + result_array.data_type(), + "Result data type should match return_field data type" + ); + + Ok(()) + } + struct ArrayAggAccumulatorBuilder { return_field: FieldRef, distinct: bool, diff --git a/datafusion/spark/src/function/aggregate/collect.rs b/datafusion/spark/src/function/aggregate/collect.rs index 50497e2826383..95c088bab430d 100644 --- a/datafusion/spark/src/function/aggregate/collect.rs +++ b/datafusion/spark/src/function/aggregate/collect.rs @@ -89,7 +89,7 @@ impl AggregateUDFImpl for SparkCollectList { let data_type = field.data_type().clone(); let ignore_nulls = true; Ok(Box::new(NullToEmptyListAccumulator::new( - ArrayAggAccumulator::try_new(&data_type, ignore_nulls)?, + ArrayAggAccumulator::try_new(&data_type, ignore_nulls, true)?, data_type, ))) } @@ -151,7 +151,7 @@ impl AggregateUDFImpl for SparkCollectSet { let data_type = field.data_type().clone(); let ignore_nulls = true; Ok(Box::new(NullToEmptyListAccumulator::new( - DistinctArrayAggAccumulator::try_new(&data_type, None, ignore_nulls)?, + DistinctArrayAggAccumulator::try_new(&data_type, None, ignore_nulls, true)?, data_type, ))) }