Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1643,6 +1643,7 @@ fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Da
(Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => {
string_coercion(lhs_value_type, rhs_value_type).or(None)
}
(Binary, Binary) => Some(Utf8),
_ => None,
})
}
Expand Down
78 changes: 65 additions & 13 deletions datafusion/functions/src/string/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ use crate::string::concat;
use crate::strings::{
ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder,
};
use datafusion_common::cast::{as_string_array, as_string_view_array};
use datafusion_common::{Result, ScalarValue, internal_err, plan_err};
use datafusion_common::cast::{as_binary_array, as_string_array, as_string_view_array};
use datafusion_common::{
Result, ScalarValue, exec_datafusion_err, internal_err, plan_err,
};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit};
Expand Down Expand Up @@ -68,7 +70,7 @@ impl ConcatFunc {
use DataType::*;
Self {
signature: Signature::variadic(
vec![Utf8View, Utf8, LargeUtf8],
vec![Utf8View, Utf8, LargeUtf8, Binary],
Volatility::Immutable,
),
}
Expand Down Expand Up @@ -130,19 +132,25 @@ impl ScalarUDFImpl for ConcatFunc {

// Scalar
if array_len.is_none() {
let mut values = Vec::with_capacity(args.len());
let mut values: Vec<&str> = Vec::with_capacity(args.len());
for arg in &args {
let ColumnarValue::Scalar(scalar) = arg else {
return internal_err!("concat expected scalar value, got {arg:?}");
};

match scalar.try_as_str() {
Some(Some(v)) => values.push(v),
Some(None) => {} // null literal
None => plan_err!(
"Concat function does not support scalar type {}",
scalar
)?,
if let ScalarValue::Binary(Some(value)) = scalar {
let s: &str = std::str::from_utf8(value).map_err(|_| {
exec_datafusion_err!("invalid UTF-8 in binary literal: {value:?}")
})?;
values.push(s);
} else {
match scalar.try_as_str() {
Some(Some(v)) => values.push(v),
Some(None) => {} // null literal
None => plan_err!(
"Concat function does not support scalar type {}",
scalar
)?,
}
}
}
let result = values.concat();
Expand Down Expand Up @@ -178,6 +186,13 @@ impl ScalarUDFImpl for ConcatFunc {
columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
}
}
ColumnarValue::Scalar(ScalarValue::Binary(maybe_value)) => {
if let Some(b) = maybe_value {
// data_size is a capacity hint, so doesn't matter if it is chars or bytes
data_size += b.len() * len;
columns.push(ColumnarValueRef::Scalar(b.as_slice()));
}
}
ColumnarValue::Array(array) => {
match array.data_type() {
DataType::Utf8 => {
Expand Down Expand Up @@ -215,6 +230,17 @@ impl ScalarUDFImpl for ConcatFunc {
};
columns.push(column);
}
DataType::Binary => {
let string_array = as_binary_array(array)?;

data_size += string_array.values().len();
let column = if array.is_nullable() {
ColumnarValueRef::NullableBinaryArray(string_array)
} else {
ColumnarValueRef::NonNullableBinaryArray(string_array)
};
columns.push(column);
}
other => {
return plan_err!(
"Input was {other} which is not a supported datatype for concat function"
Expand Down Expand Up @@ -456,7 +482,33 @@ mod tests {
Utf8View,
StringViewArray
);

test_function!(
ConcatFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Binary(Some(
"Café".as_bytes().into()
))),
ColumnarValue::Scalar(ScalarValue::Utf8(None)),
ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
],
Ok(Some("Cafécc")),
&str,
Utf8,
StringArray
);
test_function!(
ConcatFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Binary(Some(Vec::from(
"Café".as_bytes()
)))),
ColumnarValue::Scalar(ScalarValue::Binary(Some("cc".as_bytes().into()))),
],
Ok(Some("Cafécc")),
&str,
Utf8,
StringArray
);
Ok(())
}

Expand Down
38 changes: 35 additions & 3 deletions datafusion/functions/src/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use std::mem::size_of;

use arrow::array::{
Array, ArrayAccessor, ArrayDataBuilder, ByteView, LargeStringArray,
Array, ArrayAccessor, ArrayDataBuilder, BinaryArray, ByteView, LargeStringArray,
NullBufferBuilder, StringArray, StringViewArray, StringViewBuilder, make_view,
};
use arrow::buffer::{MutableBuffer, NullBuffer};
Expand Down Expand Up @@ -75,6 +75,11 @@ impl StringArrayBuilder {
.extend_from_slice(array.value(i).as_bytes());
}
}
ColumnarValueRef::NullableBinaryArray(array) => {
if !CHECK_VALID || array.is_valid(i) {
self.value_buffer.extend_from_slice(array.value(i));
}
}
ColumnarValueRef::NonNullableArray(array) => {
self.value_buffer
.extend_from_slice(array.value(i).as_bytes());
Expand All @@ -87,6 +92,9 @@ impl StringArrayBuilder {
self.value_buffer
.extend_from_slice(array.value(i).as_bytes());
}
ColumnarValueRef::NonNullableBinaryArray(array) => {
self.value_buffer.extend_from_slice(array.value(i));
}
}
}

