diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index 72c2bea9e4..15fd147740 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -670,7 +670,7 @@ case class CometExecRule(session: SparkSession) val serde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]] if (isOperatorEnabled(serde, op)) { // For operators that require native children (like writes), check if all data-producing - // children are CometNativeExec. This prevents runtime failures when the native operator + // children produce Arrow data. This prevents runtime failures when the native operator // expects Arrow arrays but receives non-Arrow data (e.g., OnHeapColumnVector). if (serde.requiresNativeChildren && op.children.nonEmpty) { // Get the actual data-producing children (unwrap WriteFilesExec if present) @@ -678,7 +678,7 @@ case class CometExecRule(session: SparkSession) case writeFiles: WriteFilesExec => Seq(writeFiles.child) case other => Seq(other) } - if (!dataProducingChildren.forall(_.isInstanceOf[CometNativeExec])) { + if (!dataProducingChildren.forall(producesArrowData)) { withInfo(op, "Cannot perform native operation because input is not in Arrow format") return None } @@ -787,4 +787,19 @@ case class CometExecRule(session: SparkSession) } } + /** + * Checks if a plan produces Arrow-formatted data by unwrapping wrapper operators. This handles + * ReusedExchangeExec (used in multi-insert), QueryStageExec (AQE), and checks for CometExec + * (includes CometNativeExec and sink operators like CometUnionExec, CometCoalesceExec, etc.). + */ + private def producesArrowData(plan: SparkPlan): Boolean = { + plan match { + case _: CometExec => true + case r: ReusedExchangeExec => producesArrowData(r.child) + case s: ShuffleQueryStageExec => producesArrowData(s.plan) + case b: BroadcastQueryStageExec => producesArrowData(b.plan) + case _ => false + } + } + } diff --git a/spark/src/test/resources/sql-tests/write/multi_insert.sql b/spark/src/test/resources/sql-tests/write/multi_insert.sql new file mode 100644 index 0000000000..7e27bb2ba6 --- /dev/null +++ b/spark/src/test/resources/sql-tests/write/multi_insert.sql @@ -0,0 +1,75 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Test multi-insert with Comet native writer (issue #3430, SPARK-48817) +-- Validates that data written via multi-insert SQL is correct when +-- the native writer handles ReusedExchangeExec in the plan. + +-- Config: spark.comet.parquet.write.enabled=true +-- Config: spark.comet.operator.DataWritingCommandExec.allowIncompatible=true +-- Config: spark.comet.exec.enabled=true +-- Config: spark.sql.adaptive.enabled=false + +statement +CREATE TABLE multi_src(c1 INT) USING PARQUET + +statement +INSERT INTO multi_src VALUES (1), (2), (3), (4), (5) + +statement +CREATE TABLE multi_dst1(c1 INT) USING PARQUET + +statement +CREATE TABLE multi_dst2(c1 INT) USING PARQUET + +-- Multi-insert: single plan with ReusedExchangeExec +statement +FROM (SELECT /*+ REPARTITION(3) */ c1 FROM multi_src) +INSERT OVERWRITE TABLE multi_dst1 SELECT c1 +INSERT OVERWRITE TABLE multi_dst2 SELECT c1 + +-- Validate data in both destination tables +query +SELECT c1 FROM multi_dst1 ORDER BY c1 + +query +SELECT c1 FROM multi_dst2 ORDER BY c1 + +-- Verify both tables have equal content +query +SELECT count(*) FROM multi_dst1 + +query +SELECT count(*) FROM multi_dst2 + +-- Multi-insert with filtered inserts into different targets +statement +CREATE TABLE multi_dst3(c1 INT) USING PARQUET + +statement +CREATE TABLE multi_dst4(c1 INT) USING PARQUET + +statement +FROM (SELECT /*+ REPARTITION(2) */ c1 FROM multi_src) +INSERT OVERWRITE TABLE multi_dst3 SELECT c1 WHERE c1 <= 3 +INSERT OVERWRITE TABLE multi_dst4 SELECT c1 WHERE c1 > 3 + +query +SELECT c1 FROM multi_dst3 ORDER BY c1 + +query +SELECT c1 FROM multi_dst4 ORDER BY c1 diff --git a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala index 815f03f213..2d741c08a5 100644 --- a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala +++ b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.{CometTestBase, DataFrame, Row} import org.apache.spark.sql.comet.{CometBatchScanExec, CometNativeScanExec, CometNativeWriteExec, CometScanExec} import org.apache.spark.sql.execution.{FileSourceScanExec, QueryExecution, SparkPlan} import org.apache.spark.sql.execution.command.DataWritingCommandExec +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -140,6 +141,52 @@ class CometParquetWriterSuite extends CometTestBase { } } + // Test for issue #3430: SPARK-48817 multi-insert with native writer in Spark 4.x + // Uses SQL multi-insert syntax to produce a plan with ReusedExchangeExec, + // which exercises the producesArrowData() path in CometExecRule. + test("parquet write with multi-insert pattern") { + withSQLConf( + CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true", + CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + + withTable("src", "dst1", "dst2") { + sql("CREATE TABLE src(c1 INT) USING PARQUET") + sql("INSERT INTO src VALUES (1), (2), (3)") + sql("CREATE TABLE dst1(c1 INT) USING PARQUET") + sql("CREATE TABLE dst2(c1 INT) USING PARQUET") + + // Multi-insert: single plan inserts from one source into two tables. + // The REPARTITION hint forces a shuffle exchange that Spark reuses + // via ReusedExchangeExec for the second insert. + val multiInsertDf = sql(""" + |FROM (SELECT /*+ REPARTITION(3) */ c1 FROM src) + |INSERT OVERWRITE TABLE dst1 SELECT c1 + |INSERT OVERWRITE TABLE dst2 SELECT c1 + """.stripMargin) + val plan = multiInsertDf.queryExecution.executedPlan + + // Assert that the plan contains ReusedExchangeExec, proving + // we are exercising the multi-insert reuse path + val reusedExchanges = plan.collect { case r: ReusedExchangeExec => r } + assert( + reusedExchanges.nonEmpty, + s"Expected ReusedExchangeExec in the multi-insert plan, but found none:\n${plan.treeString}") + + // Assert native write was used + val nativeWrites = plan.collect { case n: CometNativeWriteExec => n } + assert( + nativeWrites.nonEmpty, + s"Expected CometNativeWriteExec in the plan, but found none:\n${plan.treeString}") + + // Verify data correctness + checkAnswer(sql("SELECT c1 FROM dst1 ORDER BY c1"), Seq(Row(1), Row(2), Row(3))) + checkAnswer(sql("SELECT c1 FROM dst2 ORDER BY c1"), Seq(Row(1), Row(2), Row(3))) + } + } + } + test("parquet write with map type") { withTempPath { dir => val outputPath = new File(dir, "output.parquet").getAbsolutePath