diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml index 0dc35f4a87776..e3c11dde7693e 100644 --- a/datafusion/spark/Cargo.toml +++ b/datafusion/spark/Cargo.toml @@ -66,3 +66,7 @@ name = "char" [[bench]] harness = false name = "space" + +[[bench]] +harness = false +name = "hex" diff --git a/datafusion/spark/benches/hex.rs b/datafusion/spark/benches/hex.rs new file mode 100644 index 0000000000000..756352b034c34 --- /dev/null +++ b/datafusion/spark/benches/hex.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::array::*; +use arrow::datatypes::*; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_spark::function::math::hex::SparkHex; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +fn generate_int64_data(size: usize, null_density: f32) -> PrimitiveArray { + let mut rng = seedable_rng(); + (0..size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range::(-999_999_999_999..999_999_999_999)) + } + }) + .collect() +} + +fn generate_utf8_data(size: usize, null_density: f32) -> StringArray { + let mut rng = seedable_rng(); + let mut builder = StringBuilder::new(); + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(1..=100); + let s: String = + std::iter::repeat_with(|| rng.random_range(b'a'..=b'z') as char) + .take(len) + .collect(); + builder.append_value(&s); + } + } + builder.finish() +} + +fn generate_binary_data(size: usize, null_density: f32) -> BinaryArray { + let mut rng = seedable_rng(); + let mut builder = BinaryBuilder::new(); + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(1..=100); + let bytes: Vec = (0..len).map(|_| rng.random()).collect(); + builder.append_value(&bytes); + } + } + builder.finish() +} + +fn generate_int64_dict_data( + size: usize, + null_density: f32, +) -> DictionaryArray { + let mut rng = seedable_rng(); + let mut builder = PrimitiveDictionaryBuilder::::new(); + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + builder.append_value( + rng.random_range::(-999_999_999_999..999_999_999_999), + ); + } + } + builder.finish() +} + +fn run_benchmark(c: &mut Criterion, name: &str, size: usize, array: Arc) { + let hex_func = SparkHex::new(); + let args = vec![ColumnarValue::Array(array)]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("{name}/size={size}"), |b| { + b.iter(|| { + black_box( + hex_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Arc::new(Field::new("f", DataType::Utf8, true)), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let sizes = vec![1024, 4096, 8192]; + let null_density = 0.1; + + for &size in &sizes { + let data = generate_int64_data(size, null_density); + run_benchmark(c, "hex_int64", size, Arc::new(data)); + } + + for &size in &sizes { + let data = generate_utf8_data(size, null_density); + run_benchmark(c, "hex_utf8", size, Arc::new(data)); + } + + for &size in &sizes { + let data = generate_binary_data(size, null_density); + run_benchmark(c, "hex_binary", size, Arc::new(data)); + } + + for &size in &sizes { + let data = generate_int64_dict_data(size, null_density); + run_benchmark(c, "hex_int64_dict", size, Arc::new(data)); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/src/function/math/hex.rs b/datafusion/spark/src/function/math/hex.rs index ef62b08fb03d2..134324f45f5bc 100644 --- a/datafusion/spark/src/function/math/hex.rs +++ b/datafusion/spark/src/function/math/hex.rs @@ -16,9 +16,10 @@ // under the License. use std::any::Any; +use std::str::from_utf8_unchecked; use std::sync::Arc; -use arrow::array::{Array, StringArray}; +use arrow::array::{Array, BinaryArray, Int64Array, StringArray, StringBuilder}; use arrow::datatypes::DataType; use arrow::{ array::{as_dictionary_array, as_largestring_array, as_string_array}, @@ -110,37 +111,85 @@ impl ScalarUDFImpl for SparkHex { } } -fn hex_int64(num: i64) -> String { - format!("{num:X}") -} - /// Hex encoding lookup tables for fast byte-to-hex conversion const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef"; const HEX_CHARS_UPPER: &[u8; 16] = b"0123456789ABCDEF"; #[inline] -fn hex_encode>(data: T, lower_case: bool) -> String { - let bytes = data.as_ref(); - let mut s = String::with_capacity(bytes.len() * 2); - let hex_chars = if lower_case { +fn hex_int64(num: i64, buffer: &mut [u8; 16]) -> &[u8] { + if num == 0 { + return b"0"; + } + + let mut n = num as u64; + let mut i = 16; + while n != 0 { + i -= 1; + buffer[i] = HEX_CHARS_UPPER[(n & 0xF) as usize]; + n >>= 4; + } + &buffer[i..] +} + +/// Generic hex encoding for byte array types +fn hex_encode_bytes<'a, I, T>( + iter: I, + lowercase: bool, + len: usize, +) -> Result +where + I: Iterator>, + T: AsRef<[u8]> + 'a, +{ + let mut builder = StringBuilder::with_capacity(len, len * 64); + let mut buffer = Vec::with_capacity(64); + let hex_chars = if lowercase { HEX_CHARS_LOWER } else { HEX_CHARS_UPPER }; - for &b in bytes { - s.push(hex_chars[(b >> 4) as usize] as char); - s.push(hex_chars[(b & 0x0f) as usize] as char); + + for v in iter { + if let Some(b) = v { + buffer.clear(); + let bytes = b.as_ref(); + for &byte in bytes { + buffer.push(hex_chars[(byte >> 4) as usize]); + buffer.push(hex_chars[(byte & 0x0f) as usize]); + } + // SAFETY: buffer contains only ASCII hex digests, which are valid UTF-8 + unsafe { + builder.append_value(from_utf8_unchecked(&buffer)); + } + } else { + builder.append_null(); + } } - s + + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } -#[inline(always)] -fn hex_bytes>( - bytes: T, - lowercase: bool, -) -> Result { - let hex_string = hex_encode(bytes, lowercase); - Ok(hex_string) +/// Generic hex encoding for int64 type +fn hex_encode_int64(iter: I, len: usize) -> Result +where + I: Iterator>, +{ + let mut builder = StringBuilder::with_capacity(len, len * 16); + + for v in iter { + if let Some(num) = v { + let mut temp = [0u8; 16]; + let slice = hex_int64(num, &mut temp); + // SAFETY: slice contains only ASCII hex digests, which are valid UTF-8 + unsafe { + builder.append_value(from_utf8_unchecked(slice)); + } + } else { + builder.append_null(); + } + } + + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } /// Spark-compatible `hex` function @@ -166,103 +215,55 @@ pub fn compute_hex( ColumnarValue::Array(array) => match array.data_type() { DataType::Int64 => { let array = as_int64_array(array)?; - - let hexed_array: StringArray = - array.iter().map(|v| v.map(hex_int64)).collect(); - - Ok(ColumnarValue::Array(Arc::new(hexed_array))) + hex_encode_int64(array.iter(), array.len()) } DataType::Utf8 => { let array = as_string_array(array); - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + hex_encode_bytes(array.iter(), lowercase, array.len()) } DataType::Utf8View => { let array = as_string_view_array(array)?; - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + hex_encode_bytes(array.iter(), lowercase, array.len()) } DataType::LargeUtf8 => { let array = as_largestring_array(array); - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + hex_encode_bytes(array.iter(), lowercase, array.len()) } DataType::Binary => { let array = as_binary_array(array)?; - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + hex_encode_bytes(array.iter(), lowercase, array.len()) } DataType::LargeBinary => { let array = as_large_binary_array(array)?; - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + hex_encode_bytes(array.iter(), lowercase, array.len()) } DataType::FixedSizeBinary(_) => { let array = as_fixed_size_binary_array(array)?; - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + hex_encode_bytes(array.iter(), lowercase, array.len()) } DataType::Dictionary(_, value_type) => { let dict = as_dictionary_array::(&array); - let values = match **value_type { - DataType::Int64 => as_int64_array(dict.values())? - .iter() - .map(|v| v.map(hex_int64)) - .collect::>(), - DataType::Utf8 => as_string_array(dict.values()) - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?, - DataType::Binary => as_binary_array(dict.values())? - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?, - _ => exec_err!( - "hex got an unexpected argument type: {}", - array.data_type() - )?, - }; - - let new_values: Vec> = dict - .keys() - .iter() - .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None)) - .collect(); - - let string_array_values = StringArray::from(new_values); - - Ok(ColumnarValue::Array(Arc::new(string_array_values))) + match **value_type { + DataType::Int64 => { + let arr = dict.downcast_dict::().unwrap(); + hex_encode_int64(arr.into_iter(), dict.len()) + } + DataType::Utf8 => { + let arr = dict.downcast_dict::().unwrap(); + hex_encode_bytes(arr.into_iter(), lowercase, dict.len()) + } + DataType::Binary => { + let arr = dict.downcast_dict::().unwrap(); + hex_encode_bytes(arr.into_iter(), lowercase, dict.len()) + } + _ => { + exec_err!( + "hex got an unexpected argument type: {}", + array.data_type() + ) + } + } } _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()), }, @@ -272,9 +273,10 @@ pub fn compute_hex( #[cfg(test)] mod test { + use std::str::from_utf8_unchecked; use std::sync::Arc; - use arrow::array::{Int64Array, StringArray}; + use arrow::array::{DictionaryArray, Int32Array, Int64Array, StringArray}; use arrow::{ array::{ BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringBuilder, @@ -373,13 +375,17 @@ mod test { #[test] fn test_hex_int64() { - let num = 1234; - let hexed = super::hex_int64(num); - assert_eq!(hexed, "4D2".to_string()); + let test_cases = vec![(1234, "4D2"), (-1, "FFFFFFFFFFFFFFFF")]; + + for (num, expected) in test_cases { + let mut cache = [0u8; 16]; + let slice = super::hex_int64(num, &mut cache); - let num = -1; - let hexed = super::hex_int64(num); - assert_eq!(hexed, "FFFFFFFFFFFFFFFF".to_string()); + unsafe { + let result = from_utf8_unchecked(slice); + assert_eq!(expected, result); + } + } } #[test] @@ -403,4 +409,25 @@ mod test { assert_eq!(string_array, &expected_array); } + + #[test] + fn test_dict_values_null() { + let keys = Int32Array::from(vec![Some(0), None, Some(1)]); + let vals = Int64Array::from(vec![Some(32), None]); + // [32, null, null] + let dict = DictionaryArray::new(keys, Arc::new(vals)); + + let columnar_value = ColumnarValue::Array(Arc::new(dict)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + let expected = StringArray::from(vec![Some("20"), None, None]); + + assert_eq!(&expected, result); + } }