Expand Down Expand Up @@ -171,6 +179,12 @@ impl StringViewArrayBuilder {
);
}
}
ColumnarValueRef::NullableBinaryArray(array) => {
if !CHECK_VALID || array.is_valid(i) {
self.block
.push_str(std::str::from_utf8(array.value(i)).unwrap());
}
}
ColumnarValueRef::NonNullableArray(array) => {
self.block
.push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap());
Expand All @@ -183,6 +197,10 @@ impl StringViewArrayBuilder {
self.block
.push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap());
}
ColumnarValueRef::NonNullableBinaryArray(array) => {
self.block
.push_str(std::str::from_utf8(array.value(i)).unwrap());
}
}
}

Expand Down Expand Up @@ -244,6 +262,11 @@ impl LargeStringArrayBuilder {
.extend_from_slice(array.value(i).as_bytes());
}
}
ColumnarValueRef::NullableBinaryArray(array) => {
if !CHECK_VALID || array.is_valid(i) {
self.value_buffer.extend_from_slice(array.value(i));
}
}
ColumnarValueRef::NonNullableArray(array) => {
self.value_buffer
.extend_from_slice(array.value(i).as_bytes());
Expand All @@ -256,6 +279,9 @@ impl LargeStringArrayBuilder {
self.value_buffer
.extend_from_slice(array.value(i).as_bytes());
}
ColumnarValueRef::NonNullableBinaryArray(array) => {
self.value_buffer.extend_from_slice(array.value(i));
}
}
}

Expand Down Expand Up @@ -341,6 +367,8 @@ pub enum ColumnarValueRef<'a> {
NonNullableLargeStringArray(&'a LargeStringArray),
NullableStringViewArray(&'a StringViewArray),
NonNullableStringViewArray(&'a StringViewArray),
NullableBinaryArray(&'a BinaryArray),
NonNullableBinaryArray(&'a BinaryArray),
}

impl ColumnarValueRef<'_> {
Expand All @@ -350,10 +378,12 @@ impl ColumnarValueRef<'_> {
Self::Scalar(_)
| Self::NonNullableArray(_)
| Self::NonNullableLargeStringArray(_)
| Self::NonNullableStringViewArray(_) => true,
| Self::NonNullableStringViewArray(_)
| Self::NonNullableBinaryArray(_) => true,
Self::NullableArray(array) => array.is_valid(i),
Self::NullableStringViewArray(array) => array.is_valid(i),
Self::NullableLargeStringArray(array) => array.is_valid(i),
Self::NullableBinaryArray(array) => array.is_valid(i),
}
}

Expand All @@ -363,10 +393,12 @@ impl ColumnarValueRef<'_> {
Self::Scalar(_)
| Self::NonNullableArray(_)
| Self::NonNullableStringViewArray(_)
| Self::NonNullableLargeStringArray(_) => None,
| Self::NonNullableLargeStringArray(_)
| Self::NonNullableBinaryArray(_) => None,
Self::NullableArray(array) => array.nulls().cloned(),
Self::NullableStringViewArray(array) => array.nulls().cloned(),
Self::NullableLargeStringArray(array) => array.nulls().cloned(),
Self::NullableBinaryArray(array) => array.nulls().cloned(),
}
}
}
Expand Down
12 changes: 12 additions & 0 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,18 @@ fn pre_selection_scatter(
}

