diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java index 77044c6202fbe..12ba6dd33c23a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/transactions/Transaction.java @@ -20,8 +20,10 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.catalog.CatalogPlugin; import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin; +import org.apache.spark.sql.connector.read.Scan; import java.io.Closeable; +import java.util.List; /** * Represents a transaction. @@ -66,6 +68,29 @@ public interface Transaction extends Closeable { */ void abort(); + /** + * Attempts to register a list of materialized scans against this transaction's read set. + *
+ * Spark calls this when considering reuse of a cached subtree during the transaction. The + * list contains every materialized scan in the candidate cached subtree. That may include + * scans that belong to catalogs other than this transaction's catalog. The connector can + * decide which scans to register or ignore. + *
+ * The connector decides whether reusing the cached snapshots is compatible with the + * transaction's isolation contract. It must either accept the cache entry (returning + * {@code true} after adding any of its own scans to the read set) or refuse it (returning + * {@code false} without modifying the read set). + *
+ * This method may be called multiple times during a single query with overlapping scan
+ * lists. Registering a scan that is already in the read set must be a no-op.
+ *
+ * @param scans Every materialized scan in the candidate cached subtree.
+ * @return true if the connector accepts reuse of the cache entry; false otherwise.
+ */
+ default boolean registerScans(List
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 19371dcb94dec..4efb789e51c05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -28,9 +28,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, ExposesMetadataC
import org.apache.spark.sql.catalyst.streaming.{StreamingSourceIdentifyingName, Unassigned}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils}
-import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, SupportsMetadataColumns, Table, TableCapability, TableCatalog, V2TableUtil}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Column, FunctionCatalog, Identifier, SupportsMetadataColumns, Table, TableCapability, TableCatalog, V2TableUtil}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
-import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
+import org.apache.spark.sql.connector.catalog.constraints.Constraint
+import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference, Transform}
import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportStatistics, SupportsRuntimeV2Filtering}
import org.apache.spark.sql.connector.read.colstats.{ColumnStatistics, Histogram => V2Histogram, HistogramBin => V2HistogramBin}
import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream}
@@ -122,6 +123,21 @@ case class DataSourceV2Relation(
copy(output = output.map(_.newInstance()))
}
+ // Replace the `table` field with an identity placeholder. This allows two relations
+ // that point at the same logical table to compare equal even when one is wrapped (e.g.
+ // by a transaction-aware catalog). Note: when `Table.id()` is null we key only on catalog
+ // and identifier, so a dropped-and-recreated table is not distinguished from the original.
+ override def doCanonicalize(): LogicalPlan = {
+ val base = super.doCanonicalize().asInstanceOf[DataSourceV2Relation]
+ val catalogName = base.catalog.map(_.name())
+ (catalogName, base.identifier) match {
+ case (Some(cn), Some(ident)) =>
+ base.copy(table = DataSourceV2Relation.CanonicalTableKey(
+ cn, ident, Option(base.table.id())))
+ case _ => base
+ }
+ }
+
override lazy val metadataOutput: Seq[AttributeReference] = table match {
case hasMeta: SupportsMetadataColumns =>
metadataOutputWithOutConflicts(
@@ -320,6 +336,28 @@ object ExtractV2ScanInfo {
}
object DataSourceV2Relation {
+
+ // Substituted into a canonicalized `DataSourceV2Relation.table` field so that two
+ // relations targeting the same metastore entity compare equal regardless of which concrete
+ // `Table` instance backs each side. Should only appear in canonicalized plans.
+ private[v2] case class CanonicalTableKey(
+ catalogName: String,
+ identifier: Identifier,
+ idOpt: Option[String]) extends Table {
+ override def id(): String = idOpt.orNull
+ override def name(): String = s"$catalogName.$identifier"
+ override def capabilities(): java.util.Set[TableCapability] = throwExecutionAccess()
+ override def columns(): Array[Column] = throwExecutionAccess()
+ override def partitioning(): Array[Transform] = throwExecutionAccess()
+ override def properties(): java.util.Map[String, String] = throwExecutionAccess()
+ override def constraints(): Array[Constraint] = throwExecutionAccess()
+ override def version(): String = throwExecutionAccess()
+
+ private def throwExecutionAccess(): Nothing =
+ throw SparkException.internalError(
+ "CanonicalTableKey is canonicalization-only and must not appear in execution plans")
+ }
+
def create(
table: Table,
catalog: Option[CatalogPlugin],
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index f582f3e408cb6..ec92691811a36 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -501,6 +501,7 @@ abstract class InMemoryBaseTable(
if (evaluableFilters.nonEmpty) {
scan.filter(evaluableFilters)
}
+ scan.pushedFilters = _pushedFilters
recordScanEvent(_pushedFilters)
scan
}
@@ -665,6 +666,18 @@ abstract class InMemoryBaseTable(
options: CaseInsensitiveStringMap)
extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeFiltering {
+ // Snapshot of the table version when this scan was built.
+ val builtAtTableVersion: Int = InMemoryBaseTable.this.tableVersion
+
+ // The current table version, read fresh on each access.
+ def currentTableVersion: Int = InMemoryBaseTable.this.tableVersion
+
+ // Back-pointer to the table this scan was built against.
+ val table: InMemoryBaseTable = InMemoryBaseTable.this
+
+ // The filters pushed to this scan at build time.
+ var pushedFilters: Array[Filter] = Array.empty
+
override def filterAttributes(): Array[NamedReference] = {
val scanFields = readSchema.fields.map(_.name).toSet
partitioning.flatMap(_.references)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala
index 81f976dce510f..f544963f57784 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala
@@ -35,11 +35,18 @@ class InMemoryRowLevelOperationTableCatalog
// All transactions in order (committed and aborted), allowing per-statement
// validation in SQL scripting tests.
val observedTransactions: ArrayBuffer[Txn] = new ArrayBuffer[Txn]()
+ // Test-only knob. When true, the next transaction created by `beginTransaction` will mimic
+ // the default `Transaction.registerScans` behavior (always returns false). The flag is
+ // reset to false after being consumed.
+ var nextTxnUsesDefaultRegisterScans: Boolean = false
override def beginTransaction(info: TransactionInfo): Transaction = {
assert(transaction == null || transaction.currentState != Active)
- this.transaction = new Txn(new TxnTableCatalog(this))
- transaction
+ val txn = new Txn(new TxnTableCatalog(this))
+ txn.useDefaultRegisterScans = nextTxnUsesDefaultRegisterScans
+ nextTxnUsesDefaultRegisterScans = false
+ this.transaction = txn
+ txn
}
override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala
index 4b9dff5c3d780..a8b126684ced3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import org.apache.spark.sql.connector.catalog.transactions.Transaction
+import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, RowLevelOperationBuilder, RowLevelOperationInfo, WriteBuilder}
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
@@ -39,10 +40,50 @@ class Txn(override val catalog: TxnTableCatalog) extends Transaction {
private[this] var state: TransactionState = Active
private[this] var closed: Boolean = false
+ // Records every batch of scans the connector accepted via registerScans. Tests assert on this
+ // to confirm cache substitution went through the txn path.
+ val registeredScans: ArrayBuffer[Seq[Scan]] = ArrayBuffer.empty
+
+ // Test-only switch. When true, `registerScans` mimics the default Transaction interface
+ // behavior (always returns false) instead of running the version-equality check. Used to
+ // verify Spark's cache-bypass behavior when a connector hasn't implemented `registerScans`.
+ var useDefaultRegisterScans: Boolean = false
+
def currentState: TransactionState = state
def isClosed: Boolean = closed
+ // Accept the batch if relevant scans were built against the table at exactly the
+ // version the table is at now. A real connector could be more permissive. For example, it
+ // could accept older snapshots and record them in the read set for commit-time conflict
+ // detection.
+ //
+ // The scan list may include foreign scans that originated in other catalogs. We
+ // identify our own scans by structural type AND by table identity: an InMemoryBatchScan
+ // whose underlying table is one this catalog instance is tracking. Foreign scans are
+ // non-transactional from our perspective and are ignored by the version check.
+ override def registerScans(javaScans: java.util.List[Scan]): Boolean = {
+ if (useDefaultRegisterScans) return false
+
+ val scans = javaScans.asScala.toSeq
+ val myScans = scans.collect {
+ case s: InMemoryBaseTable#InMemoryBatchScan
+ if catalog.txnTables.values.exists(_.delegate eq s.table) => s
+ }
+ val accepted = myScans.forall { s =>
+ s.builtAtTableVersion == s.currentTableVersion
+ }
+ if (accepted) {
+ registeredScans += scans
+ myScans.foreach { s =>
+ catalog.txnTables.values
+ .find(_.delegate eq s.table)
+ .foreach(_.scanEvents += s.pushedFilters)
+ }
+ }
+ accepted
+ }
+
override def commit(): Unit = {
if (closed) throw new IllegalStateException("Can't commit, already closed")
if (state == Aborted) throw new IllegalStateException("Can't commit, already aborted")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2RelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2RelationSuite.scala
index 10ec9efca4aba..88190d3c3db5a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2RelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2RelationSuite.scala
@@ -19,9 +19,12 @@ package org.apache.spark.sql.execution.datasources.v2
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{Histogram, HistogramBin}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, Table, TableCapability}
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
class DataSourceV2RelationSuite extends SparkFunSuite {
@@ -43,7 +46,7 @@ class DataSourceV2RelationSuite extends SparkFunSuite {
rowCount = Some(10),
colStats = Map(
"id" -> idColStat,
- // "extra" is not in schema — should be silently skipped
+ // "extra" is not in schema, should be silently skipped
"extra" -> CatalogColumnStat(distinctCount = Some(5))))
val v2Stats = DataSourceV2Relation.v1StatsToV2Stats(catalogStats, schema)
@@ -107,4 +110,101 @@ class DataSourceV2RelationSuite extends SparkFunSuite {
val idV2NoHist = colStats.get(FieldReference.column("id"))
assert(!idV2NoHist.histogram().isPresent)
}
+
+ test("canonicalize matches across Table wrappers when catalog and identifier agree") {
+ val schema = StructType(Seq(StructField("a", IntegerType)))
+ val output = Seq(AttributeReference("a", IntegerType)())
+ val catalog = Some(new StubCatalog("cat"))
+ val ident = Some(Identifier.of(Array("ns"), "t"))
+ val opts = CaseInsensitiveStringMap.empty()
+
+ // Same id on both sides, strict path. Different Table instances, same id => equal.
+ val a1 = DataSourceV2Relation(new StubTable("a", schema, id = "id-1"),
+ output, catalog, ident, opts)
+ val a2 = DataSourceV2Relation(new StubTable("a-wrapped", schema, id = "id-1"),
+ output, catalog, ident, opts)
+ assert(a1.sameResult(a2),
+ "two relations to the same logical table should canonicalize equal")
+
+ // Null id on both sides, permissive path. Same catalog+identifier => still equal.
+ val b1 = DataSourceV2Relation(new StubTable("b", schema, id = null),
+ output, catalog, ident, opts)
+ val b2 = DataSourceV2Relation(new StubTable("b-wrapped", schema, id = null),
+ output, catalog, ident, opts)
+ assert(b1.sameResult(b2),
+ "relations with null id should canonicalize equal when catalog+identifier match")
+
+ // Mixed: one side has an id, the other null. Treated as different (couldn't both be the same
+ // logical table for any connector that exposes id() consistently).
+ val c1 = DataSourceV2Relation(new StubTable("c", schema, id = "id-1"),
+ output, catalog, ident, opts)
+ val c2 = DataSourceV2Relation(new StubTable("c", schema, id = null),
+ output, catalog, ident, opts)
+ assert(!c1.sameResult(c2),
+ "id present vs null should not compare equal")
+
+ // Different non-null ids, drop+recreate scenario. Must compare unequal.
+ val d1 = DataSourceV2Relation(new StubTable("d", schema, id = "id-1"),
+ output, catalog, ident, opts)
+ val d2 = DataSourceV2Relation(new StubTable("d", schema, id = "id-2"),
+ output, catalog, ident, opts)
+ assert(!d1.sameResult(d2),
+ "different ids (drop+recreate) should not compare equal")
+
+ // Different identifier, must compare unequal even if ids would otherwise match.
+ val otherIdent = Some(Identifier.of(Array("ns"), "other"))
+ val e1 = DataSourceV2Relation(new StubTable("e", schema, id = "id-1"),
+ output, catalog, ident, opts)
+ val e2 = DataSourceV2Relation(new StubTable("e", schema, id = "id-1"),
+ output, catalog, otherIdent, opts)
+ assert(!e1.sameResult(e2),
+ "different identifiers should not compare equal")
+
+ // Different catalog, must compare unequal even with same identifier and id.
+ val otherCatalog = Some(new StubCatalog("other"))
+ val f1 = DataSourceV2Relation(new StubTable("f", schema, id = "id-1"),
+ output, catalog, ident, opts)
+ val f2 = DataSourceV2Relation(new StubTable("f", schema, id = "id-1"),
+ output, otherCatalog, ident, opts)
+ assert(!f1.sameResult(f2),
+ "different catalog names should not compare equal")
+ }
+
+ test("canonicalize falls back when catalog or identifier is missing") {
+ val schema = StructType(Seq(StructField("a", IntegerType)))
+ val output = Seq(AttributeReference("a", IntegerType)())
+ val ident = Some(Identifier.of(Array("ns"), "t"))
+ val opts = CaseInsensitiveStringMap.empty()
+
+ // No catalog: falls through to the default canonical form. Equality reduces to reference
+ // equality on the `table` field: the same instance is equal, different instances are not.
+ val sharedTable = new StubTable("shared", schema, id = "id-1")
+ val sameInstance1 = DataSourceV2Relation(sharedTable, output, None, ident, opts)
+ val sameInstance2 = DataSourceV2Relation(sharedTable, output, None, ident, opts)
+ assert(sameInstance1.sameResult(sameInstance2),
+ "fallback path: same Table instance compares equal")
+
+ val diff1 = DataSourceV2Relation(new StubTable("x", schema, id = "id-1"),
+ output, None, ident, opts)
+ val diff2 = DataSourceV2Relation(new StubTable("x", schema, id = "id-1"),
+ output, None, ident, opts)
+ assert(!diff1.sameResult(diff2),
+ "fallback path: different Table instances are not equal even with matching id")
+ }
+}
+
+private class StubCatalog(catalogName: String) extends CatalogPlugin {
+ override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = ()
+ override def name(): String = catalogName
+}
+
+private class StubTable(
+ tableName: String,
+ schema: StructType,
+ id: String) extends Table {
+ override def name(): String = tableName
+ override def id(): String = id
+ override def columns(): Array[org.apache.spark.sql.connector.catalog.Column] =
+ org.apache.spark.sql.connector.catalog.CatalogV2Util.structTypeToV2Columns(schema)
+ override def capabilities(): java.util.Set[TableCapability] = java.util.Set.of()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 3f92f24156d3c..d590ebc2286ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
+import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
import org.apache.hadoop.fs.{FileSystem, Path}
@@ -35,11 +36,12 @@ import org.apache.spark.sql.classic.{Dataset, SparkSession}
import org.apache.spark.sql.connector.catalog.CatalogPlugin
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper}
import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.transactions.Transaction
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.command.CommandUtils
import org.apache.spark.sql.execution.datasources.{FileIndex, HadoopFsRelation, LogicalRelation, LogicalRelationWithTable}
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2CatalogAndIdentifier, ExtractV2Table, FileTable, V2TableRefreshUtil}
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, ExtractV2CatalogAndIdentifier, ExtractV2Table, FileTable, V2TableRefreshUtil}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
@@ -479,8 +481,13 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
lookupCachedDataInternal(normalized)
}
- private def lookupCachedDataInternal(plan: LogicalPlan): Option[CachedData] = {
- val result = cachedData.find(cd => plan.sameResult(cd.plan))
+ private def lookupCachedDataInternal(
+ plan: LogicalPlan,
+ transactionOpt: Option[Transaction] = None): Option[CachedData] = {
+ val result = cachedData.find { cd =>
+ plan.sameResult(cd.plan) &&
+ transactionOpt.forall(txn => validateCachedEntryForTransaction(cd, txn))
+ }
if (result.isDefined) {
CacheManager.logCacheOperation(log"Dataframe cache hit for input plan:" +
log"\n${MDC(QUERY_PLAN, plan)} matched with cache entry:" +
@@ -489,16 +496,34 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
result
}
+ // Decides whether the cached entry can be substituted into a plan being executed inside
+ // the given transaction. It collects the cached plan's scans and tries to register them with
+ // the connector, which returns true if reusing the cached snapshot is consistent with its
+ // isolation contract. A cached plan without scans is always accepted. The scans might belong
+ // to different catalogs; the connector decides which ones to register and which to ignore.
+ private def validateCachedEntryForTransaction(cd: CachedData, txn: Transaction): Boolean = {
+ val scans = collectWithSubqueries(cd.cachedRepresentation.cacheBuilder.cachedPlan) {
+ case b: BatchScanExec => b.scan
+ }
+ scans.isEmpty || txn.registerScans(scans.asJava)
+ }
+
/**
* Replaces segments of the given logical plan with cached versions where possible. The input
* plan must be normalized.
+ *
+ * @param plan the plan to rewrite.
+ * @param transactionOpt if defined, each candidate cache hit is validated against the
+ * transaction's isolation contract (via `Transaction.registerScans`).
*/
- private[sql] def useCachedData(plan: LogicalPlan): LogicalPlan = {
+ private[sql] def useCachedData(
+ plan: LogicalPlan,
+ transactionOpt: Option[Transaction] = None): LogicalPlan = {
val newPlan = plan transformDown {
case command: Command => command
case currentFragment =>
- lookupCachedDataInternal(currentFragment).map { cached =>
+ lookupCachedDataInternal(currentFragment, transactionOpt).map { cached =>
// After cache lookup, we should still keep the hints from the input plan.
val hints = EliminateResolvedHint.extractHintsFromPlan(currentFragment)._2
val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output)
@@ -511,7 +536,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
}
val result = newPlan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
- case s: SubqueryExpression => s.withNewPlan(useCachedData(s.plan))
+ case s: SubqueryExpression => s.withNewPlan(useCachedData(s.plan, transactionOpt))
}
if (result.fastEquals(plan)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 65bc57de907b2..7f43d14cb7827 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -329,16 +329,11 @@ class QueryExecution(
assertAnalyzed()
assertSupported()
- // During a transaction, skip cache substitution. This is to avoid replacing relations
- // loaded by the transactional catalog with potentially stale relations cached before
- // the transaction was active.
- if (transactionOpt.isDefined) {
- normalized
- } else {
- // Clone the plan to avoid sharing the plan instance between different stages like
- // analyzing, optimizing and planning.
- sparkSession.sharedState.cacheManager.useCachedData(normalized.clone())
- }
+ // Clone the plan to avoid sharing the plan instance between different stages like
+ // analyzing, optimizing and planning.
+ sparkSession.sharedState.cacheManager.useCachedData(
+ normalized.clone(),
+ transactionOpt)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
index d2dfae13b511c..8cc81062d982e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.connector
import org.apache.spark.sql.{sources, Column, Row}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.classic.MergeIntoWriter
-import org.apache.spark.sql.connector.catalog.{Aborted, Committed}
+import org.apache.spark.sql.connector.catalog.{Aborted, Committed, InMemoryBaseTable}
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.TableInfo
import org.apache.spark.sql.functions._
@@ -1233,4 +1233,266 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase {
assert(e.message.contains("incompatible changes to table `cat`.`ns1`.`source_table`"))
}
}
+
+ // The DataFrame cache is bypassed during a transaction unless the connector approves the cached
+ // scans via Transaction#registerScans. The next two tests exercise that path against the merge
+ // source, a common case where users cache an expensive source DataFrame before running MERGE.
+ // The test connector's registerScans accepts when the cached scan's table version matches the
+ // current version (no intervening writes) and refuses otherwise.
+
+ test("cached merge source is reused when the table is unchanged since caching") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ val tableVersionBeforeCache = table.version()
+
+ // Cache a derived view of the target table. The InMemoryBatchScan stamps the current table
+ // version at build time; that version is what the connector compares against later.
+ val sourceDF = spark.table(tableNameAsString).where("salary < 250").as("source")
+ sourceDF.cache()
+ sourceDF.count()
+
+ assert(table.version() == tableVersionBeforeCache,
+ "sanity: caching a read should not bump the table version")
+
+ val (txn, txnTables) = executeTransaction {
+ sourceDF
+ .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk"))
+ .whenMatched()
+ .update(Map("salary" -> targetTableCol("salary").plus(1)))
+ .merge()
+ }
+
+ assert(txn.currentState == Committed)
+
+ // Cached scan's version == current version -> registerScans accepted -> cache substituted.
+ // The non-empty registeredScans buffer is the witness that the connector ran the
+ // cache-acceptance path (not a fresh scan-builder path).
+ assert(txn.registeredScans.nonEmpty,
+ "registerScans should have been called and accepted (no writes since cache)")
+
+ // Even though the cache was used, registerScans recorded the read on the target's TxnTable
+ // so that scanEvents uniformly captures cached and fresh reads. The MERGE plan references
+ // the source twice (matched and not-matched branches), so the source filter appears twice
+ // in scanEvents, whether contributed by registerScans (cache path) or by fresh scans
+ // (rescan path). The `registeredScans` non-empty assertion above is what distinguishes
+ // the two cases.
+ val targetTxnTable = txnTables(tableNameAsString)
+ val sourceFilterScans = targetTxnTable.scanEvents.flatten.count {
+ case sources.LessThan("salary", 250) => true
+ case _ => false
+ }
+ assert(sourceFilterScans == 2,
+ s"expected two salary<250 scan events, got " +
+ s"${targetTxnTable.scanEvents.map(_.toSeq).mkString("[", ", ", "]")}")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 101, "hr"),
+ Row(2, 201, "software"),
+ Row(3, 300, "hr")))
+ }
+
+ test("cached merge source is dropped when the table version moves on between caching and MERGE") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // Cache the source at the table's current version.
+ val sourceDF = spark.table(tableNameAsString).where("salary < 250").as("source")
+ sourceDF.cache()
+ sourceDF.count()
+ val versionAtCache = table.version()
+
+ // Simulate an external committer bumping the table version between caching and the MERGE.
+ // A write done through Spark (e.g. another `append`) would also bump the version, but it
+ // would additionally trigger CacheManager.refreshCache, which re-materializes our cached
+ // InMemoryRelation against the new snapshot, defeating the test. The `registerScans`
+ // path is meant to defend against writes Spark didn't observe (other clusters, other
+ // sessions, direct table mutations), so we bump the version directly.
+ //
+ // NOTE on fidelity: the in-memory test catalog does not separate "metastore version" from
+ // "loaded snapshot version": `InMemoryBaseTable.tableVersion` is shared mutable state, so
+ // a scan's `currentTableVersion` immediately reflects this bump. A real DSv2 Delta
+ // connector holds a Table loaded at V1 and would not observe an external bump to V2 unless
+ // it explicitly refreshed from the metastore (per go/spark-refresh design doc, Section 5
+ // Scenario 1). The test therefore exercises a stricter connector policy than Delta's
+ // default, namely one that polices version drift on every registerScans call. A more
+ // faithful model would capture the txn's pinned snapshot in TxnTable at txn-begin and
+ // compare `builtAtTableVersion` against that pinned value (left as a follow-up).
+ table.increaseVersion()
+ assert(table.version() != versionAtCache, "sanity: bump should change the version")
+
+ val (txn, txnTables) = executeTransaction {
+ sourceDF
+ .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk"))
+ .whenMatched()
+ .update(Map("salary" -> targetTableCol("salary").plus(1)))
+ .merge()
+ }
+
+ assert(txn.currentState == Committed)
+ // Cached scan's builtAtTableVersion != currentTableVersion -> registerScans refused; nothing
+ // recorded.
+ assert(txn.registeredScans.isEmpty,
+ "registerScans should have refused the stale cached scan")
+
+ // The cached source was not substituted, so the source's salary<250 filter was re-evaluated
+ // through the txn catalog and shows up as a scan event on the target's TxnTable.
+ val targetTxnTable = txnTables(tableNameAsString)
+ val sourceFilterScanned = targetTxnTable.scanEvents.flatten.exists {
+ case sources.LessThan("salary", 250) => true
+ case _ => false
+ }
+ assert(sourceFilterScanned,
+ s"expected the source's salary<250 filter to appear as a scan event after cache bypass, " +
+ s"got ${targetTxnTable.scanEvents.map(_.toSeq).mkString("[", ", ", "]")}")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 101, "hr"),
+ Row(2, 201, "software"),
+ Row(3, 300, "hr")))
+ }
+
+ test("cached source from outside the txn catalog is reused without consulting registerScans") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |""".stripMargin)
+
+ // Build the source from spark.range: no DataSourceV2Relation, no catalog reference at all.
+ // Inside the txn, validateCachedEntryForTransaction collects no scans from this cached
+ // entry, so the short-circuit accepts the cache without ever consulting registerScans.
+ val sourceDF = spark.range(2)
+ .select(
+ (col("id") + 1).cast(IntegerType).as("pk"),
+ lit(999).cast(IntegerType).as("salary"),
+ lit("hr").cast(StringType).as("dep"))
+ .as("source")
+ sourceDF.cache()
+ sourceDF.count()
+
+ val (txn, _) = executeTransaction {
+ sourceDF
+ .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk"))
+ .whenMatched()
+ .update(Map("salary" -> $"source.salary"))
+ .merge()
+ }
+
+ assert(txn.currentState == Committed)
+ assert(txn.registeredScans.isEmpty,
+ "registerScans should not be consulted when the cached subtree has no txn-catalog reads")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 999, "hr"),
+ Row(2, 999, "software")))
+ }
+
+ test("cached relation inside an IN-subquery is substituted via useCachedData recursion") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ // Create and populate a lookup table in the same txn catalog. We cache a DataFrame against
+ // this lookup, then drive a MERGE whose source filters by an IN-subquery against the
+ // cached lookup. The cache substitution can only happen via useCachedData's recursive
+ // descent into SubqueryExpression; transformDown alone would not find the cached subtree
+ // hidden inside the IN expression.
+ val lookupName = "cat.ns1.cache_lookup"
+ withTable(lookupName) {
+ sql(s"CREATE TABLE $lookupName (pk INT) USING foo")
+ sql(s"INSERT INTO $lookupName VALUES (1), (2)")
+
+ val lookupDF = spark.table(lookupName)
+ lookupDF.cache()
+ lookupDF.count()
+ lookupDF.createOrReplaceTempView("cache_lookup_view")
+
+ val sourceDF = spark.table(tableNameAsString)
+ .where("pk IN (SELECT pk FROM cache_lookup_view)")
+ .as("source")
+
+ val (txn, _) = executeTransaction {
+ sourceDF
+ .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk"))
+ .whenMatched()
+ .update(Map("salary" -> targetTableCol("salary").plus(1)))
+ .merge()
+ }
+
+ assert(txn.currentState == Committed)
+ // The lookup is the only cached entry in this plan, and it sits inside an IN-subquery,
+ // so the only way the registered scans can include it is via useCachedData's recursive
+ // descent into SubqueryExpression. Every registered scan must point at the lookup.
+ val lookupTable = catalog.loadTable(Identifier.of(Array("ns1"), "cache_lookup"))
+ assert(txn.registeredScans.flatten.collect {
+ case s: InMemoryBaseTable#InMemoryBatchScan => s.table
+ }.distinct == Seq(lookupTable))
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 101, "hr"),
+ Row(2, 201, "software"),
+ Row(3, 300, "hr")))
+ }
+ }
+
+ test("default Transaction.registerScans (returns false) causes cache bypass") {
+ createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+ """{ "pk": 1, "salary": 100, "dep": "hr" }
+ |{ "pk": 2, "salary": 200, "dep": "software" }
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
+ |""".stripMargin)
+
+ val sourceDF = spark.table(tableNameAsString).where("salary < 250").as("source")
+ sourceDF.cache()
+ sourceDF.count()
+
+ // Force the next transaction to use the default Transaction.registerScans behavior, which
+ // returns false unconditionally. Spark must then bypass the cached source entirely.
+ catalog.nextTxnUsesDefaultRegisterScans = true
+
+ val (txn, txnTables) = executeTransaction {
+ sourceDF
+ .mergeInto(tableNameAsString, $"source.pk" === targetTableCol("pk"))
+ .whenMatched()
+ .update(Map("salary" -> targetTableCol("salary").plus(1)))
+ .merge()
+ }
+
+ assert(txn.currentState == Committed)
+ // The default implementation refuses every cache entry, so nothing is registered and the
+ // buffer stays empty.
+ assert(txn.registeredScans.isEmpty,
+ "default registerScans should refuse the cache")
+
+ // Cache bypass means the source's filter is re-evaluated through the txn catalog, leaving
+ // a scan event on the target's TxnTable (the source reads the same table as the target).
+ val targetTxnTable = txnTables(tableNameAsString)
+ assert(targetTxnTable.scanEvents.flatten.exists {
+ case sources.LessThan("salary", 250) => true
+ case _ => false
+ }, "expected salary<250 to appear as a scan event after cache bypass")
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, 101, "hr"),
+ Row(2, 201, "software"),
+ Row(3, 300, "hr")))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index f37a614f99b53..126b84b507caf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -1198,10 +1198,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
assert(txn.currentState == Committed)
assert(txn.isClosed)
- // both target and source must have been read through the transaction catalog
- assert(txnTables.size == 2)
+ // The source TxnTable is still created during analysis (the txn catalog routes the
+ // load), but cache substitution reuses the cached scan instead of issuing a fresh
+ // BatchScanExec for the source, so only the target appears in the executed plan.
+ assert(txn.catalog.txnTables.size == 2)
+ assert(txnTables.size == 1)
+ assert(txnTables.contains(tableNameAsString))
assert(table.version() == "2")
- assert(txnTables(sourceNameAsString).scanEvents.nonEmpty)
+ // The connector accepted the cached source scan via registerScans, which also
+ // records the scan as a read event against the source's TxnTable.
+ assert(txn.registeredScans.nonEmpty)
+ val sourceTxnTable = txn.catalog.txnTables.values.find(_.name == sourceNameAsString).get
+ assert(sourceTxnTable.scanEvents.nonEmpty)
assert(txnTables(tableNameAsString).scanEvents.nonEmpty)
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
index 379e7ba755d9d..7c2ef1f3eab31 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala
@@ -51,6 +51,7 @@ abstract class RowLevelOperationSuiteBase
}
after {
+ catalog.nextTxnUsesDefaultRegisterScans = false
spark.sessionState.catalogManager.reset()
spark.sessionState.conf.unsetConf("spark.sql.catalog.cat")
}