diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java index 77044c6202fbe..12ba6dd33c23a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java @@ -20,8 +20,10 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.catalog.CatalogPlugin; import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin; +import org.apache.spark.sql.connector.read.Scan; import java.io.Closeable; +import java.util.List; /** * Represents a transaction. @@ -66,6 +68,29 @@ public interface Transaction extends Closeable { */ void abort(); + /** + * Attempts to register a list of materialized scans against this transaction's read set. + * <p>

+ * Spark calls this when considering reuse of a cached subtree during the transaction. The + * list contains every materialized scan in the candidate cached subtree. That may include + * scans that belong to catalogs other than this transaction's catalog. The connector can + * decide which scans to register or ignore. + * <p>

+ * The connector decides whether reusing the cached snapshots is compatible with the + * transaction's isolation contract. It must either accept the cache entry (returning + * {@code true} after adding any of its own scans to the read set) or refuse it (returning + * {@code false} without modifying the read set). + * <p>

+ * This method may be called multiple times during a single query with overlapping scan + * lists. Registering a scan that is already in the read set must be a no-op. + * + * @param scans Every materialized scan in the candidate cached subtree. + * @return true if the connector accepts reuse of the cache entry; false otherwise. + */ + default boolean registerScans(List<Scan> scans) { + return false; + } + /** * Releases any resources held by this transaction. * <p>

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 19371dcb94dec..4efb789e51c05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -28,9 +28,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, ExposesMetadataC import org.apache.spark.sql.catalyst.streaming.{StreamingSourceIdentifyingName, Unassigned} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils} -import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, SupportsMetadataColumns, Table, TableCapability, TableCatalog, V2TableUtil} +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Column, FunctionCatalog, Identifier, SupportsMetadataColumns, Table, TableCapability, TableCatalog, V2TableUtil} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper -import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} +import org.apache.spark.sql.connector.catalog.constraints.Constraint +import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference, Transform} import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportStatistics, SupportsRuntimeV2Filtering} import org.apache.spark.sql.connector.read.colstats.{ColumnStatistics, Histogram => V2Histogram, HistogramBin => V2HistogramBin} import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream} @@ -122,6 +123,21 @@ case class DataSourceV2Relation( copy(output = output.map(_.newInstance())) } + // Replace the `table` field with an identity placeholder. 
This allows two relations + // that point at the same logical table to compare equal even when one is wrapped (e.g. + // by a transaction-aware catalog). Note, when `Table.id()` is null we key only on catalog + // and identifier, thus, drop-and-recreate is not distinguished. + override def doCanonicalize(): LogicalPlan = { + val base = super.doCanonicalize().asInstanceOf[DataSourceV2Relation] + val catalogName = base.catalog.map(_.name()) + (catalogName, base.identifier) match { + case (Some(cn), Some(ident)) => + base.copy(table = DataSourceV2Relation.CanonicalTableKey( + cn, ident, Option(base.table.id()))) + case _ => base + } + } + override lazy val metadataOutput: Seq[AttributeReference] = table match { case hasMeta: SupportsMetadataColumns => metadataOutputWithOutConflicts( @@ -320,6 +336,28 @@ object ExtractV2ScanInfo { } object DataSourceV2Relation { + + // Substituted into a canonicalized `DataSourceV2Relation.table` field so that two + // relations targeting the same metastore entity compare equal regardless of which concrete + // `Table` instance backs each side. Should only appear in canonicalized plans. 
+ private[v2] case class CanonicalTableKey( + catalogName: String, + identifier: Identifier, + idOpt: Option[String]) extends Table { + override def id(): String = idOpt.orNull + override def name(): String = s"$catalogName.$identifier" + override def capabilities(): java.util.Set[TableCapability] = throwExecutionAccess() + override def columns(): Array[Column] = throwExecutionAccess() + override def partitioning(): Array[Transform] = throwExecutionAccess() + override def properties(): java.util.Map[String, String] = throwExecutionAccess() + override def constraints(): Array[Constraint] = throwExecutionAccess() + override def version(): String = throwExecutionAccess() + + private def throwExecutionAccess(): Nothing = + throw SparkException.internalError( + "CanonicalTableKey is canonicalization-only and must not appear in execution plans") + } + def create( table: Table, catalog: Option[CatalogPlugin], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index f582f3e408cb6..ec92691811a36 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -501,6 +501,7 @@ abstract class InMemoryBaseTable( if (evaluableFilters.nonEmpty) { scan.filter(evaluableFilters) } + scan.pushedFilters = _pushedFilters recordScanEvent(_pushedFilters) scan } @@ -665,6 +666,18 @@ abstract class InMemoryBaseTable( options: CaseInsensitiveStringMap) extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeFiltering { + // Snapshot of the table version when this scan was built. + val builtAtTableVersion: Int = InMemoryBaseTable.this.tableVersion + + // The current table version, read fresh on each access. 
+ def currentTableVersion: Int = InMemoryBaseTable.this.tableVersion + + // Back-pointer to the table this scan was built against. + val table: InMemoryBaseTable = InMemoryBaseTable.this + + // The filters pushed to this scan at build time. + var pushedFilters: Array[Filter] = Array.empty + override def filterAttributes(): Array[NamedReference] = { val scanFields = readSchema.fields.map(_.name).toSet partitioning.flatMap(_.references) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index 81f976dce510f..f544963f57784 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -35,11 +35,18 @@ class InMemoryRowLevelOperationTableCatalog // All transactions in order (committed and aborted), allowing per-statement // validation in SQL scripting tests. val observedTransactions: ArrayBuffer[Txn] = new ArrayBuffer[Txn]() + // Test-only knob. When true, the next transaction created by `beginTransaction` will defer + // to the default `Transaction.registerScans` behavior (always returns false). The flag is + // reset to false after being consumed. 
+ var nextTxnUsesDefaultRegisterScans: Boolean = false override def beginTransaction(info: TransactionInfo): Transaction = { assert(transaction == null || transaction.currentState != Active) - this.transaction = new Txn(new TxnTableCatalog(this)) - transaction + val txn = new Txn(new TxnTableCatalog(this)) + txn.useDefaultRegisterScans = nextTxnUsesDefaultRegisterScans + nextTxnUsesDefaultRegisterScans = false + this.transaction = txn + txn } override def createTable(ident: Identifier, tableInfo: TableInfo): Table = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 4b9dff5c3d780..a8b126684ced3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import org.apache.spark.sql.connector.catalog.transactions.Transaction +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.connector.write.{LogicalWriteInfo, RowLevelOperationBuilder, RowLevelOperationInfo, WriteBuilder} import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType @@ -39,10 +40,50 @@ class Txn(override val catalog: TxnTableCatalog) extends Transaction { private[this] var state: TransactionState = Active private[this] var closed: Boolean = false + // Records every batch of scans the connector accepted via registerScans. Tests assert on this + // to confirm cache substitution went through the txn path. + val registeredScans: ArrayBuffer[Seq[Scan]] = ArrayBuffer.empty + + // Test-only switch. When true, `registerScans` mimics the default Transaction interface + // behavior (always returns false) instead of running the version-equality check. 
Used to + // verify Spark's cache-bypass behavior when a connector hasn't implemented `registerScans`. + var useDefaultRegisterScans: Boolean = false + def currentState: TransactionState = state def isClosed: Boolean = closed + // Accept the batch if relevant scans were built against the table at exactly the + // version the table is at now. A real connector could be more permissive. For example, it + // could accept older snapshots and record them in the read set for commit-time conflict + // detection. + // + // The scan list may include foreign scans that originated in other catalogs. We + // identify our own scans by structural type AND by table identity: an InMemoryBatchScan + // whose underlying table is one this catalog instance is tracking. Foreign scans are + // non-transactional from our perspective and are ignored. + override def registerScans(javaScans: java.util.List[Scan]): Boolean = { + if (useDefaultRegisterScans) return false + + val scans = javaScans.asScala.toSeq + val myScans = scans.collect { + case s: InMemoryBaseTable#InMemoryBatchScan + if catalog.txnTables.values.exists(_.delegate eq s.table) => s + } + val accepted = myScans.forall { s => + s.builtAtTableVersion == s.currentTableVersion + } + if (accepted) { + registeredScans += scans + myScans.foreach { s => + catalog.txnTables.values + .find(_.delegate eq s.table) + .foreach(_.scanEvents += s.pushedFilters) + } + } + accepted + } + override def commit(): Unit = { if (closed) throw new IllegalStateException("Can't commit, already closed") if (state == Aborted) throw new IllegalStateException("Can't commit, already aborted") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2RelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2RelationSuite.scala index 10ec9efca4aba..88190d3c3db5a 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2RelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2RelationSuite.scala @@ -19,9 +19,12 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{Histogram, HistogramBin} +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, Table, TableCapability} import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap class DataSourceV2RelationSuite extends SparkFunSuite { @@ -43,7 +46,7 @@ class DataSourceV2RelationSuite extends SparkFunSuite { rowCount = Some(10), colStats = Map( "id" -> idColStat, - // "extra" is not in schema — should be silently skipped + // "extra" is not in schema, should be silently skipped "extra" -> CatalogColumnStat(distinctCount = Some(5)))) val v2Stats = DataSourceV2Relation.v1StatsToV2Stats(catalogStats, schema) @@ -107,4 +110,101 @@ class DataSourceV2RelationSuite extends SparkFunSuite { val idV2NoHist = colStats.get(FieldReference.column("id")) assert(!idV2NoHist.histogram().isPresent) } + + test("canonicalize matches across Table wrappers when catalog and identifier agree") { + val schema = StructType(Seq(StructField("a", IntegerType))) + val output = Seq(AttributeReference("a", IntegerType)()) + val catalog = Some(new StubCatalog("cat")) + val ident = Some(Identifier.of(Array("ns"), "t")) + val opts = CaseInsensitiveStringMap.empty() + + // Same id on both sides, strict path. Different Table instances, same id => equal. 
+ val a1 = DataSourceV2Relation(new StubTable("a", schema, id = "id-1"), + output, catalog, ident, opts) + val a2 = DataSourceV2Relation(new StubTable("a-wrapped", schema, id = "id-1"), + output, catalog, ident, opts) + assert(a1.sameResult(a2), + "two relations to the same logical table should canonicalize equal") + + // Null id on both sides, permissive path. Same catalog+identifier => still equal. + val b1 = DataSourceV2Relation(new StubTable("b", schema, id = null), + output, catalog, ident, opts) + val b2 = DataSourceV2Relation(new StubTable("b-wrapped", schema, id = null), + output, catalog, ident, opts) + assert(b1.sameResult(b2), + "relations with null id should canonicalize equal when catalog+identifier match") + + // Mixed: one side has an id, the other null. Treated as different (couldn't both be the same + // logical table for any connector that exposes id() consistently). + val c1 = DataSourceV2Relation(new StubTable("c", schema, id = "id-1"), + output, catalog, ident, opts) + val c2 = DataSourceV2Relation(new StubTable("c", schema, id = null), + output, catalog, ident, opts) + assert(!c1.sameResult(c2), + "id present vs null should not compare equal") + + // Different non-null ids, drop+recreate scenario. Must compare unequal. + val d1 = DataSourceV2Relation(new StubTable("d", schema, id = "id-1"), + output, catalog, ident, opts) + val d2 = DataSourceV2Relation(new StubTable("d", schema, id = "id-2"), + output, catalog, ident, opts) + assert(!d1.sameResult(d2), + "different ids (drop+recreate) should not compare equal") + + // Different identifier, must compare unequal even if ids would otherwise match. 
+ val otherIdent = Some(Identifier.of(Array("ns"), "other")) + val e1 = DataSourceV2Relation(new StubTable("e", schema, id = "id-1"), + output, catalog, ident, opts) + val e2 = DataSourceV2Relation(new StubTable("e", schema, id = "id-1"), + output, catalog, otherIdent, opts) + assert(!e1.sameResult(e2), + "different identifiers should not compare equal") + + // Different catalog, must compare unequal even with same identifier and id. + val otherCatalog = Some(new StubCatalog("other")) + val f1 = DataSourceV2Relation(new StubTable("f", schema, id = "id-1"), + output, catalog, ident, opts) + val f2 = DataSourceV2Relation(new StubTable("f", schema, id = "id-1"), + output, otherCatalog, ident, opts) + assert(!f1.sameResult(f2), + "different catalog names should not compare equal") + } + + test("canonicalize falls back when catalog or identifier is missing") { + val schema = StructType(Seq(StructField("a", IntegerType))) + val output = Seq(AttributeReference("a", IntegerType)()) + val ident = Some(Identifier.of(Array("ns"), "t")) + val opts = CaseInsensitiveStringMap.empty() + + // No catalog: falls through to the default canonical form. Equality reduces to reference + // equality on the `table` field, same instance is equal, different instances are not. 
+ val sharedTable = new StubTable("shared", schema, id = "id-1") + val sameInstance1 = DataSourceV2Relation(sharedTable, output, None, ident, opts) + val sameInstance2 = DataSourceV2Relation(sharedTable, output, None, ident, opts) + assert(sameInstance1.sameResult(sameInstance2), + "fallback path: same Table instance compares equal") + + val diff1 = DataSourceV2Relation(new StubTable("x", schema, id = "id-1"), + output, None, ident, opts) + val diff2 = DataSourceV2Relation(new StubTable("x", schema, id = "id-1"), + output, None, ident, opts) + assert(!diff1.sameResult(diff2), + "fallback path: different Table instances are not equal even with matching id") + } +} + +private class StubCatalog(catalogName: String) extends CatalogPlugin { + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = () + override def name(): String = catalogName +} + +private class StubTable( + tableName: String, + schema: StructType, + id: String) extends Table { + override def name(): String = tableName + override def id(): String = id + override def columns(): Array[org.apache.spark.sql.connector.catalog.Column] = + org.apache.spark.sql.connector.catalog.CatalogV2Util.structTypeToV2Columns(schema) + override def capabilities(): java.util.Set[TableCapability] = java.util.Set.of() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 3f92f24156d3c..d590ebc2286ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} @@ -35,11 +36,12 @@ import org.apache.spark.sql.classic.{Dataset, SparkSession} import org.apache.spark.sql.connector.catalog.CatalogPlugin import 
org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper} import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.command.CommandUtils import org.apache.spark.sql.execution.datasources.{FileIndex, HadoopFsRelation, LogicalRelation, LogicalRelationWithTable} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2CatalogAndIdentifier, ExtractV2Table, FileTable, V2TableRefreshUtil} +import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, ExtractV2CatalogAndIdentifier, ExtractV2Table, FileTable, V2TableRefreshUtil} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK @@ -479,8 +481,13 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { lookupCachedDataInternal(normalized) } - private def lookupCachedDataInternal(plan: LogicalPlan): Option[CachedData] = { - val result = cachedData.find(cd => plan.sameResult(cd.plan)) + private def lookupCachedDataInternal( + plan: LogicalPlan, + transactionOpt: Option[Transaction] = None): Option[CachedData] = { + val result = cachedData.find { cd => + plan.sameResult(cd.plan) && + transactionOpt.forall(txn => validateCachedEntryForTransaction(cd, txn)) + } if (result.isDefined) { CacheManager.logCacheOperation(log"Dataframe cache hit for input plan:" + log"\n${MDC(QUERY_PLAN, plan)} matched with cache entry:" + @@ -489,16 +496,34 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { result } + // Decides whether the cached entry can be substituted into a plan being executed inside + // the given transaction. 
It identifies the scans and attempts to register them in the connector. + // The connector returns true if reusing the cached snapshot is consistent with its isolation + // contract. Note, the scans might belong to different catalogs. The connector can decide which + // one to register and which ones to ignore. + private def validateCachedEntryForTransaction(cd: CachedData, txn: Transaction): Boolean = { + val scans = collectWithSubqueries(cd.cachedRepresentation.cacheBuilder.cachedPlan) { + case b: BatchScanExec => b.scan + } + scans.isEmpty || txn.registerScans(scans.asJava) + } + /** * Replaces segments of the given logical plan with cached versions where possible. The input * plan must be normalized. + * + * @param plan the plan to rewrite. + * @param transactionOpt if defined, each candidate cache hit is validated against the + * transaction's isolation contract (via `Transaction.registerScans`). */ - private[sql] def useCachedData(plan: LogicalPlan): LogicalPlan = { + private[sql] def useCachedData( + plan: LogicalPlan, + transactionOpt: Option[Transaction] = None): LogicalPlan = { val newPlan = plan transformDown { case command: Command => command case currentFragment => - lookupCachedDataInternal(currentFragment).map { cached => + lookupCachedDataInternal(currentFragment, transactionOpt).map { cached => // After cache lookup, we should still keep the hints from the input plan. 
val hints = EliminateResolvedHint.extractHintsFromPlan(currentFragment)._2 val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output) @@ -511,7 +536,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { } val result = newPlan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) { - case s: SubqueryExpression => s.withNewPlan(useCachedData(s.plan)) + case s: SubqueryExpression => s.withNewPlan(useCachedData(s.plan, transactionOpt)) } if (result.fastEquals(plan)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 65bc57de907b2..7f43d14cb7827 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -329,16 +329,11 @@ class QueryExecution( assertAnalyzed() assertSupported() - // During a transaction, skip cache substitution. This is to avoid replacing relations - // loaded by the transactional catalog with potentially stale relations cached before - // the transaction was active. - if (transactionOpt.isDefined) { - normalized - } else { - // Clone the plan to avoid sharing the plan instance between different stages like - // analyzing, optimizing and planning. - sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) - } + // Clone the plan to avoid sharing the plan instance between different stages like + // analyzing, optimizing and planning. 
+ sparkSession.sharedState.cacheManager.useCachedData( + normalized.clone(), + transactionOpt) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala index d2dfae13b511c..8cc81062d982e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.{sources, Column, Row} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.classic.MergeIntoWriter -import org.apache.spark.sql.connector.catalog.{Aborted, Committed} +import org.apache.spark.sql.connector.catalog.{Aborted, Committed, InMemoryBaseTable} import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.TableInfo import org.apache.spark.sql.functions._ @@ -1233,4 +1233,266 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { assert(e.message.contains("incompatible changes to table `cat`.`ns1`.`source_table`")) } } + + // The DataFrame cache is bypassed during a transaction unless the connector approves the cached + // scans via Transaction#registerScans. The next two tests exercise that path against the merge + // source, a common case where users cache an expensive source DataFrame before running MERGE. + // The test connector's registerScans accepts when the cached scan's table version matches the + // current version (no intervening writes) and refuses otherwise. 
+ + test("cached merge source is reused when the table is unchanged since caching") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val tableVersionBeforeCache = table.version() + + // Cache a derived view of the target table. The InMemoryBatchScan stamps the current table + // version at build time; that version is what the connector compares against later. + val sourceDF = spark.table(tableNameAsString).where("salary < 250").as("source") + sourceDF.cache() + sourceDF.count() + + assert(table.version() == tableVersionBeforeCache, + "sanity: caching a read should not bump the table version") + + val (txn, txnTables) = executeTransaction { + sourceDF + .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk")) + .whenMatched() + .update(Map("salary" -> targetTableCol("salary").plus(1))) + .merge() + } + + assert(txn.currentState == Committed) + + // Cached scan's version == current version -> registerScans accepted -> cache substituted. + // The non-empty registeredScans buffer is the witness that the connector ran the + // cache-acceptance path (not a fresh scan-builder path). + assert(txn.registeredScans.nonEmpty, + "registerScans should have been called and accepted (no writes since cache)") + + // Even though the cache was used, registerScans recorded the read on the target's TxnTable + // so that scanEvents uniformly captures cached and fresh reads. The MERGE plan references + // the source twice (matched and not-matched branches), so the source filter appears twice + // in scanEvents, whether contributed by registerScans (cache path) or by fresh scans + // (rescan path). The `registeredScans` non-empty assertion above is what distinguishes + // the two cases. 
+ val targetTxnTable = txnTables(tableNameAsString) + val sourceFilterScans = targetTxnTable.scanEvents.flatten.count { + case sources.LessThan("salary", 250) => true + case _ => false + } + assert(sourceFilterScans == 2, + s"expected two salary<250 scan events, got " + + s"${targetTxnTable.scanEvents.map(_.toSeq).mkString("[", ", ", "]")}") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), + Row(2, 201, "software"), + Row(3, 300, "hr"))) + } + + test("cached merge source is dropped when the table version moves on between caching and MERGE") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // Cache the source at the table's current version. + val sourceDF = spark.table(tableNameAsString).where("salary < 250").as("source") + sourceDF.cache() + sourceDF.count() + val versionAtCache = table.version() + + // Simulate an external committer bumping the table version between caching and the MERGE. + // A write done through Spark (e.g. another `append`) would also bump the version, but it + // would additionally trigger CacheManager.refreshCache, which re-materializes our cached + // InMemoryRelation against the new snapshot, defeating the test. The `registerScans` + // path is meant to defend against writes Spark didn't observe (other clusters, other + // sessions, direct table mutations), so we bump the version directly. + // + // NOTE on fidelity: the in-memory test catalog does not separate "metastore version" from + // "loaded snapshot version": `InMemoryBaseTable.tableVersion` is shared mutable state, so + // a scan's `currentTableVersion` immediately reflects this bump. 
A real DSv2 Delta + // connector holds a Table loaded at V1 and would not observe an external bump to V2 unless + // it explicitly refreshed from the metastore (per go/spark-refresh design doc, Section 5 + // Scenario 1). The test therefore exercises a stricter connector policy than Delta's + // default, namely one that polices version drift on every registerScans call. A more + // faithful model would capture the txn's pinned snapshot in TxnTable at txn-begin and + // compare scan.builtAt against that pinned value; see follow-up notes in memory. + table.increaseVersion() + assert(table.version() != versionAtCache, "sanity: bump should change the version") + + val (txn, txnTables) = executeTransaction { + sourceDF + .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk")) + .whenMatched() + .update(Map("salary" -> targetTableCol("salary").plus(1))) + .merge() + } + + assert(txn.currentState == Committed) + // Cached scan's builtAtTableVersion != currentTableVersion -> registerScans refused; nothing + // recorded. + assert(txn.registeredScans.isEmpty, + "registerScans should have refused the stale cached scan") + + // The cached source was not substituted, so the source's salary<250 filter was re-evaluated + // through the txn catalog and shows up as a scan event on the target's TxnTable. 
+ val targetTxnTable = txnTables(tableNameAsString) + val sourceFilterScanned = targetTxnTable.scanEvents.flatten.exists { + case sources.LessThan("salary", 250) => true + case _ => false + } + assert(sourceFilterScanned, + s"expected the source's salary<250 filter to appear as a scan event after cache bypass, " + + s"got ${targetTxnTable.scanEvents.map(_.toSeq).mkString("[", ", ", "]")}") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), + Row(2, 201, "software"), + Row(3, 300, "hr"))) + } + + test("cached source from outside the txn catalog is reused without consulting registerScans") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |""".stripMargin) + + // Build the source from spark.range, no DataSourceV2Relation, no catalog reference at all. + // Inside the txn, validateCachedEntryForTransaction collects no BatchScanExec scans from this + // cached entry, so its `scans.isEmpty` short-circuit accepts the cache without ever consulting + // registerScans. 
+ val sourceDF = spark.range(2) + .select( + (col("id") + 1).cast(IntegerType).as("pk"), + lit(999).cast(IntegerType).as("salary"), + lit("hr").cast(StringType).as("dep")) + .as("source") + sourceDF.cache() + sourceDF.count() + + val (txn, _) = executeTransaction { + sourceDF + .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk")) + .whenMatched() + .update(Map("salary" -> $"source.salary")) + .merge() + } + + assert(txn.currentState == Committed) + assert(txn.registeredScans.isEmpty, + "registerScans should not be consulted when the cached subtree has no txn-catalog reads") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 999, "hr"), + Row(2, 999, "software"))) + } + + test("cached relation inside an IN-subquery is substituted via useCachedData recursion") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + // Create + populate a lookup table in the same txn catalog. We cache a DataFrame against + // this lookup, then drive a MERGE whose source filters by an IN-subquery against the + // cached lookup. The cache substitution can only happen via useCachedData's recursive + // descent into SubqueryExpression, transformDown alone would not find the cached subtree + // hidden inside the IN expression. 
+ val lookupName = "cat.ns1.cache_lookup" + withTable(lookupName) { + sql(s"CREATE TABLE $lookupName (pk INT) USING foo") + sql(s"INSERT INTO $lookupName VALUES (1), (2)") + + val lookupDF = spark.table(lookupName) + lookupDF.cache() + lookupDF.count() + lookupDF.createOrReplaceTempView("cache_lookup_view") + + val sourceDF = spark.table(tableNameAsString) + .where("pk IN (SELECT pk FROM cache_lookup_view)") + .as("source") + + val (txn, _) = executeTransaction { + sourceDF + .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk")) + .whenMatched() + .update(Map("salary" -> targetTableCol("salary").plus(1))) + .merge() + } + + assert(txn.currentState == Committed) + // The lookup is the only cached entry in this plan, and it sits inside an IN-subquery, + // so the only way the registered scans can include it is via useCachedData's recursive + // descent into SubqueryExpression. Every registered scan must point at the lookup. + val lookupTable = catalog.loadTable(Identifier.of(Array("ns1"), "cache_lookup")) + assert(txn.registeredScans.flatten.collect { + case s: InMemoryBaseTable#InMemoryBatchScan => s.table + }.distinct == Seq(lookupTable)) + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), + Row(2, 201, "software"), + Row(3, 300, "hr"))) + } + } + + test("default Transaction.registerScans (returns false) causes cache bypass") { + createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", + """{ "pk": 1, "salary": 100, "dep": "hr" } + |{ "pk": 2, "salary": 200, "dep": "software" } + |{ "pk": 3, "salary": 300, "dep": "hr" } + |""".stripMargin) + + val sourceDF = spark.table(tableNameAsString).where("salary < 250").as("source") + sourceDF.cache() + sourceDF.count() + + // Force the next transaction to use the default Transaction.registerScans behavior, which + // returns false unconditionally. Spark must then bypass the cached source entirely. 
+ catalog.nextTxnUsesDefaultRegisterScans = true + + val (txn, txnTables) = executeTransaction { + sourceDF + .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk")) + .whenMatched() + .update(Map("salary" -> targetTableCol("salary").plus(1))) + .merge() + } + + assert(txn.currentState == Committed) + // The default-impl refusal is observable: registerScans was called but accepted nothing, + // so the buffer stays empty. + assert(txn.registeredScans.isEmpty, + "default registerScans should refuse the cache") + + // Cache bypass means the source's filter is re-evaluated through the txn catalog, leaving + // a scan event on the target's TxnTable. + val targetTxnTable = txnTables(tableNameAsString) + assert(targetTxnTable.scanEvents.flatten.exists { + case sources.LessThan("salary", 250) => true + case _ => false + }, "expected salary<250 to appear as a scan event after cache bypass") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Seq( + Row(1, 101, "hr"), + Row(2, 201, "software"), + Row(3, 300, "hr"))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index f37a614f99b53..126b84b507caf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -1198,10 +1198,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase assert(txn.currentState == Committed) assert(txn.isClosed) - // both target and source must have been read through the transaction catalog - assert(txnTables.size == 2) + // The source TxnTable is still created during analysis (the txn catalog routes the + // load), but cache substitution reuses the cached scan instead of issuing a fresh + // BatchScanExec for the source, so only the target appears in the executed plan. 
+ assert(txn.catalog.txnTables.size == 2) + assert(txnTables.size == 1) + assert(txnTables.contains(tableNameAsString)) assert(table.version() == "2") - assert(txnTables(sourceNameAsString).scanEvents.nonEmpty) + // The connector accepted the cached source scan via registerScans, which also + // records the scan as a read event against the source's TxnTable. + assert(txn.registeredScans.nonEmpty) + val sourceTxnTable = txn.catalog.txnTables.values.find(_.name == sourceNameAsString).get + assert(sourceTxnTable.scanEvents.nonEmpty) assert(txnTables(tableNameAsString).scanEvents.nonEmpty) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 379e7ba755d9d..7c2ef1f3eab31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -51,6 +51,7 @@ abstract class RowLevelOperationSuiteBase } after { + catalog.nextTxnUsesDefaultRegisterScans = false spark.sessionState.catalogManager.reset() spark.sessionState.conf.unsetConf("spark.sql.catalog.cat") }