diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index f29a4feb47..6f3a06a3d3 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -584,7 +584,7 @@ expression-level). The `outer` variants are wired but marked `Incompatible`; the | `soundex` | ✅ | | | `space` | ✅ | | | `split` | ✅ | | -| `split_part` | 🔜 | Lowers to `element_at(StringSplitSQL(...))`; `StringSplitSQL` falls back ([#4561](https://github.com/apache/datafusion-comet/issues/4561)) | +| `split_part` | ✅ | Spark 4.0+ | | `startswith` | ✅ | | | `substr` | ✅ | | | `substring` | ✅ | | diff --git a/native/spark-expr/src/array_funcs/list_extract.rs b/native/spark-expr/src/array_funcs/list_extract.rs index c44237efc9..61b3b19b59 100644 --- a/native/spark-expr/src/array_funcs/list_extract.rs +++ b/native/spark-expr/src/array_funcs/list_extract.rs @@ -142,6 +142,7 @@ impl PhysicalExpr for ListExtract { fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?; let ordinal_value = self.ordinal.evaluate(batch)?.into_array(batch.num_rows())?; + let element_type = self.data_type(&batch.schema())?; let default_value = self .default_value @@ -149,9 +150,9 @@ impl PhysicalExpr for ListExtract { .map(|d| { d.evaluate(batch).map(|value| match value { ColumnarValue::Scalar(scalar) - if !scalar.data_type().equals_datatype(child_value.data_type()) => + if !scalar.data_type().equals_datatype(&element_type) => { - scalar.cast_to(child_value.data_type()) + scalar.cast_to(&element_type) } ColumnarValue::Scalar(scalar) => Ok(scalar), v => Err(DataFusionError::Execution(format!( @@ -160,7 +161,7 @@ impl PhysicalExpr for ListExtract { }) }) .transpose()? - .unwrap_or(self.data_type(&batch.schema())?.try_into())?; + .unwrap_or(element_type.try_into())?; // Create error wrapper closure that has access to self let error_wrapper = |error: SparkError| self.wrap_error_with_context(error); diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 62eeaa2b1d..770a617a3e 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -198,6 +198,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(crate::string_funcs::spark_split); make_comet_scalar_udf!("split", func, without data_type) } + "split_sql" => { + let func = Arc::new(crate::string_funcs::spark_split_sql); + make_comet_scalar_udf!("split_sql", func, without data_type) + } "regexp_extract" => { let func = Arc::new(crate::string_funcs::spark_regexp_extract); make_comet_scalar_udf!("regexp_extract", func, without data_type) diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs index ce1b8009a8..44c90ec6c2 100644 --- a/native/spark-expr/src/string_funcs/mod.rs +++ b/native/spark-expr/src/string_funcs/mod.rs @@ -27,5 +27,5 @@ pub use contains::SparkContains; pub use get_json_object::spark_get_json_object; pub use regexp_extract::spark_regexp_extract; pub use regexp_extract_all::spark_regexp_extract_all; -pub use split::spark_split; +pub use split::{spark_split, spark_split_sql}; pub use substring::SubstringExpr; diff --git a/native/spark-expr/src/string_funcs/split.rs b/native/spark-expr/src/string_funcs/split.rs index 7e4a6af171..e88c748c16 100644 --- a/native/spark-expr/src/string_funcs/split.rs +++ b/native/spark-expr/src/string_funcs/split.rs @@ -17,7 +17,7 @@ use arrow::array::{ Array, ArrayBuilder, ArrayRef, GenericListArray, GenericStringArray, GenericStringBuilder, - ListArray, OffsetSizeTrait, + ListArray, NullBufferBuilder, OffsetSizeTrait, }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, Field}; @@ -84,13 +84,13 @@ pub fn spark_split(args: &[ColumnarValue]) -> DataFusionResult { ColumnarValue::Scalar(pattern_val), ) => { if string.is_none() { - return Ok(ColumnarValue::Scalar(ScalarValue::Null)); + return Ok(ColumnarValue::Scalar(new_null_list_scalar())); } let pattern_str = match pattern_val { ScalarValue::Utf8(Some(p)) | ScalarValue::LargeUtf8(Some(p)) => p, ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) => { - return Ok(ColumnarValue::Scalar(ScalarValue::Null)); + return Ok(ColumnarValue::Scalar(new_null_list_scalar())); } _ => { return exec_err!("split pattern must be a string"); @@ -109,6 +109,61 @@ pub fn spark_split(args: &[ColumnarValue]) -> DataFusionResult { } } +/// Spark-compatible StringSplitSQL function. +/// Splits a string around literal delimiter matches and keeps trailing empty strings. +pub fn spark_split_sql(args: &[ColumnarValue]) -> DataFusionResult { + if args.len() != 2 { + return exec_err!( + "split_sql expects 2 arguments (string, delimiter), got {}", + args.len() + ); + } + + match (&args[0], &args[1]) { + (ColumnarValue::Array(string_array), ColumnarValue::Scalar(delimiter)) => { + let delimiter = match delimiter { + ScalarValue::Utf8(Some(d)) | ScalarValue::LargeUtf8(Some(d)) => d, + ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) => { + return Ok(ColumnarValue::Array(new_null_list_array( + string_array.len(), + ))); + } + _ => return exec_err!("split_sql delimiter must be a string"), + }; + split_sql_array_scalar(string_array.as_ref(), delimiter) + } + (ColumnarValue::Array(string_array), ColumnarValue::Array(delimiter_array)) => { + split_sql_array_array(string_array.as_ref(), delimiter_array.as_ref()) + } + (ColumnarValue::Scalar(ScalarValue::Utf8(string)), ColumnarValue::Scalar(delimiter)) + | ( + ColumnarValue::Scalar(ScalarValue::LargeUtf8(string)), + ColumnarValue::Scalar(delimiter), + ) => { + if string.is_none() { + return Ok(ColumnarValue::Scalar(new_null_list_scalar())); + } + + let delimiter = match delimiter { + ScalarValue::Utf8(Some(d)) | ScalarValue::LargeUtf8(Some(d)) => d, + ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) => { + return Ok(ColumnarValue::Scalar(new_null_list_scalar())); + } + _ => return exec_err!("split_sql delimiter must be a string"), + }; + + let result = split_sql_string(string.as_ref().unwrap(), delimiter); + let string_array = GenericStringArray::::from(result); + let list_array = create_list_array(Arc::new(string_array)); + + Ok(ColumnarValue::Scalar(ScalarValue::List(Arc::new( + list_array, + )))) + } + _ => exec_err!("split_sql expects string arguments"), + } +} + fn split_array( string_array: &dyn arrow::array::Array, pattern: &str, @@ -133,6 +188,63 @@ fn split_array( } } +fn split_sql_array_scalar( + string_array: &dyn arrow::array::Array, + delimiter: &str, +) -> DataFusionResult { + match string_array.data_type() { + DataType::Utf8 => split_sql_generic_scalar::( + as_generic_string_array::(string_array)?, + delimiter, + ), + DataType::LargeUtf8 => split_sql_generic_scalar::( + as_generic_string_array::(string_array)?, + delimiter, + ), + _ => exec_err!( + "split_sql expects Utf8 or LargeUtf8 string array, got {:?}", + string_array.data_type() + ), + } +} + +fn split_sql_array_array( + string_array: &dyn arrow::array::Array, + delimiter_array: &dyn arrow::array::Array, +) -> DataFusionResult { + if string_array.len() != delimiter_array.len() { + return exec_err!( + "split_sql string and delimiter arrays must have the same length, got {} and {}", + string_array.len(), + delimiter_array.len() + ); + } + + match (string_array.data_type(), delimiter_array.data_type()) { + (DataType::Utf8, DataType::Utf8) => split_sql_generic_array::( + as_generic_string_array::(string_array)?, + as_generic_string_array::(delimiter_array)?, + ), + (DataType::Utf8, DataType::LargeUtf8) => split_sql_generic_array::( + as_generic_string_array::(string_array)?, + as_generic_string_array::(delimiter_array)?, + ), + (DataType::LargeUtf8, DataType::Utf8) => split_sql_generic_array::( + as_generic_string_array::(string_array)?, + as_generic_string_array::(delimiter_array)?, + ), + (DataType::LargeUtf8, DataType::LargeUtf8) => split_sql_generic_array::( + as_generic_string_array::(string_array)?, + as_generic_string_array::(delimiter_array)?, + ), + _ => exec_err!( + "split_sql expects Utf8 or LargeUtf8 string arrays, got {:?} and {:?}", + string_array.data_type(), + delimiter_array.data_type() + ), + } +} + fn split_generic( string_array: &GenericStringArray, regex: &Regex, @@ -171,6 +283,80 @@ fn split_generic( Ok(ColumnarValue::Array(Arc::new(list_array))) } +fn split_sql_generic_scalar( + string_array: &GenericStringArray, + delimiter: &str, +) -> DataFusionResult { + let len = string_array.len(); + let mut offsets: Vec = Vec::with_capacity(len + 1); + let mut values_builder = GenericStringBuilder::::new(); + offsets.push(O::usize_as(0)); + + for i in 0..len { + if !string_array.is_null(i) { + push_split_sql_parts(string_array.value(i), delimiter, &mut values_builder); + } + offsets.push(O::usize_as(values_builder.len())); + } + + let values_array = Arc::new(values_builder.finish()) as ArrayRef; + let item_type = if O::IS_LARGE { + DataType::LargeUtf8 + } else { + DataType::Utf8 + }; + let field = Arc::new(Field::new("item", item_type, false)); + let list_array = GenericListArray::::new( + field, + OffsetBuffer::new(offsets.into()), + values_array, + string_array.nulls().cloned(), + ); + + Ok(ColumnarValue::Array(Arc::new(list_array))) +} + +fn split_sql_generic_array( + string_array: &GenericStringArray, + delimiter_array: &GenericStringArray, +) -> DataFusionResult { + let len = string_array.len(); + let mut offsets: Vec = Vec::with_capacity(len + 1); + let mut values_builder = GenericStringBuilder::::new(); + let mut nulls = NullBufferBuilder::new(len); + offsets.push(O::usize_as(0)); + + for i in 0..len { + if string_array.is_null(i) || delimiter_array.is_null(i) { + nulls.append_null(); + } else { + push_split_sql_parts( + string_array.value(i), + delimiter_array.value(i), + &mut values_builder, + ); + nulls.append_non_null(); + } + offsets.push(O::usize_as(values_builder.len())); + } + + let values_array = Arc::new(values_builder.finish()) as ArrayRef; + let item_type = if O::IS_LARGE { + DataType::LargeUtf8 + } else { + DataType::Utf8 + }; + let field = Arc::new(Field::new("item", item_type, false)); + let list_array = GenericListArray::::new( + field, + OffsetBuffer::new(offsets.into()), + values_array, + nulls.finish(), + ); + + Ok(ColumnarValue::Array(Arc::new(list_array))) +} + /// Push the splits of `string` into `builder`. Avoids materializing an /// intermediate `Vec` — appends each `&str` slice from the regex /// iterator directly (the builder copies into its own buffer). @@ -214,6 +400,20 @@ fn push_split_parts( } } +fn push_split_sql_parts( + string: &str, + delimiter: &str, + builder: &mut GenericStringBuilder, +) { + if delimiter.is_empty() { + builder.append_value(string); + } else { + for p in string.split(delimiter) { + builder.append_value(p); + } + } +} + fn split_string(string: &str, pattern: &str, limit: i32) -> DataFusionResult> { let regex = Regex::new(pattern).map_err(|e| { DataFusionError::Execution(format!("Invalid regex pattern '{}': {}", pattern, e)) @@ -256,6 +456,14 @@ fn split_with_regex(string: &str, regex: &Regex, limit: i32) -> Vec { } } +fn split_sql_string(string: &str, delimiter: &str) -> Vec { + if delimiter.is_empty() { + vec![string.to_string()] + } else { + string.split(delimiter).map(|s| s.to_string()).collect() + } +} + fn create_list_array(values: ArrayRef) -> ListArray { let field = Arc::new(Field::new("item", DataType::Utf8, false)); let offsets = vec![0i32, values.len() as i32]; @@ -268,17 +476,25 @@ fn create_list_array(values: ArrayRef) -> ListArray { } fn new_null_list_array(len: usize) -> ArrayRef { + Arc::new(new_null_list_array_value(len)) +} + +fn new_null_list_scalar() -> ScalarValue { + ScalarValue::List(Arc::new(new_null_list_array_value(1))) +} + +fn new_null_list_array_value(len: usize) -> ListArray { let field = Arc::new(Field::new("item", DataType::Utf8, false)); let values = Arc::new(GenericStringArray::::from(Vec::::new())) as ArrayRef; let offsets = vec![0i32; len + 1]; let nulls = arrow::buffer::NullBuffer::new_null(len); - Arc::new(ListArray::new( + ListArray::new( field, arrow::buffer::OffsetBuffer::new(offsets.into()), values, Some(nulls), - )) + ) } #[cfg(test)] @@ -369,4 +585,50 @@ mod tests { let parts = split_string("", ",", -1).unwrap(); assert_eq!(parts, vec![""]); } + + #[test] + fn test_split_sql_literal_delimiter() { + let parts = split_sql_string("a.b.", "."); + assert_eq!(parts, vec!["a", "b", ""]); + } + + #[test] + fn test_split_sql_empty_delimiter() { + let parts = split_sql_string("abc", ""); + assert_eq!(parts, vec!["abc"]); + } + + #[test] + fn test_split_sql_keeps_regex_chars_literal() { + let parts = split_sql_string("a.b.c", "."); + assert_eq!(parts, vec!["a", "b", "c"]); + } + + #[test] + fn test_split_sql_scalar_nulls_return_typed_null_list() { + let delimiter = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))); + let result = spark_split_sql(&[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + delimiter.clone(), + ]) + .unwrap(); + assert_null_list_scalar(result); + + let result = spark_split_sql(&[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("a,b".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ]) + .unwrap(); + assert_null_list_scalar(result); + } + + fn assert_null_list_scalar(result: ColumnarValue) { + match result { + ColumnarValue::Scalar(ScalarValue::List(array)) => { + assert_eq!(array.len(), 1); + assert!(array.is_null(0)); + } + _ => panic!("Expected typed null list scalar, got {result:?}"), + } + } } diff --git a/spark/src/main/spark-4.x/org/apache/comet/shims/CometExprShim4x.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/CometExprShim4x.scala index 349e34bb70..234be4dc54 100644 --- a/spark/src/main/spark-4.x/org/apache/comet/shims/CometExprShim4x.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/CometExprShim4x.scala @@ -19,15 +19,16 @@ package org.apache.comet.shims -import org.apache.spark.sql.catalyst.expressions.{Attribute, DayName, Expression, Literal, MonthName, StructsToXml, XmlToStructs} +import org.apache.spark.sql.catalyst.expressions.{Attribute, DayName, Expression, Literal, MonthName, StringSplitSQL, StructsToXml, XmlToStructs} import org.apache.spark.sql.catalyst.expressions.csv.SchemaOfCsvEvaluator import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionUtils, SchemaOfJsonEvaluator} import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.catalyst.expressions.xml.{XmlExpressionEvalUtils, XPathEvaluator} +import org.apache.comet.CometSparkSessionExtensions.withFallbackReason import org.apache.comet.serde.CometScalaUDF import org.apache.comet.serde.ExprOuterClass.Expr -import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType} +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, hasNonDefaultStringCollation, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType} /** * Expression conversions shared across all Spark 4.x minor versions, compiled from the @@ -38,6 +39,10 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFa */ trait CometExprShim4x { + private val stringSplitSQLCollationReason = + "StringSplitSQL does not support non-UTF8_BINARY collations " + + "(https://github.com/apache/datafusion-comet/issues/2190)" + /** * `dayname` / `monthname` (Spark 4.0+) map a `DateType` value to a fixed US-English abbreviated * name. Spark's `DateTimeUtils.getDayName` / `getMonthName` use `DayOfWeek` / `Month` @@ -62,6 +67,32 @@ trait CometExprShim4x { case _ => None } + /** + * `split_part` lowers to `element_at(StringSplitSQL(...), partNum)`. StringSplitSQL uses a + * literal delimiter instead of a regex pattern, unlike `split` / `StringSplit`. + */ + protected def convertStringSplitSQL( + expr: StringSplitSQL, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + if (hasNonDefaultStringCollation(expr.dataType) || + hasNonDefaultStringCollation(expr.str.dataType) || + hasNonDefaultStringCollation(expr.delimiter.dataType)) { + withFallbackReason(expr, stringSplitSQLCollationReason) + return None + } + + val strExpr = exprToProtoInternal(expr.str, inputs, binding) + val delimiterExpr = exprToProtoInternal(expr.delimiter, inputs, binding) + val splitExpr = scalarFunctionExprToProtoWithReturnType( + "split_sql", + expr.dataType, + false, + strExpr, + delimiterExpr) + optExprWithFallbackReason(splitExpr, expr, expr.str, expr.delimiter) + } + // Spark 4.x lowers the RuntimeReplaceable structured-text functions to an evaluator-backed // `Invoke` (`schema_of_csv`, `schema_of_json`, `xpath_*`) or `StaticInvoke` // (`json_object_keys`, `schema_of_xml`) before Comet sees the plan, so the original expression diff --git a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala index 645de7ba75..7efd17f68a 100644 --- a/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/Spark4xCometExprShim.scala @@ -130,6 +130,9 @@ trait Spark4xCometExprShim extends CometExprShim4x { case _: DayName | _: MonthName => convertDayMonthName(expr, inputs, binding) + case s: StringSplitSQL => + convertStringSplitSQL(s, inputs, binding) + case _ => None } } diff --git a/spark/src/test/resources/sql-tests/expressions/string/split_part.sql b/spark/src/test/resources/sql-tests/expressions/string/split_part.sql new file mode 100644 index 0000000000..3c0ca0807d --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/split_part.sql @@ -0,0 +1,56 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- MinSparkVersion: 4.0 + +statement +CREATE TABLE test_split_part(s string, d string, p int) USING parquet + +statement +INSERT INTO test_split_part VALUES + ('one.two.three', '.', 2), + ('a||b||', '||', 3), + ('abc', '', 1), + (NULL, '.', 1), + ('abc', NULL, 1) + +query +SELECT split_part(s, d, p) FROM test_split_part + +-- literal delimiter: delimiter is literal, not regex +query +SELECT split_part('a.b.c', '.', 2), split_part('a|b|c', '|', 3) + +-- negative part numbers select from the end +query +SELECT split_part('one/two/three', '/', -1), split_part('one/two/three', '/', -2) + +-- out-of-range part numbers return the element_at default +query +SELECT split_part('a.b', '.', 4), split_part('a.b', '.', -4) + +-- literal NULL arguments still produce a typed native array child for element_at +query +SELECT split_part(CAST(NULL AS STRING), '.', 1), split_part('abc', CAST(NULL AS STRING), 1) + +-- part number zero follows element_at semantics +query expect_error(INVALID_INDEX_OF_ZERO) +SELECT split_part('a.b', '.', 0) + +-- StringSplitSQL is collation-aware in Spark. Comet does not support non-default collations yet. +query expect_fallback(StringSplitSQL does not support non-UTF8_BINARY collations) +SELECT split_part('Hello' COLLATE UTF8_LCASE, 'L' COLLATE UTF8_LCASE, 2)