diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index c8ac8adac0..4f697451c5 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -340,6 +340,7 @@ jobs: org.apache.comet.exec.CometGenerateExecSuite org.apache.comet.exec.CometWindowExecSuite org.apache.comet.exec.CometJoinSuite + org.apache.spark.sql.comet.CometMapInBatchSuite org.apache.comet.CometNativeSuite org.apache.comet.CometSetOpWithGroupBySuite org.apache.comet.CometSparkSessionExtensionsSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index 3c6953aade..046dc6f47a 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -193,6 +193,7 @@ jobs: org.apache.comet.exec.CometGenerateExecSuite org.apache.comet.exec.CometWindowExecSuite org.apache.comet.exec.CometJoinSuite + org.apache.spark.sql.comet.CometMapInBatchSuite org.apache.comet.CometNativeSuite org.apache.comet.CometSetOpWithGroupBySuite org.apache.comet.CometSparkSessionExtensionsSuite diff --git a/.github/workflows/pyarrow_udf_test.yml b/.github/workflows/pyarrow_udf_test.yml new file mode 100644 index 0000000000..e325ab8b6d --- /dev/null +++ b/.github/workflows/pyarrow_udf_test.yml @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: PyArrow UDF Tests + +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + +on: + push: + branches: + - main + paths: &feature-paths + - "pom.xml" + - "common/pom.xml" + - "common/src/main/scala/org/apache/comet/CometConf.scala" + - "spark/pom.xml" + - "spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala" + - "spark/src/main/scala/org/apache/spark/sql/comet/CometMapInBatchExec.scala" + - "spark/src/main/scala/org/apache/spark/sql/comet/shims/MapInBatchInfo.scala" + - "spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala" + - "spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala" + - "spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala" + - "spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala" + - "spark/src/main/spark-4.2/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala" + - "spark/src/main/spark-4.x/org/apache/spark/sql/comet/shims/Spark4xMapInBatchSupport.scala" + - "spark/src/test/resources/pyspark/conftest.py" + - "spark/src/test/resources/pyspark/test_pyarrow_udf.py" + - "spark/src/test/spark-3.5/org/apache/spark/sql/comet/CometMapInBatchSuite.scala" + - "spark/src/test/spark-4.x/org/apache/spark/sql/comet/CometMapInBatchSuite.scala" + - ".github/workflows/pyarrow_udf_test.yml" + pull_request: + paths: *feature-paths + workflow_dispatch: + +permissions: + contents: read + +env: + RUST_VERSION: stable + RUST_BACKTRACE: 1 + RUSTFLAGS: "-Clink-arg=-fuse-ld=bfd" + +jobs: + pyarrow-udf: + name: PyArrow UDF (Spark 4.0, JDK 17, Python 3.11) + runs-on: ubuntu-latest + container: + # Pinned to the Debian 12 (bookworm) base so the system `python3` is 3.11. The default + # `amd64/rust` image is Debian 13 (trixie) which ships Python 3.13 and no python3.11 apt + # package, breaking `apt-get install python3.11`. + image: rust:bookworm + env: + JAVA_TOOL_OPTIONS: "--add-exports=java.base/sun.nio.ch=ALL-UNNAMED --add-exports=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.lang=ALL-UNNAMED" + steps: + - uses: actions/checkout@v6 + + - name: Setup Rust & Java toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: ${{ env.RUST_VERSION }} + jdk-version: 17 + + - name: Cache Maven dependencies + uses: actions/cache@v5 + with: + path: | + ~/.m2/repository + /root/.m2/repository + key: ${{ runner.os }}-java-maven-${{ hashFiles('**/pom.xml') }}-pyarrow-udf + restore-keys: | + ${{ runner.os }}-java-maven- + + - name: Build Comet (debug, Spark 4.0 / Scala 2.13) + run: | + cd native && cargo build + cd .. && ./mvnw -B install -DskipTests -Pspark-4.0 -Pscala-2.13 + + - name: Install Python 3.11 and pip + run: | + apt-get update + apt-get install -y --no-install-recommends python3 python3-venv python3-pip + python3 -m venv /tmp/venv + /tmp/venv/bin/pip install --upgrade pip + /tmp/venv/bin/pip install "pyspark==4.0.1" "pyarrow>=14" pandas pytest + + - name: Run PyArrow UDF pytest + env: + # Spark launches Python workers in a fresh subprocess and looks up `python3` + # on PATH unless PYSPARK_PYTHON is set. Without this, workers use the system + # python which has no pyarrow installed and UDF execution fails with + # ModuleNotFoundError. + PYSPARK_PYTHON: /tmp/venv/bin/python + PYSPARK_DRIVER_PYTHON: /tmp/venv/bin/python + run: | + /tmp/venv/bin/python -m pytest -v \ + spark/src/test/resources/pyspark/test_pyarrow_udf.py diff --git a/dev/ensure-jars-have-correct-contents.sh b/dev/ensure-jars-have-correct-contents.sh index 084936475d..3c8b7a3afa 100755 --- a/dev/ensure-jars-have-correct-contents.sh +++ b/dev/ensure-jars-have-correct-contents.sh @@ -91,6 +91,13 @@ allowed_expr+="|^org/apache/spark/shuffle/comet/.*$" allowed_expr+="|^org/apache/spark/sql/$" # allow ExplainPlanGenerator trait since it may not be available in older Spark versions allowed_expr+="|^org/apache/spark/sql/ExtendedExplainGenerator.*$" +# PyArrow UDF acceleration runner classes are under org/apache/spark/sql/execution/python +# because PythonArrowInput and BasicPythonArrowOutput are private[python]; Comet's classes +# must be in that package to mix them in. +allowed_expr+="|^org/apache/spark/sql/execution/$" +allowed_expr+="|^org/apache/spark/sql/execution/python/$" +allowed_expr+="|^org/apache/spark/sql/execution/python/CometColumnarPythonInput.*$" +allowed_expr+="|^org/apache/spark/sql/execution/python/CometArrowPythonRunner.*$" allowed_expr+="|^org/apache/spark/CometPlugin.class$" allowed_expr+="|^org/apache/spark/CometDriverPlugin.*$" allowed_expr+="|^org/apache/spark/CometSource.*$" diff --git a/docs/source/user-guide/latest/index.rst b/docs/source/user-guide/latest/index.rst index 314a0a51bd..70267c49e5 100644 --- a/docs/source/user-guide/latest/index.rst +++ b/docs/source/user-guide/latest/index.rst @@ -48,5 +48,6 @@ to read more. Understanding Comet Plans Tuning Guide Metrics Guide + PyArrow UDF Acceleration Iceberg Guide Kubernetes Guide diff --git a/docs/source/user-guide/latest/pyarrow-udfs.md b/docs/source/user-guide/latest/pyarrow-udfs.md new file mode 100644 index 0000000000..d0e2346998 --- /dev/null +++ b/docs/source/user-guide/latest/pyarrow-udfs.md @@ -0,0 +1,209 @@ + + +# PyArrow UDF Acceleration + +Comet can accelerate Python UDFs that use PyArrow-backed batch processing, such as `mapInArrow` and `mapInPandas`. +These APIs are commonly used for ML inference, feature engineering, and data transformation workloads. + +## Background + +Spark's `mapInArrow` and `mapInPandas` APIs allow users to apply Python functions that operate on Arrow +RecordBatches or Pandas DataFrames. Under the hood, Spark communicates with the Python worker process +using the Arrow IPC format. + +Without Comet, the execution path for these UDFs involves unnecessary data conversions: + +1. Comet reads data in Arrow columnar format (via CometScan) +2. Spark inserts a ColumnarToRow transition (converts Arrow to UnsafeRow) +3. The Python runner converts those rows back to Arrow to send to Python +4. Python executes the UDF on Arrow batches +5. Results are returned as Arrow and then converted back to rows + +Steps 2 and 3 are redundant since the data starts and ends in Arrow format. + +## How Comet Optimizes This + +When enabled, Comet detects `PythonMapInArrowExec` / `MapInArrowExec` and `MapInPandasExec` +operators in the physical plan and replaces them with `CometMapInBatchExec`, which: + +- Reads Arrow columnar batches directly from the upstream Comet operator +- Feeds them to the Python runner without the expensive UnsafeProjection copy +- Keeps the Python output in columnar format for downstream operators + +This eliminates the ColumnarToRow transition and the output row conversion, reducing CPU overhead +and memory allocations. The internal row-to-Arrow IPC re-encoding inside Spark's +`ArrowPythonRunner` is unchanged in this version; full round-trip elimination is tracked in +[#4240](https://github.com/apache/datafusion-comet/issues/4240). + +### Plan flow + +Without Comet's optimization: + +``` +PythonMapInArrow / MapInArrow / MapInPandas ++- ColumnarToRow <- Arrow -> Row copy + +- CometNativeExec <- Arrow batch + +- CometScan +``` + +With the optimization enabled: + +``` +CometMapInBatch <- Arrow batch in/out, Python runner attached ++- CometNativeExec + +- CometScan +``` + +## Configuration + +The optimization is experimental and disabled by default. Enable it with: + +``` +spark.comet.exec.pyarrowUdf.enabled=true +``` + +The default is `false` while the feature stabilizes. + +### Relationship to Spark's PySpark Arrow conversion conf + +`spark.comet.exec.pyarrowUdf.enabled` is **not** the same as PySpark's +[`spark.sql.execution.arrow.pyspark.enabled`](https://spark.apache.org/docs/latest/api/python/tutorial/sql/arrow_pandas.html#enabling-for-conversion-to-from-pandas). +That conf controls whether Spark uses Arrow when materializing a DataFrame to a Pandas DataFrame +(`toPandas()`) or constructing one from Pandas. The Comet conf controls a planner rewrite for +`mapInArrow` / `mapInPandas`, and only affects how Comet's columnar batches feed the Python +worker. Both confs can be set independently. + +## Supported APIs + +| PySpark API | Spark Plan Node | Supported | +| -------------------------------- | --------------------------- | --------- | +| `df.mapInArrow(func, schema)` | `PythonMapInArrowExec` | Yes | +| `df.mapInPandas(func, schema)` | `MapInPandasExec` | Yes | +| `@pandas_udf` (scalar) | `ArrowEvalPythonExec` | Not yet | +| `df.applyInPandas(func, schema)` | `FlatMapGroupsInPandasExec` | Not yet | + +## Example + +```python +import pyarrow as pa +from pyspark.sql import SparkSession, types as T + +spark = SparkSession.builder \ + .config("spark.plugins", "org.apache.spark.CometPlugin") \ + .config("spark.comet.enabled", "true") \ + .config("spark.comet.exec.enabled", "true") \ + .config("spark.comet.exec.pyarrowUdf.enabled", "true") \ + .config("spark.memory.offHeap.enabled", "true") \ + .config("spark.memory.offHeap.size", "2g") \ + .getOrCreate() + +df = spark.read.parquet("data.parquet") + +def transform(batch: pa.RecordBatch) -> pa.RecordBatch: + # Your transformation logic here + table = batch.to_pandas() + table["new_col"] = table["value"] * 2 + return pa.RecordBatch.from_pandas(table) + +output_schema = T.StructType([ + T.StructField("value", T.DoubleType()), + T.StructField("new_col", T.DoubleType()), +]) + +result = df.mapInArrow(transform, output_schema) +``` + +## Verifying the Optimization + +Use `explain()` to verify that `CometMapInBatch` appears in your plan: + +```python +result.explain(mode="extended") +``` + +You should see: + +``` +CometMapInBatch ... ++- CometNativeExec ... + +- CometScan ... +``` + +Instead of the unoptimized plan: + +``` +PythonMapInArrow ... ++- ColumnarToRow + +- CometNativeExec ... + +- CometScan ... +``` + +When AQE is enabled (the Spark default) and the query contains a shuffle, the +optimization is applied during stage materialization. Calling `explain()` before +running an action will show the unoptimized plan: + +``` +AdaptiveSparkPlan isFinalPlan=false ++- PythonMapInArrow ... + +- CometExchange ... +``` + +To see the optimized plan, run an action first (for example `result.collect()` or +`result.cache(); result.count()`) and then call `explain()`. The post-execution +plan shows the materialized stages and includes `CometMapInBatch` if the +optimization fired. + +## Barrier execution + +`mapInArrow(..., barrier=True)` and `mapInPandas(..., barrier=True)` are honored: the +optimized operator propagates `isBarrier` through `RDD.barrier()`, so all tasks are +gang-scheduled and `BarrierTaskContext.barrier()` works inside the UDF the same way it does +on the unoptimized path. + +## Limitations + +- The optimization currently applies only to `mapInArrow` and `mapInPandas`. Scalar pandas UDFs + (`@pandas_udf`) and grouped operations (`applyInPandas`) are not yet supported. +- The optimization requires Arrow data on the input side. If a shuffle sits between the upstream + Comet operator and the Python UDF, you need Comet's native shuffle for the optimization to + apply. Set `spark.shuffle.manager` to + `org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager` and enable + `spark.comet.exec.shuffle.enabled=true` at session startup. With a vanilla Spark `Exchange` + in the plan the data leaves the shuffle as rows and the optimization cannot fire. +- Spark 4.0 or newer is required. On Spark 3.4 and 3.5 the optimization is a no-op even when + enabled; vanilla `PythonMapInArrowExec` / `MapInPandasExec` handle the operation. The Spark 3.5 + `PythonArrowInput` trait has a different contract than 4.x and a separate implementation has + not been written. Track 3.5 support as a future follow-on if there is user demand. +- `spark.sql.execution.arrow.useLargeVarTypes=true` is not supported. With this conf enabled, + Spark widens `StringType` and `BinaryType` to Arrow's 8-byte-offset variants in the + destination IPC root, while Comet's source vectors always use 4-byte offsets. The buffer-copy + path cannot bridge that mismatch, so `EliminateRedundantTransitions` skips the rewrite and + vanilla Spark handles the operation. +- Each batch is copied twice on the JVM side: once from Comet's vectors into Spark's + destination IPC root (per-buffer `setBytes`), and a second time inside the IPC writer when + `VectorUnloader` / `MessageSerializer.serialize` walks the root and writes bytes to the + pipe to the Python worker. The pipe write is structural (Spark's transport to Python is + fork + pipe + Arrow IPC, so the buffer bytes must reach the pipe at least once); dropping + the first copy by serialising directly from Comet's vectors is tracked in + [#4294](https://github.com/apache/datafusion-comet/issues/4294). Even after that, + true zero-copy at the JVM boundary is blocked because Comet's source `FieldVector`s are + imported from native via Arrow C Data Interface (their buffers route `release` through FFI), + while Spark's destination IPC root is a child of `ArrowUtils.rootAllocator`. The two + reference managers cannot share buffers via `TransferPair`. diff --git a/spark/src/main/scala/org/apache/comet/CometConf.scala b/spark/src/main/scala/org/apache/comet/CometConf.scala index 1a4eb3a8c8..63e7ee4247 100644 --- a/spark/src/main/scala/org/apache/comet/CometConf.scala +++ b/spark/src/main/scala/org/apache/comet/CometConf.scala @@ -293,6 +293,18 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) + val COMET_PYARROW_UDF_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.exec.pyarrowUdf.enabled") + .category(CATEGORY_EXEC) + .doc( + "Experimental: whether to enable optimized execution of PyArrow UDFs " + + "(mapInArrow/mapInPandas). When enabled, Comet passes Arrow columnar data " + + "directly to Python UDFs without the intermediate Arrow-to-Row-to-Arrow " + + "conversion that Spark normally performs. Disabled by default while the " + + "feature stabilizes.") + .booleanConf + .createWithDefault(false) + val COMET_TRACING_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.tracing.enabled") .category(CATEGORY_TUNING) .doc(s"Enable fine-grained tracing of events and memory usage. $TRACING_GUIDE.") diff --git a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala index 7402a83248..ce3b78a9fa 100644 --- a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala +++ b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala @@ -22,13 +22,15 @@ package org.apache.comet.rules import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.sideBySide -import org.apache.spark.sql.comet.{CometCollectLimitExec, CometColumnarToRowExec, CometNativeColumnarToRowExec, CometNativeWriteExec, CometPlan, CometSparkToColumnarExec} +import org.apache.spark.sql.comet.{CometCollectLimitExec, CometColumnarToRowExec, CometMapInBatchExec, CometNativeColumnarToRowExec, CometNativeWriteExec, CometPlan, CometSparkToColumnarExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} +import org.apache.spark.sql.comet.shims.{MapInBatchInfo, ShimCometMapInBatch} import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.QueryStageExec import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.comet.CometConf +import org.apache.comet.shims.ShimSQLConf // This rule is responsible for eliminating redundant transitions between row-based and // columnar-based operators for Comet. Currently, three potential redundant transitions are: @@ -51,7 +53,10 @@ import org.apache.comet.CometConf // various reasons) or Spark requests row-based output such as a `collect` call. Spark will adds // another `ColumnarToRowExec` on top of `CometSparkToColumnarExec`. In this case, the pair could // be removed. -case class EliminateRedundantTransitions(session: SparkSession) extends Rule[SparkPlan] { +case class EliminateRedundantTransitions(session: SparkSession) + extends Rule[SparkPlan] + with ShimCometMapInBatch + with ShimSQLConf { private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get() @@ -98,6 +103,25 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa case CometNativeColumnarToRowExec(sparkToColumnar: CometSparkToColumnarExec) => sparkToColumnar.child case CometSparkToColumnarExec(child: CometSparkToColumnarExec) => child + // Replace MapInBatchExec (PythonMapInArrowExec / MapInArrowExec / MapInPandasExec) that has + // a ColumnarToRow child with CometMapInBatchExec, eliminating the input and output + // UnsafeProjection copies and keeping the stage columnar. The matchers are + // version-shimmed: Spark 3.4 / 3.5 return None (they lack the required APIs) and Spark + // 4.1+ matches the renamed `MapInArrowExec`. + // + // Falls back to vanilla Spark when `spark.sql.execution.arrow.useLargeVarTypes` is enabled: + // CometColumnarPythonInput.copyVector does raw `setBytes` on each Arrow buffer, but Comet's + // source string/binary vectors always use 4-byte offsets while the destination root is + // allocated with 8-byte offsets when this conf is on. The buffer counts match but the + // offset width does not, so a direct memcpy would corrupt the offsets. + case EligibleMapInBatch(info, columnarChild) => + CometMapInBatchExec( + info.func, + info.output, + columnarChild, + info.isBarrier, + info.pythonEvalType) + // Spark adds `RowToColumnar` under Comet columnar shuffle. But it's redundant as the // shuffle takes row-based input. case s @ CometShuffleExchangeExec( @@ -130,6 +154,39 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa } } + /** + * If the given plan is a Comet ColumnarToRow transition, returns the columnar child the Python + * UDF operator can consume directly. By the time this rule runs the earlier + * `hasCometNativeChild` arm has already rewritten any `ColumnarToRowExec` over a Comet columnar + * source to one of the Comet variants, so vanilla `ColumnarToRowExec` cannot reach here on a + * Comet-driven plan and is intentionally not handled. + */ + private def extractColumnarChild(plan: SparkPlan): Option[SparkPlan] = plan match { + case CometColumnarToRowExec(child) => Some(child) + case CometNativeColumnarToRowExec(child) => Some(child) + case _ => None + } + + /** + * Matches the plans this rule should rewrite to `CometMapInBatchExec`. Single extractor used in + * the `transformUp` arm above so the matchers and conf reads run once per visited plan. Returns + * `(info, columnarChild)` where `columnarChild` is the Comet columnar producer that + * `CometMapInBatchExec` will consume directly. Returns `None` (and the arm misses) when the + * conf is off, when `useLargeVarTypes` forces the fallback, when the plan is not one of the + * version-shimmed MapInArrow / MapInPandas operators, or when the child is not a Comet + * columnar-to-row transition we can strip. + */ + private object EligibleMapInBatch { + def unapply(plan: SparkPlan): Option[(MapInBatchInfo, SparkPlan)] = { + if (!CometConf.COMET_PYARROW_UDF_ENABLED.get()) None + else if (arrowUseLargeVarTypes(plan.conf)) None + else + matchMapInArrow(plan) + .orElse(matchMapInPandas(plan)) + .flatMap(info => extractColumnarChild(info.child).map(child => (info, child))) + } + } + /** * Creates an appropriate columnar to row transition operator. * diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMapInBatchExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMapInBatchExec.scala new file mode 100644 index 0000000000..4c40e68809 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMapInBatchExec.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.comet.shims.ShimCometMapInBatch +import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.python.PythonSQLMetrics +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} + +import org.apache.comet.vector.CometVector + +/** + * Comet replacement for Spark's `MapInBatchExec` family (`PythonMapInArrowExec` / + * `MapInArrowExec` in 4.1+ / `MapInPandasExec`). Feeds upstream Comet `ColumnarBatch` values + * directly to a `CometArrowPythonRunner`, eliminating the per-row `InternalRow.getXXX` loop that + * vanilla Spark's `ArrowPythonRunner` performs. + * + * Per-Spark-minor wiring lives in `ShimCometMapInBatch.computeArrowPython`. + */ +case class CometMapInBatchExec( + func: Expression, + output: Seq[Attribute], + child: SparkPlan, + isBarrier: Boolean, + pythonEvalType: Int) + extends UnaryExecNode + with CometPlan + with PythonSQLMetrics + with ShimCometMapInBatch { + + override def supportsColumnar: Boolean = true + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override lazy val metrics: Map[String, SQLMetric] = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows")) ++ + pythonMetrics + + // Fallback for row-consuming parents (e.g. a top-level `collect()` that produces rows). + // Wraps this columnar exec in `ColumnarToRowExec`, reintroducing the row transition this + // operator otherwise eliminates. Only fires when nothing downstream consumes columnar. + override def doExecute(): RDD[InternalRow] = { + ColumnarToRowExec(this).doExecute() + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val numOutputBatches = longMetric("numOutputBatches") + val numInputRows = longMetric("numInputRows") + + val outputAttrs = output + val childSchema = child.schema + val evalType = pythonEvalType + val metricsCopy = pythonMetrics + + // Resolve every `SQLConf`-derived input on the driver. `SQLConf.get` reads from a thread-local + // `ConfigReader` that only exists on the driver, so dereferencing `conf` from inside the task + // closure NPEs. + val resolvedRunnerInputs = runnerInputs(func.asInstanceOf[PythonUDF], conf) + + val inputRDD = child.executeColumnar() + + def processPartition(batches: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = { + val context = TaskContext.get() + val counting = batches.map { b => numInputRows += b.numRows(); b } + + val columnarBatchIter = computeArrowPython( + resolvedRunnerInputs, + evalType, + Array(Array(0)), + StructType(Array(StructField("struct", childSchema))), + metricsCopy, + Iterator(counting), + context.partitionId(), + context) + + columnarBatchIter.map { batch => + // Python returns a single struct column; flatten to the user's output columns and + // re-wrap each child as CometVector so consumers that expect Comet's vector hierarchy + // (e.g. another CometMapInBatchExec stacked on top, or NativeUtil.exportBatch for a + // downstream native Comet operator) see the right type. Sharing the underlying Arrow + // ValueVector with the original ArrowColumnVector is safe: close() on either ends up + // releasing the same buffers, and arrow-vector's release path is idempotent. + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors: Array[ColumnVector] = outputAttrs.indices.map { i => + val childArrow = structVector.getChild(i) + CometVector.getVector( + childArrow.getValueVector, + /* useDecimal128 */ true, + /* dictionaryProvider */ null) + }.toArray + val flattenedBatch = new ColumnarBatch(outputVectors) + flattenedBatch.setNumRows(batch.numRows()) + numOutputRows += flattenedBatch.numRows() + numOutputBatches += 1 + flattenedBatch + } + } + + // Preserve isBarrier semantics: when set, run inside a barrier stage so all tasks + // are gang-scheduled and BarrierTaskContext.barrier() works inside the UDF. + if (isBarrier) { + inputRDD.barrier().mapPartitions(processPartition) + } else { + inputRDD.mapPartitionsInternal(processPartition) + } + } + + override protected def withNewChildInternal(newChild: SparkPlan): CometMapInBatchExec = + copy(child = newChild) +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/shims/MapInBatchInfo.scala b/spark/src/main/scala/org/apache/spark/sql/comet/shims/MapInBatchInfo.scala new file mode 100644 index 0000000000..f610c575b1 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/shims/MapInBatchInfo.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.execution.SparkPlan + +/** + * Spark-version-agnostic projection of a `MapInBatchExec` (`PythonMapInArrowExec`, + * `MapInArrowExec`, or `MapInPandasExec`) that the Comet rewrite needs. Lives outside the shims + * so the Comet planner can pattern-match on it without depending on which concrete Spark class + * was matched. + */ +case class MapInBatchInfo( + func: Expression, + output: Seq[Attribute], + child: SparkPlan, + isBarrier: Boolean, + pythonEvalType: Int) diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSQLConf.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSQLConf.scala index 0bff426c21..e809e33904 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSQLConf.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSQLConf.scala @@ -19,9 +19,18 @@ package org.apache.comet.shims +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy trait ShimSQLConf { protected val LEGACY = LegacyBehaviorPolicy.LEGACY protected val CORRECTED = LegacyBehaviorPolicy.CORRECTED + + /** + * Reads `spark.sql.execution.arrow.useLargeVarTypes`. Spark 3.4 has no typed accessor for this + * conf, so read by raw key. The conf only governs the destination Arrow IPC root width on + * Spark 4.x, so the value returned here matters only to callers that look it up explicitly. + */ + protected def arrowUseLargeVarTypes(conf: SQLConf): Boolean = + conf.getConfString("spark.sql.execution.arrow.useLargeVarTypes", "false").toBoolean } diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSQLConf.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSQLConf.scala index bdb2739460..219e0f2a2e 100644 --- a/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSQLConf.scala +++ b/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSQLConf.scala @@ -19,9 +19,15 @@ package org.apache.comet.shims -import org.apache.spark.sql.internal.LegacyBehaviorPolicy +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} trait ShimSQLConf { protected val LEGACY = LegacyBehaviorPolicy.LEGACY protected val CORRECTED = LegacyBehaviorPolicy.CORRECTED + + /** + * Reads `spark.sql.execution.arrow.useLargeVarTypes`. Spark 3.5 has the typed accessor; + * forward to it. + */ + protected def arrowUseLargeVarTypes(conf: SQLConf): Boolean = conf.arrowUseLargeVarTypes } diff --git a/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala new file mode 100644 index 0000000000..c0a31c6e52 --- /dev/null +++ b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Spark 3.x stub for the PyArrow UDF acceleration support. + * + * The columnar runner introduced in #4234 only targets Spark 4.0+. On Spark 3.4 / 3.5 the matchers + * return `None`, the rewrite does not fire, and vanilla Spark handles `mapInArrow` / + * `mapInPandas` unchanged. The runner factory throws; it is never called because the matchers + * always return `None`. 3.x support can be added later if there is user demand. + * + * Shared across spark-3.4 and spark-3.5 because both are identical: 3.4 lacks the modern + * `ArrowPythonRunner` constructor and `arrowUseLargeVarTypes`, and 3.5's `PythonArrowInput` + * trait has a different contract (`writeIteratorToArrowStream` one-shot vs 4.x's + * `writeNextBatchToArrowStream` batch-at-a-time), so neither version can host the columnar input + * implementation without a separate rewrite. + */ +trait ShimCometMapInBatch { + + protected def matchMapInArrow(plan: SparkPlan): Option[MapInBatchInfo] = None + + protected def matchMapInPandas(plan: SparkPlan): Option[MapInBatchInfo] = None + + /** Stub; never constructed on Spark 3.x because the matchers always return `None`. */ + protected case class RunnerInputs() + + protected def runnerInputs(pythonUDF: PythonUDF, conf: SQLConf): RunnerInputs = + throw new UnsupportedOperationException("CometMapInBatchExec is not supported on Spark 3.x") + + protected def computeArrowPython( + runnerInputs: RunnerInputs, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + pythonMetrics: Map[String, SQLMetric], + batchIter: Iterator[Iterator[ColumnarBatch]], + partitionId: Int, + context: TaskContext): Iterator[ColumnarBatch] = + throw new UnsupportedOperationException("CometMapInBatchExec is not supported on Spark 3.x") +} diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala new file mode 100644 index 0000000000..ddb73ac95c --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.TaskContext +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.CometArrowPythonRunner +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +trait ShimCometMapInBatch extends Spark4xMapInBatchSupport { + + protected def computeArrowPython( + runnerInputs: RunnerInputs, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + pythonMetrics: Map[String, SQLMetric], + batchIter: Iterator[Iterator[ColumnarBatch]], + partitionId: Int, + context: TaskContext): Iterator[ColumnarBatch] = + new CometArrowPythonRunner( + runnerInputs.chainedFunc, + evalType, + argOffsets, + schema, + runnerInputs.timeZoneId, + runnerInputs.largeVarTypes, + runnerInputs.pythonRunnerConf, + pythonMetrics, + runnerInputs.jobArtifactUUID).compute(batchIter, partitionId, context) +} diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/execution/python/CometArrowPythonRunner.scala b/spark/src/main/spark-4.0/org/apache/spark/sql/execution/python/CometArrowPythonRunner.scala new file mode 100644 index 0000000000..63d282e8b9 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/spark/sql/execution/python/CometArrowPythonRunner.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.DataOutputStream + +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Comet's Arrow Python runner for Spark 4.0. Extends `BasePythonRunner` directly because Spark + * 4.0's `BaseArrowPythonRunner` is bound to `Iterator[InternalRow]` and mixes in + * `BasicPythonArrowInput`, so we cannot inherit from it. Wires the SQLConf-driven fields that + * `BaseArrowPythonRunner` provides. + */ +class CometArrowPythonRunner( + funcs: Seq[(ChainedPythonFunctions, Long)], + evalType: Int, + argOffsets: Array[Array[Int]], + protected override val schema: StructType, + protected override val timeZoneId: String, + protected override val largeVarTypes: Boolean, + override val workerConf: Map[String, String], + override val pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) + extends BasePythonRunner[Iterator[ColumnarBatch], ColumnarBatch]( + funcs.map(_._1), + evalType, + argOffsets, + jobArtifactUUID, + pythonMetrics) + with CometColumnarPythonInput + with BasicPythonArrowOutput { + + override val pythonExec: String = + SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(funcs.head._1.funcs.head.pythonExec) + + override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled + override val idleTimeoutSeconds: Long = SQLConf.get.pythonUDFWorkerIdleTimeoutSeconds + override val errorOnDuplicatedFieldNames: Boolean = true + override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback + override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + + override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize + require( + bufferSize >= 4, + "Pandas execution requires more than 4 bytes. Please set higher buffer. " + + s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") + + override protected def writeUDF(dataOut: DataOutputStream): Unit = + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, jobArtifactUUID) +} diff --git a/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala new file mode 100644 index 0000000000..ad27b7de42 --- /dev/null +++ b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.TaskContext +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.CometArrowPythonRunner +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +trait ShimCometMapInBatch extends Spark4xMapInBatchSupport { + + protected def computeArrowPython( + runnerInputs: RunnerInputs, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + pythonMetrics: Map[String, SQLMetric], + batchIter: Iterator[Iterator[ColumnarBatch]], + partitionId: Int, + context: TaskContext): Iterator[ColumnarBatch] = + new CometArrowPythonRunner( + runnerInputs.chainedFunc, + evalType, + argOffsets, + schema, + runnerInputs.timeZoneId, + runnerInputs.largeVarTypes, + runnerInputs.pythonRunnerConf, + pythonMetrics, + runnerInputs.jobArtifactUUID, + None).compute(batchIter, partitionId, context) +} diff --git a/spark/src/main/spark-4.1/org/apache/spark/sql/execution/python/CometArrowPythonRunner.scala b/spark/src/main/spark-4.1/org/apache/spark/sql/execution/python/CometArrowPythonRunner.scala new file mode 100644 index 0000000000..7b82b0aed8 --- /dev/null +++ b/spark/src/main/spark-4.1/org/apache/spark/sql/execution/python/CometArrowPythonRunner.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.DataOutputStream + +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Comet's Arrow Python runner for Spark 4.1. Extends `BaseArrowPythonRunner` parameterized over + * `Iterator[ColumnarBatch]` input, and supplies the columnar input via `CometColumnarPythonInput` + * instead of `BasicPythonArrowInput`. + * + * Spark 4.1's `PythonUDFRunner.writeUDFs` takes a `profiler: Option[String]` fourth argument; we + * pass `None` since Comet does not support Python profiling. + */ +class CometArrowPythonRunner( + funcs: Seq[(ChainedPythonFunctions, Long)], + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + timeZoneId: String, + largeVarTypes: Boolean, + workerConf: Map[String, String], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String], + sessionUUID: Option[String]) + extends BaseArrowPythonRunner[Iterator[ColumnarBatch], ColumnarBatch]( + funcs, + evalType, + argOffsets, + schema, + timeZoneId, + largeVarTypes, + workerConf, + pythonMetrics, + jobArtifactUUID, + sessionUUID) + with CometColumnarPythonInput + with BasicPythonArrowOutput { + + override protected def writeUDF(dataOut: DataOutputStream): Unit = + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, None) +} diff --git a/spark/src/main/spark-4.2/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala b/spark/src/main/spark-4.2/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala new file mode 100644 index 0000000000..ad27b7de42 --- /dev/null +++ b/spark/src/main/spark-4.2/org/apache/spark/sql/comet/shims/ShimCometMapInBatch.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.TaskContext +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.CometArrowPythonRunner +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +trait ShimCometMapInBatch extends Spark4xMapInBatchSupport { + + protected def computeArrowPython( + runnerInputs: RunnerInputs, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + pythonMetrics: Map[String, SQLMetric], + batchIter: Iterator[Iterator[ColumnarBatch]], + partitionId: Int, + context: TaskContext): Iterator[ColumnarBatch] = + new CometArrowPythonRunner( + runnerInputs.chainedFunc, + evalType, + argOffsets, + schema, + runnerInputs.timeZoneId, + runnerInputs.largeVarTypes, + runnerInputs.pythonRunnerConf, + pythonMetrics, + runnerInputs.jobArtifactUUID, + None).compute(batchIter, partitionId, context) +} diff --git a/spark/src/main/spark-4.2/org/apache/spark/sql/execution/python/CometArrowPythonRunner.scala b/spark/src/main/spark-4.2/org/apache/spark/sql/execution/python/CometArrowPythonRunner.scala new file mode 100644 index 0000000000..c9714ce068 --- /dev/null +++ b/spark/src/main/spark-4.2/org/apache/spark/sql/execution/python/CometArrowPythonRunner.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.DataOutputStream + +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Comet's Arrow Python runner for Spark 4.2. Spark 4.2's `BaseArrowPythonRunner` no longer + * accepts `workerConf` in its constructor; the subclass overrides `runnerConf` instead. + * `PythonUDFRunner.writeUDFs` drops the `profiler` argument compared to 4.1. + */ +class CometArrowPythonRunner( + funcs: Seq[(ChainedPythonFunctions, Long)], + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType, + timeZoneId: String, + largeVarTypes: Boolean, + pythonRunnerConf: Map[String, String], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String], + sessionUUID: Option[String]) + extends BaseArrowPythonRunner[Iterator[ColumnarBatch], ColumnarBatch]( + funcs, + evalType, + argOffsets, + schema, + timeZoneId, + largeVarTypes, + pythonMetrics, + jobArtifactUUID, + sessionUUID) + with CometColumnarPythonInput + with BasicPythonArrowOutput { + + override protected def runnerConf: Map[String, String] = + super.runnerConf ++ pythonRunnerConf + + override protected def writeUDF(dataOut: DataOutputStream): Unit = + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) +} diff --git a/spark/src/main/spark-4.x/org/apache/comet/shims/ShimSQLConf.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/ShimSQLConf.scala index bdb2739460..3157889b43 100644 --- a/spark/src/main/spark-4.x/org/apache/comet/shims/ShimSQLConf.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/ShimSQLConf.scala @@ -19,9 +19,17 @@ package org.apache.comet.shims -import org.apache.spark.sql.internal.LegacyBehaviorPolicy +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} trait ShimSQLConf { protected val LEGACY = LegacyBehaviorPolicy.LEGACY protected val CORRECTED = LegacyBehaviorPolicy.CORRECTED + + /** + * Reads `spark.sql.execution.arrow.useLargeVarTypes`. Spark 4.x exposes a typed accessor; 3.4 + * lacks it (a 3.5 backport added it, but Comet's 3.x shim collapses both into a single string + * fallback). Forward to the accessor here so callers do not depend on which version they're + * compiled against. + */ + protected def arrowUseLargeVarTypes(conf: SQLConf): Boolean = conf.arrowUseLargeVarTypes } diff --git a/spark/src/main/spark-4.x/org/apache/spark/sql/comet/shims/Spark4xMapInBatchSupport.scala b/spark/src/main/spark-4.x/org/apache/spark/sql/comet/shims/Spark4xMapInBatchSupport.scala new file mode 100644 index 0000000000..bfb56427cf --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/spark/sql/comet/shims/Spark4xMapInBatchSupport.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.shims + +import org.apache.spark.JobArtifactSet +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.python.{ArrowPythonRunner, MapInArrowExec, MapInPandasExec} +import org.apache.spark.sql.internal.SQLConf + +/** + * Shared 4.x bits for `ShimCometMapInBatch`. The matchers and `runnerInputs` helper are identical + * across 4.0/4.1/4.2; only the `ArrowPythonRunner` constructor parameter list differs per minor, + * so each minor's `ShimCometMapInBatch` provides only `computeArrowPython`. + */ +trait Spark4xMapInBatchSupport { + + protected def matchMapInArrow(plan: SparkPlan): Option[MapInBatchInfo] = + plan match { + case p: MapInArrowExec => + Some( + MapInBatchInfo( + p.func, + p.output, + p.child, + p.isBarrier, + PythonEvalType.SQL_MAP_ARROW_ITER_UDF)) + case _ => None + } + + protected def matchMapInPandas(plan: SparkPlan): Option[MapInBatchInfo] = + plan match { + case p: MapInPandasExec => + Some( + MapInBatchInfo( + p.func, + p.output, + p.child, + p.isBarrier, + PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)) + case _ => None + } + + /** Inputs every 4.x `ArrowPythonRunner` constructor needs in the same shape. */ + protected case class RunnerInputs( + chainedFunc: Seq[(ChainedPythonFunctions, Long)], + timeZoneId: String, + largeVarTypes: Boolean, + pythonRunnerConf: Map[String, String], + jobArtifactUUID: Option[String]) + + /** + * Resolves the `SQLConf`-derived inputs the `ArrowPythonRunner` needs. Must be called on the + * driver: `SQLConf.get` reads from a thread-local `ConfigReader` that only exists on the + * driver, so dereferencing `conf.sessionLocalTimeZone` etc. from a task closure NPEs. + */ + protected def runnerInputs(pythonUDF: PythonUDF, conf: SQLConf): RunnerInputs = + RunnerInputs( + chainedFunc = Seq((ChainedPythonFunctions(Seq(pythonUDF.func)), pythonUDF.resultId.id)), + timeZoneId = conf.sessionLocalTimeZone, + largeVarTypes = conf.arrowUseLargeVarTypes, + pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf), + jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)) +} diff --git a/spark/src/main/spark-4.x/org/apache/spark/sql/execution/python/CometColumnarPythonInput.scala b/spark/src/main/spark-4.x/org/apache/spark/sql/execution/python/CometColumnarPythonInput.scala new file mode 100644 index 0000000000..cf4f324a23 --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/spark/sql/execution/python/CometColumnarPythonInput.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.DataOutputStream +import java.nio.channels.Channels + +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.vector.{BaseFixedWidthVector, BaseLargeVariableWidthVector, BaseVariableWidthVector, FieldVector, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.complex.{LargeListVector, ListVector, StructVector} +import org.apache.arrow.vector.compression.{CompressionCodec, CompressionUtil, NoCompressionCodec} +import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel} +import org.apache.arrow.vector.ipc.message.MessageSerializer +import org.apache.spark.SparkException +import org.apache.spark.api.python.BasePythonRunner +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.unsafe.Platform + +import org.apache.comet.vector.CometDecodedVector + +/** + * `PythonArrowInput` implementation that streams Comet `ColumnarBatch` values to the Python + * worker as Arrow IPC. + * + * Per batch: walk the destination struct's children, allocate each child sized to match the + * corresponding Comet column, and copy each buffer with `ArrowBuf.setBytes`. The current path + * does two copies per batch: this one (Comet vector buffers → destination IPC root), and a + * second one inside `VectorUnloader` / `MessageSerializer.serialize` (root → pipe). The pipe + * write is structural — Spark's transport to Python is fork + pipe + Arrow IPC, so the buffer + * bytes must reach the pipe at least once. Dropping the first copy by serialising directly + * from Comet's vectors is tracked in #4294; once done, the path is at the single-copy floor. + * + * The cross-allocator constraint on `TransferPair` is independent of the copy count: even after + * #4294, true zero-copy at the JVM boundary is blocked because Comet's source `FieldVector`s + * are imported from native via Arrow C Data Interface (their buffers route `release` through + * FFI), while Spark's destination IPC root is a child of `ArrowUtils.rootAllocator`. The two + * reference managers cannot share buffers. + */ +private[python] trait CometColumnarPythonInput extends PythonArrowInput[Iterator[ColumnarBatch]] { + self: BasePythonRunner[Iterator[ColumnarBatch], _] => + + private var currentGroup: Iterator[ColumnarBatch] = _ + + // Constructed once per task: `root` (the trait's persistent destination IPC root) and + // `cometCodec` are both stable across the partition. `getRecordBatch` reads the current + // contents of `root.getFieldVectors` on every call, so re-using the unloader is safe. + private lazy val batchUnloader: VectorUnloader = + new VectorUnloader(root, /* includeNullCount */ true, cometCodec, /* alignBuffers */ true) + + // Read the codec name via raw config key. Spark 4.0.x has no `SQLConf.arrowCompressionCodec` + // accessor at all (it was added after the 4.0 line was cut), so a typed `ShimSQLConf` + // forwarder would still need a stringly-typed fallback for the 4.0 build. The codec instances + // are obtained through `CompressionCodec.Factory` (arrow-vector) rather than importing the + // concrete `Lz4CompressionCodec` / `ZstdCompressionCodec` from the separate + // arrow-compression artifact, which Comet does not depend on. + private lazy val cometCodec: CompressionCodec = { + val factory = CompressionCodec.Factory.INSTANCE + SQLConf.get.getConfString("spark.sql.execution.arrow.compression.codec", "none") match { + case "none" => NoCompressionCodec.INSTANCE + case "lz4" => + factory.createCodec(CompressionUtil.CodecType.LZ4_FRAME) + case "zstd" => + val level = + SQLConf.get.getConfString("spark.sql.execution.arrow.compression.zstd.level", "3").toInt + factory.createCodec(CompressionUtil.CodecType.ZSTD, level) + case other => + throw SparkException.internalError( + s"Unsupported Arrow compression codec: $other. Supported values: none, lz4, zstd") + } + } + + override protected def writeNextBatchToArrowStream( + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + dataOut: DataOutputStream, + inputIterator: Iterator[Iterator[ColumnarBatch]]): Boolean = { + + while (currentGroup == null || !currentGroup.hasNext) { + if (!inputIterator.hasNext) { + super[PythonArrowInput].close() + return false + } + currentGroup = inputIterator.next() + } + + val cometBatch = currentGroup.next() + val startData = dataOut.size() + val structVec = root.getVector(0).asInstanceOf[StructVector] + + var i = 0 + while (i < cometBatch.numCols()) { + val src = + cometBatch + .column(i) + .asInstanceOf[CometDecodedVector] + .getValueVector + .asInstanceOf[FieldVector] + val dst = structVec.getChildByOrdinal(i).asInstanceOf[FieldVector] + copyVector(src, dst) + i += 1 + } + val numRows = cometBatch.numRows() + structVec.setValueCount(numRows) + // Mark every row in the struct as non-null (all-1 validity bits). The struct validity + // buffer is freshly allocated (or cleared) and zero-initialised, so without this step + // Python would see an all-null struct column and return null for every output row. + val validityBytes = (numRows + 7) / 8 + Platform.setMemory(structVec.getValidityBuffer.memoryAddress(), 0xff.toByte, validityBytes) + root.setRowCount(numRows) + + val recordBatch = batchUnloader.getRecordBatch + try { + val writeChannel = new WriteChannel(Channels.newChannel(dataOut)) + MessageSerializer.serialize(writeChannel, recordBatch) + } finally { + recordBatch.close() + } + + pythonMetrics("pythonDataSent") += dataOut.size() - startData + true + } + + /** + * Copy a Comet column into the destination FieldVector. Walks both trees in lockstep: sizes + * each destination node from the source, copies every buffer with `ArrowBuf.setBytes`, then + * sets value counts bottom-up so `setValueCount` does not rewrite the offset bytes we just + * copied. + */ + private def copyVector(src: FieldVector, dst: FieldVector): Unit = { + val valueCount = src.getValueCount + + dst match { + case bfwv: BaseFixedWidthVector => + bfwv.allocateNew(valueCount) + case bvwv: BaseVariableWidthVector => + bvwv.allocateNew(src.getDataBuffer.readableBytes, valueCount) + case blvwv: BaseLargeVariableWidthVector => + blvwv.allocateNew(src.getDataBuffer.readableBytes, valueCount) + case _ => + dst.setInitialCapacity(valueCount) + dst.allocateNew() + } + + val srcBufs = src.getFieldBuffers + val dstBufs = dst.getFieldBuffers + require( + srcBufs.size == dstBufs.size, + s"buffer count mismatch for ${dst.getField}: src=${srcBufs.size}, dst=${dstBufs.size}") + var b = 0 + while (b < srcBufs.size) { + val s = srcBufs.get(b) + dstBufs.get(b).setBytes(0, s, 0, s.readableBytes) + b += 1 + } + + val srcChildren = src.getChildrenFromFields + val dstChildren = dst.getChildrenFromFields + require( + srcChildren.size == dstChildren.size, + s"child count mismatch for ${dst.getField}: src=${srcChildren.size}, dst=${dstChildren.size}") + srcChildren.asScala.zip(dstChildren.asScala).foreach { case (sc, dc) => + copyVector(sc.asInstanceOf[FieldVector], dc.asInstanceOf[FieldVector]) + } + + // For vectors that fill offset-buffer "holes" in setValueCount (variable-width and list + // types), set lastSet = vc - 1 first so fillHoles is a no-op and the already-copied + // offset bytes are preserved. + dst match { + case v: BaseVariableWidthVector => v.setLastSet(valueCount - 1) + case v: BaseLargeVariableWidthVector => v.setLastSet(valueCount - 1) + case v: ListVector => v.setLastSet(valueCount - 1) + case v: LargeListVector => v.setLastSet(valueCount - 1) + case _ => + } + dst.setValueCount(valueCount) + } +} diff --git a/spark/src/test/resources/pyspark/benchmark_pyarrow_udf.py b/spark/src/test/resources/pyspark/benchmark_pyarrow_udf.py new file mode 100644 index 0000000000..08f2c6540f --- /dev/null +++ b/spark/src/test/resources/pyspark/benchmark_pyarrow_udf.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +End-to-end wall-clock benchmark for Comet's PyArrow UDF acceleration. + +Requires PySpark 4.0.1+ (Comet's columnar runner targets Spark 4.0+ only; +3.5 and 3.4 are documented no-ops). + +Times `df.mapInArrow(passthrough, schema).count()` and the equivalent +`mapInPandas` query with `spark.comet.exec.pyarrowUdf.enabled` set +to false (vanilla Spark path) and true (Comet's optimized path). Both +modes run the same Python worker, so the measured delta covers what the +optimization actually changes for users: + + * vanilla: CometScan -> ColumnarToRow + UnsafeProjection -> ArrowPythonRunner + (per-row InternalRow.getXXX() loop inside ArrowWriter.write) + * optimized: CometScan -> CometMapInBatchExec -> CometArrowPythonRunner + (per-buffer Unsafe.copyMemory from Comet's vectors into the + runner's persistent VectorSchemaRoot; no row materialization) + +Results are wall-clock seconds, so they include Python interpreter, +Arrow IPC, and downstream count() costs. That's intentional: the +optimization's user-visible value is what fraction of end-to-end time +it shaves off, not the JVM-side delta in isolation. + +Caveat: the workload here is `passthrough_udf` + `count()` on `local[2]`, +so most of the wall time is Spark's Python fork/IPC overhead with very +little real Python work. Real UDFs (PyArrow compute, pandas ops, model +inference) increase the per-row Python cost, which dilutes the JVM-side +savings and shrinks the speedup ratio relative to what you see here. + +Usage: + # Build Comet (release for representative numbers): + make release + + pip install pyspark==3.5.8 pyarrow pandas + + python3 spark/src/test/resources/pyspark/benchmark_pyarrow_udf.py + +Override defaults via environment variables: + COMET_JAR=/path/to/comet.jar path to the Comet jar + BENCHMARK_ROWS=2000000 rows per run + BENCHMARK_WARMUP=2 warmup iterations per case + BENCHMARK_ITERS=5 measured iterations per case +""" + +import contextlib +import os +import statistics +import sys +import tempfile +import time + +from pyspark.sql import SparkSession + +sys.path.insert(0, os.path.dirname(__file__)) +from conftest import resolve_comet_jar + + +def _build_spark() -> SparkSession: + jar = resolve_comet_jar() + os.environ["PYSPARK_SUBMIT_ARGS"] = ( + f"--jars {jar} --driver-class-path {jar} pyspark-shell" + ) + return ( + SparkSession.builder.master("local[2]") + .appName("comet-pyarrow-udf-benchmark") + .config("spark.plugins", "org.apache.spark.CometPlugin") + .config("spark.comet.enabled", "true") + .config("spark.comet.exec.enabled", "true") + .config("spark.memory.offHeap.enabled", "true") + .config("spark.memory.offHeap.size", "4g") + .config("spark.driver.memory", "4g") + # Pin AQE off so the explain output and plan structure are stable + # across iterations. AQE doesn't change the optimization's behavior; + # it just makes plan inspection harder. + .config("spark.sql.adaptive.enabled", "false") + .getOrCreate() + ) + + +def _passthrough_arrow(iterator): + for batch in iterator: + yield batch + + +def _passthrough_pandas(iterator): + for pdf in iterator: + yield pdf + + +def _narrow_primitives(spark: SparkSession, n: int): + return spark.range(n).selectExpr( + "id as id_long", + "cast(id as int) as id_int", + "cast(id as double) as id_double", + ) + + +def _mixed_with_strings(spark: SparkSession, n: int): + return spark.range(n).selectExpr( + "id as id_long", + "cast(id as int) as id_int", + "cast(id as double) as id_double", + "concat('row_', cast(id as string)) as id_str", + "cast(id % 2 as boolean) as id_bool", + ) + + +def _wide_rows(spark: SparkSession, n: int): + types = ["int", "long", "double"] + cols = [ + f"cast(id + {i} as {types[i % len(types)]}) as col_{i}" for i in range(50) + ] + return spark.range(n).selectExpr(*cols) + + +WORKLOADS = [ + ("narrow primitives", _narrow_primitives), + ("mixed with strings", _mixed_with_strings), + ("wide rows (50 cols)", _wide_rows), +] + + +@contextlib.contextmanager +def _temp_parquet(spark: SparkSession, build_df, n: int): + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "src.parquet") + build_df(spark, n).write.parquet(path) + yield path + + +def _time_run(spark: SparkSession, parquet_path: str, accelerate: bool, api: str) -> float: + spark.conf.set( + "spark.comet.exec.pyarrowUdf.enabled", + "true" if accelerate else "false", + ) + df = spark.read.parquet(parquet_path) + schema = df.schema + if api == "mapInArrow": + df = df.mapInArrow(_passthrough_arrow, schema) + else: + df = df.mapInPandas(_passthrough_pandas, schema) + t0 = time.perf_counter() + df.count() + return time.perf_counter() - t0 + + +def main() -> None: + rows = int(os.environ.get("BENCHMARK_ROWS", 1024 * 1024)) + warmup = int(os.environ.get("BENCHMARK_WARMUP", 2)) + iters = int(os.environ.get("BENCHMARK_ITERS", 5)) + + spark = _build_spark() + spark.sparkContext.setLogLevel("WARN") + + print(f"\nrows per run: {rows:,}") + print(f"warmup iters: {warmup}, measured iters: {iters}") + print(f"jar: {resolve_comet_jar()}\n") + + header = " {:<14} {:<10} {:>10} {:>10} {:>10} {:>13} {:>9}".format( + "api", "mode", "min (s)", "median (s)", "max (s)", "rows/s", "speedup" + ) + print(header) + print(" " + "-" * (len(header) - 2)) + + for name, build_df in WORKLOADS: + print(f"\n=== {name} ===") + with _temp_parquet(spark, build_df, rows) as parquet_path: + for api in ("mapInArrow", "mapInPandas"): + samples_by_mode = {} + for mode, accelerate in (("vanilla", False), ("optimized", True)): + for _ in range(warmup): + _time_run(spark, parquet_path, accelerate, api) + samples = [ + _time_run(spark, parquet_path, accelerate, api) + for _ in range(iters) + ] + samples_by_mode[mode] = samples + median = statistics.median(samples) + speedup = "" + if mode == "optimized": + speedup = "{:.2f}x".format( + statistics.median(samples_by_mode["vanilla"]) / median + ) + print( + " {:<14} {:<10} {:>10} {:>10} {:>10} {:>13} {:>9}".format( + api, + mode, + "{:.3f}".format(min(samples)), + "{:.3f}".format(median), + "{:.3f}".format(max(samples)), + "{:,.0f}".format(rows / median), + speedup, + ) + ) + + spark.stop() + + +if __name__ == "__main__": + main() diff --git a/spark/src/test/resources/pyspark/conftest.py b/spark/src/test/resources/pyspark/conftest.py new file mode 100644 index 0000000000..35d6d85191 --- /dev/null +++ b/spark/src/test/resources/pyspark/conftest.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Shared helpers for the pytest modules under this directory and for the +benchmark scripts that import them. + +`resolve_comet_jar` returns the path to the Comet jar a Spark session needs. +Resolution order: the `COMET_JAR` env var (taken verbatim if it points at a +file, expanded as a glob otherwise), then `/spark/target` matched against +the installed pyspark major.minor version. +""" + +import glob +import os + + +REPO_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..") +) + + +def resolve_comet_jar() -> str: + explicit = os.environ.get("COMET_JAR") + if explicit: + if any(ch in explicit for ch in "*?["): + matches = sorted(glob.glob(explicit)) + if not matches: + raise FileNotFoundError( + f"COMET_JAR pattern matched nothing: {explicit}" + ) + return matches[-1] + return explicit + + # Pick the jar that matches the installed pyspark major.minor version. The + # Comet jars are published per Spark version (e.g. + # comet-spark-spark3.5_2.12-*.jar); using the wrong one yields + # ClassNotFoundException on Scala stdlib classes. + import pyspark + + major_minor = ".".join(pyspark.__version__.split(".")[:2]) + spark_tag = f"spark{major_minor}" + scala_tag = "_2.12" if major_minor.startswith("3.") else "_2.13" + pattern = os.path.join( + REPO_ROOT, + f"spark/target/comet-spark-{spark_tag}{scala_tag}-*-SNAPSHOT.jar", + ) + candidates = [ + m + for m in sorted(glob.glob(pattern)) + if "sources" not in os.path.basename(m) and "tests" not in os.path.basename(m) + ] + if not candidates: + raise FileNotFoundError( + "Comet jar not found. Set COMET_JAR or run `make release`. " + f"Looked under {pattern}." + ) + return candidates[-1] diff --git a/spark/src/test/resources/pyspark/test_pyarrow_udf.py b/spark/src/test/resources/pyspark/test_pyarrow_udf.py new file mode 100644 index 0000000000..6347411cb7 --- /dev/null +++ b/spark/src/test/resources/pyspark/test_pyarrow_udf.py @@ -0,0 +1,1155 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Pytest-driven integration tests for Comet's PyArrow UDF acceleration. + +Each test runs against two execution paths: + - "accelerated": spark.comet.exec.pyarrowUdf.enabled=true + (plan should contain CometMapInBatch and no ColumnarToRow) + - "fallback": spark.comet.exec.pyarrowUdf.enabled=false + (plan should contain vanilla PythonMapInArrow / MapInArrow) + +Usage: + # Build Comet first: + make + + # Then either let the test discover the jar from spark/target, or pass it + # explicitly via COMET_JAR: + export COMET_JAR=$PWD/spark/target/comet-spark-spark3.5_2.12-0.16.0-SNAPSHOT.jar + + pip install pyspark==3.5.8 pyarrow pandas pytest + pytest -v spark/src/test/resources/pyspark/test_pyarrow_udf.py +""" + +import datetime as dt +import os +from decimal import Decimal + +import pyarrow as pa +import pytest +from pyspark.sql import SparkSession, types as T + +from conftest import resolve_comet_jar + + +@pytest.fixture(scope="session") +def spark(): + jar = resolve_comet_jar() + # PYSPARK_SUBMIT_ARGS is consumed when pyspark launches its JVM. Setting + # --jars puts the Comet jar on both driver and executor classpaths so the + # CometPlugin can be loaded. + os.environ["PYSPARK_SUBMIT_ARGS"] = ( + f"--jars {jar} --driver-class-path {jar} pyspark-shell" + ) + session = ( + SparkSession.builder.master("local[2]") + .appName("comet-pyarrow-udf-tests") + .config("spark.plugins", "org.apache.spark.CometPlugin") + .config("spark.comet.enabled", "true") + .config("spark.comet.exec.enabled", "true") + # spark.comet.exec.shuffle.enabled defaults to true, and + # CometSparkSessionExtensions.isCometLoaded refuses to register Comet's rules + # at all when shuffle is on but spark.shuffle.manager is not the Comet manager. + # These tests do not need Comet shuffle, so disable it explicitly to keep + # Comet's scan and exec rules active without configuring shuffle. + .config("spark.comet.exec.shuffle.enabled", "false") + .config("spark.memory.offHeap.enabled", "true") + .config("spark.memory.offHeap.size", "2g") + .getOrCreate() + ) + try: + yield session + finally: + session.stop() + + +@pytest.fixture(params=[True, False], ids=["accelerated", "fallback"]) +def accelerated(request, spark) -> bool: + spark.conf.set( + "spark.comet.exec.pyarrowUdf.enabled", + "true" if request.param else "false", + ) + return request.param + + +def _executed_plan(df) -> str: + return df._jdf.queryExecution().executedPlan().toString() + + +def _assert_plan_matches_mode( + plan: str, accelerated: bool, vanilla_node: str = "MapInArrow" +) -> None: + if accelerated: + assert "CometMapInBatch" in plan, ( + f"expected CometMapInBatch in accelerated plan, got:\n{plan}" + ) + assert "ColumnarToRow" not in plan, ( + f"unexpected ColumnarToRow in accelerated plan:\n{plan}" + ) + else: + assert "CometMapInBatch" not in plan, ( + f"unexpected CometMapInBatch in fallback plan:\n{plan}" + ) + assert vanilla_node in plan, ( + f"expected {vanilla_node} in fallback plan, got:\n{plan}" + ) + + +def test_map_in_arrow_doubles_value(spark, tmp_path, accelerated): + data = [(i, float(i * 1.5), f"name_{i}") for i in range(100)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(data, ["id", "value", "name"]).write.parquet(src) + + def double_value(iterator): + for batch in iterator: + pdf = batch.to_pandas() + pdf["value"] = pdf["value"] * 2 + yield pa.RecordBatch.from_pandas(pdf) + + schema = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + T.StructField("name", T.StringType()), + ] + ) + result_df = spark.read.parquet(src).mapInArrow(double_value, schema) + + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + rows = result_df.orderBy("id").collect() + assert len(rows) == len(data) + for row, original in zip(rows, data): + assert row["id"] == original[0] + assert abs(row["value"] - original[1] * 2) < 1e-6 + assert row["name"] == original[2] + + +# All other tests use the default `vanilla_node="MapInArrow"`. The mapInPandas tests below +# pass `MapInPandas` explicitly. The substring is the same on Spark 3.5 (PythonMapInArrowExec) +# and Spark 4.x (MapInArrowExec) since the latter is a substring of the former. + + +def test_map_in_arrow_changes_schema(spark, tmp_path, accelerated): + data = [(i, float(i)) for i in range(50)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(data, ["id", "value"]).write.parquet(src) + + def add_computed_column(iterator): + for batch in iterator: + pdf = batch.to_pandas() + pdf["squared"] = pdf["value"] ** 2 + pdf["label"] = pdf["id"].apply(lambda x: f"item_{x}") + yield pa.RecordBatch.from_pandas(pdf) + + schema = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + T.StructField("squared", T.DoubleType()), + T.StructField("label", T.StringType()), + ] + ) + result_df = spark.read.parquet(src).mapInArrow(add_computed_column, schema) + + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + rows = result_df.orderBy("id").collect() + assert len(rows) == 50 + for i, row in enumerate(rows): + assert abs(row["squared"] - float(i) ** 2) < 1e-6 + assert row["label"] == f"item_{i}" + + +def test_map_in_pandas_doubles_value(spark, tmp_path, accelerated): + data = [(i, float(i * 1.5)) for i in range(100)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(data, ["id", "value"]).write.parquet(src) + + def double_value(iterator): + for pdf in iterator: + pdf = pdf.copy() + pdf["value"] = pdf["value"] * 2 + yield pdf + + schema = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + ] + ) + result_df = spark.read.parquet(src).mapInPandas(double_value, schema) + + _assert_plan_matches_mode( + _executed_plan(result_df), accelerated, vanilla_node="MapInPandas" + ) + + rows = result_df.orderBy("id").collect() + assert len(rows) == len(data) + for row, original in zip(rows, data): + assert row["id"] == original[0] + assert abs(row["value"] - original[1] * 2) < 1e-6 + + +def test_map_in_pandas_changes_schema(spark, tmp_path, accelerated): + data = [(i, float(i)) for i in range(50)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(data, ["id", "value"]).write.parquet(src) + + def add_squared(iterator): + for pdf in iterator: + pdf = pdf.copy() + pdf["squared"] = pdf["value"] ** 2 + yield pdf + + schema = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + T.StructField("squared", T.DoubleType()), + ] + ) + result_df = spark.read.parquet(src).mapInPandas(add_squared, schema) + + _assert_plan_matches_mode( + _executed_plan(result_df), accelerated, vanilla_node="MapInPandas" + ) + + rows = result_df.orderBy("id").collect() + assert len(rows) == 50 + for i, row in enumerate(rows): + assert abs(row["squared"] - float(i) ** 2) < 1e-6 + + +def test_map_in_arrow_preserves_nulls(spark, tmp_path, accelerated): + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("name", T.StringType()), + ] + ) + rows = [ + (1, "a"), + (2, None), + (None, "c"), + (None, None), + (5, "e"), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + # Pure Arrow passthrough so nulls survive without a pandas roundtrip + # (pandas would coerce null longs to NaN floats). + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = {(r["id"], r["name"]) for r in result_df.collect()} + assert out == set(rows) + + +def test_map_in_arrow_empty_input(spark, tmp_path, accelerated): + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + ] + ) + src = str(tmp_path / "src.parquet") + spark.createDataFrame([(1, 1.0), (2, 2.0)], schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + # Filter all rows out so the operator sees an empty stream from CometScan. + result_df = ( + spark.read.parquet(src).where("id < 0").mapInArrow(passthrough, schema_in) + ) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + assert result_df.count() == 0 + + +def test_map_in_arrow_python_exception_propagates(spark, tmp_path, accelerated): + schema_in = T.StructType([T.StructField("id", T.LongType())]) + data = [(i,) for i in range(10)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(data, schema_in).write.parquet(src) + + sentinel = "boom-from-pyarrow-udf" + + def boom(iterator): + for _batch in iterator: + raise ValueError(sentinel) + # Unreachable, but mapInArrow requires the callable to be a generator. + yield # pragma: no cover + + result_df = spark.read.parquet(src).mapInArrow(boom, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + with pytest.raises(Exception) as exc_info: + result_df.collect() + assert sentinel in str(exc_info.value), ( + f"expected sentinel {sentinel!r} in exception, got: {exc_info.value}" + ) + + +def test_map_in_arrow_decimal_type(spark, tmp_path, accelerated): + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("amount", T.DecimalType(18, 6)), + ] + ) + rows = [ + (1, Decimal("123.456789")), + (2, Decimal("0.000001")), + (3, Decimal("-99999999.999999")), + (4, None), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = {(r["id"], r["amount"]) for r in result_df.collect()} + assert out == set(rows) + + +@pytest.mark.parametrize( + "precision,scale", + [ + (1, 0), + (9, 0), + (9, 4), + (17, 8), + (18, 0), + (18, 18), + (19, 0), + (28, 14), + (38, 0), + (38, 18), + (38, 38), + ], +) +def test_map_in_arrow_decimal_precision_sweep( + spark, tmp_path, accelerated, precision, scale +): + """ + Spark's `BaseFixedWidthVector` handles short decimals (precision <= 18, long-backed) and long + decimals (precision >= 19, 16-byte `FixedSizeBinary`) on different code paths. The 18/19 + boundary is where buffer-width assumptions in `copyVector` can hide bugs. Sweep over + representative precisions and scale extremes (0, half, max). + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("amount", T.DecimalType(precision, scale)), + ] + ) + integer_digits = precision - scale + abs_int = (10**integer_digits - 1) if integer_digits > 0 else 0 + abs_frac = (10**scale - 1) if scale > 0 else 0 + largest = Decimal(f"{abs_int}.{abs_frac:0{scale}d}") if scale else Decimal(abs_int) + rows = [ + (1, Decimal(0)), + (2, largest), + (3, -largest), + (4, None), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = {(r["id"], r["amount"]) for r in result_df.collect()} + assert out == set(rows) + + +@pytest.mark.parametrize("null_fraction", [0.0, 0.01, 0.5, 0.99, 1.0]) +def test_map_in_arrow_null_density_sweep( + spark, tmp_path, accelerated, null_fraction +): + """ + Validity-buffer memcpy is where Arrow Java vector copies historically break. Sweep null + density across the corner cases: all-non-null, sparse-null, half-null, sparse-non-null, + all-null. Catches off-by-one in validity packing and edge cases where source/destination + null counts diverge. + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.LongType()), + ] + ) + n = 256 + rows = [ + (i, None if (i * 9973) % 100 < int(null_fraction * 100) else i * 2) + for i in range(n) + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = sorted((r["id"], r["value"]) for r in result_df.collect()) + assert out == sorted(rows) + + +def test_map_in_arrow_multi_batch_per_partition(spark, tmp_path, accelerated): + """ + Force many small batches in a single partition so the writer/unloader exercises the + persistent destination IPC root over multiple batches. Catches buffer-reuse bugs and + variable-width data-buffer growth across batches that single-batch tests miss. + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("s", T.StringType()), + ] + ) + n = 4000 + rows = [(i, f"row_{i}" if i % 7 != 0 else None) for i in range(n)] + src = str(tmp_path / "src.parquet") + # Single partition; small arrow batch limit forces ~250 batches per partition. + spark.createDataFrame(rows, schema_in).coalesce(1).write.parquet(src) + + prev_records = spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch") + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "16") + try: + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = sorted((r["id"], r["s"]) for r in result_df.collect()) + assert out == sorted(rows) + finally: + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", prev_records) + + +def test_map_in_arrow_wide_schema(spark, tmp_path, accelerated): + """ + 50-column mixed-type schema. The bulk-copy path walks a flattened addresses[] array indexed + across the whole vector tree; off-by-one in flattening logic surfaces at depth * width. + """ + fields = [T.StructField("id", T.LongType())] + for i in range(15): + fields.append(T.StructField(f"i{i}", T.IntegerType())) + for i in range(15): + fields.append(T.StructField(f"d{i}", T.DoubleType())) + for i in range(15): + fields.append(T.StructField(f"s{i}", T.StringType())) + for i in range(4): + fields.append(T.StructField(f"b{i}", T.BooleanType())) + assert len(fields) == 50 + schema_in = T.StructType(fields) + + rows = [] + for i in range(60): + row = [i] + row += [i + k if k % 3 != 0 else None for k in range(15)] + row += [float(i + k) * 0.5 if k % 4 != 0 else None for k in range(15)] + row += [f"s{i}_{k}" if k % 5 != 0 else None for k in range(15)] + row += [bool((i + k) % 2) for k in range(4)] + rows.append(tuple(row)) + + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = sorted(tuple(r[name] for name in schema_in.names) for r in result_df.collect()) + assert out == sorted(rows) + + +def test_map_in_arrow_zero_row_batch_in_stream(spark, tmp_path, accelerated): + """ + A non-empty stream that contains a 0-row batch mid-stream. The existing empty-input test + filters everything out so the operator sees zero batches; this one keeps later batches so + the writer must handle a 0-row batch and continue. setValueCount(0) + validity buffer + sizing are the candidates that can break here. + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.LongType()), + ] + ) + rows = [(i, i * 3) for i in range(50)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).coalesce(1).write.parquet(src) + + def emit_with_empty(iterator): + for batch in iterator: + # Yield an empty record batch first, then the real one. + yield batch.slice(0, 0) + yield batch + + result_df = spark.read.parquet(src).mapInArrow(emit_with_empty, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = sorted((r["id"], r["value"]) for r in result_df.collect()) + assert out == sorted(rows) + + +def test_map_in_arrow_transforming_array(spark, tmp_path, accelerated): + """ + Mutating UDF over a complex type: reverse each array. Catches symmetric encode/decode + mistakes that a passthrough UDF would invert and hide. + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("nums", T.ArrayType(T.IntegerType())), + ] + ) + rows = [ + (1, [1, 2, 3, 4]), + (2, [None, 5, None]), + (3, []), + (4, None), + (5, [42]), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def reverse_arrays(iterator): + for batch in iterator: + pdf = batch.to_pandas() + pdf["nums"] = pdf["nums"].apply( + lambda lst: list(reversed(lst)) if lst is not None else None + ) + yield pa.RecordBatch.from_pandas(pdf) + + result_df = spark.read.parquet(src).mapInArrow(reverse_arrays, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + def _norm(row): + nums = row["nums"] + return (row["id"], None if nums is None else tuple(nums)) + + out = {_norm(r) for r in result_df.collect()} + expected = set() + for id_, nums in rows: + rev = None if nums is None else tuple(reversed(nums)) + expected.add((id_, rev)) + assert out == expected + + +def test_map_in_arrow_date_and_timestamp(spark, tmp_path, accelerated): + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("d", T.DateType()), + T.StructField("ts", T.TimestampType()), + ] + ) + rows = [ + (1, dt.date(2024, 1, 1), dt.datetime(2024, 1, 1, 12, 30, 45)), + (2, dt.date(1999, 12, 31), dt.datetime(2000, 6, 15, 0, 0, 0)), + (3, None, None), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = {(r["id"], r["d"], r["ts"]) for r in result_df.collect()} + assert out == set(rows) + + +def test_map_in_arrow_array_and_struct(spark, tmp_path, accelerated): + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("nums", T.ArrayType(T.IntegerType())), + T.StructField( + "addr", + T.StructType( + [ + T.StructField("city", T.StringType()), + T.StructField("zip", T.IntegerType()), + ] + ), + ), + ] + ) + rows = [ + (1, [1, 2, 3], ("Berlin", 10115)), + (2, [], ("NYC", 10001)), + (3, None, None), + (4, [None, 5], ("Tokyo", None)), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + def _normalize(row): + nums = tuple(row["nums"]) if row["nums"] is not None else None + addr = row["addr"] + addr_tuple = (addr["city"], addr["zip"]) if addr is not None else None + return (row["id"], nums, addr_tuple) + + out = {_normalize(r) for r in result_df.collect()} + expected = { + (r[0], tuple(r[1]) if r[1] is not None else None, r[2]) for r in rows + } + assert out == expected + + +def test_map_in_arrow_numeric_scalars(spark, tmp_path, accelerated): + """ + Covers the BaseFixedWidthVector branch in CometColumnarPythonInput.copyVector for + every fixed-width primitive Comet's scan supports beyond the long/double/int already + exercised by other tests: boolean, byte, short, float. Each has a distinct buffer + size, and the validity bit handling is independent per column. + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("b", T.BooleanType()), + T.StructField("tiny", T.ByteType()), + T.StructField("small", T.ShortType()), + T.StructField("flt", T.FloatType()), + ] + ) + rows = [ + (1, True, 1, 1000, 1.5), + (2, False, -128, -32768, -3.25), + (3, True, 127, 32767, float("inf")), + (4, None, None, None, None), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = {(r["id"], r["b"], r["tiny"], r["small"], r["flt"]) for r in result_df.collect()} + assert out == set(rows) + + +def test_map_in_arrow_binary_type(spark, tmp_path, accelerated): + """ + BinaryType is the BaseVariableWidthVector path with non-string content. StringType + already exercises that path for utf-8 data; binary covers the case where the data + buffer can hold arbitrary bytes (including null bytes mid-string). + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("payload", T.BinaryType()), + ] + ) + rows = [ + (1, b"\x00\x01\x02\x03"), + (2, b""), + (3, b"\xff" * 64), + (4, None), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = {(r["id"], bytes(r["payload"]) if r["payload"] is not None else None) + for r in result_df.collect()} + expected = set(rows) + assert out == expected + + +def test_map_in_arrow_timestamp_ntz(spark, tmp_path, accelerated): + """ + TimestampNTZType is a separate Arrow type from TimestampType (no timezone) and goes + through a different ArrowType.Timestamp(..., tz=None) on the wire. + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("ts_ntz", T.TimestampNTZType()), + ] + ) + rows = [ + (1, dt.datetime(2024, 1, 1, 12, 30, 45)), + (2, dt.datetime(1970, 1, 1, 0, 0, 0)), + (3, None), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = {(r["id"], r["ts_ntz"]) for r in result_df.collect()} + assert out == set(rows) + + +def test_map_in_arrow_map_type(spark, tmp_path, accelerated): + """ + MapType is encoded in Arrow as a List> with extra metadata. The + buffer layout (offsets + struct child + key/value children) is distinct from a plain + list, and CometMapVector is a separate vector class from CometListVector. Without + this test the recursive copy path through map-typed columns is unexercised. + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField( + "attrs", T.MapType(T.StringType(), T.IntegerType(), valueContainsNull=True) + ), + ] + ) + rows = [ + (1, {"a": 1, "b": 2}), + (2, {}), + (3, None), + (4, {"only": None}), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + def _normalize(row): + attrs = row["attrs"] + attrs_norm = ( + tuple(sorted(attrs.items(), key=lambda kv: kv[0])) + if attrs is not None + else None + ) + return (row["id"], attrs_norm) + + out = {_normalize(r) for r in result_df.collect()} + expected = { + ( + r[0], + tuple(sorted(r[1].items(), key=lambda kv: kv[0])) if r[1] is not None else None, + ) + for r in rows + } + assert out == expected + + +def test_map_in_arrow_deeply_nested(spark, tmp_path, accelerated): + """ + Exercises the recursive descent in CometColumnarPythonInput.copyVector at depth > 1, + in every nesting combination: array-of-array, array-of-struct, struct-of-array, + struct-of-struct. Single-level nesting is covered by test_map_in_arrow_array_and_struct; + the bug surface here is that setLastSet / setValueCount must be applied bottom-up + correctly at every level. + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("matrix", T.ArrayType(T.ArrayType(T.IntegerType()))), + T.StructField( + "people", + T.ArrayType( + T.StructType( + [ + T.StructField("name", T.StringType()), + T.StructField("age", T.IntegerType()), + ] + ) + ), + ), + T.StructField( + "config", + T.StructType( + [ + T.StructField("flags", T.ArrayType(T.StringType())), + T.StructField( + "limits", + T.StructType( + [ + T.StructField("min", T.IntegerType()), + T.StructField("max", T.IntegerType()), + ] + ), + ), + ] + ), + ), + ] + ) + rows = [ + ( + 1, + [[1, 2], [3, 4, 5]], + [("alice", 30), ("bob", 25)], + (["x", "y"], (0, 100)), + ), + ( + 2, + [[], [None, 7]], + [("solo", None)], + ([], (None, None)), + ), + (3, None, None, None), + (4, [None, [9]], [None, ("ghost", 0)], (None, None)), + ] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + def _norm_array(a): + return tuple(a) if a is not None else None + + def _norm_matrix(m): + return tuple(_norm_array(inner) for inner in m) if m is not None else None + + def _norm_people(p): + if p is None: + return None + return tuple( + (item["name"], item["age"]) if item is not None else None for item in p + ) + + def _norm_config(c): + if c is None: + return None + flags = _norm_array(c["flags"]) + limits = c["limits"] + limits_norm = (limits["min"], limits["max"]) if limits is not None else None + return (flags, limits_norm) + + def _norm_row(r): + return ( + r["id"], + _norm_matrix(r["matrix"]), + _norm_people(r["people"]), + _norm_config(r["config"]), + ) + + def _norm_input_people(p): + if p is None: + return None + return tuple(item if item is not None else None for item in p) + + def _norm_input_config(c): + if c is None: + return None + flags, limits = c + return (_norm_array(flags), limits) + + out = {_norm_row(r) for r in result_df.collect()} + expected = { + ( + r[0], + _norm_matrix(r[1]), + _norm_input_people(r[2]), + _norm_input_config(r[3]), + ) + for r in rows + } + assert out == expected + + +def test_map_in_arrow_falls_back_when_use_large_var_types(spark, tmp_path): + """ + `spark.sql.execution.arrow.useLargeVarTypes=true` widens StringType / BinaryType to + LargeUtf8 / LargeBinary in the destination IPC root (8-byte offsets). Comet's source + vectors always use 4-byte offsets; CometColumnarPythonInput.copyVector does a raw + setBytes per buffer and would corrupt the offset buffer in this configuration. + EliminateRedundantTransitions must skip the rewrite in that case so vanilla Spark + handles the operation. This test does not use the `accelerated` fixture: it sets + pyarrowUdf.enabled=true AND useLargeVarTypes=true and asserts the plan still falls + back to vanilla MapInArrow. + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("name", T.StringType()), + ] + ) + rows = [(i, f"name_{i}") for i in range(20)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + prev_pyarrow = spark.conf.get("spark.comet.exec.pyarrowUdf.enabled", "false") + prev_large = spark.conf.get("spark.sql.execution.arrow.useLargeVarTypes", "false") + spark.conf.set("spark.comet.exec.pyarrowUdf.enabled", "true") + spark.conf.set("spark.sql.execution.arrow.useLargeVarTypes", "true") + try: + result_df = spark.read.parquet(src).mapInArrow(passthrough, schema_in) + plan = _executed_plan(result_df) + assert "CometMapInBatch" not in plan, ( + f"useLargeVarTypes=true should force fallback, but plan has " + f"CometMapInBatch:\n{plan}" + ) + assert "MapInArrow" in plan, ( + f"expected vanilla MapInArrow in fallback plan, got:\n{plan}" + ) + out = sorted((r["id"], r["name"]) for r in result_df.collect()) + assert out == sorted(rows) + finally: + spark.conf.set("spark.comet.exec.pyarrowUdf.enabled", prev_pyarrow) + spark.conf.set("spark.sql.execution.arrow.useLargeVarTypes", prev_large) + + +def test_map_in_arrow_after_shuffle(spark, tmp_path, accelerated): + """ + Verifies correctness when a shuffle sits between the Comet scan and the + Python UDF. Without `spark.shuffle.manager` configured at session startup + the shuffle stays a vanilla `Exchange`, which is not columnar, so the + optimization does not fire across it today. This test does not assert on + the plan; it only ensures the path produces correct results in both modes + so a future change that wires Comet shuffle into the optimization does + not silently break correctness. + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + ] + ) + rows = [(i, float(i)) for i in range(50)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = ( + spark.read.parquet(src) + .repartition(4, "id") + .mapInArrow(passthrough, schema_in) + ) + + out = sorted((r["id"], r["value"]) for r in result_df.collect()) + assert out == sorted(rows) + + +def test_chained_map_in_arrow(spark, tmp_path, accelerated): + """ + `df.mapInArrow(udf1).mapInArrow(udf2)` stacks two operators. With the rewrite + enabled both become `CometMapInBatchExec`, so the inner one's output feeds + the outer one's input. The outer operator's input path expects vectors of + `CometDecodedVector` type: if the inner's output is plain `ArrowColumnVector` + the outer throws `ClassCastException` on the first batch. + """ + schema = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + ] + ) + rows = [(i, float(i)) for i in range(50)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema).write.parquet(src) + + def add_one(iterator): + for batch in iterator: + pdf = batch.to_pandas() + pdf["value"] = pdf["value"] + 1.0 + yield pa.RecordBatch.from_pandas(pdf) + + def double_value(iterator): + for batch in iterator: + pdf = batch.to_pandas() + pdf["value"] = pdf["value"] * 2.0 + yield pa.RecordBatch.from_pandas(pdf) + + result_df = ( + spark.read.parquet(src) + .mapInArrow(add_one, schema) + .mapInArrow(double_value, schema) + ) + + if accelerated: + plan = _executed_plan(result_df) + assert plan.count("CometMapInBatch") >= 2, ( + f"expected two CometMapInBatch operators in accelerated plan, got:\n{plan}" + ) + + out = sorted((r["id"], r["value"]) for r in result_df.collect()) + expected = sorted((i, (float(i) + 1.0) * 2.0) for i in range(50)) + assert out == expected + + +def test_filter_on_map_in_arrow_output(spark, tmp_path, accelerated): + """ + A filter on the UDF output column is a downstream Comet operator (when Comet's + native filter applies) reading from `CometMapInBatchExec`'s output. If the + output were plain `ArrowColumnVector`, NativeUtil.exportBatch's case match + would fall to the `case c =>` arm and throw SparkException. + """ + schema = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.LongType()), + ] + ) + rows = [(i, i * 2) for i in range(100)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = ( + spark.read.parquet(src).mapInArrow(passthrough, schema).filter("value > 50") + ) + + out = sorted((r["id"], r["value"]) for r in result_df.collect()) + expected = sorted((i, i * 2) for i in range(100) if i * 2 > 50) + assert out == expected + + +def test_aggregate_on_map_in_arrow_output(spark, tmp_path, accelerated): + """ + `mapInArrow(...).groupBy(...).agg(...)` puts an aggregate over the UDF output. + The aggregate is a Comet operator and reads from `CometMapInBatchExec`'s + output via NativeUtil.exportBatch when promoted to the native pipeline. If + the output were ArrowColumnVector, exportBatch would throw on every batch. + """ + schema = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("grp", T.LongType()), + T.StructField("value", T.LongType()), + ] + ) + rows = [(i, i % 5, i) for i in range(100)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema).write.parquet(src) + + def passthrough(iterator): + for batch in iterator: + yield batch + + result_df = ( + spark.read.parquet(src) + .mapInArrow(passthrough, schema) + .groupBy("grp") + .agg({"value": "sum"}) + ) + + out = {r["grp"]: r["sum(value)"] for r in result_df.collect()} + expected = {} + for i in range(100): + expected[i % 5] = expected.get(i % 5, 0) + i + assert out == expected + + +def test_map_in_arrow_barrier_mode(spark, tmp_path, accelerated): + """ + `mapInArrow(..., barrier=True)` runs the stage in barrier execution mode + (gang scheduling, all-or-nothing failure semantics, BarrierTaskContext + available inside the UDF). The optimization captures isBarrier in the + operator constructor and must propagate it through to RDD.barrier(); + otherwise the runtime context the UDF sees changes when the optimization + fires and any code calling BarrierTaskContext APIs breaks. + """ + schema_in = T.StructType( + [ + T.StructField("id", T.LongType()), + T.StructField("value", T.DoubleType()), + ] + ) + rows = [(i, float(i)) for i in range(20)] + src = str(tmp_path / "src.parquet") + spark.createDataFrame(rows, schema_in).write.parquet(src) + + def assert_barrier_context(iterator): + from pyspark import BarrierTaskContext + + # Will raise if the task is not running inside a barrier stage. + BarrierTaskContext.get() + for batch in iterator: + yield batch + + result_df = ( + spark.read.parquet(src).mapInArrow( + assert_barrier_context, schema_in, barrier=True + ) + ) + _assert_plan_matches_mode(_executed_plan(result_df), accelerated) + + out = sorted((r["id"], r["value"]) for r in result_df.collect()) + assert out == sorted(rows) diff --git a/spark/src/test/spark-4.x/org/apache/spark/sql/comet/CometMapInBatchSuite.scala b/spark/src/test/spark-4.x/org/apache/spark/sql/comet/CometMapInBatchSuite.scala new file mode 100644 index 0000000000..b4f3d64d76 --- /dev/null +++ b/spark/src/test/spark-4.x/org/apache/spark/sql/comet/CometMapInBatchSuite.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import org.apache.spark.api.python.{PythonAccumulatorV2, PythonBroadcast, PythonEvalType, PythonFunction} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, ExprId, PythonUDF} +import org.apache.spark.sql.execution.{ColumnarToRowExec, LeafExecNode} +import org.apache.spark.sql.execution.python.MapInArrowExec +import org.apache.spark.sql.types.{LongType, StructField, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.CometConf +import org.apache.comet.rules.EliminateRedundantTransitions + +/** Minimal CometPlan leaf used to anchor the rule's transform without triggering execution. */ +private case class StubCometLeaf(override val output: Seq[Attribute]) + extends LeafExecNode + with CometPlan { + override def supportsColumnar: Boolean = true + override protected def doExecute(): RDD[InternalRow] = + throw new UnsupportedOperationException + override protected def doExecuteColumnar(): RDD[ColumnarBatch] = + throw new UnsupportedOperationException +} + +/** + * Plan-rule test for the `EliminateRedundantTransitions` rewrite that produces + * `CometMapInBatchExec`. Pure Python execution paths are covered by the pytest module + * `test_pyarrow_udf.py`; this suite verifies the JVM-side rule without spinning up Python. + * + * Lives under `org.apache.spark.sql.comet` so it can reference Spark's `private[spark]` + * `PythonFunction` / `PythonAccumulatorV2` / `PythonBroadcast` classes when fabricating a stub + * `PythonUDF` for `MapInArrowExec` to wrap. + */ +class CometMapInBatchSuite extends CometTestBase { + + private def stubPythonUDF: PythonUDF = { + val pyFunc = new PythonFunction { + override val command: Seq[Byte] = Seq.empty[Byte] + override val envVars: java.util.Map[String, String] = + new java.util.HashMap[String, String]() + override val pythonIncludes: java.util.List[String] = + java.util.Collections.emptyList[String]() + override val pythonExec: String = "python3" + override val pythonVer: String = "3" + override val broadcastVars: java.util.List[Broadcast[PythonBroadcast]] = + java.util.Collections.emptyList[Broadcast[PythonBroadcast]]() + override val accumulator: PythonAccumulatorV2 = null + } + PythonUDF( + name = "test_udf", + func = pyFunc, + dataType = StructType(Seq(StructField("id", LongType))), + children = Seq(AttributeReference("id", LongType)(ExprId(0L))), + evalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF, + udfDeterministic = true) + } + + private def buildPlan(): MapInArrowExec = { + val cometChild = StubCometLeaf(Seq(AttributeReference("id", LongType)(ExprId(0L)))) + MapInArrowExec( + stubPythonUDF, + cometChild.output, + ColumnarToRowExec(cometChild), + isBarrier = false, + profile = None) + } + + test("rule rewrites MapInArrowExec over Comet to CometMapInBatchExec") { + withSQLConf(CometConf.COMET_PYARROW_UDF_ENABLED.key -> "true") { + val rewritten = EliminateRedundantTransitions(spark).apply(buildPlan()) + assert( + rewritten.exists(_.isInstanceOf[CometMapInBatchExec]), + s"expected CometMapInBatchExec in rewritten plan:\n$rewritten") + } + } + + test("rule does not rewrite when feature is disabled") { + withSQLConf(CometConf.COMET_PYARROW_UDF_ENABLED.key -> "false") { + val rewritten = EliminateRedundantTransitions(spark).apply(buildPlan()) + assert( + !rewritten.exists(_.isInstanceOf[CometMapInBatchExec]), + s"unexpected CometMapInBatchExec when disabled:\n$rewritten") + } + } + + test("rule handles chained MapInArrowExec without crashing") { + // df.mapInArrow(...).mapInArrow(...) produces two MapInArrowExec operators. The outer + // consumes rows from the inner directly (MapInArrowExec is a row producer), so there is + // no ColumnarToRow between them. After the rule's bottom-up rewrite the inner becomes + // CometMapInBatchExec; the outer keeps its row contract and is satisfied by + // CometMapInBatchExec.doExecute() reintroducing a ColumnarToRow internally. The + // assertion exists mainly to pin the structure: regress this if a future change makes + // both rewrite (the bulk-copy input path would then need to accept a CometVector input + // that did not come from a CometDecodedVector chain). + withSQLConf(CometConf.COMET_PYARROW_UDF_ENABLED.key -> "true") { + val cometLeaf = StubCometLeaf(Seq(AttributeReference("id", LongType)(ExprId(0L)))) + val inner = MapInArrowExec( + stubPythonUDF, + cometLeaf.output, + ColumnarToRowExec(cometLeaf), + isBarrier = false, + profile = None) + val outer = MapInArrowExec( + stubPythonUDF, + cometLeaf.output, + inner, + isBarrier = false, + profile = None) + + val rewritten = EliminateRedundantTransitions(spark).apply(outer) + val cometOps = rewritten.collect { case op: CometMapInBatchExec => op } + assert( + cometOps.size == 1, + s"expected the inner MapInArrowExec to be rewritten, but the chain produced " + + s"${cometOps.size} CometMapInBatchExec(s):\n$rewritten") + } + } + + test("end-to-end: rewrite-on output matches rewrite-off output for primitives + varchar") { + // This test needs PySpark workers; only run if PYSPARK_PYTHON is set in the env. + assume( + sys.env.contains("PYSPARK_PYTHON"), + "set PYSPARK_PYTHON to enable end-to-end pyarrow UDF tests") + + withTempPath { path => + val pathStr = path.getCanonicalPath + spark + .range(0, 1000, 1, 4) + .selectExpr( + "id AS id", + "CAST(id AS DOUBLE) * 1.5 AS dbl", + "CASE WHEN id % 10 = 0 THEN NULL ELSE CONCAT('row_', CAST(id AS STRING)) END AS s") + .write + .mode("overwrite") + .parquet(pathStr) + + // Baseline: rewrite disabled, vanilla MapInArrowExec runs. + val baseline = withSQLConf(CometConf.COMET_PYARROW_UDF_ENABLED.key -> "false") { + spark.read.parquet(pathStr).collect().map(_.toSeq).toSet + } + + // Optimized: rewrite enabled, CometMapInBatchExec + CometArrowPythonRunner runs. + withSQLConf(CometConf.COMET_PYARROW_UDF_ENABLED.key -> "true") { + val df = spark.read.parquet(pathStr) + val result = df.collect().map(_.toSeq).toSet + assert( + result == baseline, + s"optimized output differs from baseline:\noptimized=$result\nbaseline=$baseline") + } + } + } +}