fn concat_elements(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef> {
if *left.data_type() == DataType::Binary && *right.data_type() == DataType::Binary {
// Cast Binary to Utf8 to validate UTF-8 encoding before concatenation
// Follow widespread approach of PostgreSQL, sqlite, DuckDB, Snowflake
// Spark does it in a different way by making a binary-to-binary concatenation
let left = cast(left.as_ref(), &DataType::Utf8)?;
let right = cast(right.as_ref(), &DataType::Utf8)?;
return Ok(Arc::new(concat_elements_utf8(
left.as_string::<i32>(),
right.as_string::<i32>(),
)?));
}

Ok(match left.data_type() {
DataType::Utf8 => Arc::new(concat_elements_utf8(
left.as_string::<i32>(),
Expand Down
4 changes: 3 additions & 1 deletion datafusion/sqllogictest/test_files/information_schema.slt
Original file line number Diff line number Diff line change
Expand Up @@ -840,8 +840,10 @@ datafusion public string_agg 1 OUT NULL String NULL false 1
query TTTBI rowsort
select specific_name, data_type, parameter_mode, is_variadic, rid from information_schema.parameters where specific_name = 'concat';
----
concat String IN true 0
concat Binary IN true 0
concat String IN true 1
concat String OUT false 0
concat String OUT false 1

# test coercion signature
query TTITI rowsort
Expand Down
19 changes: 19 additions & 0 deletions datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1671,6 +1671,25 @@ SELECT 'a' || 42 || 23.3
----
a4223.3

# concat of binary and text provides a text output
query T
select arrow_cast('Café', 'Utf8') || arrow_cast('Foobar', 'Binary');
----
CaféFoobar

query T
select arrow_cast('Café', 'Binary') || arrow_cast('Foobar', 'Utf8');
----
CaféFoobar

# Concat of two binaries should cast arguments to text and produce a text output,
# following the common behaviour of PostgreSQL. Spark, however, produces binary output.
query T
select arrow_cast('Café', 'Binary') || arrow_cast('Foobar', 'Binary');
----
CaféFoobar


# test_not_expressions()

query BB
Expand Down
24 changes: 24 additions & 0 deletions datafusion/sqllogictest/test_files/spark/string/concat.slt
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,27 @@ query TT
SELECT concat('a', arrow_cast('b', 'LargeUtf8'), arrow_cast('c', 'Utf8View')), arrow_typeof(concat('a', arrow_cast('b', 'LargeUtf8'), arrow_cast('c', 'Utf8View')));
----
abc Utf8View

# Test mixed types: Utf8 + Binary
query TT
SELECT concat(arrow_cast('hello', 'Utf8'), arrow_cast(' world', 'Binary')), arrow_typeof(concat(arrow_cast('hello', 'Utf8'), arrow_cast(' world', 'Binary')));
----
hello world Utf8

# Test mixed types: Utf8View + Binary
query TT
SELECT concat(arrow_cast('hello', 'Utf8View'), arrow_cast(' world', 'Binary')), arrow_typeof(concat(arrow_cast('hello', 'Utf8View'), arrow_cast(' world', 'Binary')));
----
hello world Utf8View

# Test mixed types: Binary + Binary
query TT
SELECT concat(arrow_cast('hello', 'Binary'), arrow_cast(' world', 'Binary')), arrow_typeof(concat(arrow_cast('hello', 'Binary'), arrow_cast(' world', 'Binary')));
----
hello world Utf8

# Test mixed types with ws: Binary + Binary
query TT
SELECT concat_ws('|', arrow_cast('hello', 'Binary'), arrow_cast('world', 'Binary')), arrow_typeof(concat_ws('|', arrow_cast('hello', 'Binary'), arrow_cast('world', 'Binary')));
----
hello|world Utf8
Loading