Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.catalog.CatalogPlugin;
import org.apache.spark.sql.connector.catalog.TransactionalCatalogPlugin;
import org.apache.spark.sql.connector.read.Scan;

import java.io.Closeable;
import java.util.List;

/**
* Represents a transaction.
Expand Down Expand Up @@ -66,6 +68,29 @@ public interface Transaction extends Closeable {
*/
void abort();

/**
* Attempts to register a list of materialized scans against this transaction's read set.
* <p>
* Spark calls this when considering reuse of a cached subtree during the transaction. The
* list contains every materialized scan in the candidate cached subtree. That may include
* scans that belong to catalogs other than this transaction's catalog. The connector can
* decide which scans to register or ignore.
* <p>
* The connector decides whether reusing the cached snapshots is compatible with the
* transaction's isolation contract. It must either accept the cache entry (returning
* {@code true} after adding any of its own scans to the read set) or refuse it (returning
* {@code false} without modifying the read set).
* <p>
* This method may be called multiple times during a single query with overlapping scan
* lists. Registering a scan that is already in the read set must be a no-op.
* <p>
* The default implementation refuses every cache entry: it returns {@code false} without
* inspecting {@code scans}.
*
* @param scans Every materialized scan in the candidate cached subtree.
* @return true if the connector accepts reuse of the cache entry; false otherwise.
*/
default boolean registerScans(List<Scan> scans) {
return false;
}

