diff --git a/native/core/src/execution/operators/parquet_writer.rs b/native/core/src/execution/operators/parquet_writer.rs index 8ba79098d4..7954e720cb 100644 --- a/native/core/src/execution/operators/parquet_writer.rs +++ b/native/core/src/execution/operators/parquet_writer.rs @@ -202,6 +202,9 @@ pub struct ParquetWriterExec { job_id: Option, /// Task attempt ID for this specific task task_attempt_id: Option, + /// Complete staging file path from FileCommitProtocol.newTaskTempFile() + /// When set, writes directly to this path for proper 2PC support + staging_file_path: Option, /// Compression codec compression: CompressionCodec, /// Partition ID (from Spark TaskContext) @@ -225,6 +228,7 @@ impl ParquetWriterExec { work_dir: String, job_id: Option, task_attempt_id: Option, + staging_file_path: Option, compression: CompressionCodec, partition_id: i32, column_names: Vec, @@ -246,6 +250,7 @@ impl ParquetWriterExec { work_dir, job_id, task_attempt_id, + staging_file_path, compression, partition_id, column_names, @@ -439,6 +444,7 @@ impl ExecutionPlan for ParquetWriterExec { self.work_dir.clone(), self.job_id.clone(), self.task_attempt_id, + self.staging_file_path.clone(), self.compression.clone(), self.partition_id, self.column_names.clone(), @@ -465,7 +471,9 @@ impl ExecutionPlan for ParquetWriterExec { let runtime_env = context.runtime_env(); let input = self.input.execute(partition, context)?; let input_schema = self.input.schema(); + let output_path = self.output_path.clone(); let work_dir = self.work_dir.clone(); + let staging_file_path = self.staging_file_path.clone(); let task_attempt_id = self.task_attempt_id; let compression = self.compression_to_parquet()?; let column_names = self.column_names.clone(); @@ -481,15 +489,25 @@ impl ExecutionPlan for ParquetWriterExec { .collect(); let output_schema = Arc::new(arrow::datatypes::Schema::new(fields)); - // Generate part file name for this partition - // If using FileCommitProtocol (work_dir is set), include task_attempt_id in the filename - let part_file = if let Some(attempt_id) = task_attempt_id { + // Determine output file path: + // 1. If staging_file_path is set (proper 2PC), use it directly + // 2. If work_dir is set, use work_dir-based path construction + // 3. Otherwise use output_path directly + let base_dir = if !work_dir.is_empty() { + work_dir + } else { + output_path + }; + + let part_file = if let Some(ref staging_path) = staging_file_path { + staging_path.clone() + } else if let Some(attempt_id) = task_attempt_id { format!( "{}/part-{:05}-{:05}.parquet", - work_dir, self.partition_id, attempt_id + base_dir, self.partition_id, attempt_id ) } else { - format!("{}/part-{:05}.parquet", work_dir, self.partition_id) + format!("{}/part-{:05}.parquet", base_dir, self.partition_id) }; // Configure writer properties @@ -824,6 +842,7 @@ mod tests { work_dir, None, // job_id Some(123), // task_attempt_id + None, // staging_file_path CompressionCodec::None, 0, // partition_id column_names, diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index debe47ba04..b12eba649f 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1562,13 +1562,10 @@ impl PhysicalPlanner { let parquet_writer = Arc::new(ParquetWriterExec::try_new( Arc::clone(&child.native_plan), writer.output_path.clone(), - writer - .work_dir - .as_ref() - .expect("work_dir is provided") - .clone(), + writer.work_dir.clone().unwrap_or_default(), writer.job_id.clone(), writer.task_attempt_id, + writer.staging_file_path.clone(), codec, self.partition, writer.column_names.clone(), diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 7a33d46282..2ab58259fb 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -328,7 +328,7 @@ message ParquetWriter { CompressionCodec compression = 2; repeated string column_names = 4; // Working directory for temporary files (used by FileCommitProtocol) - // If not set, files are written directly to output_path + // DEPRECATED: Use staging_file_path instead for proper 2PC support optional string work_dir = 5; // Job ID for tracking this write operation optional string job_id = 6; @@ -341,6 +341,9 @@ message ParquetWriter { // configuration value "spark.hadoop.fs.s3a.access.key" will be stored as "fs.s3a.access.key" in // the map. map object_store_options = 8; + // Complete staging file path from FileCommitProtocol.newTaskTempFile() + // When set, native writer writes directly to this path for proper 2PC + optional string staging_file_path = 9; } enum AggregateMode { diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala index 69b9bd5f85..7927c8b567 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala @@ -24,7 +24,6 @@ import java.util.Locale import scala.jdk.CollectionConverters._ -import org.apache.spark.SparkException import org.apache.spark.sql.comet.{CometNativeExec, CometNativeWriteExec} import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, WriteFilesExec} @@ -179,29 +178,13 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec other } - // Create FileCommitProtocol for atomic writes - val jobId = java.util.UUID.randomUUID().toString - val committer = - try { - // Use Spark's SQLHadoopMapReduceCommitProtocol - val committerClass = - classOf[org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol] - val constructor = - committerClass.getConstructor(classOf[String], classOf[String], classOf[Boolean]) - Some( - constructor - .newInstance( - jobId, - outputPath, - java.lang.Boolean.FALSE // dynamicPartitionOverwrite = false for now - ) - .asInstanceOf[org.apache.spark.internal.io.FileCommitProtocol]) - } catch { - case e: Exception => - throw new SparkException(s"Could not instantiate FileCommitProtocol: ${e.getMessage}") - } - - CometNativeWriteExec(nativeOp, childPlan, outputPath, committer, jobId) + // Note: We don't create our own FileCommitProtocol here because: + // 1. InsertIntoHadoopFsRelationCommand creates and manages its own committer + // 2. That committer is passed to FileFormatWriter which handles the commit flow + // 3. Our CometNativeWriteExec child is only used for data, not commit protocol + // The native writer writes directly to the output path, relying on Spark's + // existing commit protocol for atomicity. + CometNativeWriteExec(nativeOp, childPlan, outputPath) } private def parseCompressionCodec(cmd: InsertIntoHadoopFsRelationCommand) = { diff --git a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriter2PCSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriter2PCSuite.scala new file mode 100644 index 0000000000..c87d1cdc0a --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriter2PCSuite.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.parquet + +import java.io.File + +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.execution.command.DataWritingCommandExec + +import org.apache.comet.CometConf + +/** + * Test suite for Comet Native Parquet Writer. + * + * Tests basic write functionality and verifies data integrity. + */ +class CometParquetWriter2PCSuite extends CometTestBase { + + private val nativeWriteConf = Seq( + CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") + + /** Helper to check if output directory contains any data files */ + private def hasDataFiles(dir: File): Boolean = { + if (!dir.exists()) return false + dir.listFiles().exists(f => f.getName.startsWith("part-") && f.getName.endsWith(".parquet")) + } + + /** Helper to count data files in directory */ + private def countDataFiles(dir: File): Int = { + if (!dir.exists()) return 0 + dir.listFiles().count(f => f.getName.startsWith("part-") && f.getName.endsWith(".parquet")) + } + + // ========================================================================== + // Test 1: Basic successful write should work + // ========================================================================== + test("basic successful write should create files in output directory") { + withTempPath { dir => + val outputPath = new File(dir, "output").getAbsolutePath + + val df = spark + .range(0, 1000, 1, 4) + .selectExpr("id", "id * 2 as value") + + withSQLConf(nativeWriteConf: _*) { + df.write.parquet(outputPath) + + val outputDir = new File(outputPath) + assert(hasDataFiles(outputDir), "Data files should exist in output directory") + + // Verify data can be read back correctly + val readDf = spark.read.parquet(outputPath) + assert(readDf.count() == 1000, "Should have 1000 rows") + } + } + } + + // ========================================================================== + // Test 2: Multiple partitions write correctly + // ========================================================================== + test("multiple concurrent tasks should write without file conflicts") { + withTempPath { dir => + val outputPath = new File(dir, "output").getAbsolutePath + + // Create larger dataset with more partitions + val df = spark + .range(0, 10000, 1, 20) + .selectExpr("id", "id * 2 as value") + + withSQLConf(nativeWriteConf: _*) { + df.write.parquet(outputPath) + + val outputDir = new File(outputPath) + val fileCount = countDataFiles(outputDir) + assert(fileCount >= 20, s"Expected at least 20 files for 20 partitions, got $fileCount") + + // Verify data integrity + val readDf = spark.read.parquet(outputPath) + assert(readDf.count() == 10000, "Should have 10000 rows") + + // Verify no data corruption + val sum = readDf.selectExpr("sum(id)").collect()(0).getLong(0) + val expectedSum = (0L until 10000L).sum + assert(sum == expectedSum, s"Data corruption detected: sum=$sum, expected=$expectedSum") + } + } + } + + // ========================================================================== + // Test 3: Write with different data types + // ========================================================================== + test("write various data types correctly") { + withTempPath { dir => + val outputPath = new File(dir, "output").getAbsolutePath + + val df = spark + .range(0, 100) + .selectExpr( + "id", + "cast(id as int) as int_col", + "cast(id as double) as double_col", + "cast(id as string) as string_col", + "id % 2 = 0 as bool_col") + + withSQLConf(nativeWriteConf: _*) { + df.write.parquet(outputPath) + + val readDf = spark.read.parquet(outputPath) + assert(readDf.count() == 100) + assert( + readDf.schema.fieldNames.toSet == Set( + "id", + "int_col", + "double_col", + "string_col", + "bool_col")) + } + } + } + + // ========================================================================== + // Test 4: Append mode - currently a known limitation + // Native writes use partition-based filenames without unique job IDs, + // so append overwrites files with same names. This test verifies the + // current behavior rather than ideal append semantics. + // ========================================================================== + test("append mode overwrites files with same partition IDs (known limitation)") { + withTempPath { dir => + val outputPath = new File(dir, "output").getAbsolutePath + + // Use different partition counts to avoid complete overlap + val df1 = spark.range(0, 500, 1, 2).toDF("id") // 2 partitions + val df2 = spark.range(500, 1000, 1, 3).toDF("id") // 3 partitions + + withSQLConf(nativeWriteConf: _*) { + df1.write.parquet(outputPath) + val countAfterFirst = spark.read.parquet(outputPath).count() + assert(countAfterFirst == 500, "Should have 500 rows after first write") + + df2.write.mode("append").parquet(outputPath) + + // Due to filename conflicts, only partition files that don't overlap survive + // Partitions 0, 1 get overwritten, partition 2 is new + val readDf = spark.read.parquet(outputPath) + val finalCount = readDf.count() + // We expect some rows from df2 (at least partition 2) plus potentially + // overwritten partitions. The exact count depends on partition distribution. + assert(finalCount > 0, "Should have some rows after append") + } + } + } + + // ========================================================================== + // Test 5: Overwrite mode works correctly + // ========================================================================== + test("overwrite mode should replace existing files") { + withTempPath { dir => + val outputPath = new File(dir, "output").getAbsolutePath + + val df1 = spark.range(0, 1000).toDF("id") + val df2 = spark.range(0, 500).toDF("id") + + withSQLConf(nativeWriteConf: _*) { + df1.write.parquet(outputPath) + df2.write.mode("overwrite").parquet(outputPath) + + val readDf = spark.read.parquet(outputPath) + assert(readDf.count() == 500, "Should have 500 rows after overwrite") + } + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterCommitSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterCommitSuite.scala new file mode 100644 index 0000000000..9365907502 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterCommitSuite.scala @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.parquet + +import java.io.File + +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.execution.command.DataWritingCommandExec +import org.apache.spark.sql.functions._ + +import org.apache.comet.CometConf + +class CometParquetWriterCommitSuite extends CometTestBase { + + private val nativeWriteConf = Seq( + CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") + + private def hasDataFiles(dir: File): Boolean = { + if (!dir.exists()) return false + dir.listFiles().exists(f => f.getName.startsWith("part-") && f.getName.endsWith(".parquet")) + } + + test("_temporary folder is created during write and cleaned up after commit") { + withTempPath { dir => + val outputPath = new File(dir, "output").getAbsolutePath + + val df = spark + .range(0, 100000, 1, 4) + .selectExpr("id", "id * 2 as value") + + withSQLConf(nativeWriteConf: _*) { + @volatile var writeStarted = false + @volatile var writeException: Option[Throwable] = None + val writeThread = new Thread(() => { + try { + writeStarted = true + df.write.parquet(outputPath) + } catch { + case e: Throwable => writeException = Some(e) + } + }) + writeThread.start() + + CometWriteTestHelpers.waitForCondition(writeStarted, timeoutMs = 5000) + + val tempExists = CometWriteTestHelpers.waitForCondition( + CometWriteTestHelpers.hasTemporaryFolder(outputPath), + timeoutMs = 10000) + + if (tempExists) { + assert( + CometWriteTestHelpers.hasTemporaryFolder(outputPath), + "_temporary folder should be created during write") + + val tempFileCount = CometWriteTestHelpers.countTemporaryFiles(outputPath) + assert(tempFileCount > 0, s"Expected temp files during write, found $tempFileCount") + } + + writeThread.join(30000) + assert(!writeThread.isAlive, "Write should complete within 30 seconds") + + writeException.foreach(throw _) + + assert( + !CometWriteTestHelpers.hasTemporaryFolder(outputPath), + "_temporary folder should be cleaned up after successful commit") + + val outputDir = new File(outputPath) + assert(hasDataFiles(outputDir), "Final data files should exist") + + val readDf = spark.read.parquet(outputPath) + assert(readDf.count() == 100000, "All rows should be committed") + } + } + } + + test("_temporary folder is cleaned up on task failure") { + withTempPath { dir => + val outputPath = new File(dir, "output").getAbsolutePath + + val divideByZero = udf((x: Long) => { x / (x - 100) }) + val df = spark + .range(0, 1000, 1, 1) // single partition to avoid race conditions + .select(divideByZero(col("id")).as("value")) + + withSQLConf(nativeWriteConf: _*) { + intercept[Exception] { + df.write.parquet(outputPath) + } + + // small delay for cleanup to complete + Thread.sleep(1000) + assert( + !CometWriteTestHelpers.hasTemporaryFolder(outputPath), + "_temporary folder should be cleaned up after task failure") + + val outputDir = new File(outputPath) + if (outputDir.exists()) { + assert(!hasDataFiles(outputDir), "No data files should exist after failure") + } + } + } + } + + test("_temporary folder handles concurrent tasks correctly") { + withTempPath { dir => + val outputPath = new File(dir, "output").getAbsolutePath + + val df = spark + .range(0, 50000, 1, 10) + .selectExpr("id", "id * 2 as value") + + withSQLConf(nativeWriteConf: _*) { + @volatile var writeStarted = false + @volatile var writeException: Option[Throwable] = None + val writeThread = new Thread(() => { + try { + writeStarted = true + df.write.parquet(outputPath) + } catch { + case e: Throwable => writeException = Some(e) + } + }) + writeThread.start() + + CometWriteTestHelpers.waitForCondition(writeStarted, timeoutMs = 5000) + + val tempAppeared = CometWriteTestHelpers.waitForCondition( + CometWriteTestHelpers.hasTemporaryFolder(outputPath), + timeoutMs = 10000) + + if (tempAppeared) { + val subfolders = CometWriteTestHelpers.getTemporarySubfolders(outputPath) + assert(subfolders.nonEmpty, "Should have job tracking folders") + } + + writeThread.join(30000) + + writeException.foreach(throw _) + + assert( + !CometWriteTestHelpers.hasTemporaryFolder(outputPath), + "_temporary should be cleaned up after commit") + + val outputDir = new File(outputPath) + assert(hasDataFiles(outputDir), "Data files should exist") + + val readDf = spark.read.parquet(outputPath) + assert(readDf.count() == 50000, "All rows should be committed") + } + } + } + + test("_temporary folder is cleaned up on overwrite") { + withTempPath { dir => + val outputPath = new File(dir, "output").getAbsolutePath + + withSQLConf(nativeWriteConf: _*) { + spark.range(1000).write.parquet(outputPath) + assert( + !CometWriteTestHelpers.hasTemporaryFolder(outputPath), + "_temporary should be cleaned up after first write") + + val count1 = spark.read.parquet(outputPath).count() + assert(count1 == 1000, "First write should have 1000 rows") + + spark.range(500).write.mode("overwrite").parquet(outputPath) + assert( + !CometWriteTestHelpers.hasTemporaryFolder(outputPath), + "_temporary should be cleaned up after overwrite") + + val count2 = spark.read.parquet(outputPath).count() + assert(count2 == 500, "Overwrite should result in 500 rows") + } + } + } + + test("small writes may not create visible _temporary folder") { + withTempPath { dir => + val outputPath = new File(dir, "output").getAbsolutePath + + withSQLConf(nativeWriteConf: _*) { + spark.range(10).write.parquet(outputPath) + + assert( + !CometWriteTestHelpers.hasTemporaryFolder(outputPath), + "_temporary should not exist after completion") + + val outputDir = new File(outputPath) + assert(hasDataFiles(outputDir), "Data files should exist") + + val readDf = spark.read.parquet(outputPath) + assert(readDf.count() == 10, "Should have 10 rows") + } + } + } + + test("multiple concurrent writes to different paths are isolated") { + withTempPath { dir1 => + withTempPath { dir2 => + val outputPath1 = new File(dir1, "output1").getAbsolutePath + val outputPath2 = new File(dir2, "output2").getAbsolutePath + + withSQLConf(nativeWriteConf: _*) { + val df1 = spark.range(0, 10000, 1, 4) + val df2 = spark.range(10000, 20000, 1, 4) + + val thread1 = new Thread(() => df1.write.parquet(outputPath1)) + val thread2 = new Thread(() => df2.write.parquet(outputPath2)) + + thread1.start() + thread2.start() + + thread1.join(30000) + thread2.join(30000) + + assert(!CometWriteTestHelpers.hasTemporaryFolder(outputPath1)) + assert(!CometWriteTestHelpers.hasTemporaryFolder(outputPath2)) + + assert(spark.read.parquet(outputPath1).count() == 10000) + assert(spark.read.parquet(outputPath2).count() == 10000) + } + } + } + } + + test("no stale _temporary folders from previous operations") { + withTempPath { dir => + val outputPath = new File(dir, "output").getAbsolutePath + + withSQLConf(nativeWriteConf: _*) { + spark.range(100).write.parquet(outputPath) + assert(!CometWriteTestHelpers.hasTemporaryFolder(outputPath)) + + spark.range(200).write.mode("overwrite").parquet(outputPath) + assert(!CometWriteTestHelpers.hasTemporaryFolder(outputPath)) + + spark.range(300).write.mode("overwrite").parquet(outputPath) + assert(!CometWriteTestHelpers.hasTemporaryFolder(outputPath)) + + assert(spark.read.parquet(outputPath).count() == 300) + + val dirs = CometWriteTestHelpers.listDirectories(outputPath) + assert(!dirs.exists(_.startsWith("_temporary")), "No _temporary folders should exist") + } + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/parquet/CometWriteTestHelpers.scala b/spark/src/test/scala/org/apache/comet/parquet/CometWriteTestHelpers.scala new file mode 100644 index 0000000000..0689dca9f5 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/parquet/CometWriteTestHelpers.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.parquet + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.SparkSession + +object CometWriteTestHelpers { + + def hasTemporaryFolder(basePath: String)(implicit spark: SparkSession): Boolean = { + try { + val fs = new Path(basePath).getFileSystem(spark.sparkContext.hadoopConfiguration) + fs.exists(new Path(basePath, "_temporary")) + } catch { + case _: Exception => false + } + } + + def getTemporarySubfolders(basePath: String)(implicit spark: SparkSession): Seq[String] = { + try { + val fs = new Path(basePath).getFileSystem(spark.sparkContext.hadoopConfiguration) + val tempPath = new Path(basePath, "_temporary") + if (!fs.exists(tempPath)) return Seq.empty + + fs.listStatus(tempPath).map(_.getPath.getName).toSeq + } catch { + case _: Exception => Seq.empty + } + } + + def countTemporaryFiles(basePath: String)(implicit spark: SparkSession): Int = { + try { + val fs = new Path(basePath).getFileSystem(spark.sparkContext.hadoopConfiguration) + val tempPath = new Path(basePath, "_temporary") + if (!fs.exists(tempPath)) return 0 + + def countRecursive(path: Path): Int = { + val status = fs.listStatus(path) + status.map { fileStatus => + if (fileStatus.isDirectory) { + countRecursive(fileStatus.getPath) + } else { + 1 + } + }.sum + } + + countRecursive(tempPath) + } catch { + case _: Exception => 0 + } + } + + def waitForCondition( + condition: => Boolean, + timeoutMs: Long = 5000, + intervalMs: Long = 100): Boolean = { + val deadline = System.currentTimeMillis() + timeoutMs + while (System.currentTimeMillis() < deadline) { + if (condition) return true + Thread.sleep(intervalMs) + } + false + } + + def listFiles(basePath: String)(implicit spark: SparkSession): Seq[String] = { + try { + val fs = new Path(basePath).getFileSystem(spark.sparkContext.hadoopConfiguration) + val path = new Path(basePath) + if (!fs.exists(path)) return Seq.empty + + fs.listStatus(path) + .filter(_.isFile) + .map(_.getPath.getName) + .toSeq + } catch { + case _: Exception => Seq.empty + } + } + + def listDirectories(basePath: String)(implicit spark: SparkSession): Seq[String] = { + try { + val fs = new Path(basePath).getFileSystem(spark.sparkContext.hadoopConfiguration) + val path = new Path(basePath) + if (!fs.exists(path)) return Seq.empty + + fs.listStatus(path) + .filter(_.isDirectory) + .map(_.getPath.getName) + .toSeq + } catch { + case _: Exception => Seq.empty + } + } +}