/**
* Releases any resources held by this transaction.
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, ExposesMetadataC
import org.apache.spark.sql.catalyst.streaming.{StreamingSourceIdentifyingName, Unassigned}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils}
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, SupportsMetadataColumns, Table, TableCapability, TableCatalog, V2TableUtil}
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Column, FunctionCatalog, Identifier, SupportsMetadataColumns, Table, TableCapability, TableCatalog, V2TableUtil}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference, Transform}
import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportStatistics, SupportsRuntimeV2Filtering}
import org.apache.spark.sql.connector.read.colstats.{ColumnStatistics, Histogram => V2Histogram, HistogramBin => V2HistogramBin}
import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream}
Expand Down Expand Up @@ -122,6 +123,21 @@ case class DataSourceV2Relation(
copy(output = output.map(_.newInstance()))
}

// The canonical form swaps the concrete `table` for a `CanonicalTableKey` built from the
// catalog name, the identifier and `Table.id()`. This allows two relations that point at the
// same logical table to compare equal even when one of them is wrapped (e.g. by a
// transaction-aware catalog). Note: when `Table.id()` is null the key consists of catalog and
// identifier only, so drop-and-recreate is not distinguished. A relation without a catalog
// or without an identifier keeps the default canonical form.
override def doCanonicalize(): LogicalPlan = {
  val canonical = super.doCanonicalize().asInstanceOf[DataSourceV2Relation]
  val catalogNameOpt = canonical.catalog.map(_.name())
  val keyed = for {
    cn <- catalogNameOpt
    ident <- canonical.identifier
  } yield {
    val key = DataSourceV2Relation.CanonicalTableKey(
      cn, ident, Option(canonical.table.id()))
    canonical.copy(table = key)
  }
  keyed.getOrElse(canonical)
}

override lazy val metadataOutput: Seq[AttributeReference] = table match {
case hasMeta: SupportsMetadataColumns =>
metadataOutputWithOutConflicts(
Expand Down Expand Up @@ -320,6 +336,28 @@ object ExtractV2ScanInfo {
}

object DataSourceV2Relation {

// Stand-in for the `table` field of a canonicalized `DataSourceV2Relation`. Because it is a
// case class over (catalog name, identifier, optional table id), two relations targeting the
// same metastore entity compare equal regardless of which concrete `Table` instance backs
// each side. Should only appear in canonicalized plans.
private[v2] case class CanonicalTableKey(
    catalogName: String,
    identifier: Identifier,
    idOpt: Option[String]) extends Table {

  // Identity-related members are the only ones that can be answered from the key itself.
  override def name(): String = s"$catalogName.$identifier"
  override def id(): String = idOpt.orNull

  // Everything below describes a real table and is unavailable on a canonical key.
  override def columns(): Array[Column] = canonicalOnly()
  override def partitioning(): Array[Transform] = canonicalOnly()
  override def constraints(): Array[Constraint] = canonicalOnly()
  override def capabilities(): java.util.Set[TableCapability] = canonicalOnly()
  override def properties(): java.util.Map[String, String] = canonicalOnly()
  override def version(): String = canonicalOnly()

  // Raises an internal error for any member that only makes sense on a real table.
  private def canonicalOnly(): Nothing = {
    val message =
      "CanonicalTableKey is canonicalization-only and must not appear in execution plans"
    throw SparkException.internalError(message)
  }
}

def create(
table: Table,
catalog: Option[CatalogPlugin],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ abstract class InMemoryBaseTable(
if (evaluableFilters.nonEmpty) {
scan.filter(evaluableFilters)
}
scan.pushedFilters = _pushedFilters
recordScanEvent(_pushedFilters)
scan
}
Expand Down Expand Up @@ -665,6 +666,18 @@ abstract class InMemoryBaseTable(
options: CaseInsensitiveStringMap)
extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeFiltering {

// Snapshot of the enclosing table's `tableVersion`, captured once when this scan is
// constructed. Comparing it with `currentTableVersion` tells whether the table version has
// changed since the scan was built.
val builtAtTableVersion: Int = InMemoryBaseTable.this.tableVersion

// The enclosing table's current `tableVersion`, read fresh on each access (this is a `def`,
// not a `val`).
def currentTableVersion: Int = InMemoryBaseTable.this.tableVersion

// Back-pointer to the enclosing table instance this scan was built against.
val table: InMemoryBaseTable = InMemoryBaseTable.this

// The filters pushed to this scan at build time. Starts out empty and is a `var` so that
// the code building the scan can assign the pushed filters after construction.
var pushedFilters: Array[Filter] = Array.empty

override def filterAttributes(): Array[NamedReference] = {
val scanFields = readSchema.fields.map(_.name).toSet
partitioning.flatMap(_.references)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,18 @@ class InMemoryRowLevelOperationTableCatalog
// All transactions in order (committed and aborted), allowing per-statement
// validation in SQL scripting tests.
val observedTransactions: ArrayBuffer[Txn] = new ArrayBuffer[Txn]()
// Test-only knob. When true, the next transaction created by `beginTransaction` will defer
// to the default `Transaction.registerScans` behavior (always returns false). The flag is
// reset to false after being consumed.
var nextTxnUsesDefaultRegisterScans: Boolean = false

// Starts a new transaction backed by a fresh `TxnTableCatalog` over this catalog and
// remembers it as the current `transaction`.
//
// The one-shot test knob `nextTxnUsesDefaultRegisterScans` is copied onto the new transaction
// and then reset, so it only ever affects the next transaction created here.
override def beginTransaction(info: TransactionInfo): Transaction = {
  // A new transaction may only start when no previous one is still active.
  assert(transaction == null || transaction.currentState != Active)
  // Build exactly one transaction; configure it fully before publishing it.
  val txn = new Txn(new TxnTableCatalog(this))
  txn.useDefaultRegisterScans = nextTxnUsesDefaultRegisterScans
  nextTxnUsesDefaultRegisterScans = false
  this.transaction = txn
  txn
}

override def createTable(ident: Identifier, tableInfo: TableInfo): Table = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._

import org.apache.spark.sql.connector.catalog.transactions.Transaction
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, RowLevelOperationBuilder, RowLevelOperationInfo, WriteBuilder}
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
Expand All @@ -39,10 +40,50 @@ class Txn(override val catalog: TxnTableCatalog) extends Transaction {
private[this] var state: TransactionState = Active
private[this] var closed: Boolean = false

// Records every batch of scans the connector accepted via registerScans. Each element is the
// complete scan list passed to one accepted call, including scans this transaction ignored.
// Tests assert on this to confirm cache substitution went through the txn path.
val registeredScans: ArrayBuffer[Seq[Scan]] = ArrayBuffer.empty

// Test-only switch. When true, `registerScans` mimics the default Transaction interface
// behavior (always returns false) instead of running the version-equality check. Used to
// verify Spark's cache-bypass behavior when a connector hasn't implemented `registerScans`.
var useDefaultRegisterScans: Boolean = false

// Read-only view of the transaction's current state.
def currentState: TransactionState = state

// Read-only view of the `closed` flag.
def isClosed: Boolean = closed

// Accept the batch if all relevant scans were built against the table at exactly the
// version the table is at now. A real connector could be more permissive. For example, it
// could accept older snapshots and record them in the read set for commit-time conflict
// detection.
//
// The scan list may include foreign scans that originated in other catalogs. We
// identify our own scans by structural type AND by table identity: an InMemoryBatchScan
// whose underlying table is one this catalog instance is tracking. Foreign scans are
// non-transactional from our perspective and are ignored, so a batch that contains only
// foreign scans is accepted.
//
// On acceptance the full batch is appended to `registeredScans` and each of our scans
// records its pushed filters as a scan event on the matching transactional table. On
// refusal nothing is recorded.
//
// Returns true if the batch is accepted, false otherwise (always false when
// `useDefaultRegisterScans` is set).
override def registerScans(javaScans: java.util.List[Scan]): Boolean = {
  if (useDefaultRegisterScans) {
    false
  } else {
    val scans = javaScans.asScala.toSeq
    val trackedTables = catalog.txnTables.values.toSeq
    // Pair each of our scans with the transactional table that tracks it, resolving the
    // table once instead of searching for it again when recording the scan event.
    val myScans = scans
      .collect { case s: InMemoryBaseTable#InMemoryBatchScan => s }
      .flatMap(s => trackedTables.find(_.delegate eq s.table).map(txnTable => (s, txnTable)))
    val accepted = myScans.forall { case (s, _) =>
      s.builtAtTableVersion == s.currentTableVersion
    }
    if (accepted) {
      registeredScans += scans
      myScans.foreach { case (s, txnTable) => txnTable.scanEvents += s.pushedFilters }
    }
    accepted
  }
}

override def commit(): Unit = {
if (closed) throw new IllegalStateException("Can't commit, already closed")
if (state == Aborted) throw new IllegalStateException("Can't commit, already aborted")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{Histogram, HistogramBin}
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, Table, TableCapability}
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class DataSourceV2RelationSuite extends SparkFunSuite {

Expand All @@ -43,7 +46,7 @@ class DataSourceV2RelationSuite extends SparkFunSuite {
rowCount = Some(10),
colStats = Map(
"id" -> idColStat,
// "extra" is not in schema should be silently skipped
// "extra" is not in schema, should be silently skipped
"extra" -> CatalogColumnStat(distinctCount = Some(5))))

val v2Stats = DataSourceV2Relation.v1StatsToV2Stats(catalogStats, schema)
Expand Down Expand Up @@ -107,4 +110,101 @@ class DataSourceV2RelationSuite extends SparkFunSuite {
val idV2NoHist = colStats.get(FieldReference.column("id"))
assert(!idV2NoHist.histogram().isPresent)
}

test("canonicalize matches across Table wrappers when catalog and identifier agree") {
  val schema = StructType(Seq(StructField("a", IntegerType)))
  val output = Seq(AttributeReference("a", IntegerType)())
  val opts = CaseInsensitiveStringMap.empty()
  val catalog: Option[CatalogPlugin] = Some(new StubCatalog("cat"))
  val ident: Option[Identifier] = Some(Identifier.of(Array("ns"), "t"))

  // Builds a relation over a fresh StubTable; catalog and identifier default to the
  // shared ones so each scenario only spells out what it varies.
  def relation(
      tableName: String,
      tableId: String,
      cat: Option[CatalogPlugin] = catalog,
      tableIdent: Option[Identifier] = ident): DataSourceV2Relation = {
    val stub = new StubTable(tableName, schema, id = tableId)
    DataSourceV2Relation(stub, output, cat, tableIdent, opts)
  }

  // Strict path: distinct Table instances that report the same id are equal.
  val sameIdPlain = relation("a", "id-1")
  val sameIdWrapped = relation("a-wrapped", "id-1")
  assert(sameIdPlain.sameResult(sameIdWrapped),
    "two relations to the same logical table should canonicalize equal")

  // Permissive path: both ids are null, so catalog + identifier alone decide equality.
  val nullIdPlain = relation("b", null)
  val nullIdWrapped = relation("b-wrapped", null)
  assert(nullIdPlain.sameResult(nullIdWrapped),
    "relations with null id should canonicalize equal when catalog+identifier match")

  // Mixed: an id on one side and null on the other. A connector that exposes id()
  // consistently cannot produce both for the same logical table, so they differ.
  val withId = relation("c", "id-1")
  val withoutId = relation("c", null)
  assert(!withId.sameResult(withoutId),
    "id present vs null should not compare equal")

  // Drop + recreate: same name, different non-null ids.
  val beforeRecreate = relation("d", "id-1")
  val afterRecreate = relation("d", "id-2")
  assert(!beforeRecreate.sameResult(afterRecreate),
    "different ids (drop+recreate) should not compare equal")

  // A different identifier wins over a matching id.
  val otherIdent: Option[Identifier] = Some(Identifier.of(Array("ns"), "other"))
  val underIdent = relation("e", "id-1")
  val underOtherIdent = relation("e", "id-1", tableIdent = otherIdent)
  assert(!underIdent.sameResult(underOtherIdent),
    "different identifiers should not compare equal")

  // A different catalog wins over a matching identifier and id.
  val otherCatalog: Option[CatalogPlugin] = Some(new StubCatalog("other"))
  val underCatalog = relation("f", "id-1")
  val underOtherCatalog = relation("f", "id-1", cat = otherCatalog)
  assert(!underCatalog.sameResult(underOtherCatalog),
    "different catalog names should not compare equal")
}

test("canonicalize falls back when catalog or identifier is missing") {
  val schema = StructType(Seq(StructField("a", IntegerType)))
  val output = Seq(AttributeReference("a", IntegerType)())
  val ident = Some(Identifier.of(Array("ns"), "t"))
  val opts = CaseInsensitiveStringMap.empty()

  // Every relation in this test has an identifier but no catalog, which sends
  // canonicalization down the default path. There, equality reduces to reference equality
  // on the `table` field.
  def withoutCatalog(table: Table): DataSourceV2Relation =
    DataSourceV2Relation(table, output, None, ident, opts)

  // One shared Table instance on both sides compares equal.
  val sharedTable = new StubTable("shared", schema, id = "id-1")
  val viaShared1 = withoutCatalog(sharedTable)
  val viaShared2 = withoutCatalog(sharedTable)
  assert(viaShared1.sameResult(viaShared2),
    "fallback path: same Table instance compares equal")

  // Two distinct instances do not, even though name and id agree.
  val viaFresh1 = withoutCatalog(new StubTable("x", schema, id = "id-1"))
  val viaFresh2 = withoutCatalog(new StubTable("x", schema, id = "id-1"))
  assert(!viaFresh1.sameResult(viaFresh2),
    "fallback path: different Table instances are not equal even with matching id")
}
}

// Minimal `CatalogPlugin` for tests: it reports the name given at construction time and
// ignores initialization.
private class StubCatalog(catalogName: String) extends CatalogPlugin {
  override def name(): String = catalogName

  // Nothing to configure; the name comes from the constructor, not from `initialize`.
  override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {}
}

// Minimal `Table` for tests: reports the given name and id, derives its columns from
// `schema`, and advertises no capabilities.
// NOTE(review): the constructor parameter `id` has the same name as the overridden `id()`
// method below. Confirm that this compiles and that `id()` returns the constructor value as
// intended. Renaming the parameter would also require updating the named-argument call
// sites (`id = ...`).
private class StubTable(
tableName: String,
schema: StructType,
id: String) extends Table {
override def name(): String = tableName
override def id(): String = id
// Converts the StructType into V2 columns on every call.
override def columns(): Array[org.apache.spark.sql.connector.catalog.Column] =
org.apache.spark.sql.connector.catalog.CatalogV2Util.structTypeToV2Columns(schema)
// Empty, immutable capability set.
override def capabilities(): java.util.Set[TableCapability] = java.util.Set.of()
}
Loading