diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 74cff7fef..ffe9c8189 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -94,9 +94,7 @@ jobs: POSTGRES_DB: postgres POSTGRES_PORT: 5430 POSTGRES_HOST: localhost - run: | - sudo apt-get install libpq-dev -y - ./etl-api/scripts/run_migrations.sh + run: ./scripts/run_migrations.sh etl-api - name: Run Tests run: | @@ -153,9 +151,7 @@ jobs: POSTGRES_DB: postgres POSTGRES_PORT: 5430 POSTGRES_HOST: localhost - run: | - sudo apt-get install libpq-dev -y - ./etl-api/scripts/run_migrations.sh + run: ./scripts/run_migrations.sh etl-api - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov diff --git a/Cargo.toml b/Cargo.toml index ef4e382e5..1b68b1101 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ const-oid = { version = "0.9.6", default-features = false } constant_time_eq = { version = "0.4.2" } fail = { version = "0.5.1", default-features = false } futures = { version = "0.3.31", default-features = false } -gcp-bigquery-client = { version = "0.27.0", default-features = false } +gcp-bigquery-client = { git = "https://github.com/iambriccardo/gcp-bigquery-client", default-features = false, rev = "a1cc7895afce36c0c86cd71bab94253fef04f05c" } iceberg = { version = "0.7.0", default-features = false } iceberg-catalog-rest = { version = "0.7.0", default-features = false } insta = { version = "1.43.1", default-features = false } @@ -57,7 +57,7 @@ metrics-exporter-prometheus = { version = "0.17.2", default-features = false } parquet = { version = "55.0", default-features = false } pg_escape = { version = "0.1.1", default-features = false } pin-project-lite = { version = "0.2.16", default-features = false } -postgres-replication = { git = "https://github.com/MaterializeInc/rust-postgres", default-features = false, rev = "c4b473b478b3adfbf8667d2fbe895d8423f1290b" } +postgres-replication = { git = "https://github.com/iambriccardo/rust-postgres", default-features = false, rev = "31acf55c7e5c2244e5bb3a36e7afa2a01bf52c38" } prost = { version = "0.14.1", default-features = false } rand = { version = "0.9.2", default-features = false } reqwest = { version = "0.12.22", default-features = false } @@ -74,7 +74,7 @@ thiserror = "2.0.12" tikv-jemalloc-ctl = { version = "0.6.0", default-features = false, features = ["stats"] } tikv-jemallocator = { version = "0.6.1", default-features = false, features = ["background_threads_runtime_support", "unprefixed_malloc_on_supported_platforms"] } tokio = { version = "1.47.0", default-features = false } -tokio-postgres = { git = "https://github.com/MaterializeInc/rust-postgres", default-features = false, rev = "c4b473b478b3adfbf8667d2fbe895d8423f1290b" } +tokio-postgres = { git = "https://github.com/iambriccardo/rust-postgres", default-features = false, rev = "31acf55c7e5c2244e5bb3a36e7afa2a01bf52c38" } tokio-rustls = { version = "0.26.2", default-features = false } tracing = { version = "0.1.41", default-features = false } tracing-actix-web = { version = "0.7.19", default-features = false } diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 53818284e..aa3814cb0 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -99,14 +99,13 @@ If you prefer manual setup or have an existing PostgreSQL instance: #### Single Database Setup -If using one database for both the API and replicator state: +If using one database for both the API and etl state: ```bash export DATABASE_URL=postgres://USER:PASSWORD@HOST:PORT/DB -# Run both migrations on the same database 
-./etl-api/scripts/run_migrations.sh -./etl-replicator/scripts/run_migrations.sh +# Run all migrations on the same database +./scripts/run_migrations.sh ``` #### Separate Database Setup @@ -116,16 +115,16 @@ If using separate databases (recommended for production): ```bash # API migrations on the control plane database export DATABASE_URL=postgres://USER:PASSWORD@API_HOST:PORT/API_DB -./etl-api/scripts/run_migrations.sh +./scripts/run_migrations.sh etl-api -# Replicator migrations on the source database +# ETL migrations on the source database export DATABASE_URL=postgres://USER:PASSWORD@SOURCE_HOST:PORT/SOURCE_DB -./etl-replicator/scripts/run_migrations.sh +./scripts/run_migrations.sh etl ``` This separation allows you to: - Scale the control plane independently from replication workloads -- Keep the replicator state close to the source data +- Keep the etl state close to the source data - Isolate concerns between infrastructure management and data replication ## Database Migrations @@ -140,7 +139,7 @@ Located in `etl-api/migrations/`, these create the control plane schema (`app` s ```bash # From project root -./etl-api/scripts/run_migrations.sh +./scripts/run_migrations.sh etl-api # Or manually with SQLx CLI sqlx migrate run --source etl-api/migrations @@ -167,19 +166,19 @@ cd etl-api cargo sqlx prepare ``` -### ETL Replicator Migrations +### ETL Migrations -Located in `etl-replicator/migrations/`, these create the replicator's state store schema (`etl` schema) for tracking replication state, table schemas, and mappings. +Located in `etl/migrations/`, these create the etl state store schema (`etl` schema) for tracking replication state, table schemas, and mappings. -**Running replicator migrations:** +**Running etl migrations:** ```bash # From project root -./etl-replicator/scripts/run_migrations.sh +./scripts/run_migrations.sh etl # Or manually with SQLx CLI (requires setting search_path) psql $DATABASE_URL -c "create schema if not exists etl;" -sqlx migrate run --source etl-replicator/migrations --database-url "${DATABASE_URL}?options=-csearch_path%3Detl" +sqlx migrate run --source etl/migrations --database-url "${DATABASE_URL}?options=-csearch_path%3Detl" ``` **Important:** Migrations are run automatically when using the `etl-replicator` binary (see `etl-replicator/src/migrations.rs:16`). However, if you integrate the `etl` crate directly into your own application as a library, you should run these migrations manually before starting your pipeline. This design decision ensures: @@ -193,10 +192,10 @@ sqlx migrate run --source etl-replicator/migrations --database-url "${DATABASE_U - Testing migrations independently - CI/CD pipelines that separate migration and deployment steps -**Creating a new replicator migration:** +**Creating a new etl migration:** ```bash -cd etl-replicator +cd etl sqlx migrate add ``` diff --git a/etl-api/scripts/run_migrations.sh b/etl-api/scripts/run_migrations.sh deleted file mode 100755 index 6feb646e8..000000000 --- a/etl-api/scripts/run_migrations.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env bash -set -eo pipefail - -if [ ! -d "etl-api/migrations" ]; then - echo >&2 "❌ Error: 'etl-api/migrations' folder not found." - echo >&2 "Please run this script from the 'etl' directory." - exit 1 -fi - -if ! [ -x "$(command -v sqlx)" ]; then - echo >&2 "❌ Error: SQLx CLI is not installed." 
- echo >&2 "To install it, run:" - echo >&2 " cargo install --version='~0.7' sqlx-cli --no-default-features --features rustls,postgres" - exit 1 -fi - -# Database configuration -DB_USER="${POSTGRES_USER:=postgres}" -DB_PASSWORD="${POSTGRES_PASSWORD:=postgres}" -DB_NAME="${POSTGRES_DB:=postgres}" -DB_PORT="${POSTGRES_PORT:=5430}" -DB_HOST="${POSTGRES_HOST:=localhost}" - -# Set up the database URL -export DATABASE_URL=postgres://${DB_USER}:${DB_PASSWORD}@${DB_HOST}:${DB_PORT}/${DB_NAME} - -echo "🔄 Running database migrations..." -sqlx database create -sqlx migrate run --source etl-api/migrations -echo "✨ Database migrations complete! Ready to go!" diff --git a/etl-api/src/db/pipelines.rs b/etl-api/src/db/pipelines.rs index a6e3260d1..fe7505d7b 100644 --- a/etl-api/src/db/pipelines.rs +++ b/etl-api/src/db/pipelines.rs @@ -12,7 +12,7 @@ use crate::db::replicators::{Replicator, ReplicatorsDbError, create_replicator}; use crate::db::sources::Source; use crate::routes::connect_to_source_database_with_defaults; use crate::routes::pipelines::PipelineError; -use etl_postgres::replication::{health, schema, slots, state, table_mappings}; +use etl_postgres::replication::{destination_metadata, health, schema, slots, state}; use sqlx::{PgExecutor, PgTransaction}; use std::ops::DerefMut; use thiserror::Error; @@ -247,13 +247,13 @@ pub async fn delete_pipeline_cascading( None }; - // Delete state, schema, and table mappings from the source database, only if ETL tables exist. + // Delete state, schema, and destination metadata from the source database, only if ETL tables exist. if etl_present { let _ = state::delete_replication_state_for_all_tables(source_txn.deref_mut(), pipeline.id) .await?; let _ = schema::delete_table_schemas_for_all_tables(source_txn.deref_mut(), pipeline.id) .await?; - let _ = table_mappings::delete_table_mappings_for_all_tables( + let _ = destination_metadata::delete_destination_tables_metadata_for_all_tables( source_txn.deref_mut(), pipeline.id, ) diff --git a/etl-api/tests/pipelines.rs b/etl-api/tests/pipelines.rs index efe3b22c4..691244c0b 100644 --- a/etl-api/tests/pipelines.rs +++ b/etl-api/tests/pipelines.rs @@ -1153,15 +1153,15 @@ async fn rollback_table_state_with_full_reset_succeeds() { .await .unwrap(); - // Insert a table mapping for this table - sqlx::query("INSERT INTO etl.table_mappings (pipeline_id, source_table_id, destination_table_id) VALUES ($1, $2, 'dest_test_users')") + // Insert destination metadata for this table + sqlx::query("INSERT INTO etl.destination_tables_metadata (pipeline_id, table_id, destination_table_id, snapshot_id, schema_status, replication_mask) VALUES ($1, $2, 'dest_test_users', '0/0'::pg_lsn, 'applied', '\\x01')") .bind(pipeline_id) .bind(table_oid) .execute(&source_db_pool) .await .unwrap(); - // Verify table schema and mapping exist before reset + // Verify table schema and metadata exist before reset let schema_count_before: i64 = sqlx::query_scalar( "select count(*) from etl.table_schemas where pipeline_id = $1 and table_id = $2", ) @@ -1172,15 +1172,15 @@ async fn rollback_table_state_with_full_reset_succeeds() { .unwrap(); assert_eq!(schema_count_before, 1); - let mapping_count_before: i64 = sqlx::query_scalar( - "select count(*) from etl.table_mappings where pipeline_id = $1 and source_table_id = $2", + let metadata_count_before: i64 = sqlx::query_scalar( + "select count(*) from etl.destination_tables_metadata where pipeline_id = $1 and table_id = $2", ) .bind(pipeline_id) .bind(table_oid) .fetch_one(&source_db_pool) .await .unwrap(); - 
assert_eq!(mapping_count_before, 1); + assert_eq!(metadata_count_before, 1); let response = test_rollback( &app, @@ -1222,22 +1222,22 @@ async fn rollback_table_state_with_full_reset_succeeds() { .unwrap(); assert_eq!(schema_count_after, 0); - // Verify table mapping was deleted - let mapping_count_after: i64 = sqlx::query_scalar( - "select count(*) from etl.table_mappings where pipeline_id = $1 and source_table_id = $2", + // Verify destination metadata was deleted + let metadata_count_after: i64 = sqlx::query_scalar( + "select count(*) from etl.destination_tables_metadata where pipeline_id = $1 and table_id = $2", ) .bind(pipeline_id) .bind(table_oid) .fetch_one(&source_db_pool) .await .unwrap(); - assert_eq!(mapping_count_after, 0); + assert_eq!(metadata_count_after, 0); drop_pg_database(&source_db_config).await; } #[tokio::test(flavor = "multi_thread")] -async fn rollback_to_init_cleans_up_schemas_and_mappings() { +async fn rollback_to_init_cleans_up_schemas_and_metadata() { init_test_tracing(); let (app, tenant_id, pipeline_id, source_db_pool, source_db_config) = setup_pipeline_with_source_db().await; @@ -1273,7 +1273,7 @@ async fn rollback_to_init_cleans_up_schemas_and_mappings() { .await .unwrap(); - sqlx::query("INSERT INTO etl.table_mappings (pipeline_id, source_table_id, destination_table_id) VALUES ($1, $2, 'dest_test_users')") + sqlx::query("INSERT INTO etl.destination_tables_metadata (pipeline_id, table_id, destination_table_id, snapshot_id, schema_status, replication_mask) VALUES ($1, $2, 'dest_test_users', '0/0'::pg_lsn, 'applied', '\\x01')") .bind(pipeline_id) .bind(table_oid) .execute(&source_db_pool) @@ -1308,22 +1308,22 @@ async fn rollback_to_init_cleans_up_schemas_and_mappings() { .unwrap(); assert_eq!(schema_count, 0); - // Verify table mapping was deleted - let mapping_count: i64 = sqlx::query_scalar( - "select count(*) from etl.table_mappings where pipeline_id = $1 and source_table_id = $2", + // Verify destination metadata was deleted + let metadata_count: i64 = sqlx::query_scalar( + "select count(*) from etl.destination_tables_metadata where pipeline_id = $1 and table_id = $2", ) .bind(pipeline_id) .bind(table_oid) .fetch_one(&source_db_pool) .await .unwrap(); - assert_eq!(mapping_count, 0); + assert_eq!(metadata_count, 0); drop_pg_database(&source_db_config).await; } #[tokio::test(flavor = "multi_thread")] -async fn rollback_to_non_starting_state_keeps_schemas_and_mappings() { +async fn rollback_to_non_starting_state_keeps_schemas_and_metadata() { init_test_tracing(); let (app, tenant_id, pipeline_id, source_db_pool, source_db_config) = setup_pipeline_with_source_db().await; @@ -1359,7 +1359,7 @@ async fn rollback_to_non_starting_state_keeps_schemas_and_mappings() { .await .unwrap(); - sqlx::query("INSERT INTO etl.table_mappings (pipeline_id, source_table_id, destination_table_id) VALUES ($1, $2, 'dest_test_users')") + sqlx::query("INSERT INTO etl.destination_tables_metadata (pipeline_id, table_id, destination_table_id, snapshot_id, schema_status, replication_mask) VALUES ($1, $2, 'dest_test_users', '0/0'::pg_lsn, 'applied', '\\x01')") .bind(pipeline_id) .bind(table_oid) .execute(&source_db_pool) @@ -1394,16 +1394,16 @@ async fn rollback_to_non_starting_state_keeps_schemas_and_mappings() { .unwrap(); assert_eq!(schema_count, 1); - // Verify table mapping was NOT deleted - let mapping_count: i64 = sqlx::query_scalar( - "select count(*) from etl.table_mappings where pipeline_id = $1 and source_table_id = $2", + // Verify destination metadata was NOT deleted + let 
metadata_count: i64 = sqlx::query_scalar( + "select count(*) from etl.destination_tables_metadata where pipeline_id = $1 and table_id = $2", ) .bind(pipeline_id) .bind(table_oid) .fetch_one(&source_db_pool) .await .unwrap(); - assert_eq!(mapping_count, 1); + assert_eq!(metadata_count, 1); drop_pg_database(&source_db_config).await; } diff --git a/etl-api/tests/support/database.rs b/etl-api/tests/support/database.rs index b3b912ff4..ed69fd962 100644 --- a/etl-api/tests/support/database.rs +++ b/etl-api/tests/support/database.rs @@ -89,8 +89,8 @@ pub async fn run_etl_migrations_on_source_database(source_db_config: &PgConnecti .await .expect("failed to set search path"); - // Run replicator migrations to create the state store tables. - sqlx::migrate!("../etl-replicator/migrations") + // Run migrations to create the etl tables. + sqlx::migrate!("../etl/migrations") .run(&source_pool) .await .expect("failed to run etl migrations"); diff --git a/etl-benchmarks/benches/table_copies.rs b/etl-benchmarks/benches/table_copies.rs index 7764a8a16..b30729646 100644 --- a/etl-benchmarks/benches/table_copies.rs +++ b/etl-benchmarks/benches/table_copies.rs @@ -4,7 +4,7 @@ use etl::error::EtlResult; use etl::pipeline::Pipeline; use etl::state::table::TableReplicationPhaseType; use etl::test_utils::notify::NotifyingStore; -use etl::types::{Event, TableRow}; +use etl::types::{Event, ReplicatedTableSchema, TableRow}; use etl_config::Environment; use etl_config::shared::{BatchConfig, PgConnectionConfig, PipelineConfig, TlsConfig}; use etl_destinations::bigquery::BigQueryDestination; @@ -413,21 +413,30 @@ impl Destination for BenchDestination { "bench_destination" } - async fn truncate_table(&self, table_id: TableId) -> EtlResult<()> { + async fn truncate_table( + &self, + replicated_table_schema: &ReplicatedTableSchema, + ) -> EtlResult<()> { match self { - BenchDestination::Null(dest) => dest.truncate_table(table_id).await, - BenchDestination::BigQuery(dest) => dest.truncate_table(table_id).await, + BenchDestination::Null(dest) => dest.truncate_table(replicated_table_schema).await, + BenchDestination::BigQuery(dest) => dest.truncate_table(replicated_table_schema).await, } } async fn write_table_rows( &self, - table_id: TableId, + replicated_table_schema: &ReplicatedTableSchema, table_rows: Vec, ) -> EtlResult<()> { match self { - BenchDestination::Null(dest) => dest.write_table_rows(table_id, table_rows).await, - BenchDestination::BigQuery(dest) => dest.write_table_rows(table_id, table_rows).await, + BenchDestination::Null(dest) => { + dest.write_table_rows(replicated_table_schema, table_rows) + .await + } + BenchDestination::BigQuery(dest) => { + dest.write_table_rows(replicated_table_schema, table_rows) + .await + } } } @@ -444,13 +453,16 @@ impl Destination for NullDestination { "null" } - async fn truncate_table(&self, _table_id: TableId) -> EtlResult<()> { + async fn truncate_table( + &self, + _replicated_table_schema: &ReplicatedTableSchema, + ) -> EtlResult<()> { Ok(()) } async fn write_table_rows( &self, - _table_id: TableId, + _replicated_table_schema: &ReplicatedTableSchema, _table_rows: Vec, ) -> EtlResult<()> { Ok(()) diff --git a/etl-destinations/Cargo.toml b/etl-destinations/Cargo.toml index c962347e2..a8965a376 100644 --- a/etl-destinations/Cargo.toml +++ b/etl-destinations/Cargo.toml @@ -59,6 +59,7 @@ uuid = { workspace = true, optional = true, features = ["v4"] } [dev-dependencies] etl = { workspace = true, features = ["test-utils"] } +etl-postgres = { workspace = true, features = 
["test-utils"] } etl-telemetry = { workspace = true } chrono = { workspace = true } diff --git a/etl-destinations/src/bigquery/client.rs b/etl-destinations/src/bigquery/client.rs index 0b9bf41b9..b7e86cf08 100644 --- a/etl-destinations/src/bigquery/client.rs +++ b/etl-destinations/src/bigquery/client.rs @@ -1,6 +1,6 @@ use etl::error::{ErrorKind, EtlError, EtlResult}; use etl::etl_error; -use etl::types::{Cell, ColumnSchema, TableRow, Type, is_array_type}; +use etl::types::{Cell, ColumnSchema, ReplicatedTableSchema, TableRow, Type, is_array_type}; use gcp_bigquery_client::google::cloud::bigquery::storage::v1::RowError; use gcp_bigquery_client::storage::ColumnMode; use gcp_bigquery_client::yup_oauth2::parse_service_account_key; @@ -116,9 +116,9 @@ impl BigQueryClient { pub async fn new_with_flow_authenticator, P: Into>( project_id: BigQueryProjectId, secret: S, - persistant_file_path: P, + persistent_file_path: P, ) -> EtlResult { - let client = Client::from_installed_flow_authenticator(secret, persistant_file_path) + let client = Client::from_installed_flow_authenticator(secret, persistent_file_path) .await .map_err(bq_error_to_etl_error)?; @@ -151,14 +151,14 @@ impl BigQueryClient { &self, dataset_id: &BigQueryDatasetId, table_id: &BigQueryTableId, - column_schemas: &[ColumnSchema], + replicated_table_schema: &ReplicatedTableSchema, max_staleness_mins: Option, ) -> EtlResult { let table_existed = self.table_exists(dataset_id, table_id).await?; let full_table_name = self.full_table_name(dataset_id, table_id)?; - let columns_spec = Self::create_columns_spec(column_schemas)?; + let columns_spec = Self::create_columns_spec(replicated_table_schema)?; let max_staleness_option = if let Some(max_staleness_mins) = max_staleness_mins { Self::max_staleness_option(max_staleness_mins) } else { @@ -186,7 +186,7 @@ impl BigQueryClient { &self, dataset_id: &BigQueryDatasetId, table_id: &BigQueryTableId, - column_schemas: &[ColumnSchema], + column_schemas: &ReplicatedTableSchema, max_staleness_mins: Option, ) -> EtlResult { if self.table_exists(dataset_id, table_id).await? { @@ -207,12 +207,12 @@ impl BigQueryClient { &self, dataset_id: &BigQueryDatasetId, table_id: &BigQueryTableId, - column_schemas: &[ColumnSchema], + replicated_table_schema: &ReplicatedTableSchema, max_staleness_mins: Option, ) -> EtlResult<()> { let full_table_name = self.full_table_name(dataset_id, table_id)?; - let columns_spec = Self::create_columns_spec(column_schemas)?; + let columns_spec = Self::create_columns_spec(replicated_table_schema)?; let max_staleness_option = if let Some(max_staleness_mins) = max_staleness_mins { Self::max_staleness_option(max_staleness_mins) } else { @@ -290,6 +290,81 @@ impl BigQueryClient { Ok(()) } + /// Adds a column to an existing BigQuery table. + /// + /// Executes an ALTER TABLE ADD COLUMN statement to add a new column with the + /// specified schema. New columns must be nullable in BigQuery. + pub async fn add_column( + &self, + dataset_id: &BigQueryDatasetId, + table_id: &BigQueryTableId, + column_schema: &ColumnSchema, + ) -> EtlResult<()> { + let full_table_name = self.full_table_name(dataset_id, table_id)?; + let column_name = Self::sanitize_identifier(&column_schema.name, "BigQuery column name")?; + let column_type = Self::postgres_to_bigquery_type(&column_schema.typ); + + info!( + "adding column `{column_name}` ({column_type}) to table {full_table_name} in BigQuery" + ); + + // BigQuery requires new columns to be nullable (no NOT NULL constraint allowed). 
Also, we wouldn't + // be able to add it nonetheless since we don't have a way to set a default value for past columns. + let query = + format!("alter table {full_table_name} add column `{column_name}` {column_type}"); + + let _ = self.query(QueryRequest::new(query)).await?; + + Ok(()) + } + + /// Drops a column from an existing BigQuery table. + /// + /// Executes an ALTER TABLE DROP COLUMN statement to remove the specified column. + pub async fn drop_column( + &self, + dataset_id: &BigQueryDatasetId, + table_id: &BigQueryTableId, + column_name: &str, + ) -> EtlResult<()> { + let full_table_name = self.full_table_name(dataset_id, table_id)?; + let column_name = Self::sanitize_identifier(column_name, "BigQuery column name")?; + + info!("dropping column `{column_name}` from table {full_table_name} in BigQuery"); + + let query = format!("alter table {full_table_name} drop column `{column_name}`"); + + let _ = self.query(QueryRequest::new(query)).await?; + + Ok(()) + } + + /// Renames a column in an existing BigQuery table. + /// + /// Executes an ALTER TABLE RENAME COLUMN statement to rename the specified column. + pub async fn rename_column( + &self, + dataset_id: &BigQueryDatasetId, + table_id: &BigQueryTableId, + old_name: &str, + new_name: &str, + ) -> EtlResult<()> { + let full_table_name = self.full_table_name(dataset_id, table_id)?; + let old_name = Self::sanitize_identifier(old_name, "BigQuery column name")?; + let new_name = Self::sanitize_identifier(new_name, "BigQuery column name")?; + + info!( + "renaming column `{old_name}` to `{new_name}` in table {full_table_name} in BigQuery" + ); + + let query = + format!("alter table {full_table_name} rename column `{old_name}` to `{new_name}`"); + + let _ = self.query(QueryRequest::new(query)).await?; + + Ok(()) + } + /// Checks whether a table exists in the BigQuery dataset. /// /// Returns `true` if the table exists, `false` otherwise. @@ -320,12 +395,20 @@ impl BigQueryClient { /// which can be processed concurrently. /// If ordering guarantees are needed, all data for a given table must be included /// in a single batch. - pub async fn stream_table_batches_concurrent( + /// + /// TODO: we might want to improve the detection of retriable errors by having a special error + /// type that we return for this. + pub async fn stream_table_batches_concurrent( &self, - table_batches: Vec>, + table_batches: I, max_concurrent_streams: usize, - ) -> EtlResult<(usize, usize)> { - if table_batches.is_empty() { + ) -> EtlResult<(usize, usize)> + where + I: IntoIterator>>, + I::IntoIter: ExactSizeIterator, + { + let table_batches = table_batches.into_iter(); + if table_batches.len() == 0 { return Ok((0, 0)); } @@ -388,14 +471,14 @@ impl BigQueryClient { /// Creates a TableBatch for a specific table with validated rows. /// /// Converts TableRow instances to BigQueryTableRow and creates a properly configured - /// TableBatch with the appropriate stream name and table descriptor. + /// TableBatch wrapped in Arc for efficient sharing and retry operations. 
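+    ///
+    /// Returning an [`Arc`] means a failed stream can be retried (see `stream_with_retry` in
+    /// `core.rs`) by cloning the `Arc` on each attempt instead of re-validating and re-encoding
+    /// the rows.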
pub fn create_table_batch( &self, dataset_id: &BigQueryDatasetId, table_id: &BigQueryTableId, - table_descriptor: Arc, + table_descriptor: TableDescriptor, rows: Vec, - ) -> EtlResult> { + ) -> EtlResult>> { let validated_rows = rows .into_iter() .map(BigQueryTableRow::try_from) @@ -410,11 +493,11 @@ impl BigQueryClient { table_id.to_string(), ); - Ok(TableBatch::new( + Ok(Arc::new(TableBatch::new( stream_name, table_descriptor, validated_rows, - )) + ))) } /// Executes a BigQuery SQL query and returns the result set. @@ -487,38 +570,53 @@ impl BigQueryClient { /// Creates a primary key clause for table creation. /// - /// Generates a primary key constraint clause from columns marked as primary key. - fn add_primary_key_clause(column_schemas: &[ColumnSchema]) -> EtlResult { - let identity_columns: Vec = column_schemas - .iter() - .filter(|s| s.primary) + /// Generates a primary key constraint clause from columns marked as primary key, + /// sorted by their ordinal position to ensure correct composite key ordering. + fn add_primary_key_clause( + replicated_table_schema: &ReplicatedTableSchema, + ) -> EtlResult> { + let mut primary_key_columns: Vec<_> = replicated_table_schema + .column_schemas() + .filter(|s| s.primary_key()) + .collect(); + + // If no primary key columns are marked, return early. + if primary_key_columns.is_empty() { + return Ok(None); + } + + // Sort by primary_key_ordinal_position to ensure correct composite key ordering. + primary_key_columns.sort_by_key(|c| c.primary_key_ordinal_position); + + let primary_key_columns: Vec = primary_key_columns + .into_iter() .map(|c| { Self::sanitize_identifier(&c.name, "BigQuery primary key column") .map(|name| format!("`{name}`")) }) .collect::>>()?; - if identity_columns.is_empty() { - return Ok("".to_string()); - } - - Ok(format!( + let primary_key_clause = format!( ", primary key ({}) not enforced", - identity_columns.join(",") - )) + primary_key_columns.join(",") + ); + + Ok(Some(primary_key_clause)) } /// Builds complete column specifications for CREATE TABLE statements. - fn create_columns_spec(column_schemas: &[ColumnSchema]) -> EtlResult { - let mut s = column_schemas - .iter() + fn create_columns_spec(replicated_table_schema: &ReplicatedTableSchema) -> EtlResult { + let mut column_spec = replicated_table_schema + .column_schemas() .map(Self::column_spec) .collect::>>()? .join(","); - s.push_str(&Self::add_primary_key_clause(column_schemas)?); + if let Some(primary_key_clause) = Self::add_primary_key_clause(replicated_table_schema)? { + column_spec.push_str(&primary_key_clause); + } - Ok(format!("({s})")) + Ok(format!("({column_spec})")) } /// Creates max staleness option clause for CDC table creation. @@ -575,13 +673,13 @@ impl BigQueryClient { /// Maps data types and nullability to BigQuery column specifications, setting /// appropriate column modes and automatically adding CDC special columns. 
pub fn column_schemas_to_table_descriptor( - column_schemas: &[ColumnSchema], + replicated_table_schema: &ReplicatedTableSchema, use_cdc_sequence_column: bool, ) -> TableDescriptor { - let mut field_descriptors = Vec::with_capacity(column_schemas.len()); + let mut field_descriptors = vec![]; let mut number = 1; - for column_schema in column_schemas { + for column_schema in replicated_table_schema.column_schemas() { let typ = match column_schema.typ { Type::BOOL => ColumnType::Bool, Type::CHAR | Type::BPCHAR | Type::VARCHAR | Type::NAME | Type::TEXT => { @@ -755,6 +853,11 @@ fn bq_error_to_etl_error(err: BQError) -> EtlError { == "The caller does not have permission to execute the specified operation" { (ErrorKind::PermissionDenied, "BigQuery permission denied") + } else if is_retryable_streaming_message(status.message()) { + ( + ErrorKind::DestinationSchemaMismatch, + "BigQuery schema mismatch", + ) } else { (ErrorKind::DestinationError, "BigQuery gRPC status error") } @@ -777,8 +880,44 @@ fn bq_error_to_etl_error(err: BQError) -> EtlError { etl_error!(kind, description, err.to_string()) } +/// Patterns that indicate transient/retryable errors from BigQuery Storage Write API. +/// +/// These errors can occur temporarily after DDL operations (e.g., column renames, +/// ADD/DROP COLUMN) when BigQuery hasn't fully propagated changes. They include: +/// - Schema mismatch errors when the cached schema is stale +/// - Entity not found errors when streaming endpoints aren't ready +/// +/// Retrying with backoff typically resolves these issues. +const RETRYABLE_STREAMING_PATTERNS: &[&str] = &[ + // Schema mismatch patterns + "extra field", + "is missing in the proto", + // Entity not found patterns (transient after DDL) + "was not found", + "entity was not found", +]; + +/// Checks if an error message indicates a retryable streaming error. +fn is_retryable_streaming_message(message: &str) -> bool { + let lower = message.to_lowercase(); + RETRYABLE_STREAMING_PATTERNS + .iter() + .any(|pattern| lower.contains(pattern)) +} + /// Converts BigQuery row errors to ETL destination errors. +/// +/// Detects retryable streaming errors by checking for specific patterns in the +/// error message that indicate transient issues after DDL operations. fn row_error_to_etl_error(err: RowError) -> EtlError { + if is_retryable_streaming_message(&err.message) { + return etl_error!( + ErrorKind::DestinationSchemaMismatch, + "BigQuery schema mismatch", + format!("{err:?}") + ); + } + etl_error!( ErrorKind::DestinationError, "BigQuery row error", @@ -790,6 +929,43 @@ fn row_error_to_etl_error(err: RowError) -> EtlError { mod tests { use super::*; + use etl::types::{ReplicationMask, TableId, TableName, TableSchema}; + use std::collections::HashSet; + + /// Creates a test column schema with common defaults. + /// + /// This helper simplifies column schema creation in tests by providing sensible + /// defaults for fields that are typically not relevant to the test logic. + fn test_column( + name: &str, + typ: Type, + ordinal_position: i32, + nullable: bool, + primary_key_ordinal: Option, + ) -> ColumnSchema { + ColumnSchema::new( + name.to_string(), + typ, + -1, + ordinal_position, + primary_key_ordinal, + nullable, + ) + } + + /// Creates a [`ReplicatedTableSchema`] from test columns with all columns replicated. 
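+    ///
+    /// The mask is built from every column name via `ReplicationMask::build_or_all`, so the
+    /// resulting schema replicates the table in full.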
+ fn test_replicated_schema(columns: Vec) -> ReplicatedTableSchema { + let column_names: HashSet = columns.iter().map(|c| c.name.clone()).collect(); + let table_schema = Arc::new(TableSchema::new( + TableId(1), // Dummy table ID + TableName::new("public".to_string(), "test_table".to_string()), + columns, + )); + let replication_mask = ReplicationMask::build_or_all(&table_schema, &column_names); + + ReplicatedTableSchema::from_mask(table_schema, replication_mask) + } + #[test] fn test_postgres_to_bigquery_type_basic_types() { assert_eq!( @@ -848,24 +1024,23 @@ mod tests { #[test] fn test_column_spec() { - let column_schema = ColumnSchema::new("test_col".to_string(), Type::TEXT, -1, true, false); + let column_schema = test_column("test_col", Type::TEXT, 1, true, None); let spec = BigQueryClient::column_spec(&column_schema).expect("column spec generation"); assert_eq!(spec, "`test_col` string"); - let not_null_column = ColumnSchema::new("id".to_string(), Type::INT4, -1, false, true); + let not_null_column = test_column("id", Type::INT4, 1, false, Some(1)); let not_null_spec = BigQueryClient::column_spec(¬_null_column).expect("not null column spec"); assert_eq!(not_null_spec, "`id` int64 not null"); - let array_column = - ColumnSchema::new("tags".to_string(), Type::TEXT_ARRAY, -1, false, false); + let array_column = test_column("tags", Type::TEXT_ARRAY, 1, false, None); let array_spec = BigQueryClient::column_spec(&array_column).expect("array column spec"); assert_eq!(array_spec, "`tags` array"); } #[test] fn test_column_spec_escapes_backticks() { - let column_schema = ColumnSchema::new("pwn`name".to_string(), Type::TEXT, -1, true, false); + let column_schema = test_column("pwn`name", Type::TEXT, 1, true, None); let spec = BigQueryClient::column_spec(&column_schema).expect("escaped column spec"); @@ -885,43 +1060,65 @@ mod tests { #[test] fn test_add_primary_key_clause() { let columns_with_pk = vec![ - ColumnSchema::new("id".to_string(), Type::INT4, -1, false, true), - ColumnSchema::new("name".to_string(), Type::TEXT, -1, true, false), + test_column("id", Type::INT4, 1, false, Some(1)), + test_column("name", Type::TEXT, 2, true, None), ]; - let pk_clause = - BigQueryClient::add_primary_key_clause(&columns_with_pk).expect("pk clause"); + let schema_with_pk = test_replicated_schema(columns_with_pk); + let pk_clause = BigQueryClient::add_primary_key_clause(&schema_with_pk) + .expect("pk clause") + .unwrap(); assert_eq!(pk_clause, ", primary key (`id`) not enforced"); + // Composite primary key with correct ordinal positions. let columns_with_composite_pk = vec![ - ColumnSchema::new("tenant_id".to_string(), Type::INT4, -1, false, true), - ColumnSchema::new("id".to_string(), Type::INT4, -1, false, true), - ColumnSchema::new("name".to_string(), Type::TEXT, -1, true, false), + test_column("tenant_id", Type::INT4, 1, false, Some(1)), + test_column("id", Type::INT4, 2, false, Some(2)), + test_column("name", Type::TEXT, 3, true, None), ]; - let composite_pk_clause = - BigQueryClient::add_primary_key_clause(&columns_with_composite_pk) - .expect("composite pk clause"); + let schema_with_composite_pk = test_replicated_schema(columns_with_composite_pk); + let composite_pk_clause = BigQueryClient::add_primary_key_clause(&schema_with_composite_pk) + .unwrap() + .expect("composite pk clause"); assert_eq!( composite_pk_clause, ", primary key (`tenant_id`,`id`) not enforced" ); + // Composite primary key with reversed column order but correct ordinal positions. 
+ // The primary key clause should still be ordered by ordinal position. + let columns_with_reversed_pk = vec![ + test_column("id", Type::INT4, 1, false, Some(2)), + test_column("tenant_id", Type::INT4, 2, false, Some(1)), + test_column("name", Type::TEXT, 3, true, None), + ]; + let schema_with_reversed_pk = test_replicated_schema(columns_with_reversed_pk); + let reversed_pk_clause = BigQueryClient::add_primary_key_clause(&schema_with_reversed_pk) + .unwrap() + .expect("reversed pk clause"); + assert_eq!( + reversed_pk_clause, + ", primary key (`tenant_id`,`id`) not enforced" + ); + let columns_no_pk = vec![ - ColumnSchema::new("name".to_string(), Type::TEXT, -1, true, false), - ColumnSchema::new("age".to_string(), Type::INT4, -1, true, false), + test_column("name", Type::TEXT, 1, true, None), + test_column("age", Type::INT4, 2, true, None), ]; + let schema_no_pk = test_replicated_schema(columns_no_pk); let no_pk_clause = - BigQueryClient::add_primary_key_clause(&columns_no_pk).expect("no pk clause"); - assert_eq!(no_pk_clause, ""); + BigQueryClient::add_primary_key_clause(&schema_no_pk).expect("no pk clause"); + assert!(no_pk_clause.is_none()); } #[test] fn test_create_columns_spec() { let columns = vec![ - ColumnSchema::new("id".to_string(), Type::INT4, -1, false, true), - ColumnSchema::new("name".to_string(), Type::TEXT, -1, true, false), - ColumnSchema::new("active".to_string(), Type::BOOL, -1, false, false), + test_column("id", Type::INT4, 1, false, Some(1)), + test_column("name", Type::TEXT, 2, true, None), + test_column("active", Type::BOOL, 3, false, None), ]; - let spec = BigQueryClient::create_columns_spec(&columns).expect("columns spec"); + let schema = test_replicated_schema(columns); + let spec = BigQueryClient::create_columns_spec(&schema).expect("columns spec"); assert_eq!( spec, "(`id` int64 not null,`name` string,`active` bool not null, primary key (`id`) not enforced)" @@ -937,13 +1134,14 @@ mod tests { #[test] fn test_column_schemas_to_table_descriptor() { let columns = vec![ - ColumnSchema::new("id".to_string(), Type::INT4, -1, false, true), - ColumnSchema::new("name".to_string(), Type::TEXT, -1, true, false), - ColumnSchema::new("active".to_string(), Type::BOOL, -1, false, false), - ColumnSchema::new("tags".to_string(), Type::TEXT_ARRAY, -1, false, false), + test_column("id", Type::INT4, 1, false, Some(1)), + test_column("name", Type::TEXT, 2, true, None), + test_column("active", Type::BOOL, 3, false, None), + test_column("tags", Type::TEXT_ARRAY, 4, false, None), ]; + let schema = test_replicated_schema(columns); - let descriptor = BigQueryClient::column_schemas_to_table_descriptor(&columns, true); + let descriptor = BigQueryClient::column_schemas_to_table_descriptor(&schema, true); assert_eq!(descriptor.field_descriptors.len(), 6); // 4 columns + CDC columns @@ -1020,15 +1218,16 @@ mod tests { #[test] fn test_column_schemas_to_table_descriptor_complex_types() { let columns = vec![ - ColumnSchema::new("uuid_col".to_string(), Type::UUID, -1, true, false), - ColumnSchema::new("json_col".to_string(), Type::JSON, -1, true, false), - ColumnSchema::new("bytea_col".to_string(), Type::BYTEA, -1, true, false), - ColumnSchema::new("numeric_col".to_string(), Type::NUMERIC, -1, true, false), - ColumnSchema::new("date_col".to_string(), Type::DATE, -1, true, false), - ColumnSchema::new("time_col".to_string(), Type::TIME, -1, true, false), + test_column("uuid_col", Type::UUID, 1, true, None), + test_column("json_col", Type::JSON, 2, true, None), + test_column("bytea_col", 
Type::BYTEA, 3, true, None), + test_column("numeric_col", Type::NUMERIC, 4, true, None), + test_column("date_col", Type::DATE, 5, true, None), + test_column("time_col", Type::TIME, 6, true, None), ]; + let schema = test_replicated_schema(columns); - let descriptor = BigQueryClient::column_schemas_to_table_descriptor(&columns, true); + let descriptor = BigQueryClient::column_schemas_to_table_descriptor(&schema, true); assert_eq!(descriptor.field_descriptors.len(), 8); // 6 columns + CDC columns @@ -1082,9 +1281,10 @@ mod tests { let table_id = "test_table"; let columns = vec![ - ColumnSchema::new("id".to_string(), Type::INT4, -1, false, true), - ColumnSchema::new("name".to_string(), Type::TEXT, -1, true, false), + test_column("id", Type::INT4, 1, false, Some(1)), + test_column("name", Type::TEXT, 2, true, None), ]; + let schema = test_replicated_schema(columns); // Simulate the query generation logic let full_table_name = format!( @@ -1093,7 +1293,7 @@ mod tests { dataset = BigQueryClient::sanitize_identifier(dataset_id, "dataset").unwrap(), table = BigQueryClient::sanitize_identifier(table_id, "table").unwrap() ); - let columns_spec = BigQueryClient::create_columns_spec(&columns).unwrap(); + let columns_spec = BigQueryClient::create_columns_spec(&schema).unwrap(); let query = format!("create or replace table {full_table_name} {columns_spec}"); let expected_query = "create or replace table `test-project.test_dataset.test_table` (`id` int64 not null,`name` string, primary key (`id`) not enforced)"; @@ -1107,13 +1307,8 @@ mod tests { let table_id = "test_table"; let max_staleness_mins = 15; - let columns = vec![ColumnSchema::new( - "id".to_string(), - Type::INT4, - -1, - false, - true, - )]; + let columns = vec![test_column("id", Type::INT4, 1, false, Some(1))]; + let schema = test_replicated_schema(columns); // Simulate the query generation logic with staleness let full_table_name = format!( @@ -1122,7 +1317,7 @@ mod tests { dataset = BigQueryClient::sanitize_identifier(dataset_id, "dataset").unwrap(), table = BigQueryClient::sanitize_identifier(table_id, "table").unwrap() ); - let columns_spec = BigQueryClient::create_columns_spec(&columns).unwrap(); + let columns_spec = BigQueryClient::create_columns_spec(&schema).unwrap(); let max_staleness_option = BigQueryClient::max_staleness_option(max_staleness_mins); let query = format!( "create or replace table {full_table_name} {columns_spec} {max_staleness_option}" @@ -1131,4 +1326,27 @@ mod tests { let expected_query = "create or replace table `test-project.test_dataset.test_table` (`id` int64 not null, primary key (`id`) not enforced) options (max_staleness = interval 15 minute)"; assert_eq!(query, expected_query); } + + #[test] + fn test_is_retryable_streaming_message() { + // Schema mismatch patterns. + assert!(is_retryable_streaming_message("extra field in row")); + assert!(is_retryable_streaming_message("Extra Field detected")); + assert!(is_retryable_streaming_message( + "field foo is missing in the proto" + )); + + // Entity not found patterns (transient after DDL). + assert!(is_retryable_streaming_message( + "Requested entity was not found. Entity: projects/foo/datasets/bar/tables/baz/streams/_default" + )); + assert!(is_retryable_streaming_message("entity was not found")); + assert!(is_retryable_streaming_message("Table was not found")); + + // Messages that should not match. 
+ assert!(!is_retryable_streaming_message("connection timeout")); + assert!(!is_retryable_streaming_message("permission denied")); + assert!(!is_retryable_streaming_message("invalid data format")); + assert!(!is_retryable_streaming_message("")); + } } diff --git a/etl-destinations/src/bigquery/core.rs b/etl-destinations/src/bigquery/core.rs index 6c5408fbf..4d49c911c 100644 --- a/etl-destinations/src/bigquery/core.rs +++ b/etl-destinations/src/bigquery/core.rs @@ -1,16 +1,23 @@ use etl::destination::Destination; use etl::error::{ErrorKind, EtlError, EtlResult}; use etl::store::schema::SchemaStore; -use etl::store::state::StateStore; -use etl::types::{Cell, Event, TableId, TableName, TableRow, generate_sequence_number}; +use etl::store::state::{DestinationTableMetadata, DestinationTableSchemaStatus, StateStore}; +use etl::types::{ + Cell, Event, ReplicatedTableSchema, SchemaDiff, TableId, TableName, TableRow, + generate_sequence_number, +}; use etl::{bail, etl_error}; -use gcp_bigquery_client::storage::TableDescriptor; +use gcp_bigquery_client::storage::{TableBatch, TableDescriptor}; + +use crate::bigquery::encoding::BigQueryTableRow; use std::collections::{HashMap, HashSet}; use std::fmt::Display; use std::iter; use std::str::FromStr; use std::sync::Arc; +use std::time::Duration; use tokio::sync::Mutex; +use tokio::time::sleep; use tracing::{debug, info, warn}; use crate::bigquery::client::{BigQueryClient, BigQueryOperationType}; @@ -21,6 +28,14 @@ const BIGQUERY_TABLE_ID_DELIMITER: &str = "_"; /// Replacement string for escaping underscores in Postgres names. const BIGQUERY_TABLE_ID_DELIMITER_ESCAPE_REPLACEMENT: &str = "__"; +/// Maximum number of retry attempts for schema mismatch errors. +/// +/// After DDL changes, the BigQuery Storage Write API may cache stale schema information. +/// These retries allow the cache to refresh before failing permanently. +const MAX_SCHEMA_MISMATCH_ATTEMPTS: usize = 10; +/// Delay between schema mismatch retry attempts in milliseconds. +const SCHEMA_MISMATCH_RETRY_DELAY_MS: u64 = 1000; + /// Returns the [`BigQueryTableId`] for a supplied [`TableName`]. /// /// Escapes underscores in schema and table names to prevent collisions when combining them. 
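+/// For example, a name part like `user_events` is first escaped to `user__events`; the escaped
+/// schema and table parts are then joined with a single `_`.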
@@ -171,7 +186,7 @@ pub struct BigQueryDestination { dataset_id: BigQueryDatasetId, max_staleness_mins: Option, max_concurrent_streams: usize, - store: S, + state_store: S, inner: Arc>, } @@ -191,7 +206,7 @@ where sa_key: &str, max_staleness_mins: Option, max_concurrent_streams: usize, - store: S, + state_store: S, ) -> EtlResult { let client = BigQueryClient::new_with_key_path(project_id, sa_key).await?; let inner = Inner { @@ -204,7 +219,7 @@ where dataset_id, max_staleness_mins, max_concurrent_streams, - store, + state_store, inner: Arc::new(Mutex::new(inner)), }) } @@ -221,7 +236,7 @@ where sa_key: &str, max_staleness_mins: Option, max_concurrent_streams: usize, - store: S, + state_store: S, ) -> EtlResult { let client = BigQueryClient::new_with_key(project_id, sa_key).await?; let inner = Inner { @@ -234,7 +249,7 @@ where dataset_id, max_staleness_mins, max_concurrent_streams, - store, + state_store, inner: Arc::new(Mutex::new(inner)), }) } @@ -249,7 +264,7 @@ where dataset_id: BigQueryDatasetId, max_staleness_mins: Option, max_concurrent_streams: usize, - store: S, + state_store: S, ) -> EtlResult { let client = BigQueryClient::new_with_adc(project_id).await?; let inner = Inner { @@ -262,7 +277,7 @@ where dataset_id, max_staleness_mins, max_concurrent_streams, - store, + state_store, inner: Arc::new(Mutex::new(inner)), }) } @@ -299,62 +314,75 @@ where dataset_id, max_staleness_mins, max_concurrent_streams, - store, + state_store: store, inner: Arc::new(Mutex::new(inner)), }) } /// Prepares a table for CDC streaming operations with schema-aware table creation. /// - /// Retrieves the table schema from the store, creates or verifies the BigQuery table exists, + /// Creates or verifies the BigQuery table exists using the provided schema, /// and ensures the view points to the current versioned table. Uses caching to avoid /// redundant table creation checks. async fn prepare_table_for_streaming( &self, - table_id: &TableId, + replicated_table_schema: &ReplicatedTableSchema, use_cdc_sequence_column: bool, - ) -> EtlResult<(SequencedBigQueryTableId, Arc)> { + ) -> EtlResult<(SequencedBigQueryTableId, TableDescriptor)> { // We hold the lock for the entire preparation to avoid race conditions since the consistency // of this code path is critical. let mut inner = self.inner.lock().await; - // We load the schema of the table, if present. This is needed to create the table in BigQuery - // and also prepare the table descriptor for CDC streaming. - let table_schema = self - .store - .get_table_schema(table_id) - .await? - .ok_or_else(|| { - etl_error!( - ErrorKind::MissingTableSchema, - "Table not found in the schema store", - format!( - "The table schema for table {table_id} was not found in the schema store" - ) - ) - })?; + let table_id = replicated_table_schema.id(); // We determine the BigQuery table ID for the table together with the current sequence number. - let bigquery_table_id = table_name_to_bigquery_table_id(&table_schema.name); - let sequenced_bigquery_table_id = self - .get_or_create_sequenced_bigquery_table_id(table_id, &bigquery_table_id) + let bigquery_table_id = table_name_to_bigquery_table_id(replicated_table_schema.name()); + let snapshot_id = replicated_table_schema.get_inner().snapshot_id; + let replication_mask = replicated_table_schema.replication_mask().clone(); + + // Check if we have existing metadata for this table. 
+ let existing_metadata = self + .state_store + .get_destination_table_metadata(&table_id) .await?; + let sequenced_bigquery_table_id = match &existing_metadata { + Some(metadata) => metadata.destination_table_id.parse()?, + None => SequencedBigQueryTableId::new(bigquery_table_id.clone()), + }; + // Optimistically skip table creation if we've already seen this sequenced table. // // Note that if the table is deleted outside ETL and the cache marks it as created, the // inserts will fail because the table will be missing and won't be created. if !inner.created_tables.contains(&sequenced_bigquery_table_id) { + // Create metadata with applying status. For new tables, this is the initial insert. + // For existing tables, this updates the status. + let metadata = DestinationTableMetadata::new_applying( + sequenced_bigquery_table_id.to_string(), + snapshot_id, + replication_mask.clone(), + ); + + // Store or update metadata before creating the table. + self.state_store + .store_destination_table_metadata(table_id, metadata.clone()) + .await?; + self.client .create_table_if_missing( &self.dataset_id, - // TODO: down the line we might want to reduce an allocation here. &sequenced_bigquery_table_id.to_string(), - &table_schema.column_schemas, + replicated_table_schema, self.max_staleness_mins, ) .await?; + // Mark as applied after successful table creation. + self.state_store + .store_destination_table_metadata(table_id, metadata.to_applied()) + .await?; + // Add the sequenced table to the cache. Self::add_to_created_tables_cache(&mut inner, &sequenced_bigquery_table_id); @@ -373,12 +401,17 @@ where ) .await?; + // Note: We return TableDescriptor by value for simplicity, which means callers clone it + // when creating multiple batches. This is acceptable because the descriptor is small + // (one String per column) and the cost is negligible compared to network I/O. If profiling + // shows this is a bottleneck, we could wrap it in Arc here and use Arc::unwrap_or_clone + // at the call site to avoid redundant clones. let table_descriptor = BigQueryClient::column_schemas_to_table_descriptor( - &table_schema.column_schemas, + replicated_table_schema, use_cdc_sequence_column, ); - Ok((sequenced_bigquery_table_id, Arc::new(table_descriptor))) + Ok((sequenced_bigquery_table_id, table_descriptor)) } /// Adds a table to the creation cache to avoid redundant existence checks. @@ -390,37 +423,20 @@ where inner.created_tables.insert(table_id.clone()); } - /// Retrieves the current sequenced table ID or creates a new one starting at version 0. - async fn get_or_create_sequenced_bigquery_table_id( - &self, - table_id: &TableId, - bigquery_table_id: &BigQueryTableId, - ) -> EtlResult { - let Some(sequenced_bigquery_table_id) = - self.get_sequenced_bigquery_table_id(table_id).await? - else { - let sequenced_bigquery_table_id = - SequencedBigQueryTableId::new(bigquery_table_id.clone()); - self.store - .store_table_mapping(*table_id, sequenced_bigquery_table_id.to_string()) - .await?; - - return Ok(sequenced_bigquery_table_id); - }; - - Ok(sequenced_bigquery_table_id) - } - - /// Retrieves the current sequenced table ID from the state store. + /// Retrieves the current sequenced table ID from the destination metadata. async fn get_sequenced_bigquery_table_id( &self, table_id: &TableId, ) -> EtlResult> { - let Some(current_table_id) = self.store.get_table_mapping(table_id).await? else { + let Some(metadata) = self + .state_store + .get_destination_table_metadata(table_id) + .await? 
+ else { return Ok(None); }; - let sequenced_bigquery_table_id = current_table_id.parse()?; + let sequenced_bigquery_table_id = metadata.destination_table_id.parse()?; Ok(Some(sequenced_bigquery_table_id)) } @@ -468,14 +484,15 @@ where /// `max_concurrent_streams`, and streams to BigQuery using concurrent processing. async fn write_table_rows( &self, - table_id: TableId, + replicated_table_schema: &ReplicatedTableSchema, mut table_rows: Vec, ) -> EtlResult<()> { // Prepare table for streaming. - let (sequenced_bigquery_table_id, table_descriptor) = - self.prepare_table_for_streaming(&table_id, false).await?; + let (sequenced_bigquery_table_id, table_descriptor) = self + .prepare_table_for_streaming(replicated_table_schema, false) + .await?; - // Add CDC operation type to all rows (no lock needed). + // Add the CDC operation type to all rows (no lock needed). for table_row in table_rows.iter_mut() { table_row .values @@ -499,13 +516,10 @@ where } } - // Stream all the batches concurrently. - if !table_batches.is_empty() { - let (bytes_sent, bytes_received) = self - .client - .stream_table_batches_concurrent(table_batches, self.max_concurrent_streams) - .await?; + // Stream with schema mismatch retry. + let (bytes_sent, bytes_received) = self.stream_with_retry(&table_batches).await?; + if bytes_sent > 0 { // Logs with egress_metric = true can be used to identify egress logs. // This can e.g. be used to send egress logs to a location different // than the other logs. These logs should also have bytes_sent set to @@ -522,23 +536,303 @@ where Ok(()) } + /// Handles a schema change event (Relation) by computing the diff and applying changes. + /// + /// This method: + /// 1. Gets the current destination schema state from the state store. + /// 2. If no state exists, this is the initial schema - just records it as applied. + /// 3. If state exists and snapshot_id differs, computes the diff and applies changes. + /// 4. If state is in `Applying`, a previous change was interrupted - returns an error. + async fn handle_relation_event(&self, new_schema: &ReplicatedTableSchema) -> EtlResult<()> { + let table_id = new_schema.id(); + let new_snapshot_id = new_schema.get_inner().snapshot_id; + + // Get current destination metadata. + let current_metadata = self + .state_store + .get_destination_table_metadata(&table_id) + .await?; + + match current_metadata { + None => { + // No metadata exists, this is a broken invariant since the metadata should + // have been recorded during write_table_rows before any Relation event. + bail!( + ErrorKind::CorruptedTableSchema, + "Missing destination table metadata", + format!( + "No destination table metadata found for table {} when processing schema change. \ + This indicates a broken invariant, the metadata should have been recorded \ + during initial table synchronization.", + table_id + ) + ); + } + Some(metadata) if metadata.is_applying() => { + // A previous schema change was interrupted, require manual intervention since BigQuery + // DDL is not atomic, thus we can't say anything about the table schema. + bail!( + ErrorKind::CorruptedTableSchema, + "Schema change recovery required", + format!( + "A previous schema change for table {} was interrupted at snapshot_id {}. \ + Manual intervention is required to resolve the destination schema state. 
\ + The previous valid snapshot can be derived from the table_schemas table.", + table_id, metadata.snapshot_id + ) + ); + } + Some(metadata) if metadata.is_applied() => { + let current_snapshot_id = metadata.snapshot_id; + let current_replication_mask = metadata.replication_mask.clone(); + let new_replication_mask = new_schema.replication_mask().clone(); + + // Check both snapshot_id and replication mask - the mask can change + // independently if columns are added/removed from the publication. + if current_snapshot_id == new_snapshot_id + && current_replication_mask == new_replication_mask + { + // Schema hasn't changed, nothing to do. + info!( + "schema for table {} unchanged (snapshot_id: {}, replication_mask: {})", + table_id, new_snapshot_id, new_replication_mask + ); + + return Ok(()); + } + + info!( + "schema change detected for table {}: snapshot_id {} -> {}, mask {} -> {}", + table_id, + current_snapshot_id, + new_snapshot_id, + current_replication_mask, + new_replication_mask + ); + + // Get the old schema from the schema store to compute the diff. + let old_table_schema = self + .state_store + .get_table_schema(&table_id, current_snapshot_id) + .await? + .ok_or_else(|| { + etl_error!( + ErrorKind::InvalidState, + "Old schema not found", + format!( + "Could not find schema for table {} at snapshot_id {}", + table_id, current_snapshot_id + ) + ) + })?; + + // Build a ReplicatedTableSchema using the stored replication mask. + let old_schema = + ReplicatedTableSchema::from_mask(old_table_schema, current_replication_mask); + + // Mark as applying before making changes (with the NEW snapshot_id and mask). + // + // NOTE: BigQuery does not support transactional DDL, so if the system crashes + // while in 'Applying' state, the destination table may be in an inconsistent + // state and manual intervention may be required. The `previous_snapshot_id` + // is stored for debugging purposes but automatic recovery is not possible. + let updated_metadata = metadata.with_schema_change( + new_snapshot_id, + new_replication_mask.clone(), + DestinationTableSchemaStatus::Applying, + ); + self.state_store + .store_destination_table_metadata(table_id, updated_metadata.clone()) + .await?; + + // Compute and apply the diff. + let diff = old_schema.diff(new_schema); + if let Err(err) = self.apply_schema_diff(&table_id, &diff).await { + warn!( + "schema change failed for table {}: {}. Manual intervention may be required.", + table_id, err + ); + return Err(err); + } + + // Mark as applied after successful changes. + self.state_store + .store_destination_table_metadata(table_id, updated_metadata.to_applied()) + .await?; + + info!( + "schema change completed for table {}: snapshot_id {} applied", + table_id, new_snapshot_id + ); + } + Some(_) => unreachable!("All state types are covered"), + } + + Ok(()) + } + + /// Applies a schema diff to the BigQuery table. + /// + /// Executes the necessary DDL operations (ADD COLUMN, DROP COLUMN, RENAME COLUMN) + /// to transform the destination schema. + async fn apply_schema_diff(&self, table_id: &TableId, diff: &SchemaDiff) -> EtlResult<()> { + if diff.is_empty() { + debug!("no schema changes to apply for table {}", table_id); + return Ok(()); + } + + // Get the BigQuery table ID for this table. + let bigquery_table_id = self + .get_sequenced_bigquery_table_id(table_id) + .await? + .ok_or_else(|| { + etl_error!( + ErrorKind::InvalidState, + "Table not found", + format!( + "No BigQuery table mapping found for table {}. 
Schema changes cannot be applied to a non-existent table.", + table_id + ) + ) + })?; + + info!( + "applying schema changes to table {}: {} additions, {} removals, {} renames", + bigquery_table_id, + diff.columns_to_add.len(), + diff.columns_to_remove.len(), + diff.columns_to_rename.len() + ); + + // Apply column additions first (safest operation). + for column in &diff.columns_to_add { + self.client + .add_column(&self.dataset_id, &bigquery_table_id.to_string(), column) + .await?; + } + + // Apply column renames (must be done before removals in case of position conflicts). + for rename in &diff.columns_to_rename { + self.client + .rename_column( + &self.dataset_id, + &bigquery_table_id.to_string(), + &rename.old_name, + &rename.new_name, + ) + .await?; + } + + // Apply column removals last. + for column in &diff.columns_to_remove { + self.client + .drop_column( + &self.dataset_id, + &bigquery_table_id.to_string(), + &column.name, + ) + .await?; + } + + info!( + "schema changes applied successfully to table {}", + bigquery_table_id + ); + + Ok(()) + } + + /// Checks if an error is a transient streaming error that can be retried. + /// + /// Returns `true` if any of the error kinds is [`ErrorKind::DestinationSchemaMismatch`], + /// which indicates transient BigQuery issues after DDL changes such as: + /// - Stale cached schema information ("extra field" errors) + /// - Streaming endpoint not yet available ("entity not found" errors) + fn is_retryable_streaming_error(error: &EtlError) -> bool { + error + .kinds() + .contains(&ErrorKind::DestinationSchemaMismatch) + } + + /// Streams table batches to BigQuery with automatic retry on transient errors. + /// + /// After DDL changes (e.g., `ALTER TABLE ADD COLUMN`, column renames), the BigQuery + /// Storage Write API may temporarily: + /// - Cache stale schema information and reject inserts with "extra field" errors + /// - Return "entity not found" errors for streaming endpoints + /// + /// This method retries streaming operations when such transient errors are detected, + /// allowing time for BigQuery to propagate DDL changes. + /// + /// Takes a slice of `Arc` to enable efficient retries - on each attempt, + /// we iterate over the slice and clone the `Arc`s (O(1) per batch) rather than + /// recreating the batches. + async fn stream_with_retry( + &self, + table_batches: &[Arc>], + ) -> EtlResult<(usize, usize)> { + if table_batches.is_empty() { + return Ok((0, 0)); + } + + let retry_delay = Duration::from_millis(SCHEMA_MISMATCH_RETRY_DELAY_MS); + let mut attempts = 0; + + loop { + // Clone the Arc references (O(1) per batch) to create an iterator for this attempt. + let batches_iter = table_batches.iter().cloned(); + + match self + .client + .stream_table_batches_concurrent(batches_iter, self.max_concurrent_streams) + .await + { + Ok(result) => return Ok(result), + Err(error) => { + if !Self::is_retryable_streaming_error(&error) { + return Err(error); + } + + attempts += 1; + if attempts >= MAX_SCHEMA_MISMATCH_ATTEMPTS { + return Err(error); + } + + warn!( + attempt = attempts, + max_attempts = MAX_SCHEMA_MISMATCH_ATTEMPTS, + error = %error, + "transient streaming error detected; retrying after delay" + ); + sleep(retry_delay).await; + } + } + } + } + /// Processes CDC events in batches with proper ordering and truncate handling. /// /// Groups streaming operations (insert/update/delete) by table and processes them together, - /// then handles truncate events separately by creating new versioned tables. 
+ /// then handles truncate events separately by creating new versioned tables. Uses the schema + /// from the first event of each table for table creation and descriptor building. async fn write_events(&self, events: Vec) -> EtlResult<()> { - let mut event_iter = events.into_iter().peekable(); - - while event_iter.peek().is_some() { - let mut table_id_to_table_rows = HashMap::new(); - - // Process events until we hit a truncate event or run out of events - while let Some(event) = event_iter.peek() { - if matches!(event, Event::Truncate(_)) { + let mut events_iter = events.into_iter().peekable(); + + while events_iter.peek().is_some() { + // Maps table ID to (schema, rows). We are assuming that the table schema is the same for + // all events within two Relation event boundaries. + let mut table_id_to_data: HashMap)> = + HashMap::new(); + + // Process events until we hit a truncate or relation event, or run out of events. + // Truncate and Relation events require flushing all batched data first before + // they can be processed, to maintain correct ordering. + while let Some(event) = events_iter.peek() { + if matches!(event, Event::Truncate(_) | Event::Relation(_)) { break; } - let event = event_iter.next().unwrap(); + let event = events_iter.next().unwrap(); match event { Event::Insert(mut insert) => { let sequence_number = @@ -549,9 +843,11 @@ where .push(BigQueryOperationType::Upsert.into_cell()); insert.table_row.values.push(Cell::String(sequence_number)); - let table_rows: &mut Vec = - table_id_to_table_rows.entry(insert.table_id).or_default(); - table_rows.push(insert.table_row); + let table_id = insert.replicated_table_schema.id(); + let entry = table_id_to_data.entry(table_id).or_insert_with(|| { + (insert.replicated_table_schema.clone(), Vec::new()) + }); + entry.1.push(insert.table_row); } Event::Update(mut update) => { let sequence_number = @@ -562,9 +858,11 @@ where .push(BigQueryOperationType::Upsert.into_cell()); update.table_row.values.push(Cell::String(sequence_number)); - let table_rows: &mut Vec = - table_id_to_table_rows.entry(update.table_id).or_default(); - table_rows.push(update.table_row); + let table_id = update.replicated_table_schema.id(); + let entry = table_id_to_data.entry(table_id).or_insert_with(|| { + (update.replicated_table_schema.clone(), Vec::new()) + }); + entry.1.push(update.table_row); } Event::Delete(delete) => { let Some((_, mut old_table_row)) = delete.old_table_row else { @@ -579,40 +877,54 @@ where .push(BigQueryOperationType::Delete.into_cell()); old_table_row.values.push(Cell::String(sequence_number)); - let table_rows: &mut Vec = - table_id_to_table_rows.entry(delete.table_id).or_default(); - table_rows.push(old_table_row); + let table_id = delete.replicated_table_schema.id(); + let entry = table_id_to_data.entry(table_id).or_insert_with(|| { + (delete.replicated_table_schema.clone(), Vec::new()) + }); + entry.1.push(old_table_row); } _ => { - // Every other event type is currently not supported. - debug!("skipping unsupported event in BigQuery"); + // Begin, Commit, Unsupported events are skipped. + debug!("skipping non-data event in BigQuery"); } } } // Process accumulated events for each table. - if !table_id_to_table_rows.is_empty() { - let mut table_batches = Vec::with_capacity(table_id_to_table_rows.len()); + if !table_id_to_data.is_empty() { + // Prepare batch metadata for all tables before streaming. + // This collects (sequenced_table_id, table_descriptor, rows) for retry support. 
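// The bounded, fixed-delay retry performed by `stream_with_retry` above follows the
// generic pattern sketched below. This is an illustrative, standalone sketch only:
// `retry_transient`, `MAX_ATTEMPTS`, and `RETRY_DELAY_MS` are invented names, and the
// real code additionally keeps the per-table batches behind `Arc`s so that each retry
// attempt can recreate its input in O(1) per batch.
use std::future::Future;
use std::time::Duration;
use tokio::time::sleep;

const MAX_ATTEMPTS: u32 = 5;
const RETRY_DELAY_MS: u64 = 1_000;

/// Retries `op` while `is_transient` classifies its error as retryable, sleeping a
/// fixed delay between attempts and giving up after `MAX_ATTEMPTS`.
async fn retry_transient<T, E, Fut>(
    mut op: impl FnMut() -> Fut,
    is_transient: impl Fn(&E) -> bool,
) -> Result<T, E>
where
    Fut: Future<Output = Result<T, E>>,
{
    let mut attempts = 0;
    loop {
        match op().await {
            Ok(value) => return Ok(value),
            Err(error) => {
                attempts += 1;
                // Permanent errors and exhausted retry budgets are surfaced to the caller.
                if !is_transient(&error) || attempts >= MAX_ATTEMPTS {
                    return Err(error);
                }
                sleep(Duration::from_millis(RETRY_DELAY_MS)).await;
            }
        }
    }
}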
+ let mut prepared_data: Vec<(String, TableDescriptor, Vec)> = + Vec::with_capacity(table_id_to_data.len()); + + for (_, (replicated_table_schema, table_rows)) in table_id_to_data { + let (sequenced_bigquery_table_id, table_descriptor) = self + .prepare_table_for_streaming(&replicated_table_schema, true) + .await?; - for (table_id, table_rows) in table_id_to_table_rows { - let (sequenced_bigquery_table_id, table_descriptor) = - self.prepare_table_for_streaming(&table_id, true).await?; + prepared_data.push(( + sequenced_bigquery_table_id.to_string(), + table_descriptor, + table_rows, + )); + } + // Create table batches from prepared data. + let mut table_batches = Vec::with_capacity(prepared_data.len()); + for (table_id, descriptor, rows) in prepared_data { let table_batch = self.client.create_table_batch( &self.dataset_id, - &sequenced_bigquery_table_id.to_string(), - table_descriptor.clone(), - table_rows, + &table_id, + descriptor, + rows, )?; table_batches.push(table_batch); } - if !table_batches.is_empty() { - let (bytes_sent, bytes_received) = self - .client - .stream_table_batches_concurrent(table_batches, self.max_concurrent_streams) - .await?; + // Stream with schema mismatch retry. + let (bytes_sent, bytes_received) = self.stream_with_retry(&table_batches).await?; + if bytes_sent > 0 { // Logs with egress_metric = true can be used to identify egress logs. // This can e.g. be used to send egress logs to a location different // than the other logs. These logs should also have bytes_sent set to @@ -627,23 +939,32 @@ where } } - // Collect and deduplicate all table IDs from all truncate events. + // Process any Relation events (schema changes) that caused the batch to flush. + // Multiple consecutive Relation events are processed sequentially. + while let Some(Event::Relation(_)) = events_iter.peek() { + if let Some(Event::Relation(relation)) = events_iter.next() { + self.handle_relation_event(&relation.replicated_table_schema) + .await?; + } + } + + // Collect and deduplicate schemas from all truncate events. // - // This is done as an optimization since if we have multiple table ids being truncated in a + // This is done as an optimization since if we have multiple tables being truncated in a // row without applying other events in the meanwhile, it doesn't make any sense to create // new empty tables for each of them. - let mut truncate_table_ids = HashSet::new(); + let mut truncate_schemas: HashMap = HashMap::new(); - while let Some(Event::Truncate(_)) = event_iter.peek() { - if let Some(Event::Truncate(truncate_event)) = event_iter.next() { - for table_id in truncate_event.rel_ids { - truncate_table_ids.insert(TableId::new(table_id)); + while let Some(Event::Truncate(_)) = events_iter.peek() { + if let Some(Event::Truncate(truncate_event)) = events_iter.next() { + for schema in truncate_event.truncated_tables { + truncate_schemas.insert(schema.id(), schema); } } } - if !truncate_table_ids.is_empty() { - self.process_truncate_for_table_ids(truncate_table_ids.into_iter(), true) + if !truncate_schemas.is_empty() { + self.process_truncate_for_schemas(truncate_schemas.into_values()) .await?; } } @@ -654,50 +975,33 @@ where /// Handles table truncation by creating new versioned tables and updating views. /// /// Creates fresh empty tables with incremented version numbers, updates views to point - /// to new tables, and schedules cleanup of old table versions. Deduplicates table IDs - /// to optimize multiple truncates of the same table. 
- async fn process_truncate_for_table_ids( + /// to new tables, and schedules cleanup of old table versions. Uses the provided schemas + /// directly instead of looking them up from a store. + async fn process_truncate_for_schemas( &self, - table_ids: impl IntoIterator, - is_cdc_truncate: bool, + replicated_table_schemas: impl IntoIterator, ) -> EtlResult<()> { // We want to lock for the entire processing to ensure that we don't have any race conditions // and possible errors are easier to reason about. let mut inner = self.inner.lock().await; - for table_id in table_ids { - let table_schema = self.store.get_table_schema(&table_id).await?; - // If we are not doing CDC, it means that this truncation has been issued while recovering - // from a failed data sync operation. In that case, we could have failed before table schemas - // were stored in the schema store, so we just continue and emit a warning. If we are doing - // CDC, it's a problem if the schema disappears while streaming, so we error out. - if !is_cdc_truncate { + for replicated_table_schema in replicated_table_schemas { + let table_id = replicated_table_schema.id(); + + // We need to determine the current sequenced table ID for this table. + // + // If no mapping exists, it means the table was never created in BigQuery (e.g., due to + // validation errors during copy). In this case, we skip the truncate since there's + // nothing to truncate. + let Some(sequenced_bigquery_table_id) = + self.get_sequenced_bigquery_table_id(&table_id).await? + else { warn!( - "the table schema for table {table_id} was not found in the schema store while processing truncate events for BigQuery", + "skipping truncate for table {}: no mapping exists (table was likely never created)", + table_id ); - continue; - } - - let table_schema = table_schema.ok_or_else(|| etl_error!( - ErrorKind::MissingTableSchema, - "Table not found in the schema store", - format!( - "The table schema for table {table_id} was not found in the schema store while processing truncate events for BigQuery" - ) - ))?; - - // We need to determine the current sequenced table ID for this table. - let sequenced_bigquery_table_id = - self.get_sequenced_bigquery_table_id(&table_id) - .await? - .ok_or_else(|| etl_error!( - ErrorKind::MissingTableMapping, - "Table mapping not found", - format!( - "The table mapping for table id {table_id} was not found while processing truncate events for BigQuery" - ) - ))?; + }; // We compute the new sequence table ID since we want a new table for each truncate event. let next_sequenced_bigquery_table_id = sequenced_bigquery_table_id.next(); @@ -715,7 +1019,7 @@ where .create_or_replace_table( &self.dataset_id, &next_sequenced_bigquery_table_id.to_string(), - &table_schema.column_schemas, + &replicated_table_schema, self.max_staleness_mins, ) .await?; @@ -731,9 +1035,14 @@ where ) .await?; - // Update the store table mappings to point to the new table. - self.store - .store_table_mapping(table_id, next_sequenced_bigquery_table_id.to_string()) + // Update the metadata to point to the new table. 
+ let metadata = DestinationTableMetadata::new_applied( + next_sequenced_bigquery_table_id.to_string(), + replicated_table_schema.get_inner().snapshot_id, + replicated_table_schema.replication_mask().clone(), + ); + self.state_store + .store_destination_table_metadata(table_id, metadata) .await?; // Please note that the three statements above are not transactional, so if one fails, @@ -791,17 +1100,21 @@ where "bigquery" } - async fn truncate_table(&self, table_id: TableId) -> EtlResult<()> { - self.process_truncate_for_table_ids(iter::once(table_id), false) + async fn truncate_table( + &self, + replicated_table_schema: &ReplicatedTableSchema, + ) -> EtlResult<()> { + self.process_truncate_for_schemas(iter::once(replicated_table_schema.clone())) .await } async fn write_table_rows( &self, - table_id: TableId, + replicated_table_schema: &ReplicatedTableSchema, table_rows: Vec, ) -> EtlResult<()> { - self.write_table_rows(table_id, table_rows).await?; + self.write_table_rows(replicated_table_schema, table_rows) + .await?; Ok(()) } diff --git a/etl-destinations/src/iceberg/client.rs b/etl-destinations/src/iceberg/client.rs index 6927f1548..9964ecce2 100644 --- a/etl-destinations/src/iceberg/client.rs +++ b/etl-destinations/src/iceberg/client.rs @@ -166,8 +166,10 @@ impl IcebergClient { column_schemas: &[ColumnSchema], ) -> Result<(), iceberg::Error> { debug!("creating table {table_name} in namespace {namespace} if missing"); + let namespace_ident = NamespaceIdent::from_strs(namespace.split('.'))?; let table_ident = TableIdent::new(namespace_ident.clone(), table_name.clone()); + if !self.catalog.table_exists(&table_ident).await? { let iceberg_schema = postgres_to_iceberg_schema(column_schemas)?; let creation = TableCreation::builder() @@ -179,6 +181,7 @@ impl IcebergClient { .create_table(&namespace_ident, creation) .await?; } + Ok(()) } diff --git a/etl-destinations/src/iceberg/core.rs b/etl-destinations/src/iceberg/core.rs index 399820244..2c805cf4c 100644 --- a/etl-destinations/src/iceberg/core.rs +++ b/etl-destinations/src/iceberg/core.rs @@ -8,17 +8,15 @@ use crate::iceberg::IcebergClient; use crate::iceberg::error::iceberg_error_to_etl_error; use etl::destination::Destination; use etl::error::{ErrorKind, EtlResult}; -use etl::store::schema::SchemaStore; -use etl::store::state::StateStore; +use etl::etl_error; +use etl::store::state::{DestinationTableMetadata, StateStore}; use etl::types::{ - Cell, ColumnSchema, Event, TableId, TableName, TableRow, TableSchema, Type, + Cell, ColumnSchema, Event, ReplicatedTableSchema, TableId, TableName, TableRow, Type, generate_sequence_number, }; -use etl::{bail, etl_error}; use tokio::sync::Mutex; use tokio::task::JoinSet; -use tracing::log::warn; -use tracing::{debug, info}; +use tracing::{debug, info, warn}; /// CDC operation types for Iceberg changelog tables. /// @@ -159,7 +157,7 @@ struct Inner { impl IcebergDestination where - S: StateStore + SchemaStore + Send + Sync, + S: StateStore + Send + Sync, { /// Creates a new Iceberg destination instance. /// @@ -187,52 +185,27 @@ where /// Removes all data from the target table by dropping the existing Iceberg table /// and creating a fresh empty table with the same schema. Updates the internal /// table creation cache to reflect the new table state. 
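// The "applying -> applied" ordering used above (and by the Iceberg destination below)
// reduces to the write-ahead pattern sketched here. Everything in this sketch is
// hypothetical: `MetaStore`, `Status`, `change_schema`, and `needs_inspection` are
// stand-ins, not the crate's real `StateStore` / `DestinationTableMetadata` API.
#[derive(Clone, Copy, PartialEq)]
enum Status {
    Applying,
    Applied,
}

trait MetaStore {
    fn put(&mut self, table: u32, status: Status);
    fn get(&self, table: u32) -> Option<Status>;
}

/// Records intent before running the (non-transactional) destination DDL and confirms
/// it afterwards. A crash between the two writes leaves an `Applying` marker behind,
/// signalling that the destination table may be half-migrated.
fn change_schema<S: MetaStore>(
    store: &mut S,
    table: u32,
    apply_ddl: impl FnOnce() -> Result<(), String>,
) -> Result<(), String> {
    store.put(table, Status::Applying); // 1. write intent first
    apply_ddl()?;                       // 2. run the DDL (may partially succeed)
    store.put(table, Status::Applied);  // 3. confirm only after success
    Ok(())
}

/// A later run can use the lingering marker to detect a possibly inconsistent table.
fn needs_inspection<S: MetaStore>(store: &S, table: u32) -> bool {
    store.get(table) == Some(Status::Applying)
}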
- async fn truncate_table(&self, table_id: TableId, is_cdc_truncate: bool) -> EtlResult<()> { + async fn truncate_table( + &self, + replicated_table_schema: &ReplicatedTableSchema, + ) -> EtlResult<()> { let mut inner = self.inner.lock().await; - - let Some(table_schema) = self.store.get_table_schema(&table_id).await? else { - // If this is not a cdc truncate event, we just raise a warning since it could be that the - // table schema is not there. - if !is_cdc_truncate { - warn!( - "the table schema for table {table_id} was not found in the schema store while processing truncate events for Iceberg", - ); - - return Ok(()); - } - - // If this is a cdc truncate event, the table schema must be there, so we raise an error. - bail!( - ErrorKind::MissingTableSchema, - "Table not found in the schema store", - format!( - "The table schema for table {table_id} was not found in the schema store while processing truncate events for Iceberg" - ) + let table_id = replicated_table_schema.id(); + + // Check if metadata exists for this table. + // + // If no metadata exists, it means the table was never created in Iceberg (e.g., due to + // errors during copy). In this case, we skip the truncate since there's nothing to truncate. + let Some(metadata) = self.store.get_destination_table_metadata(&table_id).await? else { + warn!( + "skipping truncate for table {}: no metadata exists (table was likely never created)", + table_id ); + return Ok(()); }; + let iceberg_table_name = metadata.destination_table_id; - let Some(iceberg_table_name) = self.store.get_table_mapping(&table_id).await? else { - // If this is not a cdc truncate event, we just raise a warning since it could be that the - // table mapping is not there. - if !is_cdc_truncate { - warn!( - "the table mapping for table {table_id} was not found in the state store while processing truncate events for Iceberg", - ); - - return Ok(()); - } - - // If this is a cdc truncate event, the table mapping must be there, so we raise an error. - bail!( - ErrorKind::MissingTableMapping, - "Table mapping not found", - format!( - "The table mapping for table id {table_id} was not found while processing truncate events for Iceberg" - ) - ); - }; - - let namespace = schema_to_namespace(&table_schema.name.schema); + let namespace = schema_to_namespace(&replicated_table_schema.name().schema); let namespace = inner.namespace.get_or(&namespace); self.client @@ -242,7 +215,7 @@ where inner.created_tables.remove(&iceberg_table_name); // We recreate the table with the same schema. - self.prepare_table_for_streaming(&mut inner, table_id) + self.prepare_table_for_streaming(&mut inner, replicated_table_schema) .await?; Ok(()) @@ -255,21 +228,21 @@ where /// All rows are treated as upsert operations in this context. async fn write_table_rows( &self, - table_id: TableId, + replicated_table_schema: &ReplicatedTableSchema, mut table_rows: Vec, ) -> EtlResult<()> { let (namespace, iceberg_table_name) = { // We hold the lock for the entire preparation to avoid race conditions since the consistency // of this code path is critical. let mut inner = self.inner.lock().await; - self.prepare_table_for_streaming(&mut inner, table_id) + self.prepare_table_for_streaming(&mut inner, replicated_table_schema) .await? 
}; - for row in &mut table_rows { + for table_row in &mut table_rows { let sequence_number = generate_sequence_number(0.into(), 0.into()); - row.values.push(IcebergOperationType::Insert.into()); - row.values.push(Cell::String(sequence_number)); + table_row.values.push(IcebergOperationType::Insert.into()); + table_row.values.push(Cell::String(sequence_number)); } if !table_rows.is_empty() { @@ -279,8 +252,8 @@ where .await?; // Logs with egress_metric = true can be used to identify egress logs. - // This can e.g. be used to send egress logs to a location different - // than the other logs. These logs should also have bytes_sent set to + // This can e.g., be used to send egress logs to a location different + // from the other logs. These logs should also have bytes_sent set to // the number of bytes sent to the destination. info!( bytes_sent, @@ -300,18 +273,21 @@ where /// and deduplicated for efficiency. Each event is augmented with CDC metadata /// including operation type and sequence number based on LSN information. async fn write_events(&self, events: Vec) -> EtlResult<()> { - let mut event_iter = events.into_iter().peekable(); + let mut events_iter = events.into_iter().peekable(); - while event_iter.peek().is_some() { - let mut table_id_to_table_rows = HashMap::new(); + while events_iter.peek().is_some() { + // Maps table ID to (schema, rows); schema is the first one seen for that table. + // Once schema change support is implemented, we will re-implement this. + let mut table_id_to_data: HashMap)> = + HashMap::new(); // Process events until we hit a truncate event or run out of events - while let Some(event) = event_iter.peek() { + while let Some(event) = events_iter.peek() { if matches!(event, Event::Truncate(_)) { break; } - let event = event_iter.next().unwrap(); + let event = events_iter.next().unwrap(); match event { Event::Insert(mut insert) => { let sequence_number = @@ -322,9 +298,11 @@ where .push(IcebergOperationType::Insert.into()); insert.table_row.values.push(Cell::String(sequence_number)); - let table_rows: &mut Vec = - table_id_to_table_rows.entry(insert.table_id).or_default(); - table_rows.push(insert.table_row); + let table_id = insert.replicated_table_schema.id(); + let entry = table_id_to_data.entry(table_id).or_insert_with(|| { + (insert.replicated_table_schema.clone(), Vec::new()) + }); + entry.1.push(insert.table_row); } Event::Update(mut update) => { let sequence_number = @@ -335,9 +313,11 @@ where .push(IcebergOperationType::Update.into()); update.table_row.values.push(Cell::String(sequence_number)); - let table_rows: &mut Vec = - table_id_to_table_rows.entry(update.table_id).or_default(); - table_rows.push(update.table_row); + let table_id = update.replicated_table_schema.id(); + let entry = table_id_to_data.entry(table_id).or_insert_with(|| { + (update.replicated_table_schema.clone(), Vec::new()) + }); + entry.1.push(update.table_row); } Event::Delete(delete) => { let Some((_, mut old_table_row)) = delete.old_table_row else { @@ -352,9 +332,38 @@ where .push(IcebergOperationType::Delete.into()); old_table_row.values.push(Cell::String(sequence_number)); - let table_rows: &mut Vec = - table_id_to_table_rows.entry(delete.table_id).or_default(); - table_rows.push(old_table_row); + let table_id = delete.replicated_table_schema.id(); + let entry = table_id_to_data.entry(table_id).or_insert_with(|| { + (delete.replicated_table_schema.clone(), Vec::new()) + }); + entry.1.push(old_table_row); + } + Event::Relation(relation) => { + // Check if schema has changed - if 
so, error since Iceberg doesn't + // support schema changes yet. + let table_id = relation.replicated_table_schema.id(); + let new_snapshot_id = + relation.replicated_table_schema.get_inner().snapshot_id; + let new_replication_mask = + relation.replicated_table_schema.replication_mask(); + + if let Some(metadata) = + self.store.get_destination_table_metadata(&table_id).await? + { + if metadata.snapshot_id != new_snapshot_id + || &metadata.replication_mask != new_replication_mask + { + return Err(etl_error!( + ErrorKind::CorruptedTableSchema, + "Schema changes not supported", + format!( + "Iceberg destination does not support schema changes. \ + Table {} schema changed from snapshot_id {} to {}.", + table_id, metadata.snapshot_id, new_snapshot_id + ) + )); + } + } } _ => { // Every other event type is currently not supported. @@ -364,15 +373,15 @@ where } // Process accumulated events for each table. - if !table_id_to_table_rows.is_empty() { + if !table_id_to_data.is_empty() { let mut join_set = JoinSet::new(); - for (table_id, table_rows) in table_id_to_table_rows { + for (_, (replicated_table_schema, table_rows)) in table_id_to_data { let (namespace, iceberg_table_name) = { // We hold the lock for the entire preparation to avoid race conditions since the consistency // of this code path is critical. let mut inner = self.inner.lock().await; - self.prepare_table_for_streaming(&mut inner, table_id) + self.prepare_table_for_streaming(&mut inner, &replicated_table_schema) .await? }; @@ -403,23 +412,23 @@ where ); } - // Collect and deduplicate all table IDs from all truncate events. + // Collect and deduplicate schemas from all truncate events. // - // This is done as an optimization since if we have multiple table ids being truncated in a + // This is done as an optimization since if we have multiple tables being truncated in a // row without applying other events in the meanwhile, it doesn't make any sense to create // new empty tables for each of them. - let mut truncate_table_ids = HashSet::new(); + let mut truncate_schemas: HashMap = HashMap::new(); - while let Some(Event::Truncate(_)) = event_iter.peek() { - if let Some(Event::Truncate(truncate_event)) = event_iter.next() { - for table_id in truncate_event.rel_ids { - truncate_table_ids.insert(TableId::new(table_id)); + while let Some(Event::Truncate(_)) = events_iter.peek() { + if let Some(Event::Truncate(truncate_event)) = events_iter.next() { + for schema in truncate_event.truncated_tables { + truncate_schemas.insert(schema.id(), schema); } } } - for table_id in truncate_table_ids { - self.truncate_table(table_id, true).await?; + for (_, schema) in truncate_schemas { + self.truncate_table(&schema).await?; } } @@ -428,46 +437,76 @@ where /// Prepares a table for CDC streaming operations with schema-aware table creation. /// - /// Retrieves the table schema from the store, augments it with CDC columns, - /// and ensures the corresponding Iceberg table exists in the namespace. - /// Uses caching to avoid redundant table creation checks and holds a lock - /// during the entire preparation to prevent race conditions. + /// Augments the provided schema with CDC columns and ensures the corresponding + /// Iceberg table exists in the namespace. Uses caching to avoid redundant table + /// creation checks and holds a lock during the entire preparation to prevent race conditions. + /// + /// Follows the applying -> applied pattern for crash recovery: + /// 1. Store metadata with `Applying` status before creating the table + /// 2. 
Create the table + /// 3. Update metadata to `Applied` after successful creation async fn prepare_table_for_streaming( &self, inner: &mut Inner, - table_id: TableId, + replicated_table_schema: &ReplicatedTableSchema, ) -> EtlResult<(String, IcebergTableName)> { - let table_schema = self.get_table_schema(table_id).await?; - let table_schema = Self::modify_schema_with_cdc_columns(&table_schema); + let table_id = replicated_table_schema.id(); + let table_name = replicated_table_schema.name(); + let snapshot_id = replicated_table_schema.get_inner().snapshot_id; + let replication_mask = replicated_table_schema.replication_mask().clone(); + let column_schemas = Self::build_cdc_column_schemas(replicated_table_schema); + + // Check if we have existing metadata for this table. + let existing_metadata = self.store.get_destination_table_metadata(&table_id).await?; let iceberg_table_name = - table_name_to_iceberg_table_name(&table_schema.name, inner.namespace.is_single()); - let iceberg_table_name = self - .get_or_create_iceberg_table_name(&table_id, iceberg_table_name) - .await?; + table_name_to_iceberg_table_name(table_name, inner.namespace.is_single()); + let iceberg_table_name = match &existing_metadata { + Some(metadata) => metadata.destination_table_id.clone(), + None => iceberg_table_name, + }; - let namespace = schema_to_namespace(&table_schema.name.schema); + // We prepare the namespace. + let namespace = schema_to_namespace(&table_name.schema); let namespace = inner.namespace.get_or(&namespace).to_string(); let namespace = self.create_namespace_if_missing(inner, namespace).await?; - let iceberg_table_name = self - .create_table_if_missing(inner, iceberg_table_name, &namespace, &table_schema) + // If the table is already in the cache, we skip the creation. This works assuming that etl + // is the only system managing the underlying tables. + if inner.created_tables.contains(&iceberg_table_name) { + debug!( + "iceberg table {iceberg_table_name} found in creation cache, skipping existence check" + ); + + return Ok((namespace, iceberg_table_name)); + } + + // Create metadata with applying status before creating the table. + let metadata = DestinationTableMetadata::new_applying( + iceberg_table_name.clone(), + snapshot_id, + replication_mask, + ); + self.store + .store_destination_table_metadata(table_id, metadata.clone()) .await?; - Ok((namespace, iceberg_table_name)) - } + self.client + .create_table_if_missing(&namespace, iceberg_table_name.clone(), &column_schemas) + .await + .map_err(iceberg_error_to_etl_error)?; - async fn get_table_schema(&self, table_id: TableId) -> EtlResult> { + // Mark as applied after successful table creation. self.store - .get_table_schema(&table_id) - .await? - .ok_or_else(|| { - etl_error!( - ErrorKind::MissingTableSchema, - "Table schema not found", - format!("No schema found for table {table_id}") - ) - }) + .store_destination_table_metadata(table_id, metadata.to_applied()) + .await?; + + // We add the table to the cache. + inner.created_tables.insert(iceberg_table_name.clone()); + + debug!("iceberg table {iceberg_table_name} added to creation cache"); + + Ok((namespace, iceberg_table_name)) } /// Creates a namespace if it is missing in the destination. @@ -492,95 +531,50 @@ where Ok(namespace) } - /// Creates a table if it is missing in the destination. /// Once created adds it to the created_trees HashMap to - /// avoid creating it again. 
- async fn create_table_if_missing( - &self, - inner: &mut Inner, - iceberg_table_name: String, - namespace: &str, - table_schema: &TableSchema, - ) -> EtlResult { - if inner.created_tables.contains(&iceberg_table_name) { - return Ok(iceberg_table_name); - } - - self.client - .create_table_if_missing( - namespace, - iceberg_table_name.clone(), - &table_schema.column_schemas, - ) - .await - .map_err(iceberg_error_to_etl_error)?; - - inner.created_tables.insert(iceberg_table_name.clone()); - - Ok(iceberg_table_name) - } - - /// Derives a CDC table schema by adding CDC-specific columns. + /// Builds column schemas with CDC-specific columns added. /// - /// Creates a new table schema based on the source table schema with two - /// additional columns for CDC operations: - /// - `cdc_operation`: Tracks whether the row represents an upsert or delete + /// Takes the replicated columns from the schema and adds two additional columns + /// for CDC operations: + /// - `cdc_operation`: Tracks whether the row represents an insert, update, or delete /// - `sequence_number`: Provides ordering information based on WAL LSN /// /// These columns enable CDC consumers to understand the chronological order /// of changes and distinguish between different types of operations. - fn modify_schema_with_cdc_columns(table_schema: &TableSchema) -> TableSchema { - let mut final_schema = table_schema.clone(); + fn build_cdc_column_schemas( + replicated_table_schema: &ReplicatedTableSchema, + ) -> Vec { + let mut column_schemas: Vec = + replicated_table_schema.column_schemas().cloned().collect(); // Add cdc specific columns. - let cdc_operation_col = - find_unique_column_name(&final_schema.column_schemas, CDC_OPERATION_COLUMN_NAME); + let cdc_operation_col = find_unique_column_name(&column_schemas, CDC_OPERATION_COLUMN_NAME); let sequence_number_col = - find_unique_column_name(&final_schema.column_schemas, SEQUENCE_NUMBER_COLUMN_NAME); - - final_schema.add_column_schema(ColumnSchema { - name: cdc_operation_col, - typ: Type::TEXT, - modifier: -1, - nullable: false, - primary: false, - }); - final_schema.add_column_schema(ColumnSchema { - name: sequence_number_col, - typ: Type::TEXT, - modifier: -1, - nullable: false, - primary: false, - }); - final_schema - } - - /// Retrieves or creates a table mapping for the Iceberg table name. - /// - /// Checks if a table mapping already exists for the given table ID. - /// If no mapping exists, creates a new mapping with the provided - /// Iceberg table name. This ensures consistent table name resolution - /// across multiple operations on the same logical table. - async fn get_or_create_iceberg_table_name( - &self, - table_id: &TableId, - iceberg_table_name: IcebergTableName, - ) -> EtlResult { - let Some(iceberg_table_name) = self.store.get_table_mapping(table_id).await? else { - self.store - .store_table_mapping(*table_id, iceberg_table_name.to_string()) - .await?; - - return Ok(iceberg_table_name); - }; - - Ok(iceberg_table_name) + find_unique_column_name(&column_schemas, SEQUENCE_NUMBER_COLUMN_NAME); + + column_schemas.push(ColumnSchema::new( + cdc_operation_col, + Type::TEXT, + -1, + 0, + None, + false, + )); + column_schemas.push(ColumnSchema::new( + sequence_number_col, + Type::TEXT, + -1, + 0, + None, + false, + )); + + column_schemas } } impl Destination for IcebergDestination where - S: StateStore + SchemaStore + Send + Sync, + S: StateStore + Send + Sync, { /// Returns the identifier name for this destination type. 
fn name() -> &'static str { @@ -591,8 +585,11 @@ where /// /// Removes all data from the target Iceberg table while preserving /// the table schema structure for continued CDC operations. - async fn truncate_table(&self, table_id: TableId) -> EtlResult<()> { - self.truncate_table(table_id, false).await?; + async fn truncate_table( + &self, + replicated_table_schema: &ReplicatedTableSchema, + ) -> EtlResult<()> { + self.truncate_table(replicated_table_schema).await?; Ok(()) } @@ -604,10 +601,11 @@ where /// as upsert operations with generated sequence numbers. async fn write_table_rows( &self, - table_id: TableId, + replicated_table_schema: &ReplicatedTableSchema, table_rows: Vec, ) -> EtlResult<()> { - self.write_table_rows(table_id, table_rows).await?; + self.write_table_rows(replicated_table_schema, table_rows) + .await?; Ok(()) } @@ -697,63 +695,54 @@ mod tests { CDC_OPERATION_COLUMN_NAME, find_unique_column_name, schema_to_namespace, }; + /// Creates a test column schema with common defaults. + /// + /// This helper simplifies column schema creation in tests by providing sensible + /// defaults for fields that are typically not relevant to the test logic. + fn test_column( + name: &str, + typ: Type, + ordinal_position: i32, + nullable: bool, + primary_key_ordinal: Option, + ) -> ColumnSchema { + ColumnSchema::new( + name.to_string(), + typ, + -1, + ordinal_position, + primary_key_ordinal, + nullable, + ) + } + #[test] fn can_find_unique_column_name() { let column_schemas = vec![]; let col_name = find_unique_column_name(&column_schemas, CDC_OPERATION_COLUMN_NAME); assert_eq!(col_name, CDC_OPERATION_COLUMN_NAME.to_string()); - let column_schemas = vec![ColumnSchema { - name: "id".to_string(), - typ: Type::BOOL, - modifier: -1, - nullable: false, - primary: true, - }]; + let column_schemas = vec![test_column("id", Type::BOOL, 1, false, Some(1))]; let col_name = find_unique_column_name(&column_schemas, CDC_OPERATION_COLUMN_NAME); assert_eq!(col_name, CDC_OPERATION_COLUMN_NAME.to_string()); let column_schemas = vec![ - ColumnSchema { - name: "id".to_string(), - typ: Type::BOOL, - modifier: -1, - nullable: false, - primary: true, - }, - ColumnSchema { - name: CDC_OPERATION_COLUMN_NAME.to_string(), - typ: Type::BOOL, - modifier: -1, - nullable: false, - primary: true, - }, + test_column("id", Type::BOOL, 1, false, Some(1)), + test_column(CDC_OPERATION_COLUMN_NAME, Type::BOOL, 2, false, Some(2)), ]; let col_name = find_unique_column_name(&column_schemas, CDC_OPERATION_COLUMN_NAME); assert_eq!(col_name, format!("{CDC_OPERATION_COLUMN_NAME}_1")); let column_schemas = vec![ - ColumnSchema { - name: "id".to_string(), - typ: Type::BOOL, - modifier: -1, - nullable: false, - primary: true, - }, - ColumnSchema { - name: CDC_OPERATION_COLUMN_NAME.to_string(), - typ: Type::BOOL, - modifier: -1, - nullable: false, - primary: true, - }, - ColumnSchema { - name: format!("{CDC_OPERATION_COLUMN_NAME}_1"), - typ: Type::BOOL, - modifier: -1, - nullable: false, - primary: true, - }, + test_column("id", Type::BOOL, 1, false, Some(1)), + test_column(CDC_OPERATION_COLUMN_NAME, Type::BOOL, 2, false, Some(2)), + test_column( + &format!("{CDC_OPERATION_COLUMN_NAME}_1"), + Type::BOOL, + 3, + false, + Some(3), + ), ]; let col_name = find_unique_column_name(&column_schemas, CDC_OPERATION_COLUMN_NAME); assert_eq!(col_name, format!("{CDC_OPERATION_COLUMN_NAME}_2")); diff --git a/etl-destinations/src/iceberg/schema.rs b/etl-destinations/src/iceberg/schema.rs index bc8f72dd2..8a3e67374 100644 --- 
a/etl-destinations/src/iceberg/schema.rs +++ b/etl-destinations/src/iceberg/schema.rs @@ -79,33 +79,63 @@ fn create_iceberg_list_type(element_type: PrimitiveType, field_id: i32) -> Icebe } /// Converts a Postgres table schema to an Iceberg schema. +/// +/// Primary key columns are converted to Iceberg identifier fields. This enables +/// Iceberg to understand which columns uniquely identify rows in the table. +/// Iceberg identifier fields are unordered (stored as a set). +/// +/// Field IDs are assigned following iceberg-rust's convention: all outer field IDs +/// are assigned first, then nested field IDs (e.g., list element fields). This ensures +/// consistency with how the Iceberg library handles schema evolution. pub fn postgres_to_iceberg_schema( column_schemas: &[ColumnSchema], ) -> Result { + let mut identifier_field_ids = Vec::new(); + + // First pass: assign IDs to all outer fields (1, 2, 3, ...). + let mut outer_field_id = 1; + let outer_fields: Vec<_> = column_schemas + .iter() + .map(|col| { + let id = outer_field_id; + outer_field_id += 1; + (col, id) + }) + .collect(); + + // Second pass: assign IDs to nested fields (list elements) and build the schema. + // Nested field IDs start after all outer field IDs. + let mut nested_field_id = outer_field_id; let mut fields = Vec::new(); - let mut field_id = 1; - - // Convert each column to Iceberg field - for column in column_schemas { - let field_type = if is_array_type(&column.typ) { - // For array types, we need to assign a unique field ID to the list element - // We increment field_id and use it for the element field - field_id += 1; - postgres_array_type_to_iceberg_type(&column.typ, field_id - 1) + + for (column_schema, field_id) in outer_fields { + let field_type = if is_array_type(&column_schema.typ) { + let element_id = nested_field_id; + nested_field_id += 1; + postgres_array_type_to_iceberg_type(&column_schema.typ, element_id) } else { - postgres_scalar_type_to_iceberg_type(&column.typ) + postgres_scalar_type_to_iceberg_type(&column_schema.typ) }; - let field = if column.nullable { - NestedField::optional(field_id, &column.name, field_type) + let field = if column_schema.nullable { + NestedField::optional(field_id, &column_schema.name, field_type) } else { - NestedField::required(field_id, &column.name, field_type) + NestedField::required(field_id, &column_schema.name, field_type) }; fields.push(Arc::new(field)); - field_id += 1; + + if column_schema.primary_key() { + identifier_field_ids.push(field_id); + } + } + + let mut builder = IcebergSchema::builder().with_fields(fields); + + if !identifier_field_ids.is_empty() { + builder = builder.with_identifier_field_ids(identifier_field_ids); } - let schema = IcebergSchema::builder().with_fields(fields).build()?; + let schema = builder.build()?; Ok(schema) } @@ -366,4 +396,66 @@ mod tests { ); } } + + /// Creates a test column schema with common defaults. 
+ fn test_column( + name: &str, + typ: Type, + ordinal_position: i32, + nullable: bool, + primary_key_ordinal: Option, + ) -> ColumnSchema { + ColumnSchema::new( + name.to_string(), + typ, + -1, + ordinal_position, + primary_key_ordinal, + nullable, + ) + } + + #[test] + fn test_identifier_fields_single_primary_key() { + let columns = vec![ + test_column("id", Type::INT4, 1, false, Some(1)), + test_column("name", Type::TEXT, 2, true, None), + ]; + + let schema = postgres_to_iceberg_schema(&columns).expect("schema creation"); + + // id column should be field_id 1 + let ids: Vec = schema.identifier_field_ids().collect(); + assert_eq!(ids, vec![1]); + } + + #[test] + fn test_identifier_fields_composite_primary_key() { + let columns = vec![ + test_column("tenant_id", Type::INT4, 1, false, Some(1)), + test_column("id", Type::INT4, 2, false, Some(2)), + test_column("name", Type::TEXT, 3, true, None), + ]; + + let schema = postgres_to_iceberg_schema(&columns).expect("schema creation"); + + // tenant_id is field_id 1, id is field_id 2 + // Iceberg identifier fields are unordered, so we sort before comparing. + let mut ids: Vec = schema.identifier_field_ids().collect(); + ids.sort(); + assert_eq!(ids, vec![1, 2]); + } + + #[test] + fn test_identifier_fields_no_primary_key() { + let columns = vec![ + test_column("name", Type::TEXT, 1, true, None), + test_column("age", Type::INT4, 2, true, None), + ]; + + let schema = postgres_to_iceberg_schema(&columns).expect("schema creation"); + + let ids: Vec = schema.identifier_field_ids().collect(); + assert!(ids.is_empty()); + } } diff --git a/etl-destinations/tests/bigquery_pipeline.rs b/etl-destinations/tests/bigquery_pipeline.rs index ddf6cdcda..194a14ed0 100644 --- a/etl-destinations/tests/bigquery_pipeline.rs +++ b/etl-destinations/tests/bigquery_pipeline.rs @@ -4,6 +4,7 @@ use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; use etl::config::BatchConfig; use etl::error::ErrorKind; use etl::state::table::TableReplicationPhaseType; +use etl::store::state::StateStore; use etl::test_utils::database::{spawn_source_database, test_table_name}; use etl::test_utils::notify::NotifyingStore; use etl::test_utils::pipeline::{create_pipeline, create_pipeline_with}; @@ -11,6 +12,7 @@ use etl::test_utils::test_destination_wrapper::TestDestinationWrapper; use etl::test_utils::test_schema::{TableSelection, insert_mock_data, setup_test_database_schema}; use etl::types::{EventType, PgNumeric, PipelineId}; use etl_destinations::encryption::install_crypto_provider; +use etl_postgres::tokio::test_utils::TableModification; use etl_telemetry::tracing::init_test_tracing; use rand::random; use std::str::FromStr; @@ -556,7 +558,7 @@ async fn table_nullable_scalar_columns() { table_sync_done_notification.notified().await; - // insert + // Insert let event_notify = destination .wait_for_events_count(vec![(EventType::Insert, 1)]) .await; @@ -600,7 +602,7 @@ async fn table_nullable_scalar_columns() { let parsed_table_rows = parse_bigquery_table_rows::(table_rows); assert_eq!(parsed_table_rows, vec![NullableColsScalar::all_nulls(1),]); - // update + // Update let event_notify = destination .wait_for_events_count(vec![(EventType::Update, 1)]) .await; @@ -1580,6 +1582,7 @@ async fn table_array_with_null_values() { // We have to reset the state of the table and copy it from scratch, otherwise the CDC will contain // the inserts and deletes, failing again. 
store.reset_table_state(table_id).await.unwrap(); + // We also clear the events so that it's more idiomatic to wait for them, since we don't have // the insert of before. destination.clear_events().await; @@ -1773,3 +1776,177 @@ async fn table_validation_out_of_bounds_values() { .contains(&ErrorKind::UnsupportedValueInDestination) ); } + +#[tokio::test(flavor = "multi_thread")] +async fn table_schema_change() { + init_test_tracing(); + install_crypto_provider(); + + let database = spawn_source_database().await; + let bigquery_database = setup_bigquery_connection().await; + let table_name = test_table_name("schema_multi_ops"); + let table_id = database + .create_table( + table_name.clone(), + true, + &[ + ("name", "text not null"), + ("age", "integer not null"), + ("status", "text"), + ], + ) + .await + .unwrap(); + + let store = NotifyingStore::new(); + let pipeline_id: PipelineId = random(); + let raw_destination = bigquery_database.build_destination(store.clone()).await; + let destination = TestDestinationWrapper::wrap(raw_destination); + + let publication_name = "test_pub_multi_ops".to_string(); + database + .create_publication(&publication_name, std::slice::from_ref(&table_name)) + .await + .expect("Failed to create publication"); + + let mut pipeline = create_pipeline( + &database.config, + pipeline_id, + publication_name, + store.clone(), + destination.clone(), + ); + + let table_sync_done = store + .notify_on_table_state_type(table_id, TableReplicationPhaseType::SyncDone) + .await; + + pipeline.start().await.unwrap(); + table_sync_done.notified().await; + + // Insert the initial row. + let event_notify = destination + .wait_for_events_count(vec![(EventType::Insert, 1)]) + .await; + + database + .insert_values( + table_name.clone(), + &["name", "age", "status"], + &[&"Alice", &25, &"active"], + ) + .await + .unwrap(); + + event_notify.notified().await; + + // Verify initial schema. + let initial_schema = bigquery_database + .query_table_schema(table_name.clone()) + .await + .unwrap(); + initial_schema.assert_columns(&["id", "name", "age", "status"]); + + // Verify destination schema state is applied after initial table creation. + let initial_state = store + .get_destination_table_metadata(&table_id) + .await + .unwrap() + .expect("destination schema state should exist after table creation"); + assert!( + initial_state.is_applied(), + "initial destination schema state should be applied" + ); + let initial_snapshot_id = initial_state.snapshot_id; + + // Apply multiple schema changes: + // 1. Rename name -> full_name + // 2. Drop the status column + // 3. Add email column + // + // Note: Each DDL change is captured via the DDL event trigger and stored in the schema + // store, but PostgreSQL sends only ONE Relation message with the final schema when the + // next DML operation (INSERT) occurs. The schema diffing in handle_relation_event then + // computes and applies all changes at once. 
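// Illustrative only: the three ALTER TABLE statements issued below are expected to
// collapse into a single diff once the next INSERT triggers one Relation message.
// `expected_net_schema_changes` is a made-up helper that just spells that expectation
// out; it does not use the crate's real `SchemaDiff` type.
/// Expected net change sets for this test: (renames old -> new, removed, added).
fn expected_net_schema_changes() -> (
    Vec<(&'static str, &'static str)>,
    Vec<&'static str>,
    Vec<&'static str>,
) {
    (
        vec![("name", "full_name")], // renamed via RENAME COLUMN
        vec!["status"],              // dropped
        vec!["email"],               // added
    )
}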
+ let event_notify = destination + .wait_for_events_count(vec![(EventType::Relation, 1), (EventType::Insert, 1)]) + .await; + + database + .alter_table( + table_name.clone(), + &[TableModification::RenameColumn { + old_name: "name", + new_name: "full_name", + }], + ) + .await + .unwrap(); + + database + .alter_table( + table_name.clone(), + &[TableModification::DropColumn { name: "status" }], + ) + .await + .unwrap(); + + database + .alter_table( + table_name.clone(), + &[TableModification::AddColumn { + name: "email", + data_type: "text", + }], + ) + .await + .unwrap(); + + // Insert row with new schema. + database + .insert_values( + table_name.clone(), + &["full_name", "age", "email"], + &[&"Bob", &30, &"bob@example.com"], + ) + .await + .unwrap(); + + event_notify.notified().await; + + // Give BigQuery time to apply all schema changes. + sleep(Duration::from_secs(3)).await; + + pipeline.shutdown_and_wait().await.unwrap(); + + // Verify the final schema: + // - name should be renamed to full_name + // - status should be dropped + // - email should be added + let final_schema = bigquery_database + .query_table_schema(table_name.clone()) + .await + .unwrap(); + final_schema.assert_columns(&["id", "full_name", "age", "email"]); + final_schema.assert_no_column("name"); + final_schema.assert_no_column("status"); + + // Verify destination schema state is applied after schema changes. + let final_state = store + .get_destination_table_metadata(&table_id) + .await + .unwrap() + .expect("destination schema state should exist after schema change"); + assert!( + final_state.is_applied(), + "final destination schema state should be applied" + ); + assert!( + final_state.snapshot_id > initial_snapshot_id, + "snapshot_id should have increased after schema change" + ); + + // Verify data was inserted correctly. + let rows = bigquery_database.query_table(table_name).await.unwrap(); + assert_eq!(rows.len(), 2); +} diff --git a/etl-destinations/tests/iceberg_client.rs b/etl-destinations/tests/iceberg_client.rs index 276350d9f..615e1c6d4 100644 --- a/etl-destinations/tests/iceberg_client.rs +++ b/etl-destinations/tests/iceberg_client.rs @@ -6,6 +6,27 @@ use etl_destinations::iceberg::IcebergClient; use etl_telemetry::tracing::init_test_tracing; use uuid::Uuid; +/// Creates a test column schema with common defaults. +/// +/// This helper simplifies column schema creation in tests by providing sensible +/// defaults for fields that are typically not relevant to the test logic. 
+fn test_column( + name: &str, + typ: Type, + ordinal_position: i32, + nullable: bool, + primary_key_ordinal_position: Option, +) -> ColumnSchema { + ColumnSchema::new( + name.to_string(), + typ, + -1, + ordinal_position, + primary_key_ordinal_position, + nullable, + ) +} + use crate::support::{ iceberg::{LAKEKEEPER_URL, create_props, get_catalog_url, read_all_rows}, lakekeeper::LakekeeperClient, @@ -111,198 +132,66 @@ async fn create_table_if_missing() { let table_name = "test_table".to_string(); let column_schemas = vec![ // Primary key - ColumnSchema::new("id".to_string(), Type::INT4, -1, false, true), + test_column("id", Type::INT4, 1, false, Some(1)), // Boolean types - ColumnSchema::new("bool_col".to_string(), Type::BOOL, -1, true, false), + test_column("bool_col", Type::BOOL, 2, true, None), // String types - ColumnSchema::new("char_col".to_string(), Type::CHAR, -1, true, false), - ColumnSchema::new("bpchar_col".to_string(), Type::BPCHAR, -1, true, false), - ColumnSchema::new("varchar_col".to_string(), Type::VARCHAR, -1, true, false), - ColumnSchema::new("name_col".to_string(), Type::NAME, -1, true, false), - ColumnSchema::new("text_col".to_string(), Type::TEXT, -1, true, false), + test_column("char_col", Type::CHAR, 3, true, None), + test_column("bpchar_col", Type::BPCHAR, 4, true, None), + test_column("varchar_col", Type::VARCHAR, 5, true, None), + test_column("name_col", Type::NAME, 6, true, None), + test_column("text_col", Type::TEXT, 7, true, None), // Integer types - ColumnSchema::new("int2_col".to_string(), Type::INT2, -1, true, false), - ColumnSchema::new("int4_col".to_string(), Type::INT4, -1, true, false), - ColumnSchema::new("int8_col".to_string(), Type::INT8, -1, true, false), + test_column("int2_col", Type::INT2, 8, true, None), + test_column("int4_col", Type::INT4, 9, true, None), + test_column("int8_col", Type::INT8, 10, true, None), // Float types - ColumnSchema::new("float4_col".to_string(), Type::FLOAT4, -1, true, false), - ColumnSchema::new("float8_col".to_string(), Type::FLOAT8, -1, true, false), + test_column("float4_col", Type::FLOAT4, 11, true, None), + test_column("float8_col", Type::FLOAT8, 12, true, None), // Numeric type - ColumnSchema::new("numeric_col".to_string(), Type::NUMERIC, -1, true, false), + test_column("numeric_col", Type::NUMERIC, 13, true, None), // Date/Time types - ColumnSchema::new("date_col".to_string(), Type::DATE, -1, true, false), - ColumnSchema::new("time_col".to_string(), Type::TIME, -1, true, false), - ColumnSchema::new( - "timestamp_col".to_string(), - Type::TIMESTAMP, - -1, - true, - false, - ), - ColumnSchema::new( - "timestamptz_col".to_string(), - Type::TIMESTAMPTZ, - -1, - true, - false, - ), + test_column("date_col", Type::DATE, 14, true, None), + test_column("time_col", Type::TIME, 15, true, None), + test_column("timestamp_col", Type::TIMESTAMP, 16, true, None), + test_column("timestamptz_col", Type::TIMESTAMPTZ, 17, true, None), // UUID type - ColumnSchema::new("uuid_col".to_string(), Type::UUID, -1, true, false), + test_column("uuid_col", Type::UUID, 18, true, None), // JSON types - ColumnSchema::new("json_col".to_string(), Type::JSON, -1, true, false), - ColumnSchema::new("jsonb_col".to_string(), Type::JSONB, -1, true, false), + test_column("json_col", Type::JSON, 19, true, None), + test_column("jsonb_col", Type::JSONB, 20, true, None), // OID type - ColumnSchema::new("oid_col".to_string(), Type::OID, -1, true, false), + test_column("oid_col", Type::OID, 21, true, None), // Binary type - 
ColumnSchema::new("bytea_col".to_string(), Type::BYTEA, -1, true, false), + test_column("bytea_col", Type::BYTEA, 22, true, None), // Array types - ColumnSchema::new( - "bool_array_col".to_string(), - Type::BOOL_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "char_array_col".to_string(), - Type::CHAR_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "bpchar_array_col".to_string(), - Type::BPCHAR_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "varchar_array_col".to_string(), - Type::VARCHAR_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "name_array_col".to_string(), - Type::NAME_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "text_array_col".to_string(), - Type::TEXT_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "int2_array_col".to_string(), - Type::INT2_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "int4_array_col".to_string(), - Type::INT4_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "int8_array_col".to_string(), - Type::INT8_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "float4_array_col".to_string(), - Type::FLOAT4_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "float8_array_col".to_string(), - Type::FLOAT8_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "numeric_array_col".to_string(), - Type::NUMERIC_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "date_array_col".to_string(), - Type::DATE_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "time_array_col".to_string(), - Type::TIME_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "timestamp_array_col".to_string(), - Type::TIMESTAMP_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "timestamptz_array_col".to_string(), + test_column("bool_array_col", Type::BOOL_ARRAY, 23, true, None), + test_column("char_array_col", Type::CHAR_ARRAY, 24, true, None), + test_column("bpchar_array_col", Type::BPCHAR_ARRAY, 25, true, None), + test_column("varchar_array_col", Type::VARCHAR_ARRAY, 26, true, None), + test_column("name_array_col", Type::NAME_ARRAY, 27, true, None), + test_column("text_array_col", Type::TEXT_ARRAY, 28, true, None), + test_column("int2_array_col", Type::INT2_ARRAY, 29, true, None), + test_column("int4_array_col", Type::INT4_ARRAY, 30, true, None), + test_column("int8_array_col", Type::INT8_ARRAY, 31, true, None), + test_column("float4_array_col", Type::FLOAT4_ARRAY, 32, true, None), + test_column("float8_array_col", Type::FLOAT8_ARRAY, 33, true, None), + test_column("numeric_array_col", Type::NUMERIC_ARRAY, 34, true, None), + test_column("date_array_col", Type::DATE_ARRAY, 35, true, None), + test_column("time_array_col", Type::TIME_ARRAY, 36, true, None), + test_column("timestamp_array_col", Type::TIMESTAMP_ARRAY, 37, true, None), + test_column( + "timestamptz_array_col", Type::TIMESTAMPTZ_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "uuid_array_col".to_string(), - Type::UUID_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "json_array_col".to_string(), - Type::JSON_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "jsonb_array_col".to_string(), - Type::JSONB_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "oid_array_col".to_string(), - Type::OID_ARRAY, - -1, + 38, true, - false, - ), - ColumnSchema::new( - "bytea_array_col".to_string(), - Type::BYTEA_ARRAY, - -1, - true, - false, + None, ), + test_column("uuid_array_col", Type::UUID_ARRAY, 39, true, None), + test_column("json_array_col", Type::JSON_ARRAY, 40, true, None), + 
test_column("jsonb_array_col", Type::JSONB_ARRAY, 41, true, None), + test_column("oid_array_col", Type::OID_ARRAY, 42, true, None), + test_column("bytea_array_col", Type::BYTEA_ARRAY, 43, true, None), ]; // table doesn't exist yet @@ -327,6 +216,19 @@ async fn create_table_if_missing() { .unwrap() ); + // Verify identifier fields are set correctly + let table = client + .load_table(namespace.to_string(), table_name.clone()) + .await + .unwrap(); + let identifier_field_ids: Vec = table + .metadata() + .current_schema() + .identifier_field_ids() + .collect(); + // The "id" column is the primary key and should be the only identifier field (field_id = 1) + assert_eq!(identifier_field_ids, vec![1]); + // Creating the same table again should be a no-op (no error) client .create_table_if_missing(namespace, table_name.clone(), &column_schemas) @@ -373,13 +275,7 @@ async fn drop_table_if_exists_is_idempotent() { // Create a simple table schema let table_name = "test_table".to_string(); - let column_schemas = vec![ColumnSchema::new( - "id".to_string(), - Type::INT4, - -1, - false, - true, - )]; + let column_schemas = vec![test_column("id", Type::INT4, 1, false, Some(1))]; // Create table client @@ -451,50 +347,38 @@ async fn insert_nullable_scalars() { let table_name = "test_table".to_string(); let column_schemas = vec![ // Primary key - ColumnSchema::new("id".to_string(), Type::INT4, -1, false, true), + test_column("id", Type::INT4, 1, false, Some(1)), // Boolean types - ColumnSchema::new("bool_col".to_string(), Type::BOOL, -1, true, false), + test_column("bool_col", Type::BOOL, 2, true, None), // String types - ColumnSchema::new("char_col".to_string(), Type::CHAR, -1, true, false), - ColumnSchema::new("bpchar_col".to_string(), Type::BPCHAR, -1, true, false), - ColumnSchema::new("varchar_col".to_string(), Type::VARCHAR, -1, true, false), - ColumnSchema::new("name_col".to_string(), Type::NAME, -1, true, false), - ColumnSchema::new("text_col".to_string(), Type::TEXT, -1, true, false), + test_column("char_col", Type::CHAR, 3, true, None), + test_column("bpchar_col", Type::BPCHAR, 4, true, None), + test_column("varchar_col", Type::VARCHAR, 5, true, None), + test_column("name_col", Type::NAME, 6, true, None), + test_column("text_col", Type::TEXT, 7, true, None), // Integer types - ColumnSchema::new("int2_col".to_string(), Type::INT2, -1, true, false), - ColumnSchema::new("int4_col".to_string(), Type::INT4, -1, true, false), - ColumnSchema::new("int8_col".to_string(), Type::INT8, -1, true, false), + test_column("int2_col", Type::INT2, 8, true, None), + test_column("int4_col", Type::INT4, 9, true, None), + test_column("int8_col", Type::INT8, 10, true, None), // Float types - ColumnSchema::new("float4_col".to_string(), Type::FLOAT4, -1, true, false), - ColumnSchema::new("float8_col".to_string(), Type::FLOAT8, -1, true, false), + test_column("float4_col", Type::FLOAT4, 11, true, None), + test_column("float8_col", Type::FLOAT8, 12, true, None), // Numeric type - ColumnSchema::new("numeric_col".to_string(), Type::NUMERIC, -1, true, false), + test_column("numeric_col", Type::NUMERIC, 13, true, None), // Date/Time types - ColumnSchema::new("date_col".to_string(), Type::DATE, -1, true, false), - ColumnSchema::new("time_col".to_string(), Type::TIME, -1, true, false), - ColumnSchema::new( - "timestamp_col".to_string(), - Type::TIMESTAMP, - -1, - true, - false, - ), - ColumnSchema::new( - "timestamptz_col".to_string(), - Type::TIMESTAMPTZ, - -1, - true, - false, - ), + test_column("date_col", Type::DATE, 14, true, 
None), + test_column("time_col", Type::TIME, 15, true, None), + test_column("timestamp_col", Type::TIMESTAMP, 16, true, None), + test_column("timestamptz_col", Type::TIMESTAMPTZ, 17, true, None), // UUID type - ColumnSchema::new("uuid_col".to_string(), Type::UUID, -1, true, false), + test_column("uuid_col", Type::UUID, 18, true, None), // JSON types - ColumnSchema::new("json_col".to_string(), Type::JSON, -1, true, false), - ColumnSchema::new("jsonb_col".to_string(), Type::JSONB, -1, true, false), + test_column("json_col", Type::JSON, 19, true, None), + test_column("jsonb_col", Type::JSONB, 20, true, None), // OID type - ColumnSchema::new("oid_col".to_string(), Type::OID, -1, true, false), + test_column("oid_col", Type::OID, 21, true, None), // Binary type - ColumnSchema::new("bytea_col".to_string(), Type::BYTEA, -1, true, false), + test_column("bytea_col", Type::BYTEA, 22, true, None), ]; client @@ -620,50 +504,38 @@ async fn insert_non_nullable_scalars() { let table_name = "test_table".to_string(); let column_schemas = vec![ // Primary key - ColumnSchema::new("id".to_string(), Type::INT4, -1, false, true), + test_column("id", Type::INT4, 1, false, Some(1)), // Boolean types - ColumnSchema::new("bool_col".to_string(), Type::BOOL, -1, false, false), + test_column("bool_col", Type::BOOL, 2, false, None), // String types - ColumnSchema::new("char_col".to_string(), Type::CHAR, -1, false, false), - ColumnSchema::new("bpchar_col".to_string(), Type::BPCHAR, -1, false, false), - ColumnSchema::new("varchar_col".to_string(), Type::VARCHAR, -1, false, false), - ColumnSchema::new("name_col".to_string(), Type::NAME, -1, false, false), - ColumnSchema::new("text_col".to_string(), Type::TEXT, -1, false, false), + test_column("char_col", Type::CHAR, 3, false, None), + test_column("bpchar_col", Type::BPCHAR, 4, false, None), + test_column("varchar_col", Type::VARCHAR, 5, false, None), + test_column("name_col", Type::NAME, 6, false, None), + test_column("text_col", Type::TEXT, 7, false, None), // Integer types - ColumnSchema::new("int2_col".to_string(), Type::INT2, -1, false, false), - ColumnSchema::new("int4_col".to_string(), Type::INT4, -1, false, false), - ColumnSchema::new("int8_col".to_string(), Type::INT8, -1, false, false), + test_column("int2_col", Type::INT2, 8, false, None), + test_column("int4_col", Type::INT4, 9, false, None), + test_column("int8_col", Type::INT8, 10, false, None), // Float types - ColumnSchema::new("float4_col".to_string(), Type::FLOAT4, -1, false, false), - ColumnSchema::new("float8_col".to_string(), Type::FLOAT8, -1, false, false), + test_column("float4_col", Type::FLOAT4, 11, false, None), + test_column("float8_col", Type::FLOAT8, 12, false, None), // Numeric type - ColumnSchema::new("numeric_col".to_string(), Type::NUMERIC, -1, false, false), + test_column("numeric_col", Type::NUMERIC, 13, false, None), // Date/Time types - ColumnSchema::new("date_col".to_string(), Type::DATE, -1, false, false), - ColumnSchema::new("time_col".to_string(), Type::TIME, -1, false, false), - ColumnSchema::new( - "timestamp_col".to_string(), - Type::TIMESTAMP, - -1, - false, - false, - ), - ColumnSchema::new( - "timestamptz_col".to_string(), - Type::TIMESTAMPTZ, - -1, - false, - false, - ), + test_column("date_col", Type::DATE, 14, false, None), + test_column("time_col", Type::TIME, 15, false, None), + test_column("timestamp_col", Type::TIMESTAMP, 16, false, None), + test_column("timestamptz_col", Type::TIMESTAMPTZ, 17, false, None), // UUID type - ColumnSchema::new("uuid_col".to_string(), 
Type::UUID, -1, false, false), + test_column("uuid_col", Type::UUID, 18, false, None), // JSON types - ColumnSchema::new("json_col".to_string(), Type::JSON, -1, false, false), - ColumnSchema::new("jsonb_col".to_string(), Type::JSONB, -1, false, false), + test_column("json_col", Type::JSON, 19, false, None), + test_column("jsonb_col", Type::JSONB, 20, false, None), // OID type - ColumnSchema::new("oid_col".to_string(), Type::OID, -1, false, false), + test_column("oid_col", Type::OID, 21, false, None), // Binary type - ColumnSchema::new("bytea_col".to_string(), Type::BYTEA, -1, false, false), + test_column("bytea_col", Type::BYTEA, 22, false, None), ]; client @@ -763,164 +635,44 @@ async fn insert_nullable_array() { let table_name = "test_array_table".to_string(); let column_schemas = vec![ // Primary key - ColumnSchema::new("id".to_string(), Type::INT4, -1, false, true), + test_column("id", Type::INT4, 1, false, Some(1)), // Boolean array type - ColumnSchema::new( - "bool_array_col".to_string(), - Type::BOOL_ARRAY, - -1, - true, - false, - ), + test_column("bool_array_col", Type::BOOL_ARRAY, 2, true, None), // String array types - ColumnSchema::new( - "char_array_col".to_string(), - Type::CHAR_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "bpchar_array_col".to_string(), - Type::BPCHAR_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "varchar_array_col".to_string(), - Type::VARCHAR_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "name_array_col".to_string(), - Type::NAME_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "text_array_col".to_string(), - Type::TEXT_ARRAY, - -1, - true, - false, - ), + test_column("char_array_col", Type::CHAR_ARRAY, 3, true, None), + test_column("bpchar_array_col", Type::BPCHAR_ARRAY, 4, true, None), + test_column("varchar_array_col", Type::VARCHAR_ARRAY, 5, true, None), + test_column("name_array_col", Type::NAME_ARRAY, 6, true, None), + test_column("text_array_col", Type::TEXT_ARRAY, 7, true, None), // Integer array types - ColumnSchema::new( - "int2_array_col".to_string(), - Type::INT2_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "int4_array_col".to_string(), - Type::INT4_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "int8_array_col".to_string(), - Type::INT8_ARRAY, - -1, - true, - false, - ), + test_column("int2_array_col", Type::INT2_ARRAY, 8, true, None), + test_column("int4_array_col", Type::INT4_ARRAY, 9, true, None), + test_column("int8_array_col", Type::INT8_ARRAY, 10, true, None), // Float array types - ColumnSchema::new( - "float4_array_col".to_string(), - Type::FLOAT4_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "float8_array_col".to_string(), - Type::FLOAT8_ARRAY, - -1, - true, - false, - ), + test_column("float4_array_col", Type::FLOAT4_ARRAY, 11, true, None), + test_column("float8_array_col", Type::FLOAT8_ARRAY, 12, true, None), // Numeric array type - ColumnSchema::new( - "numeric_array_col".to_string(), - Type::NUMERIC_ARRAY, - -1, - true, - false, - ), + test_column("numeric_array_col", Type::NUMERIC_ARRAY, 13, true, None), // Date/Time array types - ColumnSchema::new( - "date_array_col".to_string(), - Type::DATE_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "time_array_col".to_string(), - Type::TIME_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "timestamp_array_col".to_string(), - Type::TIMESTAMP_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "timestamptz_array_col".to_string(), + test_column("date_array_col", Type::DATE_ARRAY, 14, 
true, None), + test_column("time_array_col", Type::TIME_ARRAY, 15, true, None), + test_column("timestamp_array_col", Type::TIMESTAMP_ARRAY, 16, true, None), + test_column( + "timestamptz_array_col", Type::TIMESTAMPTZ_ARRAY, - -1, + 17, true, - false, + None, ), // UUID array type - ColumnSchema::new( - "uuid_array_col".to_string(), - Type::UUID_ARRAY, - -1, - true, - false, - ), + test_column("uuid_array_col", Type::UUID_ARRAY, 18, true, None), // JSON array types - ColumnSchema::new( - "json_array_col".to_string(), - Type::JSON_ARRAY, - -1, - true, - false, - ), - ColumnSchema::new( - "jsonb_array_col".to_string(), - Type::JSONB_ARRAY, - -1, - true, - false, - ), + test_column("json_array_col", Type::JSON_ARRAY, 19, true, None), + test_column("jsonb_array_col", Type::JSONB_ARRAY, 20, true, None), // OID array type - ColumnSchema::new( - "oid_array_col".to_string(), - Type::OID_ARRAY, - -1, - true, - false, - ), + test_column("oid_array_col", Type::OID_ARRAY, 21, true, None), // Binary array type - ColumnSchema::new( - "bytea_array_col".to_string(), - Type::BYTEA_ARRAY, - -1, - true, - false, - ), + test_column("bytea_array_col", Type::BYTEA_ARRAY, 22, true, None), ]; client @@ -1134,164 +886,50 @@ async fn insert_non_nullable_array() { let table_name = "test_non_nullable_array_table".to_string(); let column_schemas = vec![ // Primary key - ColumnSchema::new("id".to_string(), Type::INT4, -1, false, true), + test_column("id", Type::INT4, 1, false, Some(1)), // Boolean array type - ColumnSchema::new( - "bool_array_col".to_string(), - Type::BOOL_ARRAY, - -1, - false, - false, - ), + test_column("bool_array_col", Type::BOOL_ARRAY, 2, false, None), // String array types - ColumnSchema::new( - "char_array_col".to_string(), - Type::CHAR_ARRAY, - -1, - false, - false, - ), - ColumnSchema::new( - "bpchar_array_col".to_string(), - Type::BPCHAR_ARRAY, - -1, - false, - false, - ), - ColumnSchema::new( - "varchar_array_col".to_string(), - Type::VARCHAR_ARRAY, - -1, - false, - false, - ), - ColumnSchema::new( - "name_array_col".to_string(), - Type::NAME_ARRAY, - -1, - false, - false, - ), - ColumnSchema::new( - "text_array_col".to_string(), - Type::TEXT_ARRAY, - -1, - false, - false, - ), + test_column("char_array_col", Type::CHAR_ARRAY, 3, false, None), + test_column("bpchar_array_col", Type::BPCHAR_ARRAY, 4, false, None), + test_column("varchar_array_col", Type::VARCHAR_ARRAY, 5, false, None), + test_column("name_array_col", Type::NAME_ARRAY, 6, false, None), + test_column("text_array_col", Type::TEXT_ARRAY, 7, false, None), // Integer array types - ColumnSchema::new( - "int2_array_col".to_string(), - Type::INT2_ARRAY, - -1, - false, - false, - ), - ColumnSchema::new( - "int4_array_col".to_string(), - Type::INT4_ARRAY, - -1, - false, - false, - ), - ColumnSchema::new( - "int8_array_col".to_string(), - Type::INT8_ARRAY, - -1, - false, - false, - ), + test_column("int2_array_col", Type::INT2_ARRAY, 8, false, None), + test_column("int4_array_col", Type::INT4_ARRAY, 9, false, None), + test_column("int8_array_col", Type::INT8_ARRAY, 10, false, None), // Float array types - ColumnSchema::new( - "float4_array_col".to_string(), - Type::FLOAT4_ARRAY, - -1, - false, - false, - ), - ColumnSchema::new( - "float8_array_col".to_string(), - Type::FLOAT8_ARRAY, - -1, - false, - false, - ), + test_column("float4_array_col", Type::FLOAT4_ARRAY, 11, false, None), + test_column("float8_array_col", Type::FLOAT8_ARRAY, 12, false, None), // Numeric array type - ColumnSchema::new( - "numeric_array_col".to_string(), - 
Type::NUMERIC_ARRAY, - -1, - false, - false, - ), + test_column("numeric_array_col", Type::NUMERIC_ARRAY, 13, false, None), // Date/Time array types - ColumnSchema::new( - "date_array_col".to_string(), - Type::DATE_ARRAY, - -1, - false, - false, - ), - ColumnSchema::new( - "time_array_col".to_string(), - Type::TIME_ARRAY, - -1, - false, - false, - ), - ColumnSchema::new( - "timestamp_array_col".to_string(), + test_column("date_array_col", Type::DATE_ARRAY, 14, false, None), + test_column("time_array_col", Type::TIME_ARRAY, 15, false, None), + test_column( + "timestamp_array_col", Type::TIMESTAMP_ARRAY, - -1, - false, + 16, false, + None, ), - ColumnSchema::new( - "timestamptz_array_col".to_string(), + test_column( + "timestamptz_array_col", Type::TIMESTAMPTZ_ARRAY, - -1, - false, + 17, false, + None, ), // UUID array type - ColumnSchema::new( - "uuid_array_col".to_string(), - Type::UUID_ARRAY, - -1, - false, - false, - ), + test_column("uuid_array_col", Type::UUID_ARRAY, 18, false, None), // JSON array types - ColumnSchema::new( - "json_array_col".to_string(), - Type::JSON_ARRAY, - -1, - false, - false, - ), - ColumnSchema::new( - "jsonb_array_col".to_string(), - Type::JSONB_ARRAY, - -1, - false, - false, - ), + test_column("json_array_col", Type::JSON_ARRAY, 19, false, None), + test_column("jsonb_array_col", Type::JSONB_ARRAY, 20, false, None), // OID array type - ColumnSchema::new( - "oid_array_col".to_string(), - Type::OID_ARRAY, - -1, - false, - false, - ), + test_column("oid_array_col", Type::OID_ARRAY, 21, false, None), // Binary array type - ColumnSchema::new( - "bytea_array_col".to_string(), - Type::BYTEA_ARRAY, - -1, - false, - false, - ), + test_column("bytea_array_col", Type::BYTEA_ARRAY, 22, false, None), ]; client diff --git a/etl-destinations/tests/support/bigquery.rs b/etl-destinations/tests/support/bigquery.rs index d68a72818..c689f2021 100644 --- a/etl-destinations/tests/support/bigquery.rs +++ b/etl-destinations/tests/support/bigquery.rs @@ -161,6 +161,47 @@ impl BigQueryDatabase { } } + /// Queries the schema (column metadata) for a table. + /// + /// Returns the column names and data types from INFORMATION_SCHEMA.COLUMNS. + /// The table name pattern matches using REGEXP_CONTAINS to match the sequenced + /// table name format: `{table_id}_{sequence_number}`. + pub async fn query_table_schema(&self, table_name: TableName) -> Option { + let client = self.client().unwrap(); + + let project_id = self.project_id(); + let dataset_id = self.dataset_id(); + let table_id = table_name_to_bigquery_table_id(&table_name); + + // Use REGEXP_CONTAINS to match the sequenced table name format. + // BigQuery table names have format: {schema}_{table}_{sequence_number} + // The regex matches the table_id followed by underscore and one or more digits. + let query = format!( + "SELECT column_name, data_type, ordinal_position \ + FROM `{project_id}.{dataset_id}.INFORMATION_SCHEMA.COLUMNS` \ + WHERE REGEXP_CONTAINS(table_name, r'^{table_id}_[0-9]+$') \ + ORDER BY ordinal_position" + ); + + let mut attempts_remaining = BIGQUERY_QUERY_MAX_ATTEMPTS; + + loop { + let rows = client + .job() + .query(project_id, QueryRequest::new(query.clone())) + .await + .unwrap() + .rows; + + if rows.is_some() || attempts_remaining == 1 { + return rows.map(|r| BigQueryTableSchema::new(parse_bigquery_table_rows(r))); + } + + attempts_remaining -= 1; + sleep(Duration::from_millis(BIGQUERY_QUERY_RETRY_DELAY_MS)).await; + } + } + /// Manually creates a table in the test dataset using column definitions. 
/// /// Creates a table by generating a DDL statement from the provided column specifications. @@ -1002,3 +1043,100 @@ where parsed_table_rows } + +/// Represents a column in a BigQuery table schema. +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] +pub struct BigQueryColumnSchema { + pub column_name: String, + pub data_type: String, + pub ordinal_position: i64, +} + +impl From for BigQueryColumnSchema { + fn from(value: TableRow) -> Self { + let columns = value.columns.unwrap(); + + BigQueryColumnSchema { + column_name: parse_table_cell(columns[0].clone()).unwrap(), + data_type: parse_table_cell(columns[1].clone()).unwrap(), + ordinal_position: parse_table_cell(columns[2].clone()).unwrap(), + } + } +} + +/// Wrapper around a BigQuery table schema for cleaner assertions in tests. +/// +/// Provides convenient methods to check column presence and absence, making +/// schema validation in tests more readable and reducing boilerplate. +#[derive(Debug)] +pub struct BigQueryTableSchema(Vec); + +impl BigQueryTableSchema { + /// Creates a new schema wrapper from a vector of column schemas. + pub fn new(columns: Vec) -> Self { + Self(columns) + } + + /// Returns true if a column with the given name exists in the schema. + pub fn has_column(&self, name: &str) -> bool { + self.0.iter().any(|c| c.column_name == name) + } + + /// Asserts that a column with the given name exists in the schema. + /// + /// Panics with a descriptive message if the column is not found. + pub fn assert_has_column(&self, name: &str) { + assert!( + self.has_column(name), + "expected column '{}' to exist in schema, but it was not found. Columns: {:?}", + name, + self.column_names() + ); + } + + /// Asserts that a column with the given name does not exist in the schema. + /// + /// Panics with a descriptive message if the column is found. + pub fn assert_no_column(&self, name: &str) { + assert!( + !self.has_column(name), + "expected column '{}' to not exist in schema, but it was found. Columns: {:?}", + name, + self.column_names() + ); + } + + /// Asserts that the schema contains exactly the specified columns (by name). + /// + /// The order of columns does not matter. CDC columns (`_CHANGE_TYPE` and + /// `_CHANGE_SEQUENCE_NUMBER`) are excluded from the comparison. + pub fn assert_columns(&self, expected: &[&str]) { + let actual: Vec<&str> = self + .0 + .iter() + .map(|c| c.column_name.as_str()) + .filter(|name| !name.starts_with("_CHANGE")) + .collect(); + + let mut expected_sorted: Vec<&str> = expected.to_vec(); + expected_sorted.sort(); + + let mut actual_sorted: Vec<&str> = actual.clone(); + actual_sorted.sort(); + + assert_eq!( + actual_sorted, expected_sorted, + "schema columns mismatch. Expected: {expected_sorted:?}, Actual: {actual_sorted:?}" + ); + } + + /// Returns the names of all columns in the schema. + pub fn column_names(&self) -> Vec<&str> { + self.0.iter().map(|c| c.column_name.as_str()).collect() + } + + /// Returns a reference to the underlying column schemas. 
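// Illustrative sketch of how the schema wrapper is meant to be used from a destination
// test. The table and column names are hypothetical, and the call assumes a live BigQuery
// test dataset behind `BigQueryDatabase`.
async fn example_schema_assertions(bigquery: &BigQueryDatabase) {
    let schema = bigquery
        .query_table_schema(TableName::new("public".to_string(), "orders".to_string()))
        .await
        .expect("schema rows should exist for a replicated table");
    schema.assert_has_column("id");
    schema.assert_no_column("legacy_col");
    schema.assert_columns(&["id", "amount", "created_at"]);
}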
+ pub fn columns(&self) -> &[BigQueryColumnSchema] { + &self.0 + } +} diff --git a/etl-postgres/src/replication/destination_metadata.rs b/etl-postgres/src/replication/destination_metadata.rs new file mode 100644 index 000000000..42714a579 --- /dev/null +++ b/etl-postgres/src/replication/destination_metadata.rs @@ -0,0 +1,217 @@ +use sqlx::postgres::types::Oid as SqlxTableId; +use sqlx::{PgExecutor, PgPool, Row, Type}; +use std::collections::HashMap; + +use crate::types::{SnapshotId, TableId}; + +/// Database enum type for destination table schema status. +/// +/// Maps to the `etl.destination_table_schema_status` PostgreSQL enum type. +#[derive(Debug, Clone, Copy, Type, PartialEq, Eq)] +#[sqlx( + type_name = "etl.destination_table_schema_status", + rename_all = "snake_case" +)] +pub enum DestinationTableSchemaStatus { + /// A schema change is currently being applied. + Applying, + /// The schema has been successfully applied. + Applied, +} + +/// Database row representation of destination table metadata. +#[derive(Debug, Clone)] +pub struct DestinationTableMetadataRow { + pub table_id: TableId, + pub destination_table_id: String, + pub snapshot_id: SnapshotId, + /// The schema version before the current change. None for initial schemas. + pub previous_snapshot_id: Option, + pub schema_status: DestinationTableSchemaStatus, + pub replication_mask: Vec, +} + +/// Stores destination table metadata in the database. +/// +/// Inserts or updates the complete metadata for a table at a destination. +/// Uses upsert semantics: if a row exists for (pipeline_id, table_id), +/// all fields are updated. +#[allow(clippy::too_many_arguments)] +pub async fn store_destination_table_metadata( + pool: &PgPool, + pipeline_id: i64, + table_id: TableId, + destination_table_id: &str, + snapshot_id: SnapshotId, + previous_snapshot_id: Option, + schema_status: DestinationTableSchemaStatus, + replication_mask: &[u8], +) -> Result<(), sqlx::Error> { + sqlx::query( + r#" + insert into etl.destination_tables_metadata + (pipeline_id, table_id, destination_table_id, snapshot_id, + previous_snapshot_id, schema_status, replication_mask) + values ($1, $2, $3, $4::pg_lsn, $5::pg_lsn, $6, $7) + on conflict (pipeline_id, table_id) + do update set + destination_table_id = excluded.destination_table_id, + snapshot_id = excluded.snapshot_id, + previous_snapshot_id = excluded.previous_snapshot_id, + schema_status = excluded.schema_status, + replication_mask = excluded.replication_mask, + updated_at = now() + "#, + ) + .bind(pipeline_id) + .bind(SqlxTableId(table_id.into_inner())) + .bind(destination_table_id) + .bind(snapshot_id.to_pg_lsn_string()) + .bind(previous_snapshot_id.map(|s| s.to_pg_lsn_string())) + .bind(schema_status) + .bind(replication_mask) + .execute(pool) + .await?; + + Ok(()) +} + +/// Loads all destination table metadata for a pipeline. +/// +/// Returns a map from table_id to the complete metadata row. 
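// Sketch of the intended store/load round trip. The pool, pipeline id, table OID and
// destination table id are hypothetical; only the function and type signatures above are
// taken from this module.
async fn example_metadata_roundtrip(pool: &PgPool) -> Result<(), sqlx::Error> {
    let pipeline_id = 1;
    let table_id = TableId::new(16_400);

    store_destination_table_metadata(
        pool,
        pipeline_id,
        table_id,
        "dataset.orders_1",
        SnapshotId::initial(),
        None,
        DestinationTableSchemaStatus::Applied,
        &[1, 1, 1],
    )
    .await?;

    let metadata = load_destination_tables_metadata(pool, pipeline_id).await?;
    assert_eq!(
        metadata.get(&table_id).map(|row| row.schema_status),
        Some(DestinationTableSchemaStatus::Applied)
    );

    Ok(())
}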
+pub async fn load_destination_tables_metadata(
+    pool: &PgPool,
+    pipeline_id: i64,
+) -> Result<HashMap<TableId, DestinationTableMetadataRow>, sqlx::Error> {
+    let rows = sqlx::query(
+        r#"
+        select table_id, destination_table_id, snapshot_id::text as snapshot_id,
+               previous_snapshot_id::text as previous_snapshot_id, schema_status, replication_mask
+        from etl.destination_tables_metadata
+        where pipeline_id = $1
+        "#,
+    )
+    .bind(pipeline_id)
+    .fetch_all(pool)
+    .await?;
+
+    let mut metadata = HashMap::new();
+    for row in rows {
+        let table_id: SqlxTableId = row.get("table_id");
+        let table_id = TableId::new(table_id.0);
+        let snapshot_id_str: String = row.get("snapshot_id");
+        let previous_snapshot_id_str: Option<String> = row.get("previous_snapshot_id");
+
+        let snapshot_id = SnapshotId::from_pg_lsn_string(&snapshot_id_str)
+            .map_err(|e| sqlx::Error::Protocol(e.to_string()))?;
+        let previous_snapshot_id = previous_snapshot_id_str
+            .map(|s| SnapshotId::from_pg_lsn_string(&s))
+            .transpose()
+            .map_err(|e| sqlx::Error::Protocol(e.to_string()))?;
+
+        metadata.insert(
+            table_id,
+            DestinationTableMetadataRow {
+                table_id,
+                destination_table_id: row.get("destination_table_id"),
+                snapshot_id,
+                previous_snapshot_id,
+                schema_status: row.get("schema_status"),
+                replication_mask: row.get("replication_mask"),
+            },
+        );
+    }
+
+    Ok(metadata)
+}
+
+/// Gets destination table metadata for a single table.
+pub async fn get_destination_table_metadata(
+    pool: &PgPool,
+    pipeline_id: i64,
+    table_id: TableId,
+) -> Result<Option<DestinationTableMetadataRow>, sqlx::Error> {
+    let row = sqlx::query(
+        r#"
+        select table_id, destination_table_id, snapshot_id::text as snapshot_id,
+               previous_snapshot_id::text as previous_snapshot_id, schema_status, replication_mask
+        from etl.destination_tables_metadata
+        where pipeline_id = $1 and table_id = $2
+        "#,
+    )
+    .bind(pipeline_id)
+    .bind(SqlxTableId(table_id.into_inner()))
+    .fetch_optional(pool)
+    .await?;
+
+    match row {
+        Some(r) => {
+            let table_id: SqlxTableId = r.get("table_id");
+            let snapshot_id_str: String = r.get("snapshot_id");
+            let previous_snapshot_id_str: Option<String> = r.get("previous_snapshot_id");
+
+            let snapshot_id = SnapshotId::from_pg_lsn_string(&snapshot_id_str)
+                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?;
+            let previous_snapshot_id = previous_snapshot_id_str
+                .map(|s| SnapshotId::from_pg_lsn_string(&s))
+                .transpose()
+                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?;
+
+            Ok(Some(DestinationTableMetadataRow {
+                table_id: TableId::new(table_id.0),
+                destination_table_id: r.get("destination_table_id"),
+                snapshot_id,
+                previous_snapshot_id,
+                schema_status: r.get("schema_status"),
+                replication_mask: r.get("replication_mask"),
+            }))
+        }
+        None => Ok(None),
+    }
+}
+
+/// Deletes all destination table metadata for a pipeline.
+///
+/// Used during pipeline cleanup.
+pub async fn delete_destination_tables_metadata_for_all_tables<'c, E>(
+    executor: E,
+    pipeline_id: i64,
+) -> Result<u64, sqlx::Error>
+where
+    E: PgExecutor<'c>,
+{
+    let result = sqlx::query(
+        r#"
+        delete from etl.destination_tables_metadata
+        where pipeline_id = $1
+        "#,
+    )
+    .bind(pipeline_id)
+    .execute(executor)
+    .await?;
+
+    Ok(result.rows_affected())
+}
+
+/// Deletes destination table metadata for a single table.
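// Sketch showing why the delete helpers are generic over `PgExecutor`: the same function
// works against a pool or inside an open transaction, which is how the replication state
// rollback path uses it. The pipeline id and table OID below are hypothetical.
async fn example_delete_in_transaction(pool: &PgPool) -> Result<(), sqlx::Error> {
    let mut tx = pool.begin().await?;
    delete_destination_table_metadata(&mut *tx, 1, TableId::new(16_400)).await?;
    tx.commit().await?;
    Ok(())
}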
+pub async fn delete_destination_table_metadata<'c, E>( + executor: E, + pipeline_id: i64, + table_id: TableId, +) -> Result +where + E: PgExecutor<'c>, +{ + let result = sqlx::query( + r#" + delete from etl.destination_tables_metadata + where pipeline_id = $1 and table_id = $2 + "#, + ) + .bind(pipeline_id) + .bind(SqlxTableId(table_id.into_inner())) + .execute(executor) + .await?; + + Ok(result.rows_affected()) +} diff --git a/etl-postgres/src/replication/health.rs b/etl-postgres/src/replication/health.rs index c29fcf0b8..52b9ca37e 100644 --- a/etl-postgres/src/replication/health.rs +++ b/etl-postgres/src/replication/health.rs @@ -3,7 +3,7 @@ use sqlx::PgExecutor; /// Fully-qualified table names required by ETL. pub const ETL_TABLE_NAMES: [&str; 4] = [ "etl.replication_state", - "etl.table_mappings", + "etl.destination_tables_metadata", "etl.table_schemas", "etl.table_columns", ]; @@ -12,7 +12,7 @@ pub const ETL_TABLE_NAMES: [&str; 4] = [ /// /// Checks presence of the following relations: /// - etl.replication_state -/// - etl.table_mappings +/// - etl.destination_tables_metadata /// - etl.table_schemas /// - etl.table_columns pub async fn etl_tables_present<'c, E>(executor: E) -> Result diff --git a/etl-postgres/src/replication/mod.rs b/etl-postgres/src/replication/mod.rs index d403e0a28..026ce977d 100644 --- a/etl-postgres/src/replication/mod.rs +++ b/etl-postgres/src/replication/mod.rs @@ -1,10 +1,10 @@ mod db; +pub mod destination_metadata; pub mod health; pub mod lag; pub mod schema; pub mod slots; pub mod state; -pub mod table_mappings; pub mod worker; pub use db::*; diff --git a/etl-postgres/src/replication/schema.rs b/etl-postgres/src/replication/schema.rs index cd540d653..a540f88ba 100644 --- a/etl-postgres/src/replication/schema.rs +++ b/etl-postgres/src/replication/schema.rs @@ -4,7 +4,7 @@ use sqlx::{PgExecutor, PgPool, Row}; use std::collections::HashMap; use tokio_postgres::types::Type as PgType; -use crate::types::{ColumnSchema, TableId, TableName, TableSchema}; +use crate::types::{ColumnSchema, SnapshotId, TableId, TableName, TableSchema}; macro_rules! define_type_mappings { ( @@ -134,10 +134,11 @@ define_type_mappings! { DATE_RANGE => "DATE_RANGE" } -/// Stores a table schema in the database. +/// Stores a table schema in the database with a specific snapshot ID. /// -/// Inserts or updates table schema and column information in schema storage tables -/// using a transaction to ensure atomicity. +/// Upserts table schema and replaces all column information in schema storage tables +/// using a transaction to ensure atomicity. If a schema version already exists for +/// the same (pipeline_id, table_id, snapshot_id), columns are deleted and re-inserted. 
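// Sketch of what a post-DDL schema version looks like: the same table, re-described with
// its new columns and tagged with the start_lsn of the DDL message that changed it. The
// table name, OID, column layout and LSN are hypothetical.
fn example_schema_after_ddl() -> TableSchema {
    TableSchema::with_snapshot_id(
        TableId::new(16_400),
        TableName::new("public".to_string(), "orders".to_string()),
        vec![
            ColumnSchema::new("id".to_string(), PgType::INT8, -1, 1, Some(1), false),
            ColumnSchema::new("total".to_string(), PgType::NUMERIC, -1, 2, None, true),
        ],
        SnapshotId::from(0x16B3748_u64),
    )
}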
pub async fn store_table_schema( pool: &PgPool, pipeline_id: i64, @@ -145,16 +146,13 @@ pub async fn store_table_schema( ) -> Result<(), sqlx::Error> { let mut tx = pool.begin().await?; - // Insert or update table schema record + // Upsert table schema version let table_schema_id: i64 = sqlx::query( r#" - insert into etl.table_schemas (pipeline_id, table_id, schema_name, table_name) - values ($1, $2, $3, $4) - on conflict (pipeline_id, table_id) - do update set - schema_name = excluded.schema_name, - table_name = excluded.table_name, - updated_at = now() + insert into etl.table_schemas (pipeline_id, table_id, schema_name, table_name, snapshot_id) + values ($1, $2, $3, $4, $5::pg_lsn) + on conflict (pipeline_id, table_id, snapshot_id) + do update set schema_name = excluded.schema_name, table_name = excluded.table_name returning id "#, ) @@ -162,6 +160,7 @@ pub async fn store_table_schema( .bind(table_schema.id.into_inner() as i64) .bind(&table_schema.name.schema) .bind(&table_schema.name.name) + .bind(table_schema.snapshot_id.to_pg_lsn_string()) .fetch_one(&mut *tx) .await? .get(0); @@ -173,23 +172,23 @@ pub async fn store_table_schema( .await?; // Insert all columns - for (column_order, column_schema) in table_schema.column_schemas.iter().enumerate() { - let column_type_str = postgres_type_to_string(&column_schema.typ); - + for column_schema in table_schema.column_schemas.iter() { sqlx::query( r#" - insert into etl.table_columns - (table_schema_id, column_name, column_type, type_modifier, nullable, primary_key, column_order) - values ($1, $2, $3, $4, $5, $6, $7) + insert into etl.table_columns + (table_schema_id, column_name, column_type, type_modifier, nullable, primary_key, + column_order, primary_key_ordinal_position) + values ($1, $2, $3, $4, $5, $6, $7, $8) "#, ) .bind(table_schema_id) .bind(&column_schema.name) - .bind(column_type_str) + .bind(postgres_type_to_string(&column_schema.typ)) .bind(column_schema.modifier) .bind(column_schema.nullable) - .bind(column_schema.primary) - .bind(column_order as i32) + .bind(column_schema.primary_key()) + .bind(column_schema.ordinal_position) + .bind(column_schema.primary_key_ordinal_position) .execute(&mut *tx) .await?; } @@ -199,33 +198,130 @@ pub async fn store_table_schema( Ok(()) } -/// Loads all table schemas for a pipeline from the database. +/// Loads all table schemas for a pipeline from the database at the latest snapshot. /// /// Retrieves table schemas and columns from schema storage tables, -/// reconstructing complete [`TableSchema`] objects. +/// reconstructing complete [`TableSchema`] objects. This is equivalent to +/// calling [`load_table_schemas_at_snapshot`] with the maximum LSN value. pub async fn load_table_schemas( pool: &PgPool, pipeline_id: i64, ) -> Result, sqlx::Error> { + load_table_schemas_at_snapshot(pool, pipeline_id, SnapshotId::max()).await +} + +/// Loads a single table schema with the largest snapshot_id <= the requested snapshot. +/// +/// Returns `None` if no schema version exists for the table at or before the given snapshot. 
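// Sketch of the lookup rule (hypothetical values): with versions stored at 0/0 and
// 0/16B3748 for a table, requesting any snapshot below 0/16B3748 returns the initial
// version, and requesting 0/16B3748 or later returns the post-DDL version.
async fn example_lookup(pool: &PgPool) -> Result<(), sqlx::Error> {
    let schema = load_table_schema_at_snapshot(
        pool,
        1,
        TableId::new(16_400),
        SnapshotId::from(0x16B3747_u64),
    )
    .await?;
    assert_eq!(schema.map(|s| s.snapshot_id), Some(SnapshotId::initial()));
    Ok(())
}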
+pub async fn load_table_schema_at_snapshot( + pool: &PgPool, + pipeline_id: i64, + table_id: TableId, + snapshot_id: SnapshotId, +) -> Result, sqlx::Error> { let rows = sqlx::query( r#" select ts.table_id, ts.schema_name, ts.table_name, + ts.snapshot_id::text as snapshot_id, tc.column_name, tc.column_type, tc.type_modifier, tc.nullable, tc.primary_key, - tc.column_order + tc.column_order, + tc.primary_key_ordinal_position from etl.table_schemas ts inner join etl.table_columns tc on ts.id = tc.table_schema_id - where ts.pipeline_id = $1 - order by ts.table_id, tc.column_order + where ts.id = ( + select id from etl.table_schemas + where pipeline_id = $1 and table_id = $2 and snapshot_id <= $3::pg_lsn + order by snapshot_id desc + limit 1 + ) + order by tc.column_order + "#, + ) + .bind(pipeline_id) + .bind(SqlxTableId(table_id.into_inner())) + .bind(snapshot_id.to_pg_lsn_string()) + .fetch_all(pool) + .await?; + + if rows.is_empty() { + return Ok(None); + } + + let first_row = &rows[0]; + let table_oid: SqlxTableId = first_row.get("table_id"); + let table_id = TableId::new(table_oid.0); + let schema_name: String = first_row.get("schema_name"); + let table_name: String = first_row.get("table_name"); + let snapshot_id_str: String = first_row.get("snapshot_id"); + let snapshot_id = SnapshotId::from_pg_lsn_string(&snapshot_id_str) + .map_err(|e| sqlx::Error::Protocol(e.to_string()))?; + + let mut table_schema = TableSchema::with_snapshot_id( + table_id, + TableName::new(schema_name, table_name), + vec![], + snapshot_id, + ); + + for row in rows { + table_schema.add_column_schema(parse_column_schema(&row)); + } + + Ok(Some(table_schema)) +} + +/// Loads all table schemas for a pipeline at a specific snapshot point. +/// +/// For each table, retrieves the schema version with the largest snapshot_id +/// that is <= the requested snapshot_id. Tables without any schema version +/// at or before the snapshot are excluded from the result. +pub async fn load_table_schemas_at_snapshot( + pool: &PgPool, + pipeline_id: i64, + snapshot_id: SnapshotId, +) -> Result, sqlx::Error> { + // Use DISTINCT ON to efficiently find the latest schema version for each table. + // PostgreSQL optimizes DISTINCT ON with ORDER BY using index scans when possible. 
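    // Illustration with hypothetical data: if table 16400 has versions at snapshot_ids 0/0 and
    // 0/2A while table 16401 only has 0/0, a load at snapshot 0/10 returns the 0/0 version for
    // both tables, and a load at 0/2A or later returns 0/2A for 16400 and still 0/0 for 16401.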
+ let rows = sqlx::query( + r#" + with latest_schemas as ( + select distinct on (ts.table_id) + ts.id, + ts.table_id, + ts.schema_name, + ts.table_name, + ts.snapshot_id + from etl.table_schemas ts + where ts.pipeline_id = $1 + and ts.snapshot_id <= $2::pg_lsn + order by ts.table_id, ts.snapshot_id desc + ) + select + ls.table_id, + ls.schema_name, + ls.table_name, + ls.snapshot_id::text as snapshot_id, + tc.column_name, + tc.column_type, + tc.type_modifier, + tc.nullable, + tc.primary_key, + tc.column_order, + tc.primary_key_ordinal_position + from latest_schemas ls + inner join etl.table_columns tc on ls.id = tc.table_schema_id + order by ls.table_id, tc.column_order "#, ) .bind(pipeline_id) + .bind(snapshot_id.to_pg_lsn_string()) .fetch_all(pool) .await?; @@ -236,9 +332,17 @@ pub async fn load_table_schemas( let table_id = TableId::new(table_oid.0); let schema_name: String = row.get("schema_name"); let table_name: String = row.get("table_name"); + let snapshot_id_str: String = row.get("snapshot_id"); + let row_snapshot_id = SnapshotId::from_pg_lsn_string(&snapshot_id_str) + .map_err(|e| sqlx::Error::Protocol(e.to_string()))?; let entry = table_schemas.entry(table_id).or_insert_with(|| { - TableSchema::new(table_id, TableName::new(schema_name, table_name), vec![]) + TableSchema::with_snapshot_id( + table_id, + TableName::new(schema_name, table_name), + vec![], + row_snapshot_id, + ) }); entry.add_column_schema(parse_column_schema(&row)); @@ -303,15 +407,17 @@ fn parse_column_schema(row: &PgRow) -> ColumnSchema { let column_name: String = row.get("column_name"); let column_type: String = row.get("column_type"); let type_modifier: i32 = row.get("type_modifier"); + let ordinal_position: i32 = row.get("column_order"); + let primary_key_ordinal_position: Option = row.get("primary_key_ordinal_position"); let nullable: bool = row.get("nullable"); - let primary_key: bool = row.get("primary_key"); ColumnSchema::new( column_name, string_to_postgres_type(&column_type), type_modifier, + ordinal_position, + primary_key_ordinal_position, nullable, - primary_key, ) } diff --git a/etl-postgres/src/replication/state.rs b/etl-postgres/src/replication/state.rs index fc4137fe6..b64b52fd2 100644 --- a/etl-postgres/src/replication/state.rs +++ b/etl-postgres/src/replication/state.rs @@ -3,8 +3,8 @@ use serde::{Deserialize, Serialize}; use sqlx::{PgExecutor, PgPool, Type, postgres::types::Oid as SqlxTableId, prelude::FromRow}; use tokio_postgres::types::PgLsn; +use crate::replication::destination_metadata::delete_destination_table_metadata; use crate::replication::schema::delete_table_schema_for_table; -use crate::replication::table_mappings::delete_table_mappings_for_table; use crate::types::TableId; /// Replication state of a table during the ETL process. @@ -272,12 +272,12 @@ pub async fn rollback_replication_state( .await?; // If the rollback goes to `Init` or `DataSync`, we also want to clean up the table schema - // and table mappings; this way the rollback starts clean. + // and destination metadata; this way the rollback starts clean. 
if matches!( restored_row.state, TableReplicationStateType::Init | TableReplicationStateType::DataSync ) { - delete_table_mappings_for_table(&mut *tx, pipeline_id, &table_id).await?; + delete_destination_table_metadata(&mut *tx, pipeline_id, table_id).await?; delete_table_schema_for_table(&mut *tx, pipeline_id, table_id).await?; } @@ -316,8 +316,8 @@ pub async fn reset_replication_state( .execute(&mut *tx) .await?; - // We want to clean up the table schema and table mappings to start fresh. - delete_table_mappings_for_table(&mut *tx, pipeline_id, &table_id).await?; + // We want to clean up the table schema and destination metadata to start fresh. + delete_destination_table_metadata(&mut *tx, pipeline_id, table_id).await?; delete_table_schema_for_table(&mut *tx, pipeline_id, table_id).await?; // Insert a new ` Init ` state entry and return it diff --git a/etl-postgres/src/replication/table_mappings.rs b/etl-postgres/src/replication/table_mappings.rs deleted file mode 100644 index 31c5c7bce..000000000 --- a/etl-postgres/src/replication/table_mappings.rs +++ /dev/null @@ -1,109 +0,0 @@ -use sqlx::{PgExecutor, PgPool, Row, postgres::types::Oid as SqlxTableId}; -use std::collections::HashMap; - -use crate::types::TableId; - -/// Stores a table mapping in the database. -/// -/// Inserts or updates a mapping between source table ID and destination table ID -/// for the specified pipeline. -pub async fn store_table_mapping( - pool: &PgPool, - pipeline_id: i64, - source_table_id: &TableId, - destination_table_id: &str, -) -> Result<(), sqlx::Error> { - sqlx::query( - r#" - insert into etl.table_mappings (pipeline_id, source_table_id, destination_table_id) - values ($1, $2, $3) - on conflict (pipeline_id, source_table_id) - do update set - destination_table_id = excluded.destination_table_id, - updated_at = now() - "#, - ) - .bind(pipeline_id) - .bind(SqlxTableId(source_table_id.into_inner())) - .bind(destination_table_id) - .execute(pool) - .await?; - - Ok(()) -} - -/// Loads all table mappings for a pipeline from the database. -/// -/// Retrieves all source table ID to destination table ID mappings for the specified pipeline. -pub async fn load_table_mappings( - pool: &PgPool, - pipeline_id: i64, -) -> Result, sqlx::Error> { - let rows = sqlx::query( - r#" - select source_table_id, destination_table_id - from etl.table_mappings - where pipeline_id = $1 - "#, - ) - .bind(pipeline_id) - .fetch_all(pool) - .await?; - - let mut mappings = HashMap::new(); - for row in rows { - let source_table_id: SqlxTableId = row.get("source_table_id"); - let destination_table_id: String = row.get("destination_table_id"); - - mappings.insert(TableId::new(source_table_id.0), destination_table_id); - } - - Ok(mappings) -} - -/// Deletes all table mappings for a pipeline from the database. -/// -/// Removes all table mapping records for the specified pipeline. -/// Used during pipeline cleanup. -pub async fn delete_table_mappings_for_all_tables<'c, E>( - executor: E, - pipeline_id: i64, -) -> Result -where - E: PgExecutor<'c>, -{ - let result = sqlx::query( - r#" - delete from etl.table_mappings - where pipeline_id = $1 - "#, - ) - .bind(pipeline_id) - .execute(executor) - .await?; - - Ok(result.rows_affected()) -} - -/// Deletes a single table mapping for a given pipeline and source table id. 
-pub async fn delete_table_mappings_for_table<'c, E>( - executor: E, - pipeline_id: i64, - source_table_id: &TableId, -) -> Result -where - E: PgExecutor<'c>, -{ - let result = sqlx::query( - r#" - delete from etl.table_mappings - where pipeline_id = $1 and source_table_id = $2 - "#, - ) - .bind(pipeline_id) - .bind(SqlxTableId(source_table_id.into_inner())) - .execute(executor) - .await?; - - Ok(result.rows_affected()) -} diff --git a/etl-postgres/src/tokio/test_utils.rs b/etl-postgres/src/tokio/test_utils.rs index 8934094f8..1b15c097b 100644 --- a/etl-postgres/src/tokio/test_utils.rs +++ b/etl-postgres/src/tokio/test_utils.rs @@ -22,11 +22,16 @@ pub enum TableModification<'a> { DropColumn { name: &'a str, }, - /// Alter an existing column with the specified alteration. + /// Alter an existing column with the specified alteration (e.g., "type bigint"). AlterColumn { name: &'a str, alteration: &'a str, }, + /// Rename an existing column. + RenameColumn { + old_name: &'a str, + new_name: &'a str, + }, ReplicaIdentity { value: &'a str, }, @@ -211,6 +216,9 @@ impl PgDatabase { TableModification::AlterColumn { name, alteration } => { format!("alter column {name} {alteration}") } + TableModification::RenameColumn { old_name, new_name } => { + format!("rename column {old_name} to {new_name}") + } TableModification::ReplicaIdentity { value } => { format!("replica identity {value}") } @@ -506,13 +514,7 @@ impl Drop for PgDatabase { /// Creates a [`ColumnSchema`] for a non-nullable, primary key column named "id" /// of type `INT8` that is added by default to tables created by [`PgDatabase`]. pub fn id_column_schema() -> ColumnSchema { - ColumnSchema { - name: "id".to_string(), - typ: Type::INT8, - modifier: -1, - nullable: false, - primary: true, - } + ColumnSchema::new("id".to_string(), Type::INT8, -1, 1, Some(1), false) } /// Creates a new Postgres database and returns a connected client. diff --git a/etl-postgres/src/types/schema.rs b/etl-postgres/src/types/schema.rs index c66ab41f1..8c4be04cb 100644 --- a/etl-postgres/src/types/schema.rs +++ b/etl-postgres/src/types/schema.rs @@ -1,12 +1,115 @@ use pg_escape::quote_identifier; -use std::cmp::Ordering; +use std::collections::{HashMap, HashSet}; use std::fmt; +use std::hash::{Hash, Hasher}; use std::str::FromStr; -use tokio_postgres::types::{FromSql, ToSql, Type}; +use std::sync::Arc; +use thiserror::Error; +use tokio_postgres::types::{FromSql, PgLsn, ToSql, Type}; + +/// Errors that can occur during schema operations. +#[derive(Debug, Error)] +pub enum SchemaError { + /// Columns were received during replication that do not exist in the stored table schema. + #[error("received columns during replication that are not in the stored table schema: {0:?}")] + UnknownReplicatedColumns(Vec), + + /// A snapshot ID string could not be converted to the [`SnapshotId`] type. + #[error("invalid snapshot id '{0}'")] + InvalidSnapshotId(String), +} /// An object identifier in Postgres. type Oid = u32; +/// Snapshot identifier for schema versioning. +/// +/// Wraps a [`PgLsn`] to represent the start_lsn of the DDL message that created a schema version. +/// A value of 0/0 indicates the initial schema before any DDL changes. +/// Stored as `pg_lsn` in the database. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)] +pub struct SnapshotId(PgLsn); + +impl SnapshotId { + /// Returns the initial snapshot ID (0/0) for the first schema version. + pub fn initial() -> Self { + Self(PgLsn::from(0)) + } + + /// Returns the maximum possible snapshot ID. 
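// Sketch of the expected ordering and round-trip behaviour; the LSN string is an example.
// The initial snapshot (0/0) sorts before any DDL snapshot, and a value survives the
// pg_lsn string round trip used by the storage layer.
fn example_snapshot_id_roundtrip() {
    let ddl = SnapshotId::from_pg_lsn_string("0/16B3748").unwrap();
    assert!(SnapshotId::initial() < ddl);
    assert_eq!(
        SnapshotId::from_pg_lsn_string(&ddl.to_pg_lsn_string()).unwrap(),
        ddl
    );
}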
+ pub fn max() -> Self { + Self(PgLsn::from(u64::MAX)) + } + + /// Creates a new [`SnapshotId`] from a [`PgLsn`]. + pub fn new(lsn: PgLsn) -> Self { + Self(lsn) + } + + /// Returns the inner [`PgLsn`] value. + pub fn into_inner(self) -> PgLsn { + self.0 + } + + /// Returns the underlying `u64` representation. + pub fn as_u64(self) -> u64 { + self.0.into() + } + + /// Converts to a `pg_lsn` string. + pub fn to_pg_lsn_string(self) -> String { + self.0.to_string() + } + + /// Parses a `pg_lsn` string. + /// + /// # Errors + /// + /// Returns [`SchemaError::InvalidSnapshotId`] if the string is not a valid `pg_lsn` format. + pub fn from_pg_lsn_string(s: &str) -> Result { + s.parse::() + .map(Self) + .map_err(|_| SchemaError::InvalidSnapshotId(s.to_string())) + } +} + +impl Hash for SnapshotId { + fn hash(&self, state: &mut H) { + let value: u64 = self.0.into(); + value.hash(state); + } +} + +impl From for SnapshotId { + fn from(lsn: PgLsn) -> Self { + Self(lsn) + } +} + +impl From for PgLsn { + fn from(snapshot_id: SnapshotId) -> Self { + snapshot_id.0 + } +} + +impl From for SnapshotId { + fn from(value: u64) -> Self { + Self(PgLsn::from(value)) + } +} + +impl From for u64 { + fn from(snapshot_id: SnapshotId) -> Self { + snapshot_id.0.into() + } +} + +impl fmt::Display for SnapshotId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + /// A fully qualified Postgres table name consisting of a schema and table name. /// /// This type represents a table identifier in Postgres, which requires both a schema name @@ -51,50 +154,46 @@ type TypeModifier = i32; /// Represents the schema of a single column in a Postgres table. /// /// This type contains all metadata about a column including its name, data type, -/// type modifier, nullability, and whether it's part of the primary key. +/// type modifier, ordinal position, primary key information, and nullability. #[derive(Debug, Clone, Eq, PartialEq)] pub struct ColumnSchema { - /// The name of the column + /// The name of the column. pub name: String, - /// The Postgres data type of the column + /// The Postgres data type of the column. pub typ: Type, - /// Type-specific modifier value (e.g., length for varchar) + /// Type-specific modifier value (e.g., length for varchar). pub modifier: TypeModifier, - /// Whether the column can contain NULL values + /// The 1-based ordinal position of the column in the table. + pub ordinal_position: i32, + /// The 1-based ordinal position of this column in the primary key, or None if not a primary key. + pub primary_key_ordinal_position: Option, + /// Whether the column can contain NULL values. pub nullable: bool, - /// Whether the column is part of the table's primary key - pub primary: bool, } impl ColumnSchema { + /// Creates a new [`ColumnSchema`] with all fields specified. pub fn new( name: String, typ: Type, modifier: TypeModifier, + ordinal_position: i32, + primary_key_ordinal_position: Option, nullable: bool, - primary: bool, ) -> ColumnSchema { Self { name, typ, modifier, + ordinal_position, + primary_key_ordinal_position, nullable, - primary, } } - /// Compares two [`ColumnSchema`] instances, excluding the `nullable` field. - /// - /// Return `true` if all fields except `nullable` are equal, `false` otherwise. - /// - /// This method is used for comparing table schemas loaded via the initial table sync and the - /// relation messages received via CDC. 
The reason for skipping the `nullable` field is that - /// unfortunately Postgres doesn't seem to propagate nullable information of a column via - /// relation messages. The reason for skipping the `primary` field is that if the replica - /// identity of a table is set to full, the relation message sets all columns as primary - /// key, irrespective of what the actual primary key in the table is. - fn partial_eq(&self, other: &ColumnSchema) -> bool { - self.name == other.name && self.typ == other.typ && self.modifier == other.modifier + /// Returns whether this column is part of the table's primary key. + pub fn primary_key(&self) -> bool { + self.primary_key_ordinal_position.is_some() } } @@ -183,23 +282,39 @@ impl ToSql for TableId { /// Represents the complete schema of a Postgres table. /// /// This type contains all metadata about a table including its name, OID, -/// and the schemas of all its columns. +/// the schemas of all its columns, and a snapshot identifier for versioning. #[derive(Debug, Clone, Eq, PartialEq)] pub struct TableSchema { - /// The Postgres OID of the table + /// The Postgres OID of the table. pub id: TableId, - /// The fully qualified name of the table + /// The fully qualified name of the table. pub name: TableName, - /// The schemas of all columns in the table + /// The schemas of all columns in the table. pub column_schemas: Vec, + /// The snapshot identifier for this schema version. + /// + /// Value 0 indicates the initial schema, other values are start_lsn positions of DDL changes. + pub snapshot_id: SnapshotId, } impl TableSchema { + /// Creates a new [`TableSchema`] with the initial snapshot ID (0/0). pub fn new(id: TableId, name: TableName, column_schemas: Vec) -> Self { + Self::with_snapshot_id(id, name, column_schemas, SnapshotId::initial()) + } + + /// Creates a new [`TableSchema`] with a specific snapshot ID. + pub fn with_snapshot_id( + id: TableId, + name: TableName, + column_schemas: Vec, + snapshot_id: SnapshotId, + ) -> Self { Self { id, name, column_schemas, + snapshot_id, } } @@ -212,32 +327,690 @@ impl TableSchema { /// /// This method checks if any column in the table is marked as part of the primary key. pub fn has_primary_keys(&self) -> bool { - self.column_schemas.iter().any(|cs| cs.primary) + self.column_schemas.iter().any(|cs| cs.primary_key()) } +} + +/// A bitmask indicating which columns are being replicated. +/// +/// Each element is either 0 (not replicated) or 1 (replicated), with indices +/// corresponding to the columns in the table schema. Wrapped in [`Arc`] for +/// efficient sharing across multiple events. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ReplicationMask(Arc>); - /// Compares two [`TableSchema`] instances, excluding the [`ColumnSchema`]'s `nullable` field. +impl fmt::Display for ReplicationMask { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "(")?; + for (i, &v) in self.0.iter().enumerate() { + if i > 0 { + write!(f, ",")?; + } + write!(f, "{v}")?; + } + write!(f, ")") + } +} + +impl ReplicationMask { + /// Tries to create a new [`ReplicationMask`] from a table schema and column names. + /// + /// The mask is constructed by checking which column names from the schema are present + /// in the provided set of replicated column names. + /// + /// # Errors + /// + /// Returns [`SchemaError::UnknownReplicatedColumns`] if any column in + /// `replicated_column_names` does not exist in the table schema. 
+ /// + /// The column validation occurs because we have to make sure that the stored table schema is always + /// up to date, if not, it's a critical problem. + pub fn try_build( + table_schema: &TableSchema, + replicated_column_names: &HashSet, + ) -> Result { + let schema_column_names: HashSet<&str> = table_schema + .column_schemas + .iter() + .map(|column_schema| column_schema.name.as_str()) + .collect(); + + let unknown_columns: Vec = replicated_column_names + .iter() + .filter(|name| !schema_column_names.contains(name.as_str())) + .cloned() + .collect(); + + // This check ensures all replicated columns are present in the schema. + // + // Limitation: If a column exists in the schema but is absent from the replicated columns, + // we assume publication-level column filtering is enabled. However, this is indistinguishable + // from an invalid state where the schema has diverged, we cannot detect the difference. + // + // How schema divergence occurs: When progress tracking fails and the system restarts, + // we may receive a `Relation` message reflecting the *current* table schema rather than + // the schema at the time the in-flight events were emitted. This is how Postgres handles + // initial `Relation` messages on reconnection. It's not the wrong behavior since the data + // has the columns that it announces, but it conflicts with our schema management logic. + // + // Invariant: Our schema management assumes the schema in `Relation` messages is consistent + // with the schema under which the corresponding row events were produced. + // + // In the future we might want to implement a system to go around this edge case. + if !unknown_columns.is_empty() { + return Err(SchemaError::UnknownReplicatedColumns(unknown_columns)); + } + + Ok(Self::build(table_schema, replicated_column_names)) + } + + /// Creates a new [`ReplicationMask`] from a table schema and column names, falling back + /// to an all-replicated mask if validation fails. + /// + /// This method attempts to validate that all replicated column names exist in the schema. + /// If validation succeeds, it builds a mask based on matching columns. If validation fails + /// (unknown columns are present), it returns a mask with all columns marked as replicated. + /// + /// This fallback behavior handles the case where Postgres sends a `Relation` message on + /// reconnection with the current schema, but the stored schema is from an earlier point + /// before DDL changes. Rather than failing, we enable all columns and let the system + /// converge when the actual DDL message is replayed. + pub fn build_or_all( + table_schema: &TableSchema, + replicated_column_names: &HashSet, + ) -> Self { + match Self::try_build(table_schema, replicated_column_names) { + Ok(mask) => mask, + Err(_) => Self::all(table_schema), + } + } + + /// Creates a new [`ReplicationMask`] from a table schema and column names. + pub fn build(table_schema: &TableSchema, replicated_column_names: &HashSet) -> Self { + let mask = table_schema + .column_schemas + .iter() + .map(|cs| { + if replicated_column_names.contains(&cs.name) { + 1 + } else { + 0 + } + }) + .collect(); + + Self(Arc::new(mask)) + } + + /// Creates a [`ReplicationMask`] with all columns marked as replicated. + pub fn all(table_schema: &TableSchema) -> Self { + let mask = vec![1; table_schema.column_schemas.len()]; + Self(Arc::new(mask)) + } + + /// Creates a [`ReplicationMask`] from raw bytes. + /// + /// Used for deserializing a mask from storage. 
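// Sketch of the storage round trip: the mask is persisted as raw bytes (for example in the
// replication_mask column of etl.destination_tables_metadata) and rebuilt with `from_bytes`.
fn example_mask_roundtrip(mask: &ReplicationMask) {
    let stored: Vec<u8> = mask.to_bytes();
    assert_eq!(ReplicationMask::from_bytes(stored), *mask);
}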
+ pub fn from_bytes(bytes: Vec) -> Self { + Self(Arc::new(bytes)) + } + + /// Returns the underlying mask as a slice. + pub fn as_slice(&self) -> &[u8] { + &self.0 + } + + /// Returns the underlying mask as a vector of bytes. /// - /// Return `true` if all fields except `nullable` are equal, `false` otherwise. - pub fn partial_eq(&self, other: &TableSchema) -> bool { - self.id == other.id - && self.name == other.name - && self.column_schemas.len() == other.column_schemas.len() - && self - .column_schemas - .iter() - .zip(other.column_schemas.iter()) - .all(|(c1, c2)| c1.partial_eq(c2)) + /// Used for serializing the mask to storage. + pub fn to_bytes(&self) -> Vec { + self.0.as_ref().clone() + } + + /// Returns the number of columns in the mask. + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns `true` if the mask is empty. + pub fn is_empty(&self) -> bool { + self.0.is_empty() } } -impl PartialOrd for TableSchema { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) +/// A wrapper around [`TableSchema`] that tracks which columns are being replicated. +/// +/// This struct holds a reference to the underlying table schema and a [`ReplicationMask`] +/// indicating which columns are included in the replication. +#[derive(Debug, Clone)] +pub struct ReplicatedTableSchema { + /// The underlying table schema. + table_schema: Arc, + /// A bitmask where 1 indicates the column at that index is replicated. + replication_mask: ReplicationMask, +} + +impl ReplicatedTableSchema { + /// Creates a [`ReplicatedTableSchema`] from a schema and a pre-computed mask. + pub fn from_mask(table_schema: Arc, replication_mask: ReplicationMask) -> Self { + debug_assert_eq!( + table_schema.column_schemas.len(), + replication_mask.len(), + "mask length must match column count" + ); + + Self { + table_schema, + replication_mask, + } + } + + /// Creates a [`ReplicatedTableSchema`] where all columns are replicated. + pub fn all(table_schema: Arc) -> Self { + let replication_mask = ReplicationMask::all(&table_schema); + Self { + table_schema, + replication_mask, + } + } + + /// Returns the table ID. + pub fn id(&self) -> TableId { + self.table_schema.id + } + + /// Returns the table name. + pub fn name(&self) -> &TableName { + &self.table_schema.name + } + + /// Returns the underlying table schema. + pub fn get_inner(&self) -> &TableSchema { + &self.table_schema + } + + /// Returns the replication mask. + pub fn replication_mask(&self) -> &ReplicationMask { + &self.replication_mask + } + + /// Returns an iterator over only the column schemas that are being replicated. + /// + /// This filters the columns based on the mask, returning only those where the + /// corresponding mask value is 1. + pub fn column_schemas(&self) -> impl Iterator + Clone + '_ { + // Assuming that the schema is created via the constructor, we can safely assume that the + // column schemas and replication mask are of the same length. + debug_assert!( + self.replication_mask.len() == self.table_schema.column_schemas.len(), + "the replication mask columns have a different len from the table schema columns, they should be the same" + ); + + self.table_schema + .column_schemas + .iter() + .zip(self.replication_mask.as_slice().iter()) + .filter_map(|(cs, &m)| if m == 1 { Some(cs) } else { None }) + } + + /// Computes the diff between this schema (old) and another schema (new). + /// + /// Only consider replicated columns. 
Uses ordinal positions to track columns: + /// - Columns in the same position with different names are renamed. + /// - Positions in old but not in new are columns to remove. + /// - Positions in new but not in old are columns to add. + pub fn diff(&self, new_schema: &ReplicatedTableSchema) -> SchemaDiff { + // Build maps: ordinal_position -> ColumnSchema for replicated columns only. + let old_columns: HashMap = self + .column_schemas() + .map(|col| (col.ordinal_position, col)) + .collect(); + + let new_columns: HashMap = new_schema + .column_schemas() + .map(|col| (col.ordinal_position, col)) + .collect(); + + let old_positions: HashSet = old_columns.keys().copied().collect(); + let new_positions: HashSet = new_columns.keys().copied().collect(); + + // Intersection: common positions (potential renames). + let common_positions: HashSet = old_positions + .intersection(&new_positions) + .copied() + .collect(); + + // Columns to rename: same position, different name. + let columns_to_rename: Vec = common_positions + .iter() + .filter_map(|pos| { + let old_col = old_columns.get(pos).unwrap(); + let new_col = new_columns.get(pos).unwrap(); + + if old_col.name != new_col.name { + Some(ColumnRename { + old_name: old_col.name.clone(), + new_name: new_col.name.clone(), + ordinal_position: *pos, + }) + } else { + None + } + }) + .collect(); + + // Columns to remove: positions in old but not in new. + let positions_to_remove: HashSet = + old_positions.difference(&new_positions).copied().collect(); + let columns_to_remove: Vec = positions_to_remove + .iter() + .map(|pos| old_columns.get(pos).unwrap()) + .cloned() + .cloned() + .collect(); + + // Columns to add: positions in new but not in old. + let positions_to_add: HashSet = + new_positions.difference(&old_positions).copied().collect(); + let columns_to_add: Vec = positions_to_add + .iter() + .map(|pos| new_columns.get(pos).unwrap()) + .cloned() + .cloned() + .collect(); + + SchemaDiff { + columns_to_add, + columns_to_remove, + columns_to_rename, + } + } +} + +/// Represents differences between two schema versions. +/// +/// Used to determine what schema changes need to be applied to a destination +/// when the source schema has evolved. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SchemaDiff { + /// Columns that need to be added to the destination. + pub columns_to_add: Vec, + /// Columns that need to be removed from the destination. + pub columns_to_remove: Vec, + /// Columns that need to be renamed in the destination. + pub columns_to_rename: Vec, +} + +impl SchemaDiff { + /// Returns `true` if there are no schema changes. + pub fn is_empty(&self) -> bool { + self.columns_to_add.is_empty() + && self.columns_to_remove.is_empty() + && self.columns_to_rename.is_empty() } } -impl Ord for TableSchema { - fn cmp(&self, other: &Self) -> Ordering { - self.name.cmp(&other.name) +/// Represents a column rename operation. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ColumnRename { + /// The old name of the column. + pub old_name: String, + /// The new name of the column. + pub new_name: String, + /// The ordinal position of the column (used to identify the column across renames). 
+ pub ordinal_position: i32, +} + +#[cfg(test)] +mod tests { + use super::*; + + fn create_test_table_schema() -> TableSchema { + TableSchema::new( + TableId::new(123), + TableName::new("public".to_string(), "test_table".to_string()), + vec![ + ColumnSchema::new("id".to_string(), Type::INT4, -1, 1, Some(1), false), + ColumnSchema::new("name".to_string(), Type::TEXT, -1, 2, None, true), + ColumnSchema::new("age".to_string(), Type::INT4, -1, 3, None, true), + ], + ) + } + + #[test] + fn test_replication_mask_try_build_all_columns_replicated() { + let schema = create_test_table_schema(); + let replicated_columns: HashSet = ["id", "name", "age"] + .into_iter() + .map(String::from) + .collect(); + + let mask = ReplicationMask::try_build(&schema, &replicated_columns).unwrap(); + + assert_eq!(mask.as_slice(), &[1, 1, 1]); + } + + #[test] + fn test_replication_mask_try_build_partial_columns_replicated() { + let schema = create_test_table_schema(); + let replicated_columns: HashSet = + ["id", "age"].into_iter().map(String::from).collect(); + + let mask = ReplicationMask::try_build(&schema, &replicated_columns).unwrap(); + + assert_eq!(mask.as_slice(), &[1, 0, 1]); + } + + #[test] + fn test_replication_mask_try_build_no_columns_replicated() { + let schema = create_test_table_schema(); + let replicated_columns: HashSet = HashSet::new(); + + let mask = ReplicationMask::try_build(&schema, &replicated_columns).unwrap(); + + assert_eq!(mask.as_slice(), &[0, 0, 0]); + } + + #[test] + fn test_replication_mask_try_build_unknown_column_error() { + let schema = create_test_table_schema(); + let replicated_columns: HashSet = ["id", "unknown_column"] + .into_iter() + .map(String::from) + .collect(); + + let result = ReplicationMask::try_build(&schema, &replicated_columns); + + assert!(result.is_err()); + let err = result.unwrap_err(); + match err { + SchemaError::UnknownReplicatedColumns(columns) => { + assert_eq!(columns, vec!["unknown_column".to_string()]); + } + _ => panic!("expected UnknownReplicatedColumns error"), + } + } + + #[test] + fn test_replication_mask_try_build_multiple_unknown_columns_error() { + let schema = create_test_table_schema(); + let replicated_columns: HashSet = + ["id", "foo", "bar"].into_iter().map(String::from).collect(); + + let result = ReplicationMask::try_build(&schema, &replicated_columns); + + assert!(result.is_err()); + let err = result.unwrap_err(); + match err { + SchemaError::UnknownReplicatedColumns(mut columns) => { + columns.sort(); + assert_eq!(columns, vec!["bar".to_string(), "foo".to_string()]); + } + _ => panic!("expected UnknownReplicatedColumns error"), + } + } + + #[test] + fn test_replication_mask_build_or_all_success() { + let schema = create_test_table_schema(); + let replicated_columns: HashSet = + ["id", "age"].into_iter().map(String::from).collect(); + + let mask = ReplicationMask::build_or_all(&schema, &replicated_columns); + + assert_eq!(mask.as_slice(), &[1, 0, 1]); + } + + #[test] + fn test_replication_mask_build_or_all_falls_back_to_all() { + let schema = create_test_table_schema(); + let replicated_columns: HashSet = ["id", "unknown_column"] + .into_iter() + .map(String::from) + .collect(); + + let mask = ReplicationMask::build_or_all(&schema, &replicated_columns); + + // Falls back to all columns being replicated. 
+ assert_eq!(mask.as_slice(), &[1, 1, 1]); + } + + #[test] + fn test_replication_mask_all() { + let schema = create_test_table_schema(); + let mask = ReplicationMask::all(&schema); + + assert_eq!(mask.as_slice(), &[1, 1, 1]); + } + + fn create_replicated_schema(columns: Vec) -> ReplicatedTableSchema { + let column_names: HashSet = columns.iter().map(|c| c.name.clone()).collect(); + let table_schema = Arc::new(TableSchema::new( + TableId::new(123), + TableName::new("public".to_string(), "test_table".to_string()), + columns, + )); + let mask = ReplicationMask::build(&table_schema, &column_names); + ReplicatedTableSchema::from_mask(table_schema, mask) + } + + #[test] + fn test_schema_diff_no_changes() { + let old_schema = create_replicated_schema(vec![ + ColumnSchema::new("id".to_string(), Type::INT4, -1, 1, Some(1), false), + ColumnSchema::new("name".to_string(), Type::TEXT, -1, 2, None, true), + ]); + let new_schema = create_replicated_schema(vec![ + ColumnSchema::new("id".to_string(), Type::INT4, -1, 1, Some(1), false), + ColumnSchema::new("name".to_string(), Type::TEXT, -1, 2, None, true), + ]); + + let diff = old_schema.diff(&new_schema); + + assert!(diff.is_empty()); + assert!(diff.columns_to_add.is_empty()); + assert!(diff.columns_to_remove.is_empty()); + assert!(diff.columns_to_rename.is_empty()); + } + + #[test] + fn test_schema_diff_column_added() { + let old_schema = create_replicated_schema(vec![ + ColumnSchema::new("id".to_string(), Type::INT4, -1, 1, Some(1), false), + ColumnSchema::new("name".to_string(), Type::TEXT, -1, 2, None, true), + ]); + let new_schema = create_replicated_schema(vec![ + ColumnSchema::new("id".to_string(), Type::INT4, -1, 1, Some(1), false), + ColumnSchema::new("name".to_string(), Type::TEXT, -1, 2, None, true), + ColumnSchema::new("email".to_string(), Type::TEXT, -1, 3, None, true), + ]); + + let diff = old_schema.diff(&new_schema); + + assert!(!diff.is_empty()); + assert_eq!(diff.columns_to_add.len(), 1); + assert_eq!(diff.columns_to_add[0].name, "email"); + assert_eq!(diff.columns_to_add[0].ordinal_position, 3); + assert!(diff.columns_to_remove.is_empty()); + assert!(diff.columns_to_rename.is_empty()); + } + + #[test] + fn test_schema_diff_column_removed() { + let old_schema = create_replicated_schema(vec![ + ColumnSchema::new("id".to_string(), Type::INT4, -1, 1, Some(1), false), + ColumnSchema::new("name".to_string(), Type::TEXT, -1, 2, None, true), + ColumnSchema::new("age".to_string(), Type::INT4, -1, 3, None, true), + ]); + let new_schema = create_replicated_schema(vec![ + ColumnSchema::new("id".to_string(), Type::INT4, -1, 1, Some(1), false), + ColumnSchema::new("name".to_string(), Type::TEXT, -1, 2, None, true), + ]); + + let diff = old_schema.diff(&new_schema); + + assert!(!diff.is_empty()); + assert!(diff.columns_to_add.is_empty()); + assert_eq!(diff.columns_to_remove.len(), 1); + assert_eq!(diff.columns_to_remove[0].name, "age"); + assert_eq!(diff.columns_to_remove[0].ordinal_position, 3); + assert!(diff.columns_to_rename.is_empty()); + } + + #[test] + fn test_schema_diff_column_renamed() { + let old_schema = create_replicated_schema(vec![ + ColumnSchema::new("id".to_string(), Type::INT4, -1, 1, Some(1), false), + ColumnSchema::new("name".to_string(), Type::TEXT, -1, 2, None, true), + ]); + let new_schema = create_replicated_schema(vec![ + ColumnSchema::new("id".to_string(), Type::INT4, -1, 1, Some(1), false), + ColumnSchema::new("full_name".to_string(), Type::TEXT, -1, 2, None, true), + ]); + + let diff = old_schema.diff(&new_schema); + + 
assert!(!diff.is_empty()); + assert!(diff.columns_to_add.is_empty()); + assert!(diff.columns_to_remove.is_empty()); + assert_eq!(diff.columns_to_rename.len(), 1); + assert_eq!(diff.columns_to_rename[0].old_name, "name"); + assert_eq!(diff.columns_to_rename[0].new_name, "full_name"); + assert_eq!(diff.columns_to_rename[0].ordinal_position, 2); + } + + #[test] + fn test_schema_diff_mixed_operations() { + // Old schema: id (pos 1), name (pos 2), age (pos 3) + // New schema: id (pos 1), full_name (pos 2), email (pos 4) + // Expected: age removed (pos 3), name -> full_name renamed (pos 2), email added (pos 4) + let old_schema = create_replicated_schema(vec![ + ColumnSchema::new("id".to_string(), Type::INT4, -1, 1, Some(1), false), + ColumnSchema::new("name".to_string(), Type::TEXT, -1, 2, None, true), + ColumnSchema::new("age".to_string(), Type::INT4, -1, 3, None, true), + ]); + let new_schema = create_replicated_schema(vec![ + ColumnSchema::new("id".to_string(), Type::INT4, -1, 1, Some(1), false), + ColumnSchema::new("full_name".to_string(), Type::TEXT, -1, 2, None, true), + ColumnSchema::new("email".to_string(), Type::TEXT, -1, 4, None, true), + ]); + + let diff = old_schema.diff(&new_schema); + + assert!(!diff.is_empty()); + + // Column added: email at position 4. + assert_eq!(diff.columns_to_add.len(), 1); + assert_eq!(diff.columns_to_add[0].name, "email"); + + // Column removed: age at position 3. + assert_eq!(diff.columns_to_remove.len(), 1); + assert_eq!(diff.columns_to_remove[0].name, "age"); + + // Column renamed: name -> full_name at position 2. + assert_eq!(diff.columns_to_rename.len(), 1); + assert_eq!(diff.columns_to_rename[0].old_name, "name"); + assert_eq!(diff.columns_to_rename[0].new_name, "full_name"); + } + + #[test] + fn test_schema_diff_multiple_additions() { + let old_schema = create_replicated_schema(vec![ColumnSchema::new( + "id".to_string(), + Type::INT4, + -1, + 1, + Some(1), + false, + )]); + let new_schema = create_replicated_schema(vec![ + ColumnSchema::new("id".to_string(), Type::INT4, -1, 1, Some(1), false), + ColumnSchema::new("name".to_string(), Type::TEXT, -1, 2, None, true), + ColumnSchema::new("email".to_string(), Type::TEXT, -1, 3, None, true), + ]); + + let diff = old_schema.diff(&new_schema); + + assert_eq!(diff.columns_to_add.len(), 2); + let added_names: HashSet<&str> = diff + .columns_to_add + .iter() + .map(|c| c.name.as_str()) + .collect(); + assert!(added_names.contains("name")); + assert!(added_names.contains("email")); + assert!(diff.columns_to_remove.is_empty()); + assert!(diff.columns_to_rename.is_empty()); + } + + #[test] + fn test_schema_diff_multiple_removals() { + let old_schema = create_replicated_schema(vec![ + ColumnSchema::new("id".to_string(), Type::INT4, -1, 1, Some(1), false), + ColumnSchema::new("name".to_string(), Type::TEXT, -1, 2, None, true), + ColumnSchema::new("email".to_string(), Type::TEXT, -1, 3, None, true), + ]); + let new_schema = create_replicated_schema(vec![ColumnSchema::new( + "id".to_string(), + Type::INT4, + -1, + 1, + Some(1), + false, + )]); + + let diff = old_schema.diff(&new_schema); + + assert!(diff.columns_to_add.is_empty()); + assert_eq!(diff.columns_to_remove.len(), 2); + let removed_names: HashSet<&str> = diff + .columns_to_remove + .iter() + .map(|c| c.name.as_str()) + .collect(); + assert!(removed_names.contains("name")); + assert!(removed_names.contains("email")); + assert!(diff.columns_to_rename.is_empty()); + } + + #[test] + fn test_replication_mask_display_all_replicated() { + let schema = 
create_test_table_schema(); + let replicated_columns: HashSet = ["id", "name", "age"] + .into_iter() + .map(String::from) + .collect(); + + let mask = ReplicationMask::build(&schema, &replicated_columns); + + assert_eq!(mask.to_string(), "(1,1,1)"); + } + + #[test] + fn test_replication_mask_display_partial_replicated() { + let schema = create_test_table_schema(); + let replicated_columns: HashSet = + ["id", "age"].into_iter().map(String::from).collect(); + + let mask = ReplicationMask::build(&schema, &replicated_columns); + + assert_eq!(mask.to_string(), "(1,0,1)"); + } + + #[test] + fn test_replication_mask_display_none_replicated() { + let schema = create_test_table_schema(); + let replicated_columns: HashSet = HashSet::new(); + + let mask = ReplicationMask::build(&schema, &replicated_columns); + + assert_eq!(mask.to_string(), "(0,0,0)"); + } + + #[test] + fn test_replication_mask_display_empty() { + let mask = ReplicationMask::from_bytes(vec![]); + + assert_eq!(mask.to_string(), "()"); } } diff --git a/etl-postgres/src/types/utils.rs b/etl-postgres/src/types/utils.rs index 731f09f97..388180fe9 100644 --- a/etl-postgres/src/types/utils.rs +++ b/etl-postgres/src/types/utils.rs @@ -42,13 +42,14 @@ pub fn is_array_type(typ: &Type) -> bool { /// even when they have the same system time. The format is compatible with BigQuery's /// `_CHANGE_SEQUENCE_NUMBER` column requirements. /// -/// The rationale for using the LSN is that BigQuery will preserve the highest sequence number -/// in case of equal primary key, which is what we want since in case of updates, we want the -/// latest update in Postgres order to be the winner. We have first the `commit_lsn` in the key -/// so that BigQuery can first order operations based on the LSN at which the transaction committed -/// and if two operations belong to the same transaction (meaning they have the same LSN), the -/// `start_lsn` will be used. We first order by `commit_lsn` to preserve the order in which operations -/// are received by the pipeline since transactions are ordered by commit time and not interleaved. +/// The rationale for using the LSN is that downstream systems will preserve the highest sequence +/// number in case of equal primary key, which is what we want since in case of updates, we want +/// the latest update in Postgres order to be the winner. We have first the `commit_lsn` in the key +/// so that operations are first ordered based on the LSN at which the transaction committed, +/// and if two operations belong to the same transaction (meaning they have the same `commit_lsn`), the +/// `start_lsn` will be used as a tiebreaker. We first order by `commit_lsn` to preserve the order +/// in which operations are received by the pipeline since transactions are ordered by commit time +/// and not interleaved. pub fn generate_sequence_number(start_lsn: PgLsn, commit_lsn: PgLsn) -> String { let start_lsn = u64::from(start_lsn); let commit_lsn = u64::from(commit_lsn); diff --git a/etl-replicator/scripts/run_migrations.sh b/etl-replicator/scripts/run_migrations.sh deleted file mode 100755 index 6786f6b53..000000000 --- a/etl-replicator/scripts/run_migrations.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env bash -set -eo pipefail - -if [ ! -d "etl-replicator/migrations" ]; then - echo >&2 "❌ Error: 'etl-replicator/migrations' folder not found." - echo >&2 "Please run this script from the 'etl' directory." - exit 1 -fi - -if ! [ -x "$(command -v sqlx)" ]; then - echo >&2 "❌ Error: SQLx CLI is not installed." 
- echo >&2 "To install it, run:" - echo >&2 " cargo install --version='~0.7' sqlx-cli --no-default-features --features rustls,postgres" - exit 1 -fi - -if ! [ -x "$(command -v psql)" ]; then - echo >&2 "❌ Error: Postgres client (psql) is not installed." - echo >&2 "Please install it using your system's package manager." - exit 1 -fi - -# Database configuration -DB_USER="${POSTGRES_USER:=postgres}" -DB_PASSWORD="${POSTGRES_PASSWORD:=postgres}" -DB_NAME="${POSTGRES_DB:=postgres}" -DB_PORT="${POSTGRES_PORT:=5430}" -DB_HOST="${POSTGRES_HOST:=localhost}" - -# Set up the database URL -export DATABASE_URL=postgres://${DB_USER}:${DB_PASSWORD}@${DB_HOST}:${DB_PORT}/${DB_NAME} - -echo "🔄 Running replicator state store migrations..." - -# Create the etl schema if it doesn't exist -# This matches the behavior in etl-replicator/src/migrations.rs -psql "${DATABASE_URL}" -v ON_ERROR_STOP=1 -c "create schema if not exists etl;" > /dev/null - -# Create a temporary sqlx-cli compatible database URL that sets the search_path -# This ensures the _sqlx_migrations table is created in the etl schema -SQLX_MIGRATIONS_OPTS="options=-csearch_path%3Detl" -MIGRATION_URL="${DATABASE_URL}?${SQLX_MIGRATIONS_OPTS}" - -# Run migrations with the modified URL -sqlx database create --database-url "${DATABASE_URL}" -sqlx migrate run --source etl-replicator/migrations --database-url "${MIGRATION_URL}" - -echo "✨ Replicator state store migrations complete! Ready to go!" diff --git a/etl-replicator/src/core.rs b/etl-replicator/src/core.rs index fdaaf051d..ff1626aaf 100644 --- a/etl-replicator/src/core.rs +++ b/etl-replicator/src/core.rs @@ -1,6 +1,5 @@ use std::collections::HashMap; -use crate::migrations::migrate_state_store; use etl::destination::memory::MemoryDestination; use etl::pipeline::Pipeline; use etl::store::both::postgres::PostgresStore; @@ -38,11 +37,11 @@ pub async fn start_replicator_with_config( log_config(&replicator_config); // We initialize the state store, which for the replicator is not configurable. + // Migrations are run by the pipeline during startup. let state_store = init_store( replicator_config.pipeline.id, replicator_config.pipeline.pg_connection.clone(), - ) - .await?; + ); // For each destination, we start the pipeline. This is more verbose due to static dispatch, but // we prefer more performance at the cost of ergonomics. @@ -258,17 +257,15 @@ fn log_batch_config(config: &BatchConfig) { ); } -/// Initializes the state store with migrations. +/// Initializes the state store. /// -/// Runs necessary database migrations on the state store and creates a -/// [`PostgresStore`] instance for the given pipeline and connection configuration. -async fn init_store( +/// Creates a [`PostgresStore`] instance for the given pipeline and connection configuration. +/// Migrations are handled by the pipeline during startup. +fn init_store( pipeline_id: PipelineId, pg_connection_config: PgConnectionConfig, -) -> anyhow::Result { - migrate_state_store(&pg_connection_config).await?; - - Ok(PostgresStore::new(pipeline_id, pg_connection_config)) +) -> impl StateStore + SchemaStore + CleanupStore + Clone { + PostgresStore::new(pipeline_id, pg_connection_config) } /// Starts a pipeline and handles graceful shutdown signals. 
diff --git a/etl-replicator/src/main.rs b/etl-replicator/src/main.rs index f4ae1534d..07490d4ef 100644 --- a/etl-replicator/src/main.rs +++ b/etl-replicator/src/main.rs @@ -51,7 +51,6 @@ mod core; mod feature_flags; #[cfg(not(target_env = "msvc"))] mod jemalloc_metrics; -mod migrations; mod notification; /// The name of the environment variable which contains version information for this replicator. diff --git a/etl/Cargo.toml b/etl/Cargo.toml index 64afe8446..f5386fcda 100644 --- a/etl/Cargo.toml +++ b/etl/Cargo.toml @@ -26,10 +26,12 @@ metrics = { workspace = true } pg_escape = { workspace = true } pin-project-lite = { workspace = true } postgres-replication = { workspace = true } +rand = { workspace = true, features = ["thread_rng"] } ring = { workspace = true, default-features = false } rustls = { workspace = true, features = ["aws-lc-rs", "logging"] } +serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, features = ["std"] } -sqlx = { workspace = true, features = ["runtime-tokio-rustls", "postgres"] } +sqlx = { workspace = true, features = ["runtime-tokio-rustls", "postgres", "migrate"] } tokio = { workspace = true, features = ["rt-multi-thread"] } tokio-postgres = { workspace = true, features = [ "runtime", @@ -49,5 +51,3 @@ etl-postgres = { workspace = true, features = [ "test-utils", ] } etl-telemetry = { workspace = true } - -rand = { workspace = true, features = ["thread_rng"] } diff --git a/etl-replicator/migrations/20250827000000_base.sql b/etl/migrations/20250827000000_base.sql similarity index 100% rename from etl-replicator/migrations/20250827000000_base.sql rename to etl/migrations/20250827000000_base.sql diff --git a/etl/migrations/20251127090000_schema_change_messages.sql b/etl/migrations/20251127090000_schema_change_messages.sql new file mode 100644 index 000000000..620566ba8 --- /dev/null +++ b/etl/migrations/20251127090000_schema_change_messages.sql @@ -0,0 +1,167 @@ +-- Schema change logical messages (DDL) +-- Adds helpers and trigger to emit logical decoding messages when tables change. + +create or replace function etl.describe_table_schema( + p_table pg_catalog.oid +) returns table ( + name pg_catalog.text, + type_oid pg_catalog.oid, + type_name pg_catalog.text, + type_modifier pg_catalog.int4, + ordinal_position pg_catalog.int4, + primary_key_ordinal_position pg_catalog.int4, + nullable pg_catalog.bool +) +language sql +stable +set search_path = pg_catalog +as +$fnc$ +with direct_parent as ( + select i.inhparent as parent_oid + from pg_catalog.pg_inherits i + where i.inhrelid = p_table + limit 1 +), +primary_key as ( + select x.attnum, x.n as position + from pg_catalog.pg_constraint con + cross join lateral unnest(con.conkey) with ordinality as x(attnum, n) + where con.contype = 'p' + and con.conrelid = p_table +), +parent_primary_key as ( + select a.attname, x.n as position + from pg_catalog.pg_constraint con + cross join lateral unnest(con.conkey) with ordinality as x(attnum, n) + join pg_catalog.pg_attribute a on a.attrelid = con.conrelid and a.attnum = x.attnum + join direct_parent dp on dp.parent_oid = con.conrelid + where con.contype = 'p' +) +select + a.attname::pg_catalog.text, + a.atttypid, + case + when tn.nspname = 'pg_catalog' then t.typname + else tn.nspname || '.' 
|| t.typname + end::pg_catalog.text, + a.atttypmod::pg_catalog.int4, + a.attnum::pg_catalog.int4, + coalesce(pk.position, ppk.position)::pg_catalog.int4, + not a.attnotnull +from pg_catalog.pg_attribute a +join pg_catalog.pg_type t on t.oid = a.atttypid +join pg_catalog.pg_namespace tn on tn.oid = t.typnamespace +left join primary_key pk on pk.attnum = a.attnum +left join parent_primary_key ppk on ppk.attname = a.attname +where a.attrelid = p_table + and a.attnum > 0 + and not a.attisdropped + and a.attgenerated = '' +order by a.attnum; +$fnc$; + +create or replace function etl.emit_schema_change_messages() +returns pg_catalog.event_trigger +language plpgsql +set search_path = pg_catalog +as +$fnc$ +declare + v_object_type pg_catalog.text; + v_objid pg_catalog.oid; + v_command_tag pg_catalog.text; + v_table_schema pg_catalog.text; + v_table_name pg_catalog.text; + v_schema_json pg_catalog.jsonb; + v_msg_json pg_catalog.jsonb; + v_wal_level pg_catalog.text; +begin + -- Check if logical replication is enabled; if not, silently skip. + -- This prevents crashes when Supabase ETL is installed but wal_level != logical. + v_wal_level := current_setting('wal_level', true); + if v_wal_level is distinct from 'logical' then + raise warning '[Supabase ETL] wal_level is %, not logical. Schema change will not be captured.', v_wal_level; + return; + end if; + + for v_object_type, v_objid, v_command_tag in + select object_type, objid, command_tag from pg_event_trigger_ddl_commands() + loop + begin + -- 'table' covers most ALTER TABLE operations (ADD/DROP COLUMN, ALTER TYPE, etc.) + -- 'table column' is returned specifically for RENAME COLUMN operations + if v_object_type not in ('table', 'table column') then + continue; + end if; + + if v_objid is null then + continue; + end if; + + select n.nspname, c.relname + into v_table_schema, v_table_name + from pg_catalog.pg_class c + join pg_catalog.pg_namespace n on n.oid = c.relnamespace + where c.oid = v_objid + and c.relkind in ('r', 'p'); + + if v_table_schema is null or v_table_name is null then + continue; + end if; + + select pg_catalog.jsonb_agg( + pg_catalog.jsonb_build_object( + 'name', s.name, + 'type_oid', s.type_oid::pg_catalog.int8, + 'type_modifier', s.type_modifier, + 'ordinal_position', s.ordinal_position, + 'primary_key_ordinal_position', s.primary_key_ordinal_position, + 'nullable', s.nullable + ) + ) + into v_schema_json + from etl.describe_table_schema(v_objid) s; + + if v_schema_json is null then + continue; + end if; + + v_msg_json := pg_catalog.jsonb_build_object( + 'event', v_command_tag, + 'schema_name', v_table_schema, + 'table_name', v_table_name, + 'table_id', v_objid::pg_catalog.int8, + 'columns', v_schema_json + ); + + perform pg_catalog.pg_logical_emit_message( + true, + 'supabase_etl_ddl', + pg_catalog.convert_to(v_msg_json::pg_catalog.text, 'utf8') + ); + + exception when others then + -- Never crash customer DDL; log warning instead. + raise warning using + message = format('[Supabase ETL] emit_schema_change_messages failed for table %s: %s', + coalesce(v_objid::pg_catalog.regclass::pg_catalog.text, 'unknown'), SQLERRM), + detail = 'You may need to repeat this DDL command on the downstream to keep logical replication running.'; + end; + end loop; +exception when others then + -- Outer safety net. 
+ raise warning '[Supabase ETL] emit_schema_change_messages outer exception: %', SQLERRM; +end; +$fnc$; + +drop event trigger if exists etl_ddl_message_trigger; + +-- Only ALTER TABLE is captured because: +-- - CREATE TABLE: No need, since the initial schema is loaded during the first table copy operation. +-- - DROP TABLE: No need, since dropped tables are not supported right now. +-- This trigger focuses on schema changes to existing replicated tables. +create event trigger etl_ddl_message_trigger + on ddl_command_end + when tag in ('ALTER TABLE') + execute function etl.emit_schema_change_messages(); diff --git a/etl/migrations/20251127120000_column_schema_extensions.sql b/etl/migrations/20251127120000_column_schema_extensions.sql new file mode 100644 index 000000000..af28e119a --- /dev/null +++ b/etl/migrations/20251127120000_column_schema_extensions.sql @@ -0,0 +1,27 @@ +-- Add new column fields to support extended schema information. +-- +-- These columns store additional metadata about each column: +-- - primary_key_ordinal_position: The order within the primary key (1-based), NULL if not a primary key + +-- Add new columns +alter table etl.table_columns + add column if not exists primary_key_ordinal_position pg_catalog.int4; + +-- Backfill primary_key_ordinal_position by querying the actual PK constraint from pg_constraint. +-- This assumes the source tables still exist and their PK order hasn't changed. +update etl.table_columns tc +set primary_key_ordinal_position = pk_info.pk_position +from ( + select + tc_inner.id as column_id, + x.n as pk_position + from etl.table_columns tc_inner + join etl.table_schemas ts on ts.id = tc_inner.table_schema_id + join pg_catalog.pg_constraint con on con.conrelid = ts.table_id and con.contype = 'p' + join pg_catalog.pg_attribute a on a.attrelid = ts.table_id and a.attname = tc_inner.column_name + cross join lateral unnest(con.conkey) with ordinality as x(attnum, n) + where x.attnum = a.attnum + and tc_inner.primary_key = true +) pk_info +where tc.id = pk_info.column_id + and tc.primary_key_ordinal_position is null; diff --git a/etl/migrations/20251205000000_schema_versioning.sql b/etl/migrations/20251205000000_schema_versioning.sql new file mode 100644 index 000000000..4796029e7 --- /dev/null +++ b/etl/migrations/20251205000000_schema_versioning.sql @@ -0,0 +1,19 @@ +-- Add snapshot_id column to table_schemas for schema versioning. +-- The snapshot_id value is the start_lsn of the DDL message that created this schema version. +-- Initial schemas use snapshot_id='0/0'. + +ALTER TABLE etl.table_schemas + ADD COLUMN IF NOT EXISTS snapshot_id PG_LSN NOT NULL DEFAULT '0/0'; + +-- Change unique constraint from (pipeline_id, table_id) to (pipeline_id, table_id, snapshot_id) +-- to allow multiple schema versions per table. +ALTER TABLE etl.table_schemas + DROP CONSTRAINT IF EXISTS table_schemas_pipeline_id_table_id_key; + +ALTER TABLE etl.table_schemas + ADD CONSTRAINT table_schemas_pipeline_id_table_id_snapshot_id_key + UNIQUE (pipeline_id, table_id, snapshot_id); + +-- Index for efficient "find largest snapshot_id <= X" queries. 
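-- For illustration only (an assumption about how lookups are issued, not part of this
-- migration): resolving "the schema version in effect at LSN X" is expected to look like
--
--   select *
--   from etl.table_schemas
--   where pipeline_id = $1
--     and table_id = $2
--     and snapshot_id <= $3
--   order by snapshot_id desc
--   limit 1;
--
-- which the descending index below is intended to serve with a single range scan.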
+CREATE INDEX IF NOT EXISTS idx_table_schemas_pipeline_table_snapshot_id + ON etl.table_schemas (pipeline_id, table_id, snapshot_id DESC); diff --git a/etl/migrations/20251211000000_destination_tables_metadata.sql b/etl/migrations/20251211000000_destination_tables_metadata.sql new file mode 100644 index 000000000..4611918c2 --- /dev/null +++ b/etl/migrations/20251211000000_destination_tables_metadata.sql @@ -0,0 +1,66 @@ +-- Unified destination table metadata. +-- +-- Tracks all destination-related state for each replicated table in a single row. + +-- Enum for destination table schema status. +CREATE TYPE etl.destination_table_schema_status AS ENUM ( + 'applying', + 'applied' +); + +CREATE TABLE etl.destination_tables_metadata ( + id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + pipeline_id BIGINT NOT NULL, + table_id OID NOT NULL, + -- The name/identifier of the table in the destination system. + destination_table_id TEXT NOT NULL, + -- The snapshot_id of the schema currently applied at the destination. + snapshot_id PG_LSN NOT NULL, + -- The schema version before the current change. NULL for initial schemas. + -- Destinations that support atomic DDL can use this for recovery by rolling back + -- to the previous snapshot when schema_status is 'applying' on startup. + previous_snapshot_id PG_LSN, + -- Status: 'applying' when a schema change is in progress, 'applied' when complete. + -- If 'applying' is found on startup, recovery may be needed. + schema_status etl.destination_table_schema_status NOT NULL, + -- The replication mask as a byte array where each byte is 0 (not replicated) or 1 (replicated). + -- The index corresponds to the column's ordinal position in the schema. + replication_mask BYTEA NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + -- One metadata row per table per pipeline. + UNIQUE (pipeline_id, table_id) +); + +-- Backfill destination_tables_metadata from existing table_mappings and table_schemas. +-- For each mapped table, create metadata with snapshot_id = 0/0 (initial schema) +-- and all columns marked as replicated. +INSERT INTO etl.destination_tables_metadata ( + pipeline_id, + table_id, + destination_table_id, + snapshot_id, + schema_status, + replication_mask +) +SELECT + tm.pipeline_id, + tm.source_table_id AS table_id, + tm.destination_table_id, + '0/0'::pg_lsn AS snapshot_id, + 'applied'::etl.destination_table_schema_status, + -- Create a bytea of 1s with length equal to the number of columns. + decode(repeat('01', column_counts.num_columns::int), 'hex') AS replication_mask +FROM etl.table_mappings tm +JOIN etl.table_schemas ts ON ts.pipeline_id = tm.pipeline_id + AND ts.table_id = tm.source_table_id + AND ts.snapshot_id = '0/0'::pg_lsn +JOIN ( + SELECT table_schema_id, COUNT(*) AS num_columns + FROM etl.table_columns + GROUP BY table_schema_id +) column_counts ON column_counts.table_schema_id = ts.id; + +-- Drop the old table_mappings table as it's now unified into destination_tables_metadata. +-- This is a breaking change, for now we assume that no two pipelines share the same state storage. 
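-- Worked example of the replication_mask backfill above (illustrative values): for a table
-- with three columns, repeat('01', 3) = '010101' and decode('010101', 'hex') = '\x010101',
-- i.e. a three-byte mask with every ordinal position marked as replicated.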
+DROP TABLE etl.table_mappings; diff --git a/etl/src/conversions/event.rs b/etl/src/conversions/event.rs index 235a85828..6391b0048 100644 --- a/etl/src/conversions/event.rs +++ b/etl/src/conversions/event.rs @@ -1,20 +1,97 @@ use core::str; +use std::collections::HashSet; + use etl_postgres::types::{ - ColumnSchema, TableId, TableName, TableSchema, convert_type_oid_to_type, + ColumnSchema, ReplicatedTableSchema, SnapshotId, TableId, TableName, TableSchema, + convert_type_oid_to_type, }; use postgres_replication::protocol; -use std::sync::Arc; +use serde::Deserialize; use tokio_postgres::types::PgLsn; use crate::conversions::text::{default_value_for_type, parse_cell_from_postgres_text}; use crate::error::{ErrorKind, EtlResult}; -use crate::store::schema::SchemaStore; use crate::types::{ - BeginEvent, Cell, CommitEvent, DeleteEvent, InsertEvent, RelationEvent, TableRow, - TruncateEvent, UpdateEvent, + BeginEvent, Cell, CommitEvent, DeleteEvent, InsertEvent, TableRow, TruncateEvent, UpdateEvent, }; use crate::{bail, etl_error}; +/// The prefix used for DDL schema change messages emitted by the `etl.emit_schema_change_messages` +/// event trigger. Messages with this prefix contain JSON-encoded schema information. +pub const DDL_MESSAGE_PREFIX: &str = "supabase_etl_ddl"; + +/// Represents a schema change message emitted by Postgres event trigger. +/// +/// This message is emitted when ALTER TABLE commands are executed on tables +/// that are part of a publication. +#[derive(Debug, Clone, Deserialize)] +pub struct SchemaChangeMessage { + /// The DDL command that triggered this message (e.g., "ALTER TABLE"). + pub event: String, + /// The schema name of the affected table. + pub schema_name: String, + /// The name of the affected table. + pub table_name: String, + /// The OID of the affected table. + /// + /// PostgreSQL table OIDs are `u32` values, but JSON serialization from the event trigger + /// uses `bigint` (i64) for transmission. The cast back to `u32` in [`into_table_schema`] + /// is safe because PostgreSQL OIDs are always within the `u32` range. + pub table_id: i64, + /// The columns of the table after the schema change. + pub columns: Vec, +} + +impl SchemaChangeMessage { + /// Converts a [`SchemaChangeMessage`] to a [`TableSchema`] with a specific snapshot ID. + /// + /// This is used to update the stored table schema when a DDL change is detected. + /// The snapshot_id should be the start_lsn of the DDL message. + pub fn into_table_schema(self, snapshot_id: SnapshotId) -> TableSchema { + let table_name = TableName::new(self.schema_name, self.table_name); + let column_schemas = self + .columns + .into_iter() + .map(|column| { + let typ = convert_type_oid_to_type(column.type_oid); + ColumnSchema::new( + column.name, + typ, + column.type_modifier, + column.ordinal_position, + column.primary_key_ordinal_position, + column.nullable, + ) + }) + .collect(); + + TableSchema::with_snapshot_id( + TableId::new(self.table_id as u32), + table_name, + column_schemas, + snapshot_id, + ) + } +} + +/// Represents a column schema in a schema change message. +#[allow(dead_code)] +#[derive(Debug, Clone, Deserialize)] +pub struct ColumnSchemaMessage { + /// The name of the column. + pub name: String, + /// The OID of the column's data type. + pub type_oid: u32, + /// Type-specific modifier value (e.g., length for varchar). + pub type_modifier: i32, + /// The 1-based ordinal position of the column in the table. 
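    // Illustrative payload (all values made up) in the shape produced by the
    // `etl.emit_schema_change_messages` trigger and consumed by `parse_schema_change_message`:
    //
    //   {
    //     "event": "ALTER TABLE",
    //     "schema_name": "public",
    //     "table_name": "orders",
    //     "table_id": 16492,
    //     "columns": [
    //       {"name": "id", "type_oid": 23, "type_modifier": -1,
    //        "ordinal_position": 1, "primary_key_ordinal_position": 1, "nullable": false},
    //       {"name": "note", "type_oid": 25, "type_modifier": -1,
    //        "ordinal_position": 2, "primary_key_ordinal_position": null, "nullable": true}
    //     ]
    //   }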
+ pub ordinal_position: i32, + /// The 1-based ordinal position of this column in the primary key, or null if not a primary key. + pub primary_key_ordinal_position: Option, + /// Whether the column can contain NULL values. + pub nullable: bool, +} + /// Creates a [`BeginEvent`] from Postgres protocol data. /// /// This method parses the replication protocol begin message and extracts @@ -50,56 +127,38 @@ pub fn parse_event_from_commit_message( } } -/// Creates a [`RelationEvent`] from Postgres protocol data. -/// -/// This method parses the replication protocol relation message and builds -/// a complete table schema for use in interpreting subsequent data events. -pub fn parse_event_from_relation_message( - start_lsn: PgLsn, - commit_lsn: PgLsn, +/// Returns the set of column names to replicate from a relation message. +pub fn parse_replicated_column_names( relation_body: &protocol::RelationBody, -) -> EtlResult { - let table_name = TableName::new( - relation_body.namespace()?.to_string(), - relation_body.name()?.to_string(), - ); - let column_schemas = relation_body +) -> EtlResult> { + let column_names = relation_body .columns() .iter() - .map(build_column_schema) - .collect::, _>>()?; - let table_schema = TableSchema::new( - TableId::new(relation_body.rel_id()), - table_name, - column_schemas, - ); + .map(parse_column_name_from_column) + .collect::, _>>()?; - Ok(RelationEvent { - start_lsn, - commit_lsn, - table_schema, - }) + Ok(column_names) +} + +/// Extracts the column name from a [`protocol::Column`] object. +fn parse_column_name_from_column(column: &protocol::Column) -> EtlResult { + let column_name = column.name()?.to_string(); + + Ok(column_name) } /// Converts a Postgres insert message into an [`InsertEvent`]. /// -/// This function processes an insert operation from the replication stream, -/// retrieves the table schema from the store, and constructs a complete -/// insert event with the new row data ready for ETL processing. -pub async fn parse_event_from_insert_message( - schema_store: &S, +/// This function processes an insert operation from the replication stream +/// and constructs an insert event with the new row data ready for ETL processing. +pub fn parse_event_from_insert_message( + replicated_table_schema: ReplicatedTableSchema, start_lsn: PgLsn, commit_lsn: PgLsn, insert_body: &protocol::InsertBody, -) -> EtlResult -where - S: SchemaStore, -{ - let table_id = insert_body.rel_id(); - let table_schema = get_table_schema(schema_store, TableId::new(table_id)).await?; - +) -> EtlResult { let table_row = convert_tuple_to_row( - &table_schema.column_schemas, + replicated_table_schema.column_schemas(), insert_body.tuple().tuple_data(), &mut None, false, @@ -108,7 +167,7 @@ where Ok(InsertEvent { start_lsn, commit_lsn, - table_id: TableId::new(table_id), + replicated_table_schema, table_row, }) } @@ -119,25 +178,19 @@ where /// handling both the old and new row data. The old row data may be either /// the complete row or just the key columns, depending on the table's /// `REPLICA IDENTITY` setting in Postgres. 
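// For context (table name is illustrative): which variant of the old tuple arrives is
// configured on the source table, e.g.
//
//   alter table public.orders replica identity full;     -- old tuple carries the full row
//   alter table public.orders replica identity default;  -- old tuple carries only the key columns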
-pub async fn parse_event_from_update_message( - schema_store: &S, +pub fn parse_event_from_update_message( + replicated_table_schema: ReplicatedTableSchema, start_lsn: PgLsn, commit_lsn: PgLsn, update_body: &protocol::UpdateBody, -) -> EtlResult -where - S: SchemaStore, -{ - let table_id = update_body.rel_id(); - let table_schema = get_table_schema(schema_store, TableId::new(table_id)).await?; - +) -> EtlResult { // We try to extract the old tuple by either taking the entire old tuple or the key of the old // tuple. let is_key = update_body.old_tuple().is_none(); let old_tuple = update_body.old_tuple().or(update_body.key_tuple()); let old_table_row = match old_tuple { Some(identity) => Some(convert_tuple_to_row( - &table_schema.column_schemas, + replicated_table_schema.column_schemas(), identity.tuple_data(), &mut None, true, @@ -147,7 +200,7 @@ where let mut old_table_row_mut = old_table_row; let table_row = convert_tuple_to_row( - &table_schema.column_schemas, + replicated_table_schema.column_schemas(), update_body.new_tuple().tuple_data(), &mut old_table_row_mut, false, @@ -158,7 +211,7 @@ where Ok(UpdateEvent { start_lsn, commit_lsn, - table_id: TableId::new(table_id), + replicated_table_schema, table_row, old_table_row, }) @@ -170,25 +223,19 @@ where /// extracting the old row data that was deleted. The old row data may be /// either the complete row or just the key columns, depending on the table's /// `REPLICA IDENTITY` setting in Postgres. -pub async fn parse_event_from_delete_message( - schema_store: &S, +pub fn parse_event_from_delete_message( + replicated_table_schema: ReplicatedTableSchema, start_lsn: PgLsn, commit_lsn: PgLsn, delete_body: &protocol::DeleteBody, -) -> EtlResult -where - S: SchemaStore, -{ - let table_id = delete_body.rel_id(); - let table_schema = get_table_schema(schema_store, TableId::new(table_id)).await?; - +) -> EtlResult { // We try to extract the old tuple by either taking the entire old tuple or the key of the old // tuple. let is_key = delete_body.old_tuple().is_none(); let old_tuple = delete_body.old_tuple().or(delete_body.key_tuple()); let old_table_row = match old_tuple { Some(identity) => Some(convert_tuple_to_row( - &table_schema.column_schemas, + replicated_table_schema.column_schemas(), identity.tuple_data(), &mut None, true, @@ -200,7 +247,7 @@ where Ok(DeleteEvent { start_lsn, commit_lsn, - table_id: TableId::new(table_id), + replicated_table_schema, old_table_row, }) } @@ -213,56 +260,16 @@ pub fn parse_event_from_truncate_message( start_lsn: PgLsn, commit_lsn: PgLsn, truncate_body: &protocol::TruncateBody, - overridden_rel_ids: Vec, + truncated_tables: Vec, ) -> TruncateEvent { TruncateEvent { start_lsn, commit_lsn, options: truncate_body.options(), - rel_ids: overridden_rel_ids, + truncated_tables, } } -/// Retrieves a table schema from the schema store by table ID. -/// -/// This function looks up the table schema for the specified table ID in the -/// schema store. If the schema is not found, it returns an error indicating -/// that the table is missing from the cache. -async fn get_table_schema(schema_store: &S, table_id: TableId) -> EtlResult> -where - S: SchemaStore, -{ - schema_store - .get_table_schema(&table_id) - .await? - .ok_or_else(|| { - etl_error!( - ErrorKind::MissingTableSchema, - "Table schema not found in cache", - format!("Table schema for table {} not found in cache", table_id) - ) - }) -} - -/// Constructs a [`ColumnSchema`] from Postgres protocol column data. 
-/// -/// This helper method extracts column metadata from the replication protocol -/// and converts it into the internal column schema representation. Some fields -/// like nullable status have default values due to protocol limitations. -fn build_column_schema(column: &protocol::Column) -> EtlResult { - Ok(ColumnSchema::new( - column.name()?.to_string(), - convert_type_oid_to_type(column.type_id() as u32), - column.type_modifier(), - // We do not have access to this information, so we default it to `false`. - // TODO: figure out how to fill this value correctly or how to handle the missing value - // better. - false, - // Currently 1 means that the column is part of the primary key. - column.flags() == 1, - )) -} - /// Converts Postgres tuple data into a [`TableRow`] using column schemas. /// /// This function transforms raw tuple data from the replication protocol into @@ -275,15 +282,15 @@ fn build_column_schema(column: &protocol::Column) -> EtlResult { /// Panics if a required (non-nullable) column receives null data and /// `use_default_for_missing_cols` is false, as this indicates protocol-level /// corruption that should not be handled gracefully. -pub fn convert_tuple_to_row( - column_schemas: &[ColumnSchema], +pub fn convert_tuple_to_row<'a>( + column_schemas: impl Iterator, tuple_data: &[protocol::TupleData], old_table_row: &mut Option, use_default_for_missing_cols: bool, ) -> EtlResult { - let mut values = Vec::with_capacity(column_schemas.len()); + let mut values = Vec::with_capacity(tuple_data.len()); - for (i, column_schema) in column_schemas.iter().enumerate() { + for (i, column_schema) in column_schemas.enumerate() { // We are expecting that for each column, there is corresponding tuple data, even for null // values. let Some(tuple_data) = &tuple_data.get(i) else { @@ -300,7 +307,7 @@ pub fn convert_tuple_to_row( } else if use_default_for_missing_cols { default_value_for_type(&column_schema.typ)? } else { - // This is protocol level error, so we panic instead of carrying on + // This is a protocol level error, so we panic instead of carrying on // with incorrect data to avoid corruption downstream. panic!( "A required column {} was missing from the tuple", @@ -341,3 +348,16 @@ pub fn convert_tuple_to_row( Ok(TableRow { values }) } + +/// Parses a DDL schema change message from its JSON content. +/// +/// Returns the parsed message if successful, or an error if the JSON is malformed. +pub fn parse_schema_change_message(content: &str) -> EtlResult { + serde_json::from_str(content).map_err(|e| { + etl_error!( + ErrorKind::ConversionError, + "Failed to parse schema change message", + format!("Invalid JSON in schema change message: {}", e) + ) + }) +} diff --git a/etl/src/conversions/table_row.rs b/etl/src/conversions/table_row.rs index c34c95944..bb58fd486 100644 --- a/etl/src/conversions/table_row.rs +++ b/etl/src/conversions/table_row.rs @@ -16,14 +16,13 @@ use crate::types::{Cell, TableRow}; /// # Panics /// /// Panics if the number of parsed values doesn't match the number of column schemas. 
-pub fn parse_table_row_from_postgres_copy_bytes( +pub fn parse_table_row_from_postgres_copy_bytes<'a>( row: &[u8], - column_schemas: &[ColumnSchema], + mut column_schemas: impl Iterator, ) -> EtlResult { - let mut values = Vec::with_capacity(column_schemas.len()); + let mut values = Vec::new(); let row_str = str::from_utf8(row)?; - let mut column_schemas_iter = column_schemas.iter(); let mut chars = row_str.chars(); let mut val_str = String::with_capacity(10); let mut in_escape = false; @@ -99,15 +98,10 @@ pub fn parse_table_row_from_postgres_copy_bytes( // Process the parsed field value if we're not done with the entire row if !done { // Get the next column schema - error if we have more fields than expected - let Some(column_schema) = column_schemas_iter.next() else { + let Some(column_schema) = column_schemas.next() else { bail!( ErrorKind::ConversionError, - "Column count mismatch between schema and row", - format!( - "Schema has {} columns but row has {} columns", - column_schemas.len(), - values.len() - ) + "Column count mismatch between schema and row" ); }; @@ -143,15 +137,10 @@ pub fn parse_table_row_from_postgres_copy_bytes( // Validate that all expected columns were present in the row // If there are still columns left in the schema iterator, it means the row // had fewer fields than expected, which is an error - if column_schemas_iter.next().is_some() { + if column_schemas.next().is_some() { bail!( ErrorKind::ConversionError, - "Column count mismatch between schema and row", - format!( - "Schema has {} columns but row has {} columns", - column_schemas.len(), - values.len() - ) + "Column count mismatch between schema and row" ); } @@ -165,24 +154,43 @@ mod tests { use etl_postgres::types::ColumnSchema; use tokio_postgres::types::Type; - fn create_test_schema() -> Vec { + /// Creates a test column schema with sensible defaults. 
+ fn test_column( + name: &str, + typ: Type, + ordinal_position: i32, + nullable: bool, + primary_key: bool, + ) -> ColumnSchema { + ColumnSchema::new( + name.to_string(), + typ, + -1, + ordinal_position, + if primary_key { Some(1) } else { None }, + nullable, + ) + } + + fn create_test_column_schemas() -> Vec { vec![ - ColumnSchema::new("id".to_string(), Type::INT4, -1, false, true), - ColumnSchema::new("name".to_string(), Type::TEXT, -1, true, false), - ColumnSchema::new("active".to_string(), Type::BOOL, -1, false, false), + test_column("id", Type::INT4, 1, false, true), + test_column("name", Type::TEXT, 2, true, false), + test_column("active", Type::BOOL, 3, false, false), ] } fn create_single_column_schema(name: &str, typ: Type) -> Vec { - vec![ColumnSchema::new(name.to_string(), typ, -1, false, false)] + vec![test_column(name, typ, 1, false, false)] } #[test] fn try_from_simple_row() { - let schema = create_test_schema(); + let column_schemas = create_test_column_schemas(); let row_data = b"123\tJohn Doe\tt\n"; - let result = parse_table_row_from_postgres_copy_bytes(row_data, &schema).unwrap(); + let result = + parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()).unwrap(); assert_eq!(result.values.len(), 3); assert_eq!(result.values[0], Cell::I32(123)); @@ -192,10 +200,11 @@ mod tests { #[test] fn try_from_with_null_values() { - let schema = create_test_schema(); + let column_schemas = create_test_column_schemas(); let row_data = b"456\t\\N\tf\n"; - let result = parse_table_row_from_postgres_copy_bytes(row_data, &schema).unwrap(); + let result = + parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()).unwrap(); assert_eq!(result.values.len(), 3); assert_eq!(result.values[0], Cell::I32(456)); @@ -205,10 +214,11 @@ mod tests { #[test] fn try_from_empty_strings() { - let schema = create_test_schema(); + let column_schemas = create_test_column_schemas(); let row_data = b"0\t\tf\n"; - let result = parse_table_row_from_postgres_copy_bytes(row_data, &schema).unwrap(); + let result = + parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()).unwrap(); assert_eq!(result.values.len(), 3); assert_eq!(result.values[0], Cell::I32(0)); @@ -218,10 +228,11 @@ mod tests { #[test] fn try_from_single_column() { - let schema = create_single_column_schema("value", Type::INT4); + let column_schemas = create_single_column_schema("value", Type::INT4); let row_data = b"42\n"; - let result = parse_table_row_from_postgres_copy_bytes(row_data, &schema).unwrap(); + let result = + parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()).unwrap(); assert_eq!(result.values.len(), 1); assert_eq!(result.values[0], Cell::I32(42)); @@ -229,16 +240,17 @@ mod tests { #[test] fn try_from_multiple_columns_different_types() { - let schema = vec![ - ColumnSchema::new("int_col".to_string(), Type::INT4, -1, false, false), - ColumnSchema::new("float_col".to_string(), Type::FLOAT8, -1, false, false), - ColumnSchema::new("text_col".to_string(), Type::TEXT, -1, false, false), - ColumnSchema::new("bool_col".to_string(), Type::BOOL, -1, false, false), + let column_schemas = vec![ + test_column("int_col", Type::INT4, 1, false, false), + test_column("float_col", Type::FLOAT8, 2, false, false), + test_column("text_col", Type::TEXT, 3, false, false), + test_column("bool_col", Type::BOOL, 4, false, false), ]; let row_data = b"123\t3.15\tHello World\tt\n"; - let result = parse_table_row_from_postgres_copy_bytes(row_data, &schema).unwrap(); + let result = + 
parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()).unwrap(); assert_eq!(result.values.len(), 4); assert_eq!(result.values[0], Cell::I32(123)); @@ -249,10 +261,10 @@ mod tests { #[test] fn try_from_not_terminated() { - let schema = create_single_column_schema("value", Type::INT4); + let column_schemas = create_single_column_schema("value", Type::INT4); let row_data = b"42"; // Missing newline - let result = parse_table_row_from_postgres_copy_bytes(row_data, &schema); + let result = parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()); assert!(result.is_err()); let err = result.unwrap_err(); @@ -262,39 +274,41 @@ mod tests { #[test] fn try_from_column_count_mismatch() { - let schema = create_test_schema(); // Expects 3 columns + let column_schemas = create_test_column_schemas(); // Expects 3 columns let row_data = b"123\tJohn\n"; // Only 2 values - this should actually fail at parsing the bool because there's no third column - let result_empty = parse_table_row_from_postgres_copy_bytes(row_data, &schema); + let result_empty = + parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()); assert!(result_empty.is_err()); } #[test] fn try_from_invalid_utf8() { - let schema = create_single_column_schema("value", Type::TEXT); + let column_schemas = create_single_column_schema("value", Type::TEXT); let row_data = &[0xFF, 0xFE, 0xFD, b'\n']; // Invalid UTF-8 - let result = parse_table_row_from_postgres_copy_bytes(row_data, &schema); + let result = parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()); assert!(result.is_err()); } #[test] fn try_from_parsing_error() { - let schema = create_single_column_schema("number", Type::INT4); + let column_schemas = create_single_column_schema("number", Type::INT4); let row_data = b"not_a_number\n"; - let result = parse_table_row_from_postgres_copy_bytes(row_data, &schema); + let result = parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()); assert!(result.is_err()); } #[test] fn try_from_trailing_escape() { - let schema = create_single_column_schema("data", Type::TEXT); + let column_schemas = create_single_column_schema("data", Type::TEXT); let row_data = b"Text\\\\\n"; - let result = parse_table_row_from_postgres_copy_bytes(row_data, &schema).unwrap(); + let result = + parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()).unwrap(); assert_eq!(result.values.len(), 1); assert_eq!(result.values[0], Cell::String("Text\\".to_string())); @@ -302,27 +316,31 @@ mod tests { #[test] fn try_from_null_literal_vs_null_marker() { - let schema = create_single_column_schema("value", Type::TEXT); + let column_schemas = create_single_column_schema("value", Type::TEXT); let row_data = b"\\N\n"; - let result = parse_table_row_from_postgres_copy_bytes(row_data, &schema).unwrap(); + let result = + parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()).unwrap(); assert_eq!(result.values[0], Cell::Null); let row_data = b"\\\\N\n"; - let result_test = parse_table_row_from_postgres_copy_bytes(row_data, &schema).unwrap(); + let result_test = + parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()).unwrap(); assert_eq!(result_test.values[0], Cell::Null); let row_data = b"\\\\A\n"; - let result_test = parse_table_row_from_postgres_copy_bytes(row_data, &schema).unwrap(); + let result_test = + parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()).unwrap(); assert_eq!(result_test.values[0], 
Cell::String("\\A".to_string())); } #[test] fn try_from_whitespace_handling() { - let schema = create_test_schema(); + let column_schemas = create_test_column_schemas(); let row_data = b"123\t John Doe \tt\n"; - let result = parse_table_row_from_postgres_copy_bytes(row_data, &schema).unwrap(); + let result = + parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()).unwrap(); assert_eq!(result.values.len(), 3); assert_eq!(result.values[0], Cell::I32(123)); @@ -332,14 +350,14 @@ mod tests { #[test] fn try_from_large_row() { - let mut schema = Vec::new(); + let mut column_schemas = Vec::new(); let mut expected_row = String::new(); - for i in 0..50 { - schema.push(ColumnSchema::new( - format!("col{i}"), + for i in 0i32..50 { + column_schemas.push(test_column( + &format!("col{i}"), Type::INT4, - -1, + i + 1, false, false, )); @@ -350,8 +368,11 @@ mod tests { } expected_row.push('\n'); - let result = - parse_table_row_from_postgres_copy_bytes(expected_row.as_bytes(), &schema).unwrap(); + let result = parse_table_row_from_postgres_copy_bytes( + expected_row.as_bytes(), + column_schemas.iter(), + ) + .unwrap(); assert_eq!(result.values.len(), 50); for i in 0..50 { @@ -361,24 +382,25 @@ mod tests { #[test] fn try_from_empty_row_with_columns() { - let schema = create_test_schema(); + let column_schemas = create_test_column_schemas(); let row_data = b"\t\t\n"; // Empty values but correct number of tabs - let result = parse_table_row_from_postgres_copy_bytes(row_data, &schema); + let result = parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()); assert!(result.is_err()); } #[test] fn try_from_postgres_delimiter_escaping() { - let schema = vec![ - ColumnSchema::new("col1".to_string(), Type::TEXT, -1, false, false), - ColumnSchema::new("col2".to_string(), Type::TEXT, -1, false, false), + let column_schemas = [ + test_column("col1", Type::TEXT, 1, false, false), + test_column("col2", Type::TEXT, 2, false, false), ]; // Postgres escapes tab characters in data with \\t let row_data = b"value\\twith\\ttabs\tnormal\\tvalue\n"; - let result = parse_table_row_from_postgres_copy_bytes(row_data, &schema).unwrap(); + let result = + parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()).unwrap(); assert_eq!( result.values[0], @@ -389,15 +411,16 @@ mod tests { #[test] fn try_from_postgres_escape_at_field_boundaries() { - let schema = vec![ - ColumnSchema::new("col1".to_string(), Type::TEXT, -1, false, false), - ColumnSchema::new("col2".to_string(), Type::TEXT, -1, false, false), - ColumnSchema::new("col3".to_string(), Type::TEXT, -1, false, false), + let column_schemas = [ + test_column("col1", Type::TEXT, 1, false, false), + test_column("col2", Type::TEXT, 2, false, false), + test_column("col3", Type::TEXT, 3, false, false), ]; // Escapes at the beginning, middle, and end of fields let row_data = b"\\tstart\tmiddle\\nvalue\tend\\r\n"; - let result = parse_table_row_from_postgres_copy_bytes(row_data, &schema).unwrap(); + let result = + parse_table_row_from_postgres_copy_bytes(row_data, column_schemas.iter()).unwrap(); assert_eq!(result.values[0], Cell::String("\tstart".to_string())); assert_eq!(result.values[1], Cell::String("middle\nvalue".to_string())); @@ -406,14 +429,16 @@ mod tests { #[test] fn try_from_postgres_multibyte_with_escapes() { - let schema = create_single_column_schema("data", Type::TEXT); + let column_schemas = create_single_column_schema("data", Type::TEXT); // Unicode text with escape sequences (testing multibyte character handling) let 
row_data = "Hello\\t🌍\\nWorld\\r测试".as_bytes(); let mut row_with_newline = row_data.to_vec(); row_with_newline.push(b'\n'); - let result = parse_table_row_from_postgres_copy_bytes(&row_with_newline, &schema).unwrap(); + let result = + parse_table_row_from_postgres_copy_bytes(&row_with_newline, column_schemas.iter()) + .unwrap(); assert_eq!( result.values[0], @@ -423,7 +448,7 @@ mod tests { #[test] fn try_from_postgres_escape_sequences() { - let schema = create_single_column_schema("data", Type::TEXT); + let column_schemas = create_single_column_schema("data", Type::TEXT); // Comprehensive test of all escape sequences that Postgres COPY TO produces let test_cases: Vec<(&[u8], &str)> = vec![ @@ -464,7 +489,8 @@ mod tests { ]; for (input, expected) in test_cases { - let result = parse_table_row_from_postgres_copy_bytes(input, &schema).unwrap(); + let result = + parse_table_row_from_postgres_copy_bytes(input, column_schemas.iter()).unwrap(); assert_eq!( result.values[0], Cell::String(expected.to_string()), @@ -476,7 +502,7 @@ mod tests { #[test] fn try_from_postgres_null_handling() { - let schema = create_single_column_schema("data", Type::TEXT); + let column_schemas = create_single_column_schema("data", Type::TEXT); // Test NULL marker vs empty string vs literal \N let test_cases: Vec<(&[u8], Cell)> = vec![ @@ -486,7 +512,8 @@ mod tests { ]; for (input, expected) in test_cases { - let result = parse_table_row_from_postgres_copy_bytes(input, &schema).unwrap(); + let result = + parse_table_row_from_postgres_copy_bytes(input, column_schemas.iter()).unwrap(); assert_eq!( result.values[0], expected, diff --git a/etl/src/destination/base.rs b/etl/src/destination/base.rs index 0526895da..7bcd76986 100644 --- a/etl/src/destination/base.rs +++ b/etl/src/destination/base.rs @@ -1,4 +1,4 @@ -use etl_postgres::types::TableId; +use etl_postgres::types::ReplicatedTableSchema; use std::future::Future; use crate::error::EtlResult; @@ -23,7 +23,10 @@ pub trait Destination { /// destination table starts from a clean state before bulk loading. The operation /// should be atomic and handle cases where the table and its states may not exist, since /// truncation is unconditionally called before a table is copied. - fn truncate_table(&self, table_id: TableId) -> impl Future> + Send; + fn truncate_table( + &self, + replicated_table_schema: &ReplicatedTableSchema, + ) -> impl Future> + Send; /// Writes a batch of table rows to the destination. /// @@ -37,7 +40,7 @@ pub trait Destination { /// prepare the initial tables before starting streaming. fn write_table_rows( &self, - table_id: TableId, + replicated_table_schema: &ReplicatedTableSchema, table_rows: Vec, ) -> impl Future> + Send; @@ -49,5 +52,8 @@ pub trait Destination { /// /// Event ordering within a transaction is guaranteed, and transactions are ordered according to /// their commit time. + /// + /// Each [`Event`] that involves data changes (Insert/Update/Delete) contains its own + /// [`ReplicatedTableSchema`] which destinations can use to understand the schema for that event. 
fn write_events(&self, events: Vec) -> impl Future> + Send; } diff --git a/etl/src/destination/memory.rs b/etl/src/destination/memory.rs index 06976fd74..a7cf293d9 100644 --- a/etl/src/destination/memory.rs +++ b/etl/src/destination/memory.rs @@ -5,7 +5,7 @@ use tracing::info; use crate::destination::Destination; use crate::error::EtlResult; -use crate::types::{Event, TableId, TableRow}; +use crate::types::{Event, ReplicatedTableSchema, TableId, TableRow}; #[derive(Debug)] struct Inner { @@ -80,25 +80,27 @@ impl Destination for MemoryDestination { fn name() -> &'static str { "memory" } - async fn truncate_table(&self, table_id: TableId) -> EtlResult<()> { + async fn truncate_table( + &self, + replicated_table_schema: &ReplicatedTableSchema, + ) -> EtlResult<()> { // For truncation, we simulate removing all table rows for a specific table and also the events // of that table. let mut inner = self.inner.lock().await; + let table_id = replicated_table_schema.id(); info!("truncating table {}", table_id); inner.table_rows.remove(&table_id); inner.events.retain_mut(|event| { let has_table_id = event.has_table_id(&table_id); - if let Event::Truncate(event) = event + if let Event::Truncate(truncate_event) = event && has_table_id { - let Some(index) = event.rel_ids.iter().position(|&id| table_id.0 == id) else { - return true; - }; - - event.rel_ids.remove(index); - if event.rel_ids.is_empty() { + truncate_event + .truncated_tables + .retain(|s| s.id() != table_id); + if truncate_event.truncated_tables.is_empty() { return false; } @@ -113,10 +115,11 @@ impl Destination for MemoryDestination { async fn write_table_rows( &self, - table_id: TableId, + replicated_table_schema: &ReplicatedTableSchema, table_rows: Vec, ) -> EtlResult<()> { let mut inner = self.inner.lock().await; + let table_id = replicated_table_schema.id(); info!("writing a batch of {} table rows:", table_rows.len()); diff --git a/etl/src/error.rs b/etl/src/error.rs index c8aeb3d23..328f2ec64 100644 --- a/etl/src/error.rs +++ b/etl/src/error.rs @@ -96,15 +96,16 @@ pub enum ErrorKind { SourceLockTimeout, SourceOperationCanceled, - // Schema & Mapping Errors + // Schema Errors SourceSchemaError, MissingTableSchema, - MissingTableMapping, + CorruptedTableSchema, DestinationTableNameInvalid, DestinationNamespaceAlreadyExists, DestinationTableAlreadyExists, DestinationNamespaceMissing, DestinationTableMissing, + DestinationSchemaMismatch, // Data & Transformation Errors ConversionError, @@ -1032,6 +1033,40 @@ impl From for EtlErro } } +/// Converts [`etl_postgres::types::SchemaError`] to [`EtlError`] with [`ErrorKind::CorruptedTableSchema`]. +impl From for EtlError { + #[track_caller] + fn from(err: etl_postgres::types::SchemaError) -> EtlError { + match err { + etl_postgres::types::SchemaError::UnknownReplicatedColumns(columns) => { + EtlError::from_components( + ErrorKind::CorruptedTableSchema, + Cow::Borrowed( + "Received columns during replication that are not in the stored table schema", + ), + Some(Cow::Owned(format!( + "Unknown columns: {columns:?}\n\n\ + Cause: The pipeline crashed after a schema change but before reporting progress \ + back to Postgres. On restart, event streaming resumed from past events with an \ + outdated schema." + ))), + None, + ) + } + etl_postgres::types::SchemaError::InvalidSnapshotId(lsn_str) => { + EtlError::from_components( + ErrorKind::CorruptedTableSchema, + Cow::Borrowed("Invalid snapshot id"), + Some(Cow::Owned(format!( + "Failed to parse snapshot '{lsn_str}' as PgLsn." 
+ ))), + None, + ) + } + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/etl/src/failpoints.rs b/etl/src/failpoints.rs index 7c3073a31..b4c4551a9 100644 --- a/etl/src/failpoints.rs +++ b/etl/src/failpoints.rs @@ -8,9 +8,10 @@ use fail::fail_point; use crate::bail; use crate::error::{ErrorKind, EtlResult}; -pub const START_TABLE_SYNC_BEFORE_DATA_SYNC_SLOT_CREATION: &str = - "start_table_sync.before_data_sync_slot_creation"; -pub const START_TABLE_SYNC_DURING_DATA_SYNC: &str = "start_table_sync.during_data_sync"; +pub const START_TABLE_SYNC_BEFORE_DATA_SYNC_SLOT_CREATION_FP: &str = + "start_table_sync.before_data_sync_slot_creation_fp"; +pub const START_TABLE_SYNC_DURING_DATA_SYNC_FP: &str = "start_table_sync.during_data_sync_fp"; +pub const SEND_STATUS_UPDATE_FP: &str = "send_status_update_fp"; /// Executes a configurable failpoint for testing error scenarios. /// @@ -45,3 +46,10 @@ pub fn etl_fail_point(name: &str) -> EtlResult<()> { Ok(()) } + +/// Returns `true` if a specific failpoint is active, `false` otherwise. +/// +/// A failpoint is considered active if it throws an error. +pub fn etl_fail_point_active(name: &str) -> bool { + etl_fail_point(name).is_err() +} diff --git a/etl/src/lib.rs b/etl/src/lib.rs index 01a4dc249..a0c730fd9 100644 --- a/etl/src/lib.rs +++ b/etl/src/lib.rs @@ -107,6 +107,7 @@ pub mod error; pub mod failpoints; pub mod macros; pub mod metrics; +pub mod migrations; pub mod pipeline; pub mod replication; pub mod state; diff --git a/etl-replicator/src/migrations.rs b/etl/src/migrations.rs similarity index 56% rename from etl-replicator/src/migrations.rs rename to etl/src/migrations.rs index 09261aae0..73f643574 100644 --- a/etl-replicator/src/migrations.rs +++ b/etl/src/migrations.rs @@ -1,33 +1,32 @@ +//! Database migration management for the ETL pipeline. +//! +//! Handles schema creation and migration execution for the ETL state store. +//! Migrations are embedded at compile time and run automatically during +//! pipeline startup. + use etl_config::shared::{IntoConnectOptions, PgConnectionConfig}; -use sqlx::{ - Executor, - postgres::{PgConnectOptions, PgPoolOptions}, -}; +use sqlx::{Executor, postgres::PgPoolOptions}; use tracing::info; -/// Number of database connections to use for the migration pool. -const NUM_POOL_CONNECTIONS: u32 = 1; - -/// Runs database migrations on the state store. +/// Runs database migrations on the source database `etl` schema. /// /// Creates a connection pool to the source database, sets up the `etl` schema, /// and applies all pending migrations. The migrations are run in the `etl` schema /// to avoid cluttering the public schema with migration metadata tables created by `sqlx`. -pub async fn migrate_state_store( +pub async fn apply_etl_migrations( connection_config: &PgConnectionConfig, ) -> Result<(), sqlx::Error> { - let options: PgConnectOptions = connection_config.with_db(); + let options = connection_config.with_db(); let pool = PgPoolOptions::new() - .max_connections(NUM_POOL_CONNECTIONS) - .min_connections(NUM_POOL_CONNECTIONS) .after_connect(|conn, _meta| { Box::pin(async move { - // Create the etl schema if it doesn't exist + // Create the `etl` schema if it doesn't exist. conn.execute("create schema if not exists etl;").await?; - // We set the search_path to etl so that the _sqlx_migrations + + // Set the `search_path` to `etl` so that the `_sqlx_migrations` // metadata table is created inside that schema instead of the public - // schema + // schema. 
conn.execute("set search_path = 'etl';").await?; Ok(()) @@ -36,12 +35,12 @@ pub async fn migrate_state_store( .connect_with(options) .await?; - info!("applying migrations in the state store before starting replicator"); + info!("applying etl migrations before starting pipeline"); let migrator = sqlx::migrate!("./migrations"); migrator.run(&pool).await?; - info!("migrations successfully applied in the state store"); + info!("etl migrations successfully applied"); Ok(()) } diff --git a/etl/src/pipeline.rs b/etl/src/pipeline.rs index a0e15605e..77f2ce8b9 100644 --- a/etl/src/pipeline.rs +++ b/etl/src/pipeline.rs @@ -8,7 +8,9 @@ use crate::concurrency::shutdown::{ShutdownTx, create_shutdown_channel}; use crate::destination::Destination; use crate::error::{ErrorKind, EtlResult}; use crate::metrics::register_metrics; +use crate::migrations::apply_etl_migrations; use crate::replication::client::PgReplicationClient; +use crate::replication::masks::ReplicationMasks; use crate::state::table::TableReplicationPhase; use crate::store::cleanup::CleanupStore; use crate::store::schema::SchemaStore; @@ -112,26 +114,36 @@ where /// Starts the pipeline and begins replication processing. /// - /// This method initializes the connection to Postgres, sets up table mappings and schemas, - /// creates the worker pool for table synchronization, and starts the apply worker for - /// processing replication stream events. + /// This method runs any pending migrations, initializes the connection to Postgres, + /// sets up table mappings and schemas, creates the worker pool for table synchronization, + /// and starts the apply worker for processing replication stream events. pub async fn start(&mut self) -> EtlResult<()> { info!( "starting pipeline for publication '{}' with id {}", self.config.publication_name, self.config.id ); + // Run migrations before starting the pipeline. + apply_etl_migrations(&self.config.pg_connection) + .await + .map_err(|e| { + crate::etl_error!( + ErrorKind::SourceError, + "Failed to run state store migrations", + format!("{}", e) + ) + })?; + // We create the first connection to Postgres. let replication_client = PgReplicationClient::connect(self.config.pg_connection.clone()).await?; - // We load the table mappings and schemas from the store to have them cached for quick - // access. + // We load the destination table metadata and schemas from the store to have them cached + // for quick access. // - // It's really important to load the mappings and schemas before starting the apply worker - // since downstream code relies on the assumption that the mappings and schemas are loaded - // in the cache. - self.store.load_table_mappings().await?; + // It's really important to load the metadata and schemas before starting the apply worker + // since downstream code relies on the assumption that they are loaded in the cache. + self.store.load_destination_tables_metadata().await?; self.store.load_table_schemas().await?; // We load the table states by checking the table ids of a publication and loading/creating @@ -146,8 +158,11 @@ where let table_sync_worker_permits = Arc::new(Semaphore::new(self.config.max_table_sync_workers as usize)); - // We create and start the apply worker (temporarily leaving out retries_orchestrator) - // TODO: Remove retries_orchestrator from ApplyWorker constructor + // We create the shared replication masks container that will be used by both the apply + // worker and table sync workers to track which columns are being replicated for each table. 
+ let replication_masks = ReplicationMasks::new(); + + // We create and start the apply worker. let apply_worker = ApplyWorker::new( self.config.id, self.config.clone(), @@ -155,6 +170,7 @@ where pool.clone(), self.store.clone(), self.destination.clone(), + replication_masks, self.shutdown_tx.subscribe(), table_sync_worker_permits, ) diff --git a/etl/src/replication/apply.rs b/etl/src/replication/apply.rs index 2a085d774..e9bed5c0e 100644 --- a/etl/src/replication/apply.rs +++ b/etl/src/replication/apply.rs @@ -1,27 +1,12 @@ -use etl_config::shared::PipelineConfig; -use etl_postgres::replication::worker::WorkerType; -use etl_postgres::types::TableId; -use futures::StreamExt; -use metrics::histogram; -use postgres_replication::protocol; -use postgres_replication::protocol::{LogicalReplicationMessage, ReplicationMessage}; -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; -use std::time::{Duration, Instant}; -use tokio::pin; -use tokio_postgres::types::PgLsn; -use tracing::log::warn; -use tracing::{debug, info}; - +use crate::bail; use crate::concurrency::shutdown::ShutdownRx; use crate::concurrency::signal::SignalRx; use crate::concurrency::stream::{TimeoutStream, TimeoutStreamResult}; use crate::conversions::event::{ - parse_event_from_begin_message, parse_event_from_commit_message, + DDL_MESSAGE_PREFIX, parse_event_from_begin_message, parse_event_from_commit_message, parse_event_from_delete_message, parse_event_from_insert_message, - parse_event_from_relation_message, parse_event_from_truncate_message, - parse_event_from_update_message, + parse_event_from_truncate_message, parse_event_from_update_message, + parse_replicated_column_names, parse_schema_change_message, }; use crate::destination::Destination; use crate::error::{ErrorKind, EtlResult}; @@ -31,11 +16,26 @@ use crate::metrics::{ ETL_TRANSACTION_SIZE, PIPELINE_ID_LABEL, WORKER_TYPE_LABEL, }; use crate::replication::client::PgReplicationClient; +use crate::replication::masks::ReplicationMasks; use crate::replication::stream::EventsStream; -use crate::state::table::{RetryPolicy, TableReplicationError}; use crate::store::schema::SchemaStore; -use crate::types::{Event, PipelineId}; -use crate::{bail, etl_error}; +use crate::types::{Event, PipelineId, RelationEvent}; +use etl_config::shared::PipelineConfig; +use etl_postgres::replication::worker::WorkerType; +use etl_postgres::types::{ + ReplicatedTableSchema, ReplicationMask, SnapshotId, TableId, TableSchema, +}; +use futures::StreamExt; +use metrics::histogram; +use postgres_replication::protocol; +use postgres_replication::protocol::{LogicalReplicationMessage, ReplicationMessage}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::pin; +use tokio_postgres::types::PgLsn; +use tracing::{debug, info, warn}; /// The minimum interval (in milliseconds) between consecutive status updates. /// @@ -152,14 +152,6 @@ pub trait ApplyLoopHook { update_state: bool, ) -> impl Future> + Send; - /// Called when a table encounters an error during replication. - /// - /// This hook handles error reporting and retry logic for failed tables. - fn mark_table_errored( - &self, - table_replication_error: TableReplicationError, - ) -> impl Future> + Send; - /// Called to check if the events processed by this apply loop should be applied in the destination. /// /// Returns `true` if the event should be applied, `false` otherwise. 
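The hunks above and below wire a shared `ReplicationMasks` container through the pipeline and the apply loop. As a rough, non-authoritative sketch of how the pieces introduced in this change fit together (the helper below is hypothetical and assumes the `ReplicationMask::try_build`, `ReplicationMasks::set`, and `ReplicatedTableSchema::from_mask` APIs added elsewhere in this diff; exact signatures, error types, and whether the schema is passed behind an `Arc` may differ):

```rust
use std::collections::HashSet;

use etl::replication::masks::ReplicationMasks;
use etl_postgres::types::{
    ReplicatedTableSchema, ReplicationMask, SchemaError, TableId, TableSchema,
};

/// Hypothetical helper mirroring what the apply loop does when a RELATION message
/// arrives: validate the replicated column names against the stored schema, cache
/// the resulting mask, and pair schema + mask for downstream consumers.
async fn on_relation_message(
    masks: &ReplicationMasks,
    table_id: TableId,
    table_schema: TableSchema,           // loaded from the SchemaStore at the current snapshot
    replicated_columns: HashSet<String>, // parsed from the RELATION message
) -> Result<ReplicatedTableSchema, SchemaError> {
    // Fails if the stream names columns the stored schema does not contain,
    // i.e. the stored schema is stale relative to the source database.
    let mask = ReplicationMask::try_build(&table_schema, &replicated_columns)?;

    // Cache the mask so Insert/Update/Delete/Truncate handlers can read it back
    // instead of rebuilding it for every event.
    masks.set(table_id, mask.clone()).await;

    Ok(ReplicatedTableSchema::from_mask(table_schema, mask))
}
```

The DML handlers later in this file do the reverse lookup: `get_replicated_table_schema` reads the cached mask back and fails if none is present for the table.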
@@ -208,8 +200,6 @@ impl StatusUpdate { enum EndBatch { /// The batch should include the last processed event and end. Inclusive, - /// The batch should exclude the last processed event and end. - Exclusive, } /// Result returned from `handle_replication_message` and related functions @@ -244,20 +234,6 @@ struct HandleMessageResult { /// mark the table as skipped in this case. The replication event will be excluded /// from the batch. end_batch: Option, - /// Set when the table has encountered an error, and it should consequently be marked as errored - /// in the state store. - /// - /// This error is a "caught" error, meaning that it doesn't crash the apply loop, but it makes it - /// continue or gracefully stop based on the worker type that runs the loop. - /// - /// Other errors that make the apply loop fail, will be propagated to the caller and handled differently - /// based on the worker that runs the loop: - /// - Apply worker -> the error will make the apply loop crash, which will be propagated to the - /// worker and up if the worker is awaited. - /// - Table sync worker -> the error will make the apply loop crash, which will be propagated - /// to the worker, however the error will be caught and persisted via the observer mechanism - /// in place for the table sync workers. - table_replication_error: Option, /// The action that this event should have on the loop. /// /// Note that this action might be overridden by operations that are happening when a batch is flushed @@ -284,18 +260,6 @@ impl HandleMessageResult { ..Default::default() } } - - /// Creates a result that excludes the current event and requests batch termination. - /// - /// Used when the current message triggers a recoverable table-level error. - /// The error is propagated to be handled by the apply loop hook. - fn finish_batch_and_exclude_event(error: TableReplicationError) -> Self { - Self { - end_batch: Some(EndBatch::Exclusive), - table_replication_error: Some(error), - ..Default::default() - } - } } /// A shared state that is used throughout the apply loop to track progress. @@ -329,14 +293,22 @@ struct ApplyLoopState { /// transaction boundary is found. If not found, the process will continue until it is killed via /// a `SIGKILL`. shutdown_discarded: bool, + /// The current schema snapshot being tracked. + /// + /// This is updated when DDL messages are processed, tracking the latest schema version. + current_schema_snapshot_id: SnapshotId, } impl ApplyLoopState { - /// Creates a new [`ApplyLoopState`] with initial status update and event batch. + /// Creates a new [`ApplyLoopState`] with initial status update, event batch, and schema snapshot. /// /// This constructor initializes the state tracking structure used throughout /// the apply loop to maintain replication progress and coordinate batching. - fn new(next_status_update: StatusUpdate, events_batch: Vec) -> Self { + fn new( + next_status_update: StatusUpdate, + events_batch: Vec, + current_schema_snapshot_id: SnapshotId, + ) -> Self { Self { last_commit_end_lsn: None, remote_final_lsn: None, @@ -346,9 +318,15 @@ impl ApplyLoopState { current_tx_begin_ts: None, current_tx_events: 0, shutdown_discarded: false, + current_schema_snapshot_id, } } + /// Updates the current schema snapshot to a new value. + fn update_schema_snapshot_id(&mut self, snapshot_id: SnapshotId) { + self.current_schema_snapshot_id = snapshot_id; + } + /// Updates the last commit end LSN to track transaction boundaries. 
/// /// This method maintains the highest commit end LSN seen, which represents @@ -490,6 +468,7 @@ pub async fn start_apply_loop( schema_store: S, destination: D, hook: T, + replication_masks: ReplicationMasks, mut shutdown_rx: ShutdownRx, mut force_syncing_tables_rx: Option, ) -> EtlResult @@ -514,6 +493,12 @@ where return Ok(result); } + // Initialize the current schema snapshot from the start LSN (this is fine to do even if the + // start lsn is not like any of the existing snapshot ids since the system is designed to return + // the biggest snapshot id <= the current snapshot id). + // Schemas will be loaded on-demand when get_table_schema is called. + let current_schema_snapshot_id: SnapshotId = start_lsn.into(); + // The first status update is defaulted from the start lsn since at this point we haven't // processed anything. let first_status_update = StatusUpdate { @@ -551,6 +536,7 @@ where let mut state = ApplyLoopState::new( first_status_update, Vec::with_capacity(config.batch.max_size), + current_schema_snapshot_id, ); // Main event processing loop - continues until shutdown or fatal error @@ -613,6 +599,7 @@ where &schema_store, &destination, &hook, + &replication_masks, config.batch.max_size, pipeline_id ) @@ -706,6 +693,7 @@ async fn handle_replication_message_with_timeout( schema_store: &S, destination: &D, hook: &T, + replication_masks: &ReplicationMasks, max_batch_size: usize, pipeline_id: PipelineId, ) -> EtlResult @@ -722,6 +710,7 @@ where message?, schema_store, hook, + replication_masks, pipeline_id, ) .await?; @@ -767,24 +756,6 @@ where .await?; } - // If we have a caught table error, we want to mark the table as errored. - // - // Note that if we have a failure after marking a table as errored and events will - // be reprocessed, even the events before the failure will be skipped. - // - // Usually in the apply loop, errors are propagated upstream and handled based on if - // we are in a table sync worker or apply worker, however we have an edge case (for - // relation messages that change the schema) where we want to mark a table as errored - // manually, not propagating the error outside the loop, which is going to be handled - // differently based on the worker: - // - Apply worker -> will continue the loop skipping the table. - // - Table sync worker -> will stop the work (as if it had a normal uncaught error). - // Ideally we would get rid of this since it's an anomalous case which adds unnecessary - // complexity. - if let Some(error) = result.table_replication_error { - action = action.merge(hook.mark_table_errored(error).await?); - } - // Once the batch is sent, we have the guarantee that all events up to this point have // been durably persisted, so we do synchronization. // @@ -923,6 +894,32 @@ where .await } +/// Retrieves a table schema from the schema store by table ID and snapshot. +/// +/// Returns an error if the schema is not found in the store. +async fn get_table_schema( + schema_store: &S, + table_id: &TableId, + snapshot_id: SnapshotId, +) -> EtlResult> +where + S: SchemaStore, +{ + schema_store + .get_table_schema(table_id, snapshot_id) + .await? + .ok_or_else(|| { + crate::etl_error!( + ErrorKind::MissingTableSchema, + "Table schema not found", + format!( + "Table schema for table {} at snapshot {} not found", + table_id, snapshot_id + ) + ) + }) +} + /// Dispatches replication protocol messages to appropriate handlers. 
/// /// This function serves as the main routing mechanism for Postgres replication @@ -938,6 +935,7 @@ async fn handle_replication_message( message: ReplicationMessage, schema_store: &S, hook: &T, + replication_masks: &ReplicationMasks, pipeline_id: PipelineId, ) -> EtlResult where @@ -965,6 +963,7 @@ where message.into_data(), schema_store, hook, + replication_masks, pipeline_id, ) .await @@ -1007,6 +1006,7 @@ async fn handle_logical_replication_message( message: LogicalReplicationMessage, schema_store: &S, hook: &T, + replication_masks: &ReplicationMasks, pipeline_id: PipelineId, ) -> EtlResult where @@ -1023,29 +1023,75 @@ where handle_commit_message(state, start_lsn, commit_body, hook, pipeline_id).await } LogicalReplicationMessage::Relation(relation_body) => { - handle_relation_message(state, start_lsn, relation_body, schema_store, hook).await + handle_relation_message( + state, + start_lsn, + relation_body, + hook, + schema_store, + replication_masks, + ) + .await } LogicalReplicationMessage::Insert(insert_body) => { - handle_insert_message(state, start_lsn, insert_body, hook, schema_store).await + handle_insert_message( + state, + start_lsn, + insert_body, + hook, + schema_store, + replication_masks, + ) + .await } LogicalReplicationMessage::Update(update_body) => { - handle_update_message(state, start_lsn, update_body, hook, schema_store).await + handle_update_message( + state, + start_lsn, + update_body, + hook, + schema_store, + replication_masks, + ) + .await } LogicalReplicationMessage::Delete(delete_body) => { - handle_delete_message(state, start_lsn, delete_body, hook, schema_store).await + handle_delete_message( + state, + start_lsn, + delete_body, + hook, + schema_store, + replication_masks, + ) + .await } LogicalReplicationMessage::Truncate(truncate_body) => { - handle_truncate_message(state, start_lsn, truncate_body, hook).await + handle_truncate_message( + state, + start_lsn, + truncate_body, + hook, + schema_store, + replication_masks, + ) + .await } - LogicalReplicationMessage::Origin(_) => { - debug!("received unsupported ORIGIN message"); - Ok(HandleMessageResult::default()) + LogicalReplicationMessage::Message(message_body) => { + handle_logical_message( + state, + start_lsn, + message_body, + hook, + schema_store, + replication_masks, + ) + .await } - LogicalReplicationMessage::Type(_) => { - debug!("received unsupported TYPE message"); - Ok(HandleMessageResult::default()) + message => { + debug!("received unsupported message: {:?}", message); + Ok(HandleMessageResult::no_event()) } - _ => Ok(HandleMessageResult::default()), } } @@ -1186,25 +1232,35 @@ where Ok(result) } -/// Handles Postgres RELATION messages that describe table schemas. +/// Handles Postgres RELATION messages that describe the schema of data in the replication stream. +/// +/// RELATION messages are sent by Postgres before any DML events for a table, describing which +/// columns are being replicated. This function extracts the replicated column names and builds +/// a [`ReplicationMask`] which is stored in the shared [`ReplicationMasks`] container. /// -/// This function processes schema definition messages by validating that table -/// schemas haven't changed unexpectedly during replication. Schema stability -/// is critical for maintaining data consistency between source and destination. +/// The mask is built by matching the replicated column names from the RELATION message against +/// the table schema from the schema store. 
/// -/// When schema changes are detected, the function creates appropriate error -/// conditions and signals batch termination to prevent processing of events -/// with mismatched schemas. This protection mechanism ensures data integrity -/// by failing fast on incompatible schema evolution. +/// Emits an [`Event::Relation`] containing the [`ReplicatedTableSchema`] to notify downstream +/// consumers about which columns are being replicated for this table. +/// +/// # Errors +/// +/// Returns [`ErrorKind::CorruptedTableSchema`] if the replicated columns in the `Relation` +/// message do not match the stored table schema. This can occur when DDL changes happen +/// but the pipeline crashes before acknowledging progress, causing the stored schema to +/// be out of sync with the source. Manual intervention is required to update the stored +/// schema before the pipeline can continue. async fn handle_relation_message( - state: &mut ApplyLoopState, + state: &ApplyLoopState, start_lsn: PgLsn, message: &protocol::RelationBody, - schema_store: &S, hook: &T, + schema_store: &S, + replication_masks: &ReplicationMasks, ) -> EtlResult where - S: SchemaStore + Clone + Send + 'static, + S: SchemaStore, T: ApplyLoopHook, { let Some(remote_final_lsn) = state.remote_final_lsn else { @@ -1217,6 +1273,7 @@ where let table_id = TableId::new(message.rel_id()); + // Skip relation messages for tables we should not apply changes to. if !hook .should_apply_changes(table_id, remote_final_lsn) .await? @@ -1224,38 +1281,42 @@ where return Ok(HandleMessageResult::no_event()); } - // If no table schema is found, it means that something went wrong since we should have schemas - // ready before starting the apply loop. - let existing_table_schema = - schema_store - .get_table_schema(&table_id) - .await? - .ok_or_else(|| { - etl_error!( - ErrorKind::MissingTableSchema, - "Table schema not found in cache", - format!("Table schema for table {} not found in cache", table_id) - ) - })?; + let replicated_columns = parse_replicated_column_names(message)?; - // Convert event from the protocol message. - let event = parse_event_from_relation_message(start_lsn, remote_final_lsn, message)?; - - // We compare the table schema from the relation message with the existing schema (if any). - // The purpose of this comparison is that we want to throw an error and stop the processing - // of any table that incurs in a schema change after the initial table sync is performed. - if !existing_table_schema.partial_eq(&event.table_schema) { - let error = TableReplicationError::with_solution( - table_id, - format!("The schema for table {table_id} has changed during streaming"), - "ETL doesn't support schema changes at this point in time, rollback the schema", - RetryPolicy::ManualRetry, - ); + let table_schema = + get_table_schema(schema_store, &table_id, state.current_schema_snapshot_id).await?; - return Ok(HandleMessageResult::finish_batch_and_exclude_event(error)); - } + info!( + table_id = %table_id, + replicated_columns = ?replicated_columns, + "received relation message, building replication mask" + ); - Ok(HandleMessageResult::return_event(Event::Relation(event))) + // Build the replication mask by validating that all replicated columns exist in the schema. + // If validation fails, it indicates the stored schema is out of sync with the source + // database and requires manual intervention to update the stored schema. + // + // TODO: Currently we fail and require manual intervention. 
In the future, we might want to + // handle this case automatically (e.g., by rebuilding the schema from the source) if this + // error becomes common. + let replication_mask = ReplicationMask::try_build(&table_schema, &replicated_columns)?; + + replication_masks + .set(table_id, replication_mask.clone()) + .await; + + // Build the ReplicatedTableSchema and emit a Relation event. + let replicated_table_schema = ReplicatedTableSchema::from_mask(table_schema, replication_mask); + + let relation_event = RelationEvent { + start_lsn, + commit_lsn: remote_final_lsn, + replicated_table_schema, + }; + + Ok(HandleMessageResult::return_event(Event::Relation( + relation_event, + ))) } /// Handles Postgres INSERT messages for row insertion events. @@ -1265,6 +1326,7 @@ async fn handle_insert_message( message: &protocol::InsertBody, hook: &T, schema_store: &S, + replication_masks: &ReplicationMasks, ) -> EtlResult where S: SchemaStore + Clone + Send + 'static, @@ -1278,16 +1340,30 @@ where ); }; + let table_id = TableId::new(message.rel_id()); + if !hook - .should_apply_changes(TableId::new(message.rel_id()), remote_final_lsn) + .should_apply_changes(table_id, remote_final_lsn) .await? { return Ok(HandleMessageResult::no_event()); } + let replicated_table_schema = get_replicated_table_schema( + &table_id, + state.current_schema_snapshot_id, + schema_store, + replication_masks, + ) + .await?; + // Convert event from the protocol message. - let event = - parse_event_from_insert_message(schema_store, start_lsn, remote_final_lsn, message).await?; + let event = parse_event_from_insert_message( + replicated_table_schema, + start_lsn, + remote_final_lsn, + message, + )?; Ok(HandleMessageResult::return_event(Event::Insert(event))) } @@ -1299,6 +1375,7 @@ async fn handle_update_message( message: &protocol::UpdateBody, hook: &T, schema_store: &S, + replication_masks: &ReplicationMasks, ) -> EtlResult where S: SchemaStore + Clone + Send + 'static, @@ -1312,16 +1389,30 @@ where ); }; + let table_id = TableId::new(message.rel_id()); + if !hook - .should_apply_changes(TableId::new(message.rel_id()), remote_final_lsn) + .should_apply_changes(table_id, remote_final_lsn) .await? { return Ok(HandleMessageResult::no_event()); } + let replicated_table_schema = get_replicated_table_schema( + &table_id, + state.current_schema_snapshot_id, + schema_store, + replication_masks, + ) + .await?; + // Convert event from the protocol message. - let event = - parse_event_from_update_message(schema_store, start_lsn, remote_final_lsn, message).await?; + let event = parse_event_from_update_message( + replicated_table_schema, + start_lsn, + remote_final_lsn, + message, + )?; Ok(HandleMessageResult::return_event(Event::Update(event))) } @@ -1333,6 +1424,7 @@ async fn handle_delete_message( message: &protocol::DeleteBody, hook: &T, schema_store: &S, + replication_masks: &ReplicationMasks, ) -> EtlResult where S: SchemaStore + Clone + Send + 'static, @@ -1346,16 +1438,30 @@ where ); }; + let table_id = TableId::new(message.rel_id()); + if !hook - .should_apply_changes(TableId::new(message.rel_id()), remote_final_lsn) + .should_apply_changes(table_id, remote_final_lsn) .await? { return Ok(HandleMessageResult::no_event()); } + let replicated_table_schema = get_replicated_table_schema( + &table_id, + state.current_schema_snapshot_id, + schema_store, + replication_masks, + ) + .await?; + // Convert event from the protocol message. 
- let event = - parse_event_from_delete_message(schema_store, start_lsn, remote_final_lsn, message).await?; + let event = parse_event_from_delete_message( + replicated_table_schema, + start_lsn, + remote_final_lsn, + message, + )?; Ok(HandleMessageResult::return_event(Event::Delete(event))) } @@ -1366,13 +1472,16 @@ where /// ensuring transaction context, and filtering the affected table list based on /// hook decisions. Since TRUNCATE can affect multiple tables simultaneously, /// it evaluates each table individually. -async fn handle_truncate_message( +async fn handle_truncate_message( state: &mut ApplyLoopState, start_lsn: PgLsn, message: &protocol::TruncateBody, hook: &T, + schema_store: &S, + replication_masks: &ReplicationMasks, ) -> EtlResult where + S: SchemaStore + Clone + Send + 'static, T: ApplyLoopHook, { let Some(remote_final_lsn) = state.remote_final_lsn else { @@ -1383,24 +1492,161 @@ where ); }; - // We collect only the relation ids for which we are allow to apply changes, thus in this case - // the truncation. - let mut rel_ids = Vec::with_capacity(message.rel_ids().len()); - for &table_id in message.rel_ids().iter() { + // We collect the replicated schemas for tables we are allowed to apply changes to. + let mut truncated_tables = Vec::with_capacity(message.rel_ids().len()); + for &rel_id in message.rel_ids().iter() { + let table_id = TableId::new(rel_id); + if hook - .should_apply_changes(TableId::new(table_id), remote_final_lsn) + .should_apply_changes(table_id, remote_final_lsn) .await? { - rel_ids.push(table_id) + let replicated_table_schema = get_replicated_table_schema( + &table_id, + state.current_schema_snapshot_id, + schema_store, + replication_masks, + ) + .await?; + truncated_tables.push(replicated_table_schema); } } - // If nothing to apply, skip conversion entirely - if rel_ids.is_empty() { + + // If nothing to apply, skip conversion entirely. + if truncated_tables.is_empty() { return Ok(HandleMessageResult::no_event()); } // Convert event from the protocol message. - let event = parse_event_from_truncate_message(start_lsn, remote_final_lsn, message, rel_ids); + let event = + parse_event_from_truncate_message(start_lsn, remote_final_lsn, message, truncated_tables); Ok(HandleMessageResult::return_event(Event::Truncate(event))) } + +/// Handles a logical replication message. +/// +/// Processes `pg_logical_emit_message` messages from the replication stream. +/// Handles DDL schema change messages with the `supabase_etl_ddl` prefix by +/// storing the new schema version with the start_lsn as the snapshot_id. +async fn handle_logical_message( + state: &mut ApplyLoopState, + start_lsn: PgLsn, + message: &protocol::MessageBody, + hook: &T, + schema_store: &S, + replication_masks: &ReplicationMasks, +) -> EtlResult +where + S: SchemaStore, + T: ApplyLoopHook, +{ + // If the prefix is unknown, we don't want to process it. + let prefix = message.prefix()?; + if prefix != DDL_MESSAGE_PREFIX { + info!( + prefix = %prefix, + "received logical message with unknown prefix, discarding" + ); + + return Ok(HandleMessageResult::no_event()); + } + + // DDL messages must be transactional (emitted with transactional=true in pg_logical_emit_message). + // This ensures they are part of a transaction and have a valid commit LSN for ordering. + let Some(remote_final_lsn) = state.remote_final_lsn else { + bail!( + ErrorKind::InvalidState, + "Invalid transaction state", + "DDL schema change messages must be transactional (transactional=true). 
\ + Received a DDL message outside of a transaction boundary." + ); + }; + + let content = message.content()?; + let Ok(schema_change_message) = parse_schema_change_message(content) else { + bail!( + ErrorKind::InvalidData, + "Failed to parse DDL schema change message", + "Invalid JSON format in schema change message content" + ); + }; + + let table_id = TableId::new(schema_change_message.table_id as u32); + if !hook + .should_apply_changes(table_id, remote_final_lsn) + .await? + { + return Ok(HandleMessageResult::no_event()); + } + + info!( + table_id = schema_change_message.table_id, + table_name = %schema_change_message.table_name, + schema_name = %schema_change_message.schema_name, + event = %schema_change_message.event, + columns = schema_change_message.columns.len(), + "received ddl schema change message" + ); + + // Build table schema from DDL message with start_lsn as the snapshot_id. + let snapshot_id: SnapshotId = start_lsn.into(); + let table_schema = schema_change_message.into_table_schema(snapshot_id); + + // Store the new schema version in the store. + schema_store.store_table_schema(table_schema).await?; + + // Update the current schema snapshot in the state. + state.update_schema_snapshot_id(snapshot_id); + + // Invalidate the cached replication mask for this table. While PostgreSQL guarantees that + // a RELATION message will be sent before any DML events after a schema change, we + // proactively invalidate the mask to ensure consistency. + replication_masks.remove(&table_id).await; + + let table_id: u32 = table_id.into(); + info!( + table_id = table_id, + %snapshot_id, + "stored new schema version from ddl message" + ); + + Ok(HandleMessageResult::no_event()) +} + +/// Retrieves a [`ReplicatedTableSchema`] for the given table at the specified snapshot. +/// +/// This function combines the table schema from the schema store with the replication mask +/// from the shared [`ReplicationMasks`] to create a [`ReplicatedTableSchema`]. +/// +/// # Errors +/// +/// Returns an error if no replication mask is found for this table in the shared masks +/// container, or if the table schema is not found in the schema store. 
+async fn get_replicated_table_schema( + table_id: &TableId, + snapshot_id: SnapshotId, + schema_store: &S, + replication_masks: &ReplicationMasks, +) -> EtlResult +where + S: SchemaStore + Clone + Send + 'static, +{ + let Some(replication_mask) = replication_masks.get(table_id).await else { + bail!( + ErrorKind::InvalidState, + "Missing replication mask", + format!( + "No replication mask found for table {}, this event can't be processed", + table_id + ) + ); + }; + + let table_schema = get_table_schema(schema_store, table_id, snapshot_id).await?; + + Ok(ReplicatedTableSchema::from_mask( + table_schema, + replication_mask, + )) +} diff --git a/etl/src/replication/client.rs b/etl/src/replication/client.rs index 7db618685..e6ab59f0e 100644 --- a/etl/src/replication/client.rs +++ b/etl/src/replication/client.rs @@ -2,16 +2,16 @@ use crate::error::{ErrorKind, EtlResult}; use crate::utils::tokio::MakeRustlsConnect; use crate::{bail, etl_error}; use etl_config::shared::{IntoConnectOptions, PgConnectionConfig}; +use etl_postgres::below_version; use etl_postgres::replication::extract_server_version; use etl_postgres::types::convert_type_oid_to_type; use etl_postgres::types::{ColumnSchema, TableId, TableName, TableSchema}; use etl_postgres::version::POSTGRES_15; -use etl_postgres::{below_version, requires_version}; use pg_escape::{quote_identifier, quote_literal}; use postgres_replication::LogicalReplicationStream; use rustls::ClientConfig; use rustls::pki_types::{CertificateDer, pem::PemObject}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fmt; use std::num::NonZeroI32; use std::sync::Arc; @@ -117,32 +117,25 @@ impl PgReplicationSlotTransaction { pub async fn get_table_schemas( &self, table_ids: &[TableId], - publication_name: Option<&str>, ) -> EtlResult> { - self.client - .get_table_schemas(table_ids, publication_name) - .await + self.client.get_table_schemas(table_ids).await } /// Retrieves the schema information for the supplied table. /// /// If a publication is specified, only columns included in that publication /// will be returned. - pub async fn get_table_schema( - &self, - table_id: TableId, - publication: Option<&str>, - ) -> EtlResult { - self.client.get_table_schema(table_id, publication).await + pub async fn get_table_schema(&self, table_id: TableId) -> EtlResult { + self.client.get_table_schema(table_id).await } /// Creates a COPY stream for reading data from the specified table. /// /// The stream will include only the columns specified in `column_schemas`. - pub async fn get_table_copy_stream( + pub async fn get_table_copy_stream<'a>( &self, table_id: TableId, - column_schemas: &[ColumnSchema], + column_schemas: impl Iterator, publication_name: Option<&str>, ) -> EtlResult { self.client @@ -150,6 +143,22 @@ impl PgReplicationSlotTransaction { .await } + /// Retrieves the names of columns being replicated for a table in a publication. + /// + /// Returns a HashSet containing the names of columns that are included in the publication + /// for the specified table. If no publication is specified, returns all column names from + /// the table schema. + pub async fn get_replicated_column_names( + &self, + table_id: TableId, + table_schema: &TableSchema, + publication_name: &str, + ) -> EtlResult> { + self.client + .get_replicated_column_names(table_id, table_schema, publication_name) + .await + } + /// Commits the current transaction. 
pub async fn commit(self) -> EtlResult<()> { self.client.commit_tx().await @@ -161,14 +170,6 @@ impl PgReplicationSlotTransaction { } } -/// Result of building publication filter SQL components. -struct PublicationFilter { - /// CTEs to include in the WITH clause (empty string if no publication filtering). - ctes: String, - /// Predicate to include in the WHERE clause (empty string if no publication filtering). - predicate: String, -} - /// A client for interacting with Postgres's logical replication features. /// /// This client provides methods for creating replication slots, managing transactions, @@ -478,6 +479,11 @@ impl PgReplicationClient { /// For partitioned tables with `publish_via_partition_root=true`, this returns only the parent /// table OID. The query uses a recursive CTE to walk up the partition inheritance hierarchy /// and identify root tables that have no parent themselves. + /// + /// # Errors + /// + /// Returns [`ErrorKind::ConfigError`] if the publication contains no tables. This typically + /// indicates a misconfigured publication that won't replicate any data. pub async fn get_publication_table_ids( &self, publication_name: &str, @@ -496,9 +502,7 @@ impl PgReplicationClient { hierarchy(relid) as ( -- Start with published tables select oid from pub_tables - union - -- Recursively find parent tables in inheritance hierarchy select i.inhparent from pg_inherits i @@ -514,15 +518,28 @@ impl PgReplicationClient { pub = quote_literal(publication_name) ); - let mut roots = vec![]; - for msg in self.client.simple_query(&query).await? { - if let SimpleQueryMessage::Row(row) = msg { + let mut root_tables = vec![]; + for row in self.client.simple_query(&query).await? { + if let SimpleQueryMessage::Row(row) = row { let table_id = Self::get_row_value::(&row, "oid", "pg_class").await?; - roots.push(table_id); + root_tables.push(table_id); } } - Ok(roots) + if root_tables.is_empty() { + bail!( + ErrorKind::ConfigError, + "Publication has no tables", + format!( + "Publication '{}' does not contain any tables. Ensure the publication \ + is configured with tables using FOR TABLE, FOR ALL TABLES, or \ + FOR TABLES IN SCHEMA.", + publication_name + ) + ); + } + + Ok(root_tables) } /// Starts a logical replication stream from the specified publication and slot. @@ -541,8 +558,8 @@ impl PgReplicationClient { // Do not convert the query or the options to lowercase, see comment in `create_slot_internal`. let options = format!( - r#"("proto_version" '1', "publication_names" {})"#, - quote_literal(quote_identifier(publication_name).as_ref()), + r#"("proto_version" '1', "publication_names" {}, "messages" 'true')"#, + quote_literal(quote_identifier(publication_name).as_ref()) ); let query = format!( @@ -652,13 +669,12 @@ impl PgReplicationClient { async fn get_table_schemas( &self, table_ids: &[TableId], - publication_name: Option<&str>, ) -> EtlResult> { let mut table_schemas = HashMap::new(); // TODO: consider if we want to fail when at least one table was missing or not. for table_id in table_ids { - let table_schema = self.get_table_schema(*table_id, publication_name).await?; + let table_schema = self.get_table_schema(*table_id).await?; // TODO: this warning and skipping should not happen in this method, // but rather higher in the stack. @@ -680,19 +696,11 @@ impl PgReplicationClient { /// /// If a publication is specified, only columns included in that publication /// will be returned. 
- async fn get_table_schema( - &self, - table_id: TableId, - publication: Option<&str>, - ) -> EtlResult { + async fn get_table_schema(&self, table_id: TableId) -> EtlResult { let table_name = self.get_table_name(table_id).await?; - let column_schemas = self.get_column_schemas(table_id, publication).await?; + let column_schemas = self.get_column_schemas(table_id).await?; - Ok(TableSchema { - name: table_name, - id: table_id, - column_schemas, - }) + Ok(TableSchema::new(table_id, table_name, column_schemas)) } /// Loads the table name and schema information for a given table OID. @@ -727,168 +735,51 @@ impl PgReplicationClient { ); } - /// Builds SQL fragments for filtering columns based on publication settings. - /// - /// Returns CTEs and predicates that filter columns according to: - /// - Postgres 15+: Column-level filtering using `prattrs` - /// - Postgres 14 and earlier: Table-level filtering only - /// - No publication: No filtering (empty strings) - fn build_publication_filter_sql( - &self, - table_id: TableId, - publication_name: Option<&str>, - ) -> PublicationFilter { - let Some(publication_name) = publication_name else { - return PublicationFilter { - ctes: String::new(), - predicate: String::new(), - }; - }; - - // Postgres 15+ supports column-level filtering via prattrs - if requires_version!(self.server_version, POSTGRES_15) { - return PublicationFilter { - ctes: format!( - "pub_info as ( - select p.oid as puboid, p.puballtables, r.prattrs - from pg_publication p - left join pg_publication_rel r on r.prpubid = p.oid and r.prrelid = {table_id} - where p.pubname = {publication} - ), - pub_attrs as ( - select unnest(prattrs) as attnum - from pub_info - where prattrs is not null - ), - pub_schema as ( - select 1 as exists_in_schema_pub - from pub_info - join pg_publication_namespace pn on pn.pnpubid = pub_info.puboid - join pg_class c on c.relnamespace = pn.pnnspid - where c.oid = {table_id} - ),", - publication = quote_literal(publication_name), - ), - predicate: "and ( - (select puballtables from pub_info) = true - or (select count(*) from pub_schema) > 0 - or ( - case (select count(*) from pub_attrs) - when 0 then true - else (a.attnum in (select attnum from pub_attrs)) - end - ) - )" - .to_string(), - }; - } - - // Postgres 14 and earlier: table-level filtering only - PublicationFilter { - ctes: format!( - "pub_info as ( - select p.puballtables - from pg_publication p - where p.pubname = {publication} - ), - pub_table as ( - select 1 as exists_in_pub - from pg_publication_rel r - join pg_publication p on r.prpubid = p.oid - where p.pubname = {publication} - and r.prrelid = {table_id} - ),", - publication = quote_literal(publication_name), - ), - predicate: "and ((select puballtables from pub_info) = true or (select count(*) from pub_table) > 0)".to_string(), - } - } - /// Retrieves schema information for all columns in a table. /// - /// If a publication is specified, only columns included in that publication - /// will be returned. - async fn get_column_schemas( - &self, - table_id: TableId, - publication: Option<&str>, - ) -> EtlResult> { - // Build publication filter CTEs and predicates based on Postgres version. - let publication_filter = self.build_publication_filter_sql(table_id, publication); - + /// All columns are initially loaded with `replicated = true`. The replication status + /// will be updated later when relation messages are received during CDC streaming. 
+ async fn get_column_schemas(&self, table_id: TableId) -> EtlResult> { let column_info_query = format!( r#" - with {publication_ctes} - -- Find the direct parent table (for child partitions) - direct_parent as ( - select i.inhparent as parent_oid - from pg_inherits i - where i.inhrelid = {table_id} - limit 1 - ), - -- Extract primary key column names from the parent table - parent_pk_cols as ( - select array_agg(a.attname order by x.n) as pk_column_names - from pg_constraint con - join unnest(con.conkey) with ordinality as x(attnum, n) on true - join pg_attribute a on a.attrelid = con.conrelid and a.attnum = x.attnum - join direct_parent dp on con.conrelid = dp.parent_oid - where con.contype = 'p' - group by con.conname - ) select - a.attname, - a.atttypid, - a.atttypmod, - a.attnotnull, - case - -- Check if column has a direct primary key index - when coalesce(i.indisprimary, false) = true then true - -- Check if column name matches parent's primary key (for partitions) - when exists ( - select 1 - from parent_pk_cols pk - where a.attname = any(pk.pk_column_names) - ) then true - else false - end as primary - from pg_attribute a - left join pg_index i - on a.attrelid = i.indrelid - and a.attnum = any(i.indkey) - and i.indisprimary = true - where a.attnum > 0::int2 - and not a.attisdropped - and a.attgenerated = '' - and a.attrelid = {table_id} - {publication_predicate} - order by a.attnum - "#, - publication_ctes = publication_filter.ctes, - publication_predicate = publication_filter.predicate, + name, + type_oid, + type_modifier, + ordinal_position, + primary_key_ordinal_position, + nullable + from etl.describe_table_schema({table_id}) + order by ordinal_position + "# ); let mut column_schemas = vec![]; for message in self.client.simple_query(&column_info_query).await? { if let SimpleQueryMessage::Row(row) = message { - let name = Self::get_row_value::(&row, "attname", "pg_attribute").await?; - let type_oid = Self::get_row_value::(&row, "atttypid", "pg_attribute").await?; - let modifier = - Self::get_row_value::(&row, "atttypmod", "pg_attribute").await?; - let nullable = - Self::get_row_value::(&row, "attnotnull", "pg_attribute").await? == "f"; - let primary = - Self::get_row_value::(&row, "primary", "pg_index").await? == "t"; + let name = Self::get_row_value::(&row, "name", "pg_attribute").await?; + let type_oid = Self::get_row_value::(&row, "type_oid", "pg_type").await?; + let type_modifier = + Self::get_row_value::(&row, "type_modifier", "pg_attribute").await?; + let ordinal_position = + Self::get_row_value::(&row, "ordinal_position", "pg_attribute").await?; + let primary_key_ordinal_position: Option = row + .try_get("primary_key_ordinal_position")? + .and_then(|s: &str| s.parse().ok()); + let nullable_str = + Self::get_row_value::(&row, "nullable", "pg_attribute").await?; + let nullable = nullable_str == "t" || nullable_str == "true"; let typ = convert_type_oid_to_type(type_oid); - column_schemas.push(ColumnSchema { + column_schemas.push(ColumnSchema::new( name, typ, - modifier, + type_modifier, + ordinal_position, + primary_key_ordinal_position, nullable, - primary, - }) + )) } } @@ -939,17 +830,113 @@ impl PgReplicationClient { Ok(None) } + /// Retrieves the names of columns being replicated for a table in a publication. + /// + /// Returns a HashSet containing the names of columns that are included in the publication + /// for the specified table. If the PostgreSQL version is below 15 (which doesn't support + /// column filtering), returns all column names from the table schema. 
+ /// + /// For publications created with `FOR ALL TABLES` or `FOR TABLES IN SCHEMA`, all columns + /// are replicated since these publication types don't support column filtering. + /// + /// # Errors + /// + /// Returns [`ErrorKind::ConfigError`] if the table is not included in the publication. + /// This prevents silently syncing tables that won't receive CDC updates. + /// + /// This method should be called in the same transaction as describe_table_schema to ensure + /// consistency during initial table sync. + pub async fn get_replicated_column_names( + &self, + table_id: TableId, + table_schema: &TableSchema, + publication_name: &str, + ) -> EtlResult> { + // Column filtering in publications was added in Postgres 15. For earlier versions, + // all columns are replicated. + if below_version!(self.server_version, POSTGRES_15) { + return Ok(table_schema + .column_schemas + .iter() + .map(|cs| cs.name.clone()) + .collect()); + } + + // Query pg_publication_tables using unnest() to properly decode the attnames array. + // This correctly handles column names containing special characters (spaces, commas, + // quotes) that would break naive string parsing. + // + // The query returns two columns: + // - table_in_publication: true if the table is in the publication + // - column_name: the column name (NULL if attnames is NULL, meaning all columns) + // + // When attnames is NULL (FOR ALL TABLES or FOR TABLES IN SCHEMA publications), + // all columns are replicated. When attnames has values, only those columns are + // replicated. If no rows are returned, the table is not in the publication. + let column_query = format!( + "select true as table_in_publication, u.column_name + from pg_publication_tables pt + left join lateral unnest(pt.attnames) as u(column_name) on true + join pg_namespace n on n.nspname = pt.schemaname + join pg_class c on c.relnamespace = n.oid and c.relname = pt.tablename + where pt.pubname = {} and c.oid = {};", + quote_literal(publication_name), + table_id, + ); + + let rows = self.client.simple_query(&column_query).await?; + let mut column_names: HashSet = HashSet::new(); + let mut table_in_publication = false; + + for row in rows { + if let SimpleQueryMessage::Row(row) = row { + // If we got any row, the table is in the publication. + table_in_publication = true; + + if let Some(column_name) = row.try_get::<&str>("column_name")? { + column_names.insert(column_name.to_string()); + } + } + } + + // If the table is not in the publication, error out. This prevents silently syncing + // tables that won't receive events, leaving the destination stale. + if !table_in_publication { + bail!( + ErrorKind::ConfigError, + "Table not in publication", + format!( + "Table '{}' is not included in publication '{}'. \ + The table must be added to the publication to receive events.", + table_schema.name, publication_name + ) + ); + } + + // If column_names is empty but table is in publication, it means attnames was NULL, + // which indicates all columns are replicated (FOR ALL TABLES or FOR TABLES IN SCHEMA). + if column_names.is_empty() { + return Ok(table_schema + .column_schemas + .iter() + .map(|cs| cs.name.clone()) + .collect()); + } + + Ok(column_names) + } + /// Creates a COPY stream for reading data from a table using its OID. 
/// - /// The stream will include only the specified columns and use text format, and respect publication row filters (if a publication is specified) - pub async fn get_table_copy_stream( + /// The stream will include only the specified columns and use text format, and respect + /// publication row filters (if a publication is specified) + pub async fn get_table_copy_stream<'a>( &self, table_id: TableId, - column_schemas: &[ColumnSchema], + column_schemas: impl Iterator, publication: Option<&str>, ) -> EtlResult { let column_list = column_schemas - .iter() .map(|col| quote_identifier(&col.name)) .collect::>() .join(", "); diff --git a/etl/src/replication/masks.rs b/etl/src/replication/masks.rs new file mode 100644 index 000000000..06a946168 --- /dev/null +++ b/etl/src/replication/masks.rs @@ -0,0 +1,129 @@ +//! Shared replication mask storage for tracking replicated columns across workers. +//! +//! PostgreSQL 15+ supports column-level publication filtering, where only specific columns +//! are replicated rather than all columns. A [`ReplicationMask`] is a bitmask indicating +//! which columns in a table are being replicated (1 = replicated, 0 = not replicated). +//! +//! This mask is needed to correctly decode replication events, since the stream only +//! contains values for replicated columns. Both the apply worker (for CDC events) and +//! table sync workers (for initial copy) need access to these masks, so they are stored +//! in a shared container passed to all workers. +//! +//! The replication mask is kept in-memory only because PostgreSQL guarantees that RELATION +//! messages are sent at the start of each connection and whenever the schema changes. This +//! ensures we always receive schema information before any data events that depend on it, +//! allowing us to compute the mask on-demand without persistence. +//! +//! **Limitation**: Adding or removing columns from a publication while the pipeline is +//! running will cause schema mismatches. Downstream tables that rely on a fixed schema +//! will break because the replicated column set changes but the destination schema does +//! not automatically update. + +use etl_postgres::types::{ReplicationMask, TableId}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; + +/// Thread-safe container for replication masks shared across workers. +#[derive(Debug, Clone, Default)] +pub struct ReplicationMasks { + inner: Arc>>, +} + +impl ReplicationMasks { + /// Creates a new empty [`ReplicationMasks`] container. + pub fn new() -> Self { + Self { + inner: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Stores a replication mask for a table. + /// + /// This is typically called by the table sync worker after it has determined + /// which columns are being replicated for a table. + pub async fn set(&self, table_id: TableId, mask: ReplicationMask) { + let mut guard = self.inner.write().await; + guard.insert(table_id, mask); + } + + /// Retrieves the replication mask for a table. + /// + /// Returns `None` if no mask has been set for the given table. + pub async fn get(&self, table_id: &TableId) -> Option { + let guard = self.inner.read().await; + guard.get(table_id).cloned() + } + + /// Removes the replication mask for a table. + /// + /// This is called after processing a DDL schema change message to invalidate the cached + /// mask. While PostgreSQL guarantees that a RELATION message will be sent before any DML + /// events after a schema change, we proactively invalidate the mask to ensure consistency. 
+ /// The next RELATION message will rebuild the mask with the updated schema. + pub async fn remove(&self, table_id: &TableId) { + let mut guard = self.inner.write().await; + guard.remove(table_id); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use etl_postgres::types::{ColumnSchema, TableName, TableSchema}; + use std::collections::HashSet; + use tokio_postgres::types::Type; + + fn create_test_mask() -> ReplicationMask { + let schema = TableSchema::new( + TableId::new(123), + TableName::new("public".to_string(), "test_table".to_string()), + vec![ + ColumnSchema::new("id".to_string(), Type::INT4, -1, 1, Some(1), false), + ColumnSchema::new("name".to_string(), Type::TEXT, -1, 2, None, true), + ColumnSchema::new("age".to_string(), Type::INT4, -1, 3, None, true), + ], + ); + + let replicated_columns: HashSet = + ["id".to_string(), "age".to_string()].into_iter().collect(); + ReplicationMask::build(&schema, &replicated_columns) + } + + #[tokio::test] + async fn test_set_and_get() { + let masks = ReplicationMasks::new(); + let table_id = TableId::new(123); + let mask = create_test_mask(); + + masks.set(table_id, mask.clone()).await; + + let retrieved = masks.get(&table_id).await; + assert!(retrieved.is_some()); + assert_eq!(retrieved.unwrap().as_slice(), mask.as_slice()); + } + + #[tokio::test] + async fn test_get_nonexistent() { + let masks = ReplicationMasks::new(); + let table_id = TableId::new(123); + + let retrieved = masks.get(&table_id).await; + assert!(retrieved.is_none()); + } + + #[tokio::test] + async fn test_clone_shares_state() { + let masks1 = ReplicationMasks::new(); + let masks2 = masks1.clone(); + let table_id = TableId::new(123); + let mask = create_test_mask(); + + masks1.set(table_id, mask.clone()).await; + + // masks2 should see the same data + let retrieved = masks2.get(&table_id).await; + assert!(retrieved.is_some()); + assert_eq!(retrieved.unwrap().as_slice(), mask.as_slice()); + } +} diff --git a/etl/src/replication/mod.rs b/etl/src/replication/mod.rs index 39fed1317..77ef39338 100644 --- a/etl/src/replication/mod.rs +++ b/etl/src/replication/mod.rs @@ -6,5 +6,6 @@ pub mod apply; pub mod client; pub mod common; +pub mod masks; pub mod stream; pub mod table_sync; diff --git a/etl/src/replication/stream.rs b/etl/src/replication/stream.rs index 555b18b73..2df9cb256 100644 --- a/etl/src/replication/stream.rs +++ b/etl/src/replication/stream.rs @@ -10,10 +10,14 @@ use std::time::{Duration, Instant}; use tokio_postgres::CopyOutStream; use tokio_postgres::types::PgLsn; use tracing::debug; +#[cfg(feature = "failpoints")] +use tracing::warn; use crate::conversions::table_row::parse_table_row_from_postgres_copy_bytes; use crate::error::{ErrorKind, EtlResult}; use crate::etl_error; +#[cfg(feature = "failpoints")] +use crate::failpoints::{SEND_STATUS_UPDATE_FP, etl_fail_point_active}; use crate::metrics::{ETL_COPIED_TABLE_ROW_SIZE_BYTES, PIPELINE_ID_LABEL}; use crate::types::{PipelineId, TableRow}; use metrics::histogram; @@ -29,23 +33,19 @@ pin_project! { /// using the provided column schemas. The conversion process handles both text and /// binary format data. #[must_use = "streams do nothing unless polled"] - pub struct TableCopyStream<'a> { + pub struct TableCopyStream { #[pin] stream: CopyOutStream, - column_schemas: &'a [ColumnSchema], + column_schemas: I, pipeline_id: PipelineId, } } -impl<'a> TableCopyStream<'a> { +impl TableCopyStream { /// Creates a new [`TableCopyStream`] from a [`CopyOutStream`] and column schemas. 
/// /// The column schemas are used to convert the raw Postgres data into [`TableRow`]s. - pub fn wrap( - stream: CopyOutStream, - column_schemas: &'a [ColumnSchema], - pipeline_id: PipelineId, - ) -> Self { + pub fn wrap(stream: CopyOutStream, column_schemas: I, pipeline_id: PipelineId) -> Self { Self { stream, column_schemas, @@ -54,7 +54,10 @@ impl<'a> TableCopyStream<'a> { } } -impl<'a> Stream for TableCopyStream<'a> { +impl<'a, I> Stream for TableCopyStream +where + I: Iterator + Clone, +{ type Item = EtlResult; /// Polls the stream for the next converted table row with comprehensive error handling. @@ -66,7 +69,7 @@ impl<'a> Stream for TableCopyStream<'a> { match ready!(this.stream.poll_next(cx)) { // TODO: allow pluggable table row conversion based on if the data is in text or binary format. Some(Ok(row)) => { - // Emit raw row size in bytes. This is a low effort way to estimate table rows size. + // Emit raw row size in bytes. This is a low-effort way to estimate table rows size. histogram!( ETL_COPIED_TABLE_ROW_SIZE_BYTES, PIPELINE_ID_LABEL => this.pipeline_id.to_string() @@ -75,7 +78,7 @@ impl<'a> Stream for TableCopyStream<'a> { // CONVERSION PHASE: Transform raw bytes into structured TableRow // This is where most errors occur due to data format or type issues - match parse_table_row_from_postgres_copy_bytes(&row, this.column_schemas) { + match parse_table_row_from_postgres_copy_bytes(&row, this.column_schemas.clone()) { Ok(row) => Poll::Ready(Some(Ok(row))), Err(err) => { // CONVERSION ERROR: Preserve full error context for debugging @@ -128,6 +131,15 @@ impl EventsStream { mut flush_lsn: PgLsn, force: bool, ) -> EtlResult<()> { + // If the failpoint is active, we do not send any status update. This is useful for testing + // the system when we want to check what happens when no status updates are sent. 
+ #[cfg(feature = "failpoints")] + if etl_fail_point_active(SEND_STATUS_UPDATE_FP) { + warn!("not sending status update due to active failpoint"); + + return Ok(()); + } + let this = self.project(); // If the new LSN is less than the last one, we can safely ignore it, since we only want diff --git a/etl/src/replication/table_sync.rs b/etl/src/replication/table_sync.rs index 10bef2416..64082952b 100644 --- a/etl/src/replication/table_sync.rs +++ b/etl/src/replication/table_sync.rs @@ -1,6 +1,6 @@ use etl_config::shared::PipelineConfig; use etl_postgres::replication::slots::EtlReplicationSlot; -use etl_postgres::types::TableId; +use etl_postgres::types::{ReplicatedTableSchema, ReplicationMask, SchemaError, TableId}; use futures::StreamExt; use metrics::histogram; use std::sync::Arc; @@ -17,7 +17,7 @@ use crate::destination::Destination; use crate::error::{ErrorKind, EtlResult}; #[cfg(feature = "failpoints")] use crate::failpoints::{ - START_TABLE_SYNC_BEFORE_DATA_SYNC_SLOT_CREATION, START_TABLE_SYNC_DURING_DATA_SYNC, + START_TABLE_SYNC_BEFORE_DATA_SYNC_SLOT_CREATION_FP, START_TABLE_SYNC_DURING_DATA_SYNC_FP, etl_fail_point, }; use crate::metrics::{ @@ -26,6 +26,7 @@ use crate::metrics::{ PIPELINE_ID_LABEL, WORKER_TYPE_LABEL, }; use crate::replication::client::PgReplicationClient; +use crate::replication::masks::ReplicationMasks; use crate::replication::stream::TableCopyStream; use crate::state::table::{TableReplicationPhase, TableReplicationPhaseType}; use crate::store::schema::SchemaStore; @@ -64,6 +65,7 @@ pub async fn start_table_sync( table_sync_worker_state: TableSyncWorkerState, store: S, destination: D, + replication_masks: &ReplicationMasks, shutdown_rx: ShutdownRx, force_syncing_tables_tx: SignalTx, ) -> EtlResult @@ -150,20 +152,6 @@ where } } - // We must truncate the destination table before starting a copy to avoid data inconsistencies. - // Example scenario: - // 1. The source table has a single row (id = 1) that is copied to the destination during the initial copy. - // 2. Before the table’s phase is set to `FinishedCopy`, the process crashes. - // 3. While down, the source deletes row id = 1 and inserts row id = 2. - // 4. When restarted, the process sees the table in the ` DataSync ` state, deletes the slot, and copies again. - // 5. This time, only row id = 2 is copied, but row id = 1 still exists in the destination. - // Result: the destination has two rows (id = 1 and id = 2) instead of only one (id = 2). - // Fix: Always truncate the destination table before starting a copy. - // - // We try to truncate the table also during `Init` because we support state rollback and - // a table might be there from a previous run. - destination.truncate_table(table_id).await?; - // We are ready to start copying table data, and we update the state accordingly. info!("starting data copy for table {}", table_id); { @@ -175,7 +163,7 @@ where // Fail point to test when the table sync fails before copying data. #[cfg(feature = "failpoints")] - etl_fail_point(START_TABLE_SYNC_BEFORE_DATA_SYNC_SLOT_CREATION)?; + etl_fail_point(START_TABLE_SYNC_BEFORE_DATA_SYNC_SLOT_CREATION_FP)?; // We create the slot with a transaction, since we need to have a consistent snapshot of the database // before copying the schema and tables. @@ -194,10 +182,7 @@ where // - Destination -> we write here because some consumers might want to have the schema of incoming // data. 
info!("fetching table schema for table {}", table_id); - let table_schema = transaction - .get_table_schema(table_id, Some(&config.publication_name)) - .await?; - + let table_schema = transaction.get_table_schema(table_id).await?; if !table_schema.has_primary_keys() { bail!( ErrorKind::SourceSchemaError, @@ -208,18 +193,63 @@ where // We store the table schema in the schema store to be able to retrieve it even when the // pipeline is restarted, since it's outside the lifecycle of the pipeline. - store.store_table_schema(table_schema.clone()).await?; + let table_schema = store.store_table_schema(table_schema).await?; + + // Get the names of columns being replicated based on the publication's column filter. + // This must be done in the same transaction as `get_table_schema` for consistency. + let replicated_column_names = transaction + .get_replicated_column_names(table_id, &table_schema, &config.publication_name) + .await?; + + // Build and store the replication mask for use during CDC. + // We use `try_build` here because the schema was just loaded and should match + // the publication's column filter. Any mismatch indicates a schema inconsistency. + let replication_mask = + ReplicationMask::try_build(&table_schema, &replicated_column_names).map_err( + |err: SchemaError| { + crate::etl_error!( + ErrorKind::InvalidState, + "Schema mismatch during table sync", + format!("{}", err) + ) + }, + )?; + replication_masks + .set(table_id, replication_mask.clone()) + .await; - // We create the copy table stream. + // Create the replicated table schema with the replication mask. + let replicated_table_schema = + ReplicatedTableSchema::from_mask(table_schema, replication_mask); + + // We must truncate the destination table before starting a copy to avoid data inconsistencies. + // + // Example scenario: + // 1. The source table has a single row (id = 1) that is copied to the destination during the initial copy. + // 2. Before the table’s phase is set to `FinishedCopy`, the process crashes. + // 3. While down, the source deletes row id = 1 and inserts row id = 2. + // 4. When restarted, the process sees the table in the ` DataSync ` state, deletes the slot, and copies again. + // 5. This time, only row id = 2 is copied, but row id = 1 still exists in the destination. + // Result: the destination has two rows (id = 1 and id = 2) instead of only one (id = 2). + // Fix: Always truncate the destination table before starting a copy. + // + // We try to truncate the table also during `Init` because we support state rollback and + // a table might be there from a previous run. + destination.truncate_table(&replicated_table_schema).await?; + + // We create the copy table stream on the replicated columns. 
let table_copy_stream = transaction .get_table_copy_stream( table_id, - &table_schema.column_schemas, + replicated_table_schema.column_schemas(), Some(&config.publication_name), ) .await?; - let table_copy_stream = - TableCopyStream::wrap(table_copy_stream, &table_schema.column_schemas, pipeline_id); + let table_copy_stream = TableCopyStream::wrap( + table_copy_stream, + replicated_table_schema.column_schemas(), + pipeline_id, + ); let table_copy_stream = TimeoutBatchStream::wrap( table_copy_stream, config.batch.clone(), @@ -253,7 +283,9 @@ where let before_sending = Instant::now(); - destination.write_table_rows(table_id, table_rows).await?; + destination + .write_table_rows(&replicated_table_schema, table_rows) + .await?; table_rows_written = true; metrics::counter!( @@ -277,7 +309,7 @@ where // Fail point to test when the table sync fails after copying one batch. #[cfg(feature = "failpoints")] - etl_fail_point(START_TABLE_SYNC_DURING_DATA_SYNC)?; + etl_fail_point(START_TABLE_SYNC_DURING_DATA_SYNC_FP)?; } ShutdownResult::Shutdown(_) => { // If we received a shutdown in the middle of a table copy, we bail knowing @@ -296,7 +328,9 @@ where // If no table rows were written, we call the method nonetheless with no rows, to kickstart // table creation. if !table_rows_written { - destination.write_table_rows(table_id, vec![]).await?; + destination + .write_table_rows(&replicated_table_schema, vec![]) + .await?; info!( "writing empty table rows since table {} was empty", table_id diff --git a/etl/src/state/destination.rs b/etl/src/state/destination.rs new file mode 100644 index 000000000..c675510bb --- /dev/null +++ b/etl/src/state/destination.rs @@ -0,0 +1,115 @@ +use etl_postgres::types::{ReplicationMask, SnapshotId}; + +/// Status of the schema at a destination. +/// +/// Tracks whether a schema change is in progress or complete. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DestinationTableSchemaStatus { + /// A schema change is currently being applied. + Applying, + /// The schema has been successfully applied. + Applied, +} + +/// Unified metadata for a table at a destination. +/// +/// Tracks all destination-related state for a replicated table in a single +/// structure. This structure is created atomically when a table is first +/// replicated to a destination, containing all the information needed to +/// track and manage that table's destination state. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DestinationTableMetadata { + /// The name/identifier of the table in the destination system. + pub destination_table_id: String, + /// The snapshot_id of the schema currently applied at the destination. + pub snapshot_id: SnapshotId, + /// The schema version before the current change. None for initial schemas. + /// + /// Destinations that support atomic DDL can use this for recovery: if + /// `schema_status` is `Applying` on startup, the destination knows the + /// DDL was rolled back and can reset to this snapshot to retry. + pub previous_snapshot_id: Option, + /// Status of the current schema change operation. + /// + /// If `Applying` is found on startup, the destination schema may be in + /// an unknown state and recovery may be needed depending on the destination. + pub schema_status: DestinationTableSchemaStatus, + /// The replication mask indicating which columns are replicated. + /// + /// Each byte is 0 (not replicated) or 1 (replicated), with the index + /// corresponding to the column's ordinal position in the schema. 
+ pub replication_mask: ReplicationMask, +} + +impl DestinationTableMetadata { + /// Creates new metadata for a table being created at the destination. + /// + /// Initializes with `Applying` status since the table creation is in progress. + /// For initial table creation, `previous_snapshot_id` is None. + pub fn new_applying( + destination_table_id: String, + snapshot_id: SnapshotId, + replication_mask: ReplicationMask, + ) -> Self { + Self { + destination_table_id, + snapshot_id, + previous_snapshot_id: None, + schema_status: DestinationTableSchemaStatus::Applying, + replication_mask, + } + } + + /// Creates new metadata for a table that has been successfully created. + /// + /// Initializes with `Applied` status. + pub fn new_applied( + destination_table_id: String, + snapshot_id: SnapshotId, + replication_mask: ReplicationMask, + ) -> Self { + Self { + destination_table_id, + snapshot_id, + previous_snapshot_id: None, + schema_status: DestinationTableSchemaStatus::Applied, + replication_mask, + } + } + + /// Returns true if a schema change is in progress. + pub fn is_applying(&self) -> bool { + self.schema_status == DestinationTableSchemaStatus::Applying + } + + /// Returns true if the schema has been applied. + pub fn is_applied(&self) -> bool { + self.schema_status == DestinationTableSchemaStatus::Applied + } + + /// Transitions this metadata to applied status. + /// + /// Clears the previous_snapshot_id since the change completed successfully. + pub fn to_applied(mut self) -> Self { + self.schema_status = DestinationTableSchemaStatus::Applied; + self.previous_snapshot_id = None; + self + } + + /// Updates the schema state for a new schema change. + /// + /// Sets `previous_snapshot_id` to the current snapshot before updating, + /// enabling recovery if the change fails on destinations that support atomic DDL. + pub fn with_schema_change( + mut self, + snapshot_id: SnapshotId, + replication_mask: ReplicationMask, + status: DestinationTableSchemaStatus, + ) -> Self { + self.previous_snapshot_id = Some(self.snapshot_id); + self.snapshot_id = snapshot_id; + self.replication_mask = replication_mask; + self.schema_status = status; + self + } +} diff --git a/etl/src/state/mod.rs b/etl/src/state/mod.rs index cba0a1af4..8227c14f9 100644 --- a/etl/src/state/mod.rs +++ b/etl/src/state/mod.rs @@ -3,4 +3,5 @@ //! Defines state types and enums used to track table replication phases and pipeline progress //! across restarts and worker coordination. +pub mod destination; pub mod table; diff --git a/etl/src/state/table.rs b/etl/src/state/table.rs index fbe1d5c82..40df06cc2 100644 --- a/etl/src/state/table.rs +++ b/etl/src/state/table.rs @@ -148,6 +148,12 @@ impl TableReplicationError { "Check replication slot status and database configuration.", RetryPolicy::ManualRetry, ), + ErrorKind::CorruptedTableSchema => Self::with_solution( + table_id, + error, + "Reset the table state and restart the replication.", + RetryPolicy::ManualRetry, + ), // Special handling for error kinds used during failure injection. 
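The field docs on `DestinationTableMetadata` above describe how a destination can recover when a previous run crashed while a schema change was still being applied. A small sketch of that startup decision, using only the accessors and fields defined in this patch; the import paths and the placement of this check are assumptions, not the pipeline's actual recovery code.

```rust
use etl::store::state::DestinationTableMetadata;
use etl_postgres::types::SnapshotId;

/// Picks the snapshot the destination should treat as current on startup.
/// If a schema change was interrupted (`Applying`), fall back to the previous snapshot so
/// destinations with atomic DDL can retry; `None` means the table was still being created.
fn snapshot_to_use(metadata: &DestinationTableMetadata) -> Option<SnapshotId> {
    if metadata.is_applying() {
        metadata.previous_snapshot_id.clone()
    } else {
        Some(metadata.snapshot_id.clone())
    }
}
```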
#[cfg(feature = "failpoints")] diff --git a/etl/src/store/both/memory.rs b/etl/src/store/both/memory.rs index 5bff8dab7..91e612a92 100644 --- a/etl/src/store/both/memory.rs +++ b/etl/src/store/both/memory.rs @@ -1,10 +1,11 @@ -use etl_postgres::types::{TableId, TableSchema}; +use etl_postgres::types::{SnapshotId, TableId, TableSchema}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::Mutex; use crate::error::{ErrorKind, EtlResult}; use crate::etl_error; +use crate::state::destination::DestinationTableMetadata; use crate::state::table::TableReplicationPhase; use crate::store::cleanup::CleanupStore; use crate::store::schema::SchemaStore; @@ -20,13 +21,10 @@ struct Inner { /// This is an append-only log that grows over time and provides visibility into /// table state evolution. Entries are chronologically ordered. table_state_history: HashMap>, - /// Cached table schema definitions, reference-counted for efficient sharing. - /// Schemas are expensive to fetch from Postgres, so they're cached here - /// once retrieved and shared via Arc across the application. - table_schemas: HashMap>, - /// Mapping from table IDs to human-readable table names for easier debugging - /// and logging. These mappings are established during schema discovery. - table_mappings: HashMap, + /// Cached table schemas keyed by (TableId, SnapshotId) for versioning support. + table_schemas: HashMap<(TableId, SnapshotId), Arc>, + /// Cached destination table metadata indexed by table ID. + destination_tables_metadata: HashMap, } /// In-memory storage for ETL pipeline state and schema information. @@ -53,7 +51,7 @@ impl MemoryStore { table_replication_states: HashMap::new(), table_state_history: HashMap::new(), table_schemas: HashMap::new(), - table_mappings: HashMap::new(), + destination_tables_metadata: HashMap::new(), }; Self { @@ -139,43 +137,54 @@ impl StateStore for MemoryStore { Ok(previous_state) } - async fn get_table_mapping(&self, source_table_id: &TableId) -> EtlResult> { - let inner = self.inner.lock().await; - - Ok(inner.table_mappings.get(source_table_id).cloned()) - } - - async fn get_table_mappings(&self) -> EtlResult> { + async fn get_destination_table_metadata( + &self, + table_id: &TableId, + ) -> EtlResult> { let inner = self.inner.lock().await; - Ok(inner.table_mappings.clone()) + Ok(inner.destination_tables_metadata.get(table_id).cloned()) } - async fn load_table_mappings(&self) -> EtlResult { + async fn load_destination_tables_metadata(&self) -> EtlResult { let inner = self.inner.lock().await; - Ok(inner.table_mappings.len()) + Ok(inner.destination_tables_metadata.len()) } - async fn store_table_mapping( + async fn store_destination_table_metadata( &self, - source_table_id: TableId, - destination_table_id: String, + table_id: TableId, + metadata: DestinationTableMetadata, ) -> EtlResult<()> { let mut inner = self.inner.lock().await; - inner - .table_mappings - .insert(source_table_id, destination_table_id); + inner.destination_tables_metadata.insert(table_id, metadata); Ok(()) } } impl SchemaStore for MemoryStore { - async fn get_table_schema(&self, table_id: &TableId) -> EtlResult>> { + /// Returns the table schema for the given table at the specified snapshot point. + /// + /// Returns the schema version with the largest snapshot_id <= the requested snapshot_id. + /// For MemoryStore, this only looks in the in-memory cache. 
+ async fn get_table_schema( + &self, + table_id: &TableId, + snapshot_id: SnapshotId, + ) -> EtlResult>> { let inner = self.inner.lock().await; - Ok(inner.table_schemas.get(table_id).cloned()) + // Find the best matching schema (largest snapshot_id <= requested) + let best_match = inner + .table_schemas + .iter() + .filter(|((tid, sid), _)| *tid == *table_id && *sid <= snapshot_id) + .max_by_key(|((_, sid), _)| *sid) + .map(|(_, schema)| schema.clone()); + + Ok(best_match) } async fn get_table_schemas(&self) -> EtlResult>> { @@ -190,13 +199,14 @@ impl SchemaStore for MemoryStore { Ok(inner.table_schemas.len()) } - async fn store_table_schema(&self, table_schema: TableSchema) -> EtlResult<()> { + async fn store_table_schema(&self, table_schema: TableSchema) -> EtlResult> { let mut inner = self.inner.lock().await; - inner - .table_schemas - .insert(table_schema.id, Arc::new(table_schema)); - Ok(()) + let key = (table_schema.id, table_schema.snapshot_id); + let table_schema = Arc::new(table_schema); + inner.table_schemas.insert(key, table_schema.clone()); + + Ok(table_schema) } } @@ -206,8 +216,9 @@ impl CleanupStore for MemoryStore { inner.table_replication_states.remove(&table_id); inner.table_state_history.remove(&table_id); - inner.table_schemas.remove(&table_id); - inner.table_mappings.remove(&table_id); + // Remove all schema versions for this table + inner.table_schemas.retain(|(tid, _), _| *tid != table_id); + inner.destination_tables_metadata.remove(&table_id); Ok(()) } diff --git a/etl/src/store/both/postgres.rs b/etl/src/store/both/postgres.rs index 7178e323d..1c21c70e6 100644 --- a/etl/src/store/both/postgres.rs +++ b/etl/src/store/both/postgres.rs @@ -1,8 +1,8 @@ use std::{collections::HashMap, sync::Arc}; use etl_config::shared::PgConnectionConfig; -use etl_postgres::replication::{connect_to_source_database, schema, state, table_mappings}; -use etl_postgres::types::{TableId, TableSchema}; +use etl_postgres::replication::{connect_to_source_database, destination_metadata, schema, state}; +use etl_postgres::types::{ReplicationMask, SnapshotId, TableId, TableSchema}; use metrics::gauge; use sqlx::PgPool; use tokio::sync::Mutex; @@ -10,6 +10,7 @@ use tracing::{debug, info}; use crate::error::{ErrorKind, EtlError, EtlResult}; use crate::metrics::{ETL_TABLES_TOTAL, PHASE_LABEL, PIPELINE_ID_LABEL}; +use crate::state::destination::{DestinationTableMetadata, DestinationTableSchemaStatus}; use crate::state::table::{RetryPolicy, TableReplicationPhase}; use crate::store::cleanup::CleanupStore; use crate::store::schema::SchemaStore; @@ -19,6 +20,13 @@ use crate::{bail, etl_error}; const NUM_POOL_CONNECTIONS: u32 = 1; +/// Maximum number of schema snapshots to keep cached per table. +/// +/// This limits memory usage by evicting older snapshots when new ones are added. +/// In practice, during a single batch of events, it's highly unlikely to need +/// more than 2 schema versions for any given table. +const MAX_CACHED_SCHEMAS_PER_TABLE: usize = 2; + /// Converts ETL table replication phases to Postgres database state format. /// /// This conversion transforms internal ETL replication states into the format @@ -128,6 +136,34 @@ impl TryFrom for TableReplicationPhase { } } +/// Converts application layer destination schema status to postgres layer. 
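Both schema stores above resolve a lookup to the version with the largest `snapshot_id` that is less than or equal to the requested one. A tiny self-contained illustration of that rule follows, using plain integers and string labels instead of the crate's `SnapshotId`/`TableSchema` types; the table id and snapshot values are made up. The status-conversion impls continue right after it.

```rust
use std::collections::HashMap;

/// Same rule as the stores above: pick the largest snapshot_id <= the requested one.
fn resolve_schema<'a>(
    cache: &'a HashMap<(u32, u64), &'a str>,
    table_id: u32,
    snapshot_id: u64,
) -> Option<&'a str> {
    cache
        .iter()
        .filter(|((tid, sid), _)| *tid == table_id && *sid <= snapshot_id)
        .max_by_key(|((_, sid), _)| *sid)
        .map(|(_, schema)| *schema)
}

fn main() {
    let mut cache = HashMap::new();
    cache.insert((7, 0), "v0: id, name");
    cache.insert((7, 150), "v1: id, name, age");
    cache.insert((7, 300), "v2: id, age");

    // A request at snapshot 200 resolves to the version created at 150.
    assert_eq!(resolve_schema(&cache, 7, 200), Some("v1: id, name, age"));
    // A request at the initial snapshot finds v0.
    assert_eq!(resolve_schema(&cache, 7, 0), Some("v0: id, name"));
    // Unknown tables resolve to None.
    assert_eq!(resolve_schema(&cache, 9, 200), None);
}
```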
+impl From for destination_metadata::DestinationTableSchemaStatus { + fn from(value: DestinationTableSchemaStatus) -> Self { + match value { + DestinationTableSchemaStatus::Applying => { + destination_metadata::DestinationTableSchemaStatus::Applying + } + DestinationTableSchemaStatus::Applied => { + destination_metadata::DestinationTableSchemaStatus::Applied + } + } + } +} + +/// Converts postgres layer destination schema status to application layer. +impl From for DestinationTableSchemaStatus { + fn from(value: destination_metadata::DestinationTableSchemaStatus) -> Self { + match value { + destination_metadata::DestinationTableSchemaStatus::Applying => { + DestinationTableSchemaStatus::Applying + } + destination_metadata::DestinationTableSchemaStatus::Applied => { + DestinationTableSchemaStatus::Applied + } + } + } +} + /// Inner state of [`PostgresStore`]. #[derive(Debug)] struct Inner { @@ -135,10 +171,16 @@ struct Inner { phase_counts: HashMap<&'static str, u64>, /// Cached table replication states indexed by table ID. table_states: HashMap, - /// Cached table schemas indexed by table ID. - table_schemas: HashMap>, - /// Cached table mappings from source table ID to destination table name. - table_mappings: HashMap, + /// Cached table schemas indexed by (table_id, snapshot_id) for versioning support. + /// + /// This cache is optimized for keeping the most actively used schemas in memory, + /// not all historical snapshots. Schemas are loaded on-demand from the database + /// when not found in cache. During normal operation, this typically contains + /// only the latest schema version for each table, since that's what the + /// replication pipeline actively uses. + table_schemas: HashMap<(TableId, SnapshotId), Arc>, + /// Cached destination table metadata indexed by table ID. + destination_tables_metadata: HashMap, } impl Inner { @@ -151,6 +193,39 @@ impl Inner { let count = self.phase_counts.entry(phase).or_default(); *count += 1; } + + /// Inserts a schema into the cache and evicts older snapshots if necessary. + /// + /// Maintains at most [`MAX_CACHED_SCHEMAS_PER_TABLE`] snapshots per table, + /// evicting the oldest snapshots when the limit is exceeded. + fn insert_schema_with_eviction(&mut self, table_schema: Arc) { + let table_id = table_schema.id; + let snapshot_id = table_schema.snapshot_id; + + // Insert the new schema + self.table_schemas + .insert((table_id, snapshot_id), table_schema); + + // Collect all snapshot_ids for this table + let mut snapshots_for_table: Vec = self + .table_schemas + .keys() + .filter(|(tid, _)| *tid == table_id) + .map(|(_, sid)| *sid) + .collect(); + + // If we exceed the limit, evict oldest snapshots + if snapshots_for_table.len() > MAX_CACHED_SCHEMAS_PER_TABLE { + // Sort ascending so oldest are first + snapshots_for_table.sort(); + + // Remove oldest entries until we're at the limit + let to_remove = snapshots_for_table.len() - MAX_CACHED_SCHEMAS_PER_TABLE; + for &old_snapshot_id in snapshots_for_table.iter().take(to_remove) { + self.table_schemas.remove(&(table_id, old_snapshot_id)); + } + } + } } /// Postgres-backed storage for ETL pipeline state and schema information. @@ -181,7 +256,7 @@ impl PostgresStore { phase_counts: HashMap::new(), table_states: HashMap::new(), table_schemas: HashMap::new(), - table_mappings: HashMap::new(), + destination_tables_metadata: HashMap::new(), }; Self { @@ -404,119 +479,186 @@ impl StateStore for PostgresStore { } } - /// Retrieves a table mapping from source table ID to destination name. 
+ /// Retrieves destination table metadata for a specific table from cache. /// - /// This method looks up the destination table name for a given source table - /// ID from the cache. Table mappings define how source tables are mapped - /// to tables in the destination system. - async fn get_table_mapping(&self, source_table_id: &TableId) -> EtlResult> { - let inner = self.inner.lock().await; - - Ok(inner.table_mappings.get(source_table_id).cloned()) - } - - /// Retrieves all table mappings from cache. - /// - /// This method returns a complete snapshot of all cached table mappings, - /// showing how source table IDs map to destination table names. Useful - /// for operations that need visibility into the complete mapping configuration. - async fn get_table_mappings(&self) -> EtlResult> { + /// This method provides fast access to destination metadata by reading + /// from the in-memory cache. + async fn get_destination_table_metadata( + &self, + table_id: &TableId, + ) -> EtlResult> { let inner = self.inner.lock().await; - Ok(inner.table_mappings.clone()) + Ok(inner.destination_tables_metadata.get(table_id).cloned()) } - /// Loads table mappings from Postgres into memory cache. + /// Loads all destination table metadata from Postgres into memory cache. /// - /// This method connects to the source database, retrieves all table mapping - /// definitions for this pipeline, and populates the in-memory cache. - /// Called during pipeline initialization to establish source-to-destination - /// table mappings. - async fn load_table_mappings(&self) -> EtlResult { - debug!("loading table mappings from postgres state store"); + /// This method connects to the source database, retrieves all destination + /// table metadata for this pipeline, and populates the in-memory cache. + async fn load_destination_tables_metadata(&self) -> EtlResult { + debug!("loading destination tables metadata from postgres state store"); let pool = self.connect_to_source().await?; - let table_mappings = table_mappings::load_table_mappings(&pool, self.pipeline_id as i64) - .await - .map_err(|err| { - etl_error!( - ErrorKind::SourceQueryFailed, - "Table mappings loading failed", - format!("Failed to load table mappings from PostgreSQL: {}", err) - ) - })?; - let table_mappings_len = table_mappings.len(); + let rows = + destination_metadata::load_destination_tables_metadata(&pool, self.pipeline_id as i64) + .await + .map_err(|err| { + etl_error!( + ErrorKind::SourceQueryFailed, + "Destination tables metadata loading failed", + format!( + "Failed to load destination tables metadata from PostgreSQL: {}", + err + ) + ) + })?; + + let mut metadata: HashMap = HashMap::new(); + for (table_id, row) in rows { + metadata.insert( + table_id, + DestinationTableMetadata { + destination_table_id: row.destination_table_id, + snapshot_id: row.snapshot_id, + previous_snapshot_id: row.previous_snapshot_id, + schema_status: row.schema_status.into(), + replication_mask: ReplicationMask::from_bytes(row.replication_mask), + }, + ); + } + let metadata_len = metadata.len(); let mut inner = self.inner.lock().await; - inner.table_mappings = table_mappings; + inner.destination_tables_metadata = metadata; info!( - "loaded {} table mappings from postgres state store", - table_mappings_len + "loaded {} destination tables metadata from postgres state store", + metadata_len ); - Ok(table_mappings_len) + Ok(metadata_len) } - /// Stores a table mapping in both database and cache. + /// Stores complete destination table metadata in both database and cache. 
/// - /// This method persists a table mapping from source table ID to destination - /// table name in the database and updates the in-memory cache atomically. - /// Used when establishing or updating the mapping configuration between - /// source and destination systems. - async fn store_table_mapping( + /// Performs a full upsert - use for initial table creation. + async fn store_destination_table_metadata( &self, - source_table_id: TableId, - destination_table_id: String, + table_id: TableId, + metadata: DestinationTableMetadata, ) -> EtlResult<()> { - debug!( - "storing table mapping: '{}' -> '{}'", - source_table_id, destination_table_id - ); - let pool = self.connect_to_source().await?; let mut inner = self.inner.lock().await; - table_mappings::store_table_mapping( + destination_metadata::store_destination_table_metadata( &pool, self.pipeline_id as i64, - &source_table_id, - &destination_table_id, + table_id, + &metadata.destination_table_id, + metadata.snapshot_id, + metadata.previous_snapshot_id, + metadata.schema_status.into(), + metadata.replication_mask.as_slice(), ) .await .map_err(|err| { etl_error!( ErrorKind::SourceQueryFailed, - "Table mapping storage failed", - format!("Failed to store table mapping in PostgreSQL: {}", err) + "Destination table metadata storage failed", + format!( + "Failed to store destination table metadata in PostgreSQL: {}", + err + ) ) })?; - inner - .table_mappings - .insert(source_table_id, destination_table_id); + + inner.destination_tables_metadata.insert(table_id, metadata); Ok(()) } } impl SchemaStore for PostgresStore { - /// Retrieves a table schema from cache by table ID. + /// Retrieves a table schema at a specific snapshot point. /// - /// This method provides fast access to cached table schemas, which are - /// essential for processing replication events. Schemas are loaded during - /// startup and cached for the lifetime of the pipeline. - async fn get_table_schema(&self, table_id: &TableId) -> EtlResult>> { - let inner = self.inner.lock().await; + /// Returns the schema version with the largest snapshot_id <= the requested snapshot_id. + /// First checks the in-memory cache, then loads from the database if not found. + /// The loaded schema is cached for subsequent requests. Note that the cache is + /// optimized for active schemas, not historical snapshots. + async fn get_table_schema( + &self, + table_id: &TableId, + snapshot_id: SnapshotId, + ) -> EtlResult>> { + // First, check if we have a cached schema that matches the criteria. + // + // We can afford to hold the lock only for this short critical section since we assume that + // there is not really concurrency at the table level since each table is processed by exactly + // one worker. + { + let inner = self.inner.lock().await; + + // Find the best matching schema in the cache (largest snapshot_id <= requested). + let newest_table_schema = inner + .table_schemas + .iter() + .filter(|((tid, sid), _)| *tid == *table_id && *sid <= snapshot_id) + .max_by_key(|((_, sid), _)| *sid) + .map(|(_, schema)| schema.clone()); + + if newest_table_schema.is_some() { + return Ok(newest_table_schema); + } + } + + debug!( + "schema for table {} at snapshot {} not in cache, loading from database", + table_id, snapshot_id + ); + + let pool = self.connect_to_source().await?; + + // Load the schema at the requested snapshot. 
+ let table_schema = schema::load_table_schema_at_snapshot( + &pool, + self.pipeline_id as i64, + *table_id, + snapshot_id, + ) + .await + .map_err(|err| { + etl_error!( + ErrorKind::SourceQueryFailed, + "Table schema loading failed", + format!( + "Failed to load table schema for table {} at snapshot {} from PostgreSQL: {}", + table_id, snapshot_id, err + ) + ) + })?; + + let Some(table_schema) = table_schema else { + return Ok(None); + }; - Ok(inner.table_schemas.get(table_id).cloned()) + let result = { + let mut inner = self.inner.lock().await; + + let table_schema = Arc::new(table_schema); + inner.insert_schema_with_eviction(table_schema.clone()); + + Some(table_schema) + }; + + Ok(result) } /// Retrieves all cached table schemas as a vector. /// /// This method returns all currently cached table schemas, providing a /// complete view of the schema information available to the pipeline. - /// Useful for operations that need to process or analyze all table schemas. async fn get_table_schemas(&self) -> EtlResult>> { let inner = self.inner.lock().await; @@ -525,8 +667,8 @@ impl SchemaStore for PostgresStore { /// Loads table schemas from Postgres into memory cache. /// - /// This method connects to the source database, retrieves schema information - /// for all tables in this pipeline, and populates the in-memory cache. + /// This method connects to the source database, retrieves the latest schema + /// version for all tables in this pipeline, and populates the in-memory cache. /// Called during pipeline initialization to establish the schema context /// needed for processing replication events. async fn load_table_schemas(&self) -> EtlResult { @@ -545,14 +687,11 @@ impl SchemaStore for PostgresStore { })?; let table_schemas_len = table_schemas.len(); - // For performance reasons, since we load the table schemas only once during startup - // and from a single thread, we can afford to have a super short critical section. let mut inner = self.inner.lock().await; inner.table_schemas.clear(); for table_schema in table_schemas { - inner - .table_schemas - .insert(table_schema.id, Arc::new(table_schema)); + let key = (table_schema.id, table_schema.snapshot_id); + inner.table_schemas.insert(key, Arc::new(table_schema)); } info!( @@ -566,15 +705,16 @@ impl SchemaStore for PostgresStore { /// Stores a table schema in both database and cache. /// /// This method persists a table schema to the database and updates the - /// in-memory cache atomically. Used when new tables are discovered during - /// replication or when schema definitions need to be updated. - async fn store_table_schema(&self, table_schema: TableSchema) -> EtlResult<()> { - debug!("storing table schema for table '{}'", table_schema.name); + /// in-memory cache atomically. The schema's snapshot_id determines which + /// version this schema represents. + async fn store_table_schema(&self, table_schema: TableSchema) -> EtlResult> { + debug!( + "storing table schema for table '{}' at snapshot {}", + table_schema.name, table_schema.snapshot_id + ); let pool = self.connect_to_source().await?; - // We also lock the entire section to be consistent. 
- let mut inner = self.inner.lock().await; schema::store_table_schema(&pool, self.pipeline_id as i64, &table_schema) .await .map_err(|err| { @@ -584,11 +724,12 @@ impl SchemaStore for PostgresStore { format!("Failed to store table schema in PostgreSQL: {}", err) ) })?; - inner - .table_schemas - .insert(table_schema.id, Arc::new(table_schema)); - Ok(()) + let mut inner = self.inner.lock().await; + let table_schema = Arc::new(table_schema); + inner.insert_schema_with_eviction(table_schema.clone()); + + Ok(table_schema) } } @@ -599,17 +740,20 @@ impl CleanupStore for PostgresStore { // Use a single DB transaction to keep persistent state consistent. let mut tx = pool.begin().await?; - table_mappings::delete_table_mappings_for_table( + destination_metadata::delete_destination_table_metadata( &mut *tx, self.pipeline_id as i64, - &table_id, + table_id, ) .await .map_err(|err| { etl_error!( ErrorKind::SourceQueryFailed, - "Table mapping deletion failed", - format!("Failed to delete table mapping in PostgreSQL: {}", err) + "Destination table metadata deletion failed", + format!( + "Failed to delete destination table metadata in PostgreSQL: {}", + err + ) ) })?; @@ -632,8 +776,9 @@ impl CleanupStore for PostgresStore { let mut inner = self.inner.lock().await; inner.table_states.remove(&table_id); - inner.table_schemas.remove(&table_id); - inner.table_mappings.remove(&table_id); + // Remove all schema versions for this table + inner.table_schemas.retain(|(tid, _), _| *tid != table_id); + inner.destination_tables_metadata.remove(&table_id); emit_table_metrics( self.pipeline_id, diff --git a/etl/src/store/schema/base.rs b/etl/src/store/schema/base.rs index 2a52126b9..d094336c4 100644 --- a/etl/src/store/schema/base.rs +++ b/etl/src/store/schema/base.rs @@ -1,4 +1,4 @@ -use etl_postgres::types::{TableId, TableSchema}; +use etl_postgres::types::{SnapshotId, TableId, TableSchema}; use std::sync::Arc; use crate::error::EtlResult; @@ -6,19 +6,25 @@ use crate::error::EtlResult; /// Trait for storing and retrieving database table schema information. /// /// [`SchemaStore`] implementations are responsible for defining how the schema information -/// is stored and retrieved. +/// is stored and retrieved. The store supports schema versioning where each schema version +/// is identified by a snapshot_id (the start_lsn of the DDL message that created it). /// /// Implementations should ensure thread-safety and handle concurrent access to the data. pub trait SchemaStore { - /// Returns table schema for table with id `table_id` from the cache. + /// Returns the table schema for the given table at the specified snapshot point. /// - /// Does not load any new data into the cache. + /// Returns the schema version with the largest snapshot_id that is <= the requested + /// snapshot_id. If not found in cache, loads from the persistent store. As an optimization, + /// also loads the latest schema version when fetching from the database. + /// + /// Returns `None` if no schema version exists for the table at or before the given snapshot. fn get_table_schema( &self, table_id: &TableId, + snapshot_id: SnapshotId, ) -> impl Future>>> + Send; - /// Returns all table schemas from the cache. + /// Returns all cached table schemas. /// /// Does not read from the persistent store. fn get_table_schemas(&self) -> impl Future>>> + Send; @@ -29,8 +35,10 @@ pub trait SchemaStore { fn load_table_schemas(&self) -> impl Future> + Send; /// Stores a table schema in both the cache and the persistent store. 
+ /// + /// The schema's `snapshot_id` field determines which version this schema represents. fn store_table_schema( &self, table_schema: TableSchema, - ) -> impl Future> + Send; + ) -> impl Future>> + Send; } diff --git a/etl/src/store/state/base.rs b/etl/src/store/state/base.rs index acb55c2d7..23ffa7dd8 100644 --- a/etl/src/store/state/base.rs +++ b/etl/src/store/state/base.rs @@ -2,13 +2,13 @@ use etl_postgres::types::TableId; use std::{collections::HashMap, future::Future}; use crate::error::EtlResult; +use crate::state::destination::DestinationTableMetadata; use crate::state::table::TableReplicationPhase; -/// Trait for storing and retrieving table replication state and mapping information. +/// Trait for storing and retrieving table replication state and destination metadata. /// -/// [`StateStore`] implementations are responsible for defining how table replication states and -/// table mappings are stored and retrieved. Table mappings define the relationship between -/// source table identifiers and destination table names. +/// [`StateStore`] implementations are responsible for defining how table replication states +/// and destination table metadata are stored and retrieved. /// /// Implementations should ensure thread-safety and handle concurrent access to the data. pub trait StateStore { @@ -21,12 +21,14 @@ pub trait StateStore { ) -> impl Future>> + Send; /// Returns the table replication states for all the tables from the cache. + /// /// Does not read from the persistent store. fn get_table_replication_states( &self, ) -> impl Future>> + Send; /// Loads the table replication states from the persistent state into the cache. + /// /// This should be called once at program start to load the state into the cache /// and then use only the `get_X` methods to access the state. Updating the state /// by calling the `update_table_replication_state` updates in both the cache and @@ -47,30 +49,26 @@ pub trait StateStore { table_id: TableId, ) -> impl Future> + Send; - /// Returns table mapping for a specific source table ID from the cache. + /// Returns destination table metadata for a specific table from the cache. /// /// Does not load any new data into the cache. - fn get_table_mapping( + fn get_destination_table_metadata( &self, - source_table_id: &TableId, - ) -> impl Future>> + Send; + table_id: &TableId, + ) -> impl Future>> + Send; - /// Returns all table mappings from the cache. + /// Loads all destination table metadata from the persistent state into the cache. /// - /// Does not read from the persistent store. - fn get_table_mappings( - &self, - ) -> impl Future>> + Send; + /// This should be called during startup to load the metadata into the cache. + fn load_destination_tables_metadata(&self) -> impl Future> + Send; - /// Loads all table mappings from the persistent state into the cache. + /// Stores destination table metadata in both the cache and persistent store. /// - /// This can be called lazily when table mappings are needed by the destination. - fn load_table_mappings(&self) -> impl Future> + Send; - - /// Stores a table mapping in both the cache and the persistent store. - fn store_table_mapping( + /// This performs a full upsert. For updates, get the current metadata, modify + /// the fields you need to change, and store it back. 
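The trait documentation above spells out the update pattern for destination metadata: read the current value, modify what changed, and store it back as a full upsert. A sketch of that flow using the transitions defined on `DestinationTableMetadata` in this patch, generic over any `StateStore` implementation; the public paths (`etl::store::state`, `etl::error`) and the placement of this logic are assumptions, not the pipeline's actual DDL handling.

```rust
use etl::error::EtlResult;
use etl::store::state::{DestinationTableMetadata, DestinationTableSchemaStatus, StateStore};
use etl_postgres::types::{ReplicationMask, SnapshotId, TableId};

/// Records a schema change for a table using the read-modify-write pattern described above.
async fn record_schema_change<S: StateStore>(
    store: &S,
    table_id: TableId,
    new_snapshot_id: SnapshotId,
    new_mask: ReplicationMask,
) -> EtlResult<()> {
    // 1. Read the current metadata from the cache.
    let Some(current) = store.get_destination_table_metadata(&table_id).await? else {
        // Table not known at the destination yet.
        return Ok(());
    };

    // 2. Mark the change as in progress, remembering the previous snapshot for recovery.
    let applying = current.with_schema_change(
        new_snapshot_id,
        new_mask,
        DestinationTableSchemaStatus::Applying,
    );
    store
        .store_destination_table_metadata(table_id, applying.clone())
        .await?;

    // ... apply the DDL at the destination here ...

    // 3. Mark the change as applied; `to_applied` clears previous_snapshot_id.
    store
        .store_destination_table_metadata(table_id, applying.to_applied())
        .await?;

    Ok(())
}
```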
+ fn store_destination_table_metadata( &self, - source_table_id: TableId, - destination_table_id: String, + table_id: TableId, + metadata: DestinationTableMetadata, ) -> impl Future> + Send; } diff --git a/etl/src/store/state/mod.rs b/etl/src/store/state/mod.rs index cbcb6ac7e..096e6ca5e 100644 --- a/etl/src/store/state/mod.rs +++ b/etl/src/store/state/mod.rs @@ -1,3 +1,4 @@ mod base; +pub use crate::state::destination::{DestinationTableMetadata, DestinationTableSchemaStatus}; pub use base::*; diff --git a/etl/src/test_utils/database.rs b/etl/src/test_utils/database.rs index 106bd0e2b..48e7de87a 100644 --- a/etl/src/test_utils/database.rs +++ b/etl/src/test_utils/database.rs @@ -44,19 +44,19 @@ fn local_pg_connection_config() -> PgConnectionConfig { } } -/// Creates a new test database instance with a unique name. +/// Creates a new test database instance with a unique name and runs migrations. /// /// This function spawns a new Postgres database with a random UUID as its name, /// using default credentials and disabled SSL. It automatically creates the test schema -/// for organizing test tables. +/// for organizing test tables and runs all ETL migrations. /// /// # Panics /// -/// Panics if the test schema cannot be created. +/// Panics if the test schema cannot be created or migrations fail. pub async fn spawn_source_database() -> PgDatabase { // We create the database via tokio postgres. let config = local_pg_connection_config(); - let database = PgDatabase::new(config).await; + let database = PgDatabase::new(config.clone()).await; // We create the test schema, where all tables will be added. database @@ -67,23 +67,6 @@ pub async fn spawn_source_database() -> PgDatabase { .await .expect("Failed to create test schema"); - database -} - -/// Creates a new test database instance with a unique name and all the ETL migrations run. -/// -/// This function spawns a new Postgres database with a random UUID as its name, -/// using default credentials and disabled SSL. It automatically creates the test schema -/// for organizing test tables. -/// -/// # Panics -/// -/// Panics if the test schema cannot be created. -pub async fn spawn_source_database_for_store() -> PgDatabase { - // We create the database via tokio postgres. - let config = local_pg_connection_config(); - let database = PgDatabase::new(config.clone()).await; - // We now connect via sqlx just to run the migrations, but we still use the original tokio postgres // connection for the db object returned. let pool = connect_to_source_database(&config, 1, 1) @@ -103,11 +86,11 @@ pub async fn spawn_source_database_for_store() -> PgDatabase { .await .expect("Failed to set search path to 'etl'"); - // Run replicator migrations to create the state store tables. - sqlx::migrate!("../etl-replicator/migrations") + // Run migrations to create the state store tables. + sqlx::migrate!("./migrations") .run(&pool) .await - .expect("Failed to run replicator migrations"); + .expect("Failed to run migrations"); database } diff --git a/etl/src/test_utils/event.rs b/etl/src/test_utils/event.rs index b30361877..80eae1a7b 100644 --- a/etl/src/test_utils/event.rs +++ b/etl/src/test_utils/event.rs @@ -21,15 +21,16 @@ pub fn group_events_by_type_and_table_id( let mut grouped = HashMap::new(); for event in events { let event_type = EventType::from(event); - // This grouping only works on simple DML operations. + // This grouping works on DML operations and Relation events. 
let table_ids = match event { - Event::Insert(event) => vec![event.table_id], - Event::Update(event) => vec![event.table_id], - Event::Delete(event) => vec![event.table_id], + Event::Relation(event) => vec![event.replicated_table_schema.id()], + Event::Insert(event) => vec![event.replicated_table_schema.id()], + Event::Update(event) => vec![event.replicated_table_schema.id()], + Event::Delete(event) => vec![event.replicated_table_schema.id()], Event::Truncate(event) => event - .rel_ids + .truncated_tables .iter() - .map(|rel_id| TableId::new(*rel_id)) + .map(|schema| schema.id()) .collect(), _ => vec![], }; @@ -55,6 +56,43 @@ pub fn check_events_count(events: &[Event], conditions: Vec<(EventType, u64)>) - }) } +/// Compares two events for equality in test contexts. +/// +/// This function compares events based on their key fields, ignoring LSN values since those +/// may vary between pipeline runs. +fn events_equal(a: &Event, b: &Event) -> bool { + match (a, b) { + (Event::Begin(a), Event::Begin(b)) => a == b, + (Event::Commit(a), Event::Commit(b)) => a == b, + (Event::Truncate(a), Event::Truncate(b)) => { + if a.options != b.options || a.truncated_tables.len() != b.truncated_tables.len() { + return false; + } + + // Compare table IDs of truncated tables + let a_ids: Vec<_> = a.truncated_tables.iter().map(|s| s.id()).collect(); + let b_ids: Vec<_> = b.truncated_tables.iter().map(|s| s.id()).collect(); + + a_ids == b_ids + } + (Event::Insert(a), Event::Insert(b)) => { + a.replicated_table_schema.id() == b.replicated_table_schema.id() + && a.table_row == b.table_row + } + (Event::Update(a), Event::Update(b)) => { + a.replicated_table_schema.id() == b.replicated_table_schema.id() + && a.table_row == b.table_row + && a.old_table_row == b.old_table_row + } + (Event::Delete(a), Event::Delete(b)) => { + a.replicated_table_schema.id() == b.replicated_table_schema.id() + && a.old_table_row == b.old_table_row + } + (Event::Unsupported, Event::Unsupported) => true, + _ => false, + } +} + /// Returns a new Vec of events with duplicates removed. /// /// Events that are not tied to a specific row (Begin/Commit/Relation/Truncate/Unsupported) @@ -68,10 +106,12 @@ pub fn check_events_count(events: &[Event], conditions: Vec<(EventType, u64)>) - /// thus in some tests we might have to exclude duplicates while performing assertions. 
pub fn deduplicate_events(events: &[Event]) -> Vec { let mut result: Vec = Vec::with_capacity(events.len()); - for e in events.iter().cloned() { - if !result.contains(&e) { - result.push(e); + + for event in events.iter().cloned() { + if !result.iter().any(|existing| events_equal(existing, &event)) { + result.push(event); } } + result } diff --git a/etl/src/test_utils/materialize.rs b/etl/src/test_utils/materialize.rs index 14ec648fc..061c95c8d 100644 --- a/etl/src/test_utils/materialize.rs +++ b/etl/src/test_utils/materialize.rs @@ -30,9 +30,9 @@ where // Filter by table_id if specified if let Some(target_table_id) = table_id { let event_table_id = match &event { - Event::Insert(insert_event) => insert_event.table_id, - Event::Update(update_event) => update_event.table_id, - Event::Delete(delete_event) => delete_event.table_id, + Event::Insert(insert_event) => insert_event.replicated_table_schema.id(), + Event::Update(update_event) => update_event.replicated_table_schema.id(), + Event::Delete(delete_event) => delete_event.replicated_table_schema.id(), _ => continue, // Skip other event types }; diff --git a/etl/src/test_utils/mod.rs b/etl/src/test_utils/mod.rs index 9381d2b03..613748440 100644 --- a/etl/src/test_utils/mod.rs +++ b/etl/src/test_utils/mod.rs @@ -9,6 +9,6 @@ pub mod event; pub mod materialize; pub mod notify; pub mod pipeline; -pub mod table; +pub mod schema; pub mod test_destination_wrapper; pub mod test_schema; diff --git a/etl/src/test_utils/notify.rs b/etl/src/test_utils/notify.rs index 3738c1f59..a60830b99 100644 --- a/etl/src/test_utils/notify.rs +++ b/etl/src/test_utils/notify.rs @@ -1,10 +1,11 @@ use std::{collections::HashMap, fmt, sync::Arc}; -use etl_postgres::types::{TableId, TableSchema}; +use etl_postgres::types::{SnapshotId, TableId, TableSchema}; use tokio::sync::{Notify, RwLock}; use crate::error::{ErrorKind, EtlResult}; use crate::etl_error; +use crate::state::destination::DestinationTableMetadata; use crate::state::table::{TableReplicationPhase, TableReplicationPhaseType}; use crate::store::cleanup::CleanupStore; use crate::store::schema::SchemaStore; @@ -29,8 +30,9 @@ type TableStateCondition = ( struct Inner { table_replication_states: HashMap, table_state_history: HashMap>, - table_schemas: HashMap>, - table_mappings: HashMap, + /// Stores table schemas in insertion order per table. + table_schemas: HashMap>>, + destination_tables_metadata: HashMap, table_state_type_conditions: Vec, table_state_conditions: Vec, method_call_notifiers: HashMap>>, @@ -87,7 +89,7 @@ impl NotifyingStore { table_replication_states: HashMap::new(), table_state_history: HashMap::new(), table_schemas: HashMap::new(), - table_mappings: HashMap::new(), + destination_tables_metadata: HashMap::new(), table_state_type_conditions: Vec::new(), table_state_conditions: Vec::new(), method_call_notifiers: HashMap::new(), @@ -103,12 +105,35 @@ impl NotifyingStore { inner.table_replication_states.clone() } - pub async fn get_table_schemas(&self) -> HashMap { + pub async fn get_latest_table_schemas(&self) -> HashMap { let inner = self.inner.read().await; + + // Return the latest schema version for each table (last in the Vec). + inner + .table_schemas + .iter() + .filter_map(|(table_id, schemas)| { + schemas + .last() + .map(|schema| (*table_id, Arc::as_ref(schema).clone())) + }) + .collect() + } + + pub async fn get_table_schemas(&self) -> HashMap> { + let inner = self.inner.read().await; + + // Return schemas in insertion order per table. 
inner .table_schemas .iter() - .map(|(id, schema)| (*id, Arc::as_ref(schema).clone())) + .map(|(table_id, schemas)| { + let schemas_with_ids: Vec<_> = schemas + .iter() + .map(|schema| (schema.snapshot_id, Arc::as_ref(schema).clone())) + .collect(); + (*table_id, schemas_with_ids) + }) .collect() } @@ -261,59 +286,77 @@ impl StateStore for NotifyingStore { Ok(previous_state) } - async fn get_table_mapping(&self, source_table_id: &TableId) -> EtlResult> { - let inner = self.inner.read().await; - Ok(inner.table_mappings.get(source_table_id).cloned()) - } - - async fn get_table_mappings(&self) -> EtlResult> { + async fn get_destination_table_metadata( + &self, + table_id: &TableId, + ) -> EtlResult> { let inner = self.inner.read().await; - Ok(inner.table_mappings.clone()) + Ok(inner.destination_tables_metadata.get(table_id).cloned()) } - async fn load_table_mappings(&self) -> EtlResult { + async fn load_destination_tables_metadata(&self) -> EtlResult { let inner = self.inner.read().await; - Ok(inner.table_mappings.len()) + Ok(inner.destination_tables_metadata.len()) } - async fn store_table_mapping( + async fn store_destination_table_metadata( &self, - source_table_id: TableId, - destination_table_id: String, + table_id: TableId, + metadata: DestinationTableMetadata, ) -> EtlResult<()> { let mut inner = self.inner.write().await; - inner - .table_mappings - .insert(source_table_id, destination_table_id); + inner.destination_tables_metadata.insert(table_id, metadata); Ok(()) } } impl SchemaStore for NotifyingStore { - async fn get_table_schema(&self, table_id: &TableId) -> EtlResult>> { + async fn get_table_schema( + &self, + table_id: &TableId, + snapshot_id: SnapshotId, + ) -> EtlResult>> { let inner = self.inner.read().await; - Ok(inner.table_schemas.get(table_id).cloned()) + // Find the best matching schema (largest snapshot_id <= requested). 
+ let best_match = inner.table_schemas.get(table_id).and_then(|schemas| { + schemas + .iter() + .filter(|schema| schema.snapshot_id <= snapshot_id) + .max_by_key(|schema| schema.snapshot_id) + .cloned() + }); + + Ok(best_match) } async fn get_table_schemas(&self) -> EtlResult>> { let inner = self.inner.read().await; - Ok(inner.table_schemas.values().cloned().collect()) + Ok(inner + .table_schemas + .values() + .flat_map(|schemas| schemas.iter().cloned()) + .collect()) } async fn load_table_schemas(&self) -> EtlResult { let inner = self.inner.read().await; - Ok(inner.table_schemas.len()) + Ok(inner.table_schemas.values().map(|v| v.len()).sum()) } - async fn store_table_schema(&self, table_schema: TableSchema) -> EtlResult<()> { + async fn store_table_schema(&self, table_schema: TableSchema) -> EtlResult> { let mut inner = self.inner.write().await; + + let table_id = table_schema.id; + let table_schema = Arc::new(table_schema); inner .table_schemas - .insert(table_schema.id, Arc::new(table_schema)); + .entry(table_id) + .or_default() + .push(table_schema.clone()); - Ok(()) + Ok(table_schema) } } @@ -324,7 +367,7 @@ impl CleanupStore for NotifyingStore { inner.table_replication_states.remove(&table_id); inner.table_state_history.remove(&table_id); inner.table_schemas.remove(&table_id); - inner.table_mappings.remove(&table_id); + inner.destination_tables_metadata.remove(&table_id); Ok(()) } diff --git a/etl/src/test_utils/pipeline.rs b/etl/src/test_utils/pipeline.rs index 045d21a87..4f316e3cb 100644 --- a/etl/src/test_utils/pipeline.rs +++ b/etl/src/test_utils/pipeline.rs @@ -1,12 +1,20 @@ -use etl_config::shared::{BatchConfig, PgConnectionConfig, PipelineConfig}; -use uuid::Uuid; - use crate::destination::Destination; +use crate::destination::memory::MemoryDestination; use crate::pipeline::Pipeline; +use crate::state::table::TableReplicationPhaseType; use crate::store::cleanup::CleanupStore; use crate::store::schema::SchemaStore; use crate::store::state::StateStore; +use crate::test_utils::database::{spawn_source_database, test_table_name}; +use crate::test_utils::notify::NotifyingStore; +use crate::test_utils::test_destination_wrapper::TestDestinationWrapper; use crate::types::PipelineId; +use etl_config::shared::{BatchConfig, PgConnectionConfig, PipelineConfig}; +use etl_postgres::tokio::test_utils::PgDatabase; +use etl_postgres::types::{TableId, TableName}; +use rand::random; +use tokio_postgres::Client; +use uuid::Uuid; /// Generates a test-specific replication slot name with a random component. 
/// @@ -73,3 +81,69 @@ where Pipeline::new(config, store, destination) } + +pub async fn create_database_and_pipeline_with_table( + table_suffix: &str, + columns: &[(&str, &str)], +) -> ( + PgDatabase, + TableName, + TableId, + NotifyingStore, + TestDestinationWrapper, + Pipeline>, + PipelineId, + String, +) { + let database = spawn_source_database().await; + + let table_name = test_table_name(table_suffix); + let table_id = database + .create_table(table_name.clone(), true, columns) + .await + .unwrap(); + + let publication_name = format!("pub_{}", random::()); + database + .create_publication(&publication_name, &[table_name.clone()]) + .await + .unwrap(); + + let store = NotifyingStore::new(); + let destination = TestDestinationWrapper::wrap(MemoryDestination::new()); + + let pipeline_id: PipelineId = random(); + let mut pipeline = create_pipeline( + &database.config, + pipeline_id, + publication_name.clone(), + store.clone(), + destination.clone(), + ); + + // We wait for sync done so that we have the apply worker dealing with events, this is the common + // testing condition which ensures that the table is ready to be streamed from the main apply worker. + // + // The rationale for wanting to test ETL mainly on the apply worker is that it's really hard to test + // ETL in a state before `SyncDone` since the system will advance on its own. To properly test all + // the table sync worker states, we would need a way to programmatically drive execution, but we deemed + // it too much work compared to the benefit it brings. + let sync_done = store + .notify_on_table_state_type(table_id, TableReplicationPhaseType::SyncDone) + .await; + + pipeline.start().await.unwrap(); + + sync_done.notified().await; + + ( + database, + table_name, + table_id, + store, + destination, + pipeline, + pipeline_id, + publication_name, + ) +} diff --git a/etl/src/test_utils/schema.rs b/etl/src/test_utils/schema.rs new file mode 100644 index 000000000..237b7f1a2 --- /dev/null +++ b/etl/src/test_utils/schema.rs @@ -0,0 +1,231 @@ +use crate::types::Type; +use etl_postgres::types::{ColumnSchema, ReplicatedTableSchema, SnapshotId, TableSchema}; + +/// Asserts that two column schemas are equal. +pub fn assert_column_schema_eq(actual: &ColumnSchema, expected: &ColumnSchema) { + assert_eq!( + actual.name, expected.name, + "column name mismatch: got '{}', expected '{}'", + actual.name, expected.name + ); + assert_eq!( + actual.typ, expected.typ, + "column '{}' type mismatch: got {:?}, expected {:?}", + actual.name, actual.typ, expected.typ + ); + assert_eq!( + actual.modifier, expected.modifier, + "column '{}' modifier mismatch: got {}, expected {}", + actual.name, actual.modifier, expected.modifier + ); + assert_eq!( + actual.nullable, expected.nullable, + "column '{}' nullable mismatch: got {}, expected {}", + actual.name, actual.nullable, expected.nullable + ); + assert_eq!( + actual.primary_key(), + expected.primary_key(), + "column '{}' primary_key mismatch: got {}, expected {}", + actual.name, + actual.primary_key(), + expected.primary_key() + ); +} + +/// Asserts that a column has the expected name and type. 
+pub fn assert_column_name_type(column: &ColumnSchema, expected_name: &str, expected_type: &Type) { + assert_eq!( + column.name, expected_name, + "column name mismatch: got '{}', expected '{expected_name}'", + column.name + ); + assert_eq!( + &column.typ, expected_type, + "column '{expected_name}' type mismatch: got {:?}, expected {expected_type:?}", + column.typ + ); +} + +/// Asserts that a column has the expected name. +pub fn assert_column_name(column: &ColumnSchema, expected_name: &str) { + assert_eq!( + column.name, expected_name, + "column name mismatch: got '{}', expected '{expected_name}'", + column.name + ); +} + +/// Asserts that columns match the expected column schemas. +pub fn assert_columns_eq<'a>( + columns: impl Iterator, + expected_columns: &[ColumnSchema], +) { + let columns: Vec<_> = columns.collect(); + assert_eq!( + columns.len(), + expected_columns.len(), + "column count mismatch: got {}, expected {}", + columns.len(), + expected_columns.len() + ); + + for (actual, expected) in columns.iter().zip(expected_columns.iter()) { + assert_column_schema_eq(actual, expected); + } +} + +/// Asserts that columns have the expected names and types. +pub fn assert_columns_names_types<'a>( + columns: impl Iterator, + expected_columns: &[(&str, Type)], +) { + let columns: Vec<_> = columns.collect(); + assert_eq!( + columns.len(), + expected_columns.len(), + "column count mismatch: got {}, expected {}", + columns.len(), + expected_columns.len() + ); + + for (i, (actual, (expected_name, expected_type))) in + columns.iter().zip(expected_columns.iter()).enumerate() + { + assert_eq!( + actual.name, *expected_name, + "column name mismatch at index {i}: got '{}', expected '{expected_name}'", + actual.name + ); + assert_eq!( + actual.typ, *expected_type, + "column '{expected_name}' type mismatch at index {i}: got {:?}, expected {expected_type:?}", + actual.typ + ); + } +} + +/// Asserts that columns have the expected names. +pub fn assert_columns_names<'a>( + columns: impl Iterator, + expected_names: &[&str], +) { + let columns: Vec<_> = columns.collect(); + assert_eq!( + columns.len(), + expected_names.len(), + "column count mismatch: got {}, expected {}", + columns.len(), + expected_names.len() + ); + + for (i, (actual, expected_name)) in columns.iter().zip(expected_names.iter()).enumerate() { + assert_eq!( + actual.name, *expected_name, + "column name mismatch at index {i}: got '{}', expected '{expected_name}'", + actual.name + ); + } +} + +/// Asserts that a table schema has columns matching the expected column schemas. +pub fn assert_table_schema_columns(schema: &TableSchema, expected_columns: &[ColumnSchema]) { + assert_columns_eq(schema.column_schemas.iter(), expected_columns); +} + +/// Asserts that a table schema has columns with the expected names and types. +pub fn assert_table_schema_column_names_types( + schema: &TableSchema, + expected_columns: &[(&str, Type)], +) { + assert_columns_names_types(schema.column_schemas.iter(), expected_columns); +} + +/// Asserts that a table schema has columns with the expected names. +pub fn assert_table_schema_column_names(schema: &TableSchema, expected_names: &[&str]) { + assert_columns_names(schema.column_schemas.iter(), expected_names); +} + +/// Asserts that a replicated table schema has columns matching the expected column schemas, +/// and that all columns are replicated. 
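The assertion helpers above accept an iterator over column schemas so that a full `TableSchema` (a slice) and a masked `ReplicatedTableSchema` view can be checked by the same function. A self-contained sketch of that design choice, with stand-in types instead of the crate's `ColumnSchema`:

```rust
// Stand-in column type; only the name matters for this illustration.
#[derive(Debug)]
struct Column {
    name: String,
}

/// Accepting an iterator lets callers pass either a slice iterator or a
/// filtering iterator without building an intermediate Vec.
fn assert_names<'a>(columns: impl Iterator<Item = &'a Column>, expected: &[&str]) {
    let names: Vec<_> = columns.map(|c| c.name.as_str()).collect();
    assert_eq!(names, expected);
}

fn main() {
    let all = vec![
        Column { name: "id".into() },
        Column { name: "name".into() },
        Column { name: "email".into() },
    ];
    let mask = [1u8, 1, 0];

    // Full schema: iterate the slice directly.
    assert_names(all.iter(), &["id", "name", "email"]);

    // Masked view: filter by the replication mask on the fly.
    assert_names(
        all.iter().zip(mask).filter_map(|(c, bit)| (bit == 1).then_some(c)),
        &["id", "name"],
    );
}
```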
+pub fn assert_replicated_schema_columns( + schema: &ReplicatedTableSchema, + expected_columns: &[ColumnSchema], +) { + assert_columns_eq(schema.column_schemas(), expected_columns); + assert_all_columns_replicated(schema, expected_columns.len()); +} + +/// Asserts that a replicated table schema has columns with the expected names and types, +/// and that all columns are replicated. +pub fn assert_replicated_schema_column_names_types( + schema: &ReplicatedTableSchema, + expected_columns: &[(&str, Type)], +) { + assert_columns_names_types(schema.column_schemas(), expected_columns); + assert_all_columns_replicated(schema, expected_columns.len()); +} + +/// Asserts that a replicated table schema has columns with the expected names, +/// and that all columns are replicated. +pub fn assert_replicated_schema_column_names( + schema: &ReplicatedTableSchema, + expected_names: &[&str], +) { + assert_columns_names(schema.column_schemas(), expected_names); + assert_all_columns_replicated(schema, expected_names.len()); +} + +/// Asserts that all columns in the replication mask are set to 1. +fn assert_all_columns_replicated(schema: &ReplicatedTableSchema, expected_len: usize) { + let mask = schema.replication_mask().as_slice(); + assert_eq!( + mask.len(), + expected_len, + "replication mask length mismatch: got {}, expected {}", + mask.len(), + expected_len + ); + assert!( + mask.iter().all(|&bit| bit == 1), + "expected all columns to be replicated, but mask is {mask:?}" + ); +} + +/// Asserts that schema snapshots are in strictly increasing order by snapshot ID. +/// +/// If `first_is_zero` is true, the first snapshot ID must be 0. +/// If `first_is_zero` is false, the first snapshot ID must be > 0. +/// Each subsequent snapshot ID must be strictly greater than the previous one. +pub fn assert_schema_snapshots_ordering( + snapshots: &[(SnapshotId, TableSchema)], + first_is_zero: bool, +) { + assert!( + !snapshots.is_empty(), + "expected at least one schema snapshot" + ); + + let (first_snapshot_id, _) = &snapshots[0]; + if first_is_zero { + assert_eq!( + *first_snapshot_id, + SnapshotId::initial(), + "first snapshot_id is {first_snapshot_id}, expected 0" + ); + } else { + assert!( + *first_snapshot_id > SnapshotId::initial(), + "first snapshot_id is {first_snapshot_id}, expected > 0" + ); + } + + for i in 1..snapshots.len() { + let (prev_snapshot_id, _) = &snapshots[i - 1]; + let (snapshot_id, _) = &snapshots[i]; + assert!( + *snapshot_id > *prev_snapshot_id, + "snapshot at index {i} has snapshot_id {snapshot_id} which is not greater than previous snapshot_id {prev_snapshot_id}" + ); + } +} diff --git a/etl/src/test_utils/table.rs b/etl/src/test_utils/table.rs deleted file mode 100644 index 15e8b05d1..000000000 --- a/etl/src/test_utils/table.rs +++ /dev/null @@ -1,34 +0,0 @@ -use etl_postgres::types::{ColumnSchema, TableId, TableName, TableSchema}; -use std::collections::HashMap; - -/// Asserts that a table schema matches the expected schema. -/// -/// Compares all aspects of the table schema including table ID, name, and column -/// definitions. Each column's properties (name, type, modifier, nullability, and -/// primary key status) are verified. -/// -/// # Panics -/// -/// Panics if the table ID doesn't exist in the provided schemas, or if any aspect -/// of the schema doesn't match the expected values. 
-pub fn assert_table_schema( - table_schemas: &HashMap, - table_id: TableId, - expected_table_name: TableName, - expected_columns: &[ColumnSchema], -) { - let table_schema = table_schemas.get(&table_id).unwrap(); - assert_eq!(table_schema.id, table_id); - assert_eq!(table_schema.name, expected_table_name); - - let columns = &table_schema.column_schemas; - assert_eq!(columns.len(), expected_columns.len()); - - for (actual, expected) in columns.iter().zip(expected_columns.iter()) { - assert_eq!(actual.name, expected.name); - assert_eq!(actual.typ, expected.typ); - assert_eq!(actual.modifier, expected.modifier); - assert_eq!(actual.nullable, expected.nullable); - assert_eq!(actual.primary, expected.primary); - } -} diff --git a/etl/src/test_utils/test_destination_wrapper.rs b/etl/src/test_utils/test_destination_wrapper.rs index 9f841ee76..e055d99f0 100644 --- a/etl/src/test_utils/test_destination_wrapper.rs +++ b/etl/src/test_utils/test_destination_wrapper.rs @@ -1,4 +1,4 @@ -use etl_postgres::types::TableId; +use etl_postgres::types::{ReplicatedTableSchema, TableId}; use std::collections::HashMap; use std::fmt; use std::sync::Arc; @@ -158,28 +158,30 @@ where "wrapper" } - async fn truncate_table(&self, table_id: TableId) -> EtlResult<()> { + async fn truncate_table( + &self, + replicated_table_schema: &ReplicatedTableSchema, + ) -> EtlResult<()> { let destination = { let inner = self.inner.read().await; inner.wrapped_destination.clone() }; - let result = destination.truncate_table(table_id).await; + let result = destination.truncate_table(replicated_table_schema).await; let mut inner = self.inner.write().await; + let table_id = replicated_table_schema.id(); inner.table_rows.remove(&table_id); inner.events.retain_mut(|event| { let has_table_id = event.has_table_id(&table_id); - if let Event::Truncate(event) = event + if let Event::Truncate(truncate_event) = event && has_table_id { - let Some(index) = event.rel_ids.iter().position(|&id| table_id.0 == id) else { - return true; - }; - - event.rel_ids.remove(index); - if event.rel_ids.is_empty() { + truncate_event + .truncated_tables + .retain(|s| s.id() != table_id); + if truncate_event.truncated_tables.is_empty() { return false; } @@ -194,7 +196,7 @@ where async fn write_table_rows( &self, - table_id: TableId, + replicated_table_schema: &ReplicatedTableSchema, table_rows: Vec, ) -> EtlResult<()> { let destination = { @@ -204,10 +206,11 @@ where }; let result = destination - .write_table_rows(table_id, table_rows.clone()) + .write_table_rows(replicated_table_schema, table_rows.clone()) .await; { + let table_id = replicated_table_schema.id(); let mut inner = self.inner.write().await; if result.is_ok() { inner diff --git a/etl/src/test_utils/test_schema.rs b/etl/src/test_utils/test_schema.rs index c5a5b7ced..a6f501fb4 100644 --- a/etl/src/test_utils/test_schema.rs +++ b/etl/src/test_utils/test_schema.rs @@ -1,6 +1,10 @@ use etl_postgres::tokio::test_utils::{PgDatabase, id_column_schema}; -use etl_postgres::types::{ColumnSchema, TableId, TableName, TableSchema}; +use etl_postgres::types::{ + ColumnSchema, ReplicatedTableSchema, ReplicationMask, TableId, TableName, TableSchema, +}; +use std::collections::HashSet; use std::ops::RangeInclusive; +use std::sync::Arc; use tokio_postgres::types::{PgLsn, Type}; use tokio_postgres::{Client, GenericClient}; @@ -8,6 +12,24 @@ use crate::test_utils::database::{TEST_DATABASE_SCHEMA, test_table_name}; use crate::test_utils::test_destination_wrapper::TestDestinationWrapper; use crate::types::{Cell, Event, 
InsertEvent, TableRow}; +/// Creates a test column schema with sensible defaults. +fn test_column( + name: &str, + typ: Type, + ordinal_position: i32, + nullable: bool, + primary_key: bool, +) -> ColumnSchema { + ColumnSchema::new( + name.to_string(), + typ, + -1, + ordinal_position, + if primary_key { Some(1) } else { None }, + nullable, + ) +} + #[derive(Debug, Clone, Copy)] pub enum TableSelection { Both, @@ -70,20 +92,8 @@ pub async fn setup_test_database_schema( users_table_name, vec![ id_column_schema(), - ColumnSchema { - name: "name".to_string(), - typ: Type::TEXT, - modifier: -1, - nullable: false, - primary: false, - }, - ColumnSchema { - name: "age".to_string(), - typ: Type::INT4, - modifier: -1, - nullable: false, - primary: false, - }, + test_column("name", Type::TEXT, 2, false, false), + test_column("age", Type::INT4, 3, false, false), ], )); } @@ -106,13 +116,7 @@ pub async fn setup_test_database_schema( orders_table_name, vec![ id_column_schema(), - ColumnSchema { - name: "description".to_string(), - typ: Type::TEXT, - modifier: -1, - nullable: false, - primary: false, - }, + test_column("description", Type::TEXT, 2, false, false), ], )); } @@ -301,19 +305,28 @@ pub fn events_equal_excluding_fields(left: &Event, right: &Event) -> bool { && left.timestamp == right.timestamp } (Event::Insert(left), Event::Insert(right)) => { - left.table_id == right.table_id && left.table_row == right.table_row + left.replicated_table_schema.id() == right.replicated_table_schema.id() + && left.table_row == right.table_row } (Event::Update(left), Event::Update(right)) => { - left.table_id == right.table_id + left.replicated_table_schema.id() == right.replicated_table_schema.id() && left.table_row == right.table_row && left.old_table_row == right.old_table_row } (Event::Delete(left), Event::Delete(right)) => { - left.table_id == right.table_id && left.old_table_row == right.old_table_row + left.replicated_table_schema.id() == right.replicated_table_schema.id() + && left.old_table_row == right.old_table_row } - (Event::Relation(left), Event::Relation(right)) => left.table_schema == right.table_schema, (Event::Truncate(left), Event::Truncate(right)) => { - left.options == right.options && left.rel_ids == right.rel_ids + if left.options != right.options + || left.truncated_tables.len() != right.truncated_tables.len() + { + return false; + } + // Compare table IDs of truncated tables + let left_ids: Vec<_> = left.truncated_tables.iter().map(|s| s.id()).collect(); + let right_ids: Vec<_> = right.truncated_tables.iter().map(|s| s.id()).collect(); + left_ids == right_ids } (Event::Unsupported, Event::Unsupported) => true, _ => false, // Different event types @@ -322,16 +335,27 @@ pub fn events_equal_excluding_fields(left: &Event, right: &Event) -> bool { pub fn build_expected_users_inserts( mut starting_id: i64, - users_table_id: TableId, + users_table_schema: &TableSchema, expected_rows: Vec<(&str, i32)>, ) -> Vec { let mut events = Vec::new(); + // We build the replicated table schema with a mask for all columns. 
+ let users_table_column_names = users_table_schema + .column_schemas + .iter() + .map(|c| c.name.clone()) + .collect::>(); + let replicated_table_schema = ReplicatedTableSchema::from_mask( + Arc::new(users_table_schema.clone()), + ReplicationMask::build_or_all(users_table_schema, &users_table_column_names), + ); + for (name, age) in expected_rows { events.push(Event::Insert(InsertEvent { start_lsn: PgLsn::from(0), commit_lsn: PgLsn::from(0), - table_id: users_table_id, + replicated_table_schema: replicated_table_schema.clone(), table_row: TableRow { values: vec![ Cell::I64(starting_id), @@ -349,16 +373,27 @@ pub fn build_expected_users_inserts( pub fn build_expected_orders_inserts( mut starting_id: i64, - orders_table_id: TableId, + orders_table_schema: &TableSchema, expected_rows: Vec<&str>, ) -> Vec { let mut events = Vec::new(); + // We build the replicated table schema with a mask for all columns. + let orders_table_column_names = orders_table_schema + .column_schemas + .iter() + .map(|c| c.name.clone()) + .collect::>(); + let replicated_table_schema = ReplicatedTableSchema::from_mask( + Arc::new(orders_table_schema.clone()), + ReplicationMask::build_or_all(orders_table_schema, &orders_table_column_names), + ); + for name in expected_rows { events.push(Event::Insert(InsertEvent { start_lsn: PgLsn::from(0), commit_lsn: PgLsn::from(0), - table_id: orders_table_id, + replicated_table_schema: replicated_table_schema.clone(), table_row: TableRow { values: vec![Cell::I64(starting_id), Cell::String(name.to_owned())], }, diff --git a/etl/src/types/event.rs b/etl/src/types/event.rs index f0d8de737..6559922e0 100644 --- a/etl/src/types/event.rs +++ b/etl/src/types/event.rs @@ -1,4 +1,4 @@ -use etl_postgres::types::{TableId, TableSchema}; +use etl_postgres::types::{ReplicatedTableSchema, TableId}; use std::fmt; use tokio_postgres::types::PgLsn; @@ -40,33 +40,18 @@ pub struct CommitEvent { pub timestamp: i64, } -/// Table schema definition event from Postgres logical replication. -/// -/// [`RelationEvent`] provides schema information for tables involved in replication. -/// It contains complete column definitions and metadata needed to interpret -/// subsequent data modification events for the table. -#[derive(Debug, Clone, PartialEq)] -pub struct RelationEvent { - /// LSN position where the event started. - pub start_lsn: PgLsn, - /// LSN position where the transaction of this event will commit. - pub commit_lsn: PgLsn, - /// Complete table schema including columns and types. - pub table_schema: TableSchema, -} - /// Row insertion event from Postgres logical replication. /// /// [`InsertEvent`] represents a new row being added to a table. It contains /// the complete row data for insertion into the destination system. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub struct InsertEvent { /// LSN position where the event started. pub start_lsn: PgLsn, /// LSN position where the transaction of this event will commit. pub commit_lsn: PgLsn, - /// ID of the table where the row was inserted. - pub table_id: TableId, + /// The replicated table schema for this event. + pub replicated_table_schema: ReplicatedTableSchema, /// Complete row data for the inserted row. pub table_row: TableRow, } @@ -76,14 +61,14 @@ pub struct InsertEvent { /// [`UpdateEvent`] represents an existing row being modified. It contains /// both the new row data and optionally the old row data for comparison /// and conflict resolution in the destination system. 
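With data events now carrying a `ReplicatedTableSchema` rather than a bare `TableId`, a consumer can derive both the table identity and the replicated column set from the same payload. A hedged fragment, not standalone-runnable: the `buffer_insert` helper is hypothetical, and only the accessors that appear in this diff (`id()`, `column_schemas()`) are assumed to exist:

```rust
use std::collections::HashMap;

// Assumes the crate's InsertEvent, TableRow, and TableId are in scope.
fn buffer_insert(buffers: &mut HashMap<TableId, Vec<TableRow>>, event: InsertEvent) {
    let table_id = event.replicated_table_schema.id();

    // The row carries one value per replicated column, in the order reported by
    // `column_schemas()`; non-published columns are simply absent.
    debug_assert_eq!(
        event.replicated_table_schema.column_schemas().count(),
        event.table_row.values.len()
    );

    buffers.entry(table_id).or_default().push(event.table_row);
}
```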
-#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub struct UpdateEvent { /// LSN position where the event started. pub start_lsn: PgLsn, /// LSN position where the transaction of this event will commit. pub commit_lsn: PgLsn, - /// ID of the table where the row was updated. - pub table_id: TableId, + /// The replicated table schema for this event. + pub replicated_table_schema: ReplicatedTableSchema, /// New row data after the update. pub table_row: TableRow, /// Previous row data before the update. @@ -98,14 +83,14 @@ pub struct UpdateEvent { /// /// [`DeleteEvent`] represents a row being removed from a table. It contains /// information about the deleted row for proper cleanup in the destination system. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub struct DeleteEvent { /// LSN position where the event started. pub start_lsn: PgLsn, /// LSN position where the transaction of this event will commit. pub commit_lsn: PgLsn, - /// ID of the table where the row was deleted. - pub table_id: TableId, + /// The replicated table schema for this event. + pub replicated_table_schema: ReplicatedTableSchema, /// Data from the deleted row. /// /// The boolean indicates whether the row contains only key columns (`true`) @@ -119,7 +104,7 @@ pub struct DeleteEvent { /// [`TruncateEvent`] represents one or more tables being truncated (all rows deleted). /// This is a bulk operation that clears entire tables and may affect multiple tables /// in a single operation when using cascading truncates. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub struct TruncateEvent { /// LSN position where the event started. pub start_lsn: PgLsn, @@ -127,8 +112,24 @@ pub struct TruncateEvent { pub commit_lsn: PgLsn, /// Truncate operation options from Postgres. pub options: i8, - /// List of table IDs that were truncated in this operation. - pub rel_ids: Vec, + /// List of schemas for tables that were truncated in this operation. + pub truncated_tables: Vec, +} + +/// Relation (schema) event from Postgres logical replication. +/// +/// [`RelationEvent`] represents a table schema notification in the replication stream. +/// It is emitted when a RELATION message is received, containing the current +/// replication mask for the table. This event notifies downstream consumers +/// about which columns are being replicated for a table. +#[derive(Debug, Clone)] +pub struct RelationEvent { + /// LSN position where the event started. + pub start_lsn: PgLsn, + /// LSN position where the transaction of this event will commit. + pub commit_lsn: PgLsn, + /// The replicated table schema containing the table schema and replication mask. + pub replicated_table_schema: ReplicatedTableSchema, } /// Represents a single replication event from Postgres logical replication. @@ -136,7 +137,7 @@ pub struct TruncateEvent { /// [`Event`] encapsulates all possible events that can occur in a Postgres replication /// stream, including data modification events and transaction control events. Each event /// type corresponds to specific operations in the source database. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub enum Event { /// Transaction begin event marking the start of a new transaction. Begin(BeginEvent), @@ -148,10 +149,10 @@ pub enum Event { Update(UpdateEvent), /// Row deletion event with deleted row data. Delete(DeleteEvent), - /// Relation schema information event describing table structure. - Relation(RelationEvent), /// Table truncation event clearing all rows from tables. 
Truncate(TruncateEvent), + /// Relation (schema) event notifying about table schema and replication mask. + Relation(RelationEvent), /// Unsupported event type that cannot be processed. Unsupported, } @@ -172,11 +173,11 @@ impl Event { /// specific tables and will always return false. pub fn has_table_id(&self, table_id: &TableId) -> bool { match self { - Event::Insert(insert_event) => insert_event.table_id == *table_id, - Event::Update(update_event) => update_event.table_id == *table_id, - Event::Delete(delete_event) => delete_event.table_id == *table_id, - Event::Relation(relation_event) => relation_event.table_schema.id == *table_id, - Event::Truncate(event) => event.rel_ids.contains(&table_id.0), + Event::Insert(e) => e.replicated_table_schema.id() == *table_id, + Event::Update(e) => e.replicated_table_schema.id() == *table_id, + Event::Delete(e) => e.replicated_table_schema.id() == *table_id, + Event::Truncate(e) => e.truncated_tables.iter().any(|s| s.id() == *table_id), + Event::Relation(e) => e.replicated_table_schema.id() == *table_id, _ => false, } } @@ -230,8 +231,8 @@ impl From<&Event> for EventType { Event::Insert(_) => EventType::Insert, Event::Update(_) => EventType::Update, Event::Delete(_) => EventType::Delete, - Event::Relation(_) => EventType::Relation, Event::Truncate(_) => EventType::Truncate, + Event::Relation(_) => EventType::Relation, Event::Unsupported => EventType::Unsupported, } } diff --git a/etl/src/workers/apply.rs b/etl/src/workers/apply.rs index 26acdb00a..312bc0e48 100644 --- a/etl/src/workers/apply.rs +++ b/etl/src/workers/apply.rs @@ -19,9 +19,8 @@ use crate::etl_error; use crate::replication::apply::{ApplyLoopAction, ApplyLoopHook, start_apply_loop}; use crate::replication::client::{GetOrCreateSlotResult, PgReplicationClient}; use crate::replication::common::get_active_table_replication_states; -use crate::state::table::{ - TableReplicationError, TableReplicationPhase, TableReplicationPhaseType, -}; +use crate::replication::masks::ReplicationMasks; +use crate::state::table::{TableReplicationPhase, TableReplicationPhaseType}; use crate::store::schema::SchemaStore; use crate::store::state::StateStore; use crate::types::PipelineId; @@ -81,6 +80,7 @@ pub struct ApplyWorker { pool: TableSyncWorkerPool, store: S, destination: D, + replication_masks: ReplicationMasks, shutdown_rx: ShutdownRx, table_sync_worker_permits: Arc, } @@ -99,6 +99,7 @@ impl ApplyWorker { pool: TableSyncWorkerPool, store: S, destination: D, + replication_masks: ReplicationMasks, shutdown_rx: ShutdownRx, table_sync_worker_permits: Arc, ) -> Self { @@ -109,6 +110,7 @@ impl ApplyWorker { pool, store, destination, + replication_masks, shutdown_rx, table_sync_worker_permits, } @@ -142,7 +144,7 @@ where // We create the signal used to notify the apply worker that it should force syncing tables. 
let (force_syncing_tables_tx, force_syncing_tables_rx) = create_signal(); - start_apply_loop( + let result = start_apply_loop( self.pipeline_id, start_lsn, self.config.clone(), @@ -155,18 +157,28 @@ where self.pool, self.store, self.destination, + self.replication_masks.clone(), self.shutdown_rx.clone(), force_syncing_tables_tx, self.table_sync_worker_permits.clone(), ), + self.replication_masks, self.shutdown_rx, Some(force_syncing_tables_rx), ) - .await?; + .await; - info!("apply worker completed successfully"); - - Ok(()) + match result { + Ok(_) => { + info!("apply worker completed successfully"); + Ok(()) + } + Err(err) => { + // We log the error here, this way it's logged even if the worker is not awaited. + error!("apply worker failed: {}", err); + Err(err) + } + } } .instrument(apply_worker_span.or_current()); @@ -271,6 +283,8 @@ struct ApplyWorkerHook { store: S, /// Destination where replicated data is written. destination: D, + /// Shared replication masks container for tracking column replication status. + replication_masks: ReplicationMasks, /// Shutdown signal receiver for graceful termination. shutdown_rx: ShutdownRx, /// Signal transmitter for triggering table sync operations. @@ -291,6 +305,7 @@ impl ApplyWorkerHook { pool: TableSyncWorkerPool, store: S, destination: D, + replication_masks: ReplicationMasks, shutdown_rx: ShutdownRx, force_syncing_tables_tx: SignalTx, table_sync_worker_permits: Arc, @@ -301,6 +316,7 @@ impl ApplyWorkerHook { pool, store, destination, + replication_masks, shutdown_rx, force_syncing_tables_tx, table_sync_worker_permits, @@ -328,6 +344,7 @@ where table_id, self.store.clone(), self.destination.clone(), + self.replication_masks.clone(), self.shutdown_rx.clone(), self.force_syncing_tables_tx.clone(), self.table_sync_worker_permits.clone(), @@ -473,7 +490,7 @@ where /// Processes all tables currently in synchronization phases. /// /// This method coordinates the lifecycle of syncing tables by promoting - /// `SyncDone` tables to `Ready` state when the apply worker catches up + /// `SyncDone` tables to the `Ready` state when the apply worker catches up /// to their sync LSN. For other tables, it handles the typical sync process. async fn process_syncing_tables( &self, @@ -529,33 +546,6 @@ where Ok(ApplyLoopAction::Continue) } - /// Handles table replication errors by updating the table's state. - /// - /// This method processes errors that occur during table replication by - /// converting them to appropriate error states and persisting the updated - /// state. The apply loop continues processing other tables after handling - /// the error. - async fn mark_table_errored( - &self, - table_replication_error: TableReplicationError, - ) -> EtlResult { - let pool = self.pool.lock().await; - - // Convert the table replication error directly to a phase. - let table_id = table_replication_error.table_id(); - TableSyncWorkerState::set_and_store( - &pool, - &self.store, - table_id, - table_replication_error.into(), - ) - .await?; - - // We want to always continue the loop, since we have to deal with the events of other - // tables. - Ok(ApplyLoopAction::Continue) - } - /// Determines whether changes should be applied for a given table. 
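The apply worker now logs its own failure before returning it, so the error is visible even when the task's join handle is never awaited. A small self-contained illustration of that pattern with a generic spawned task (assumes the `tokio`, `tracing`, and `tracing-subscriber` crates; none of the names are from this codebase):

```rust
use tokio::task::JoinHandle;
use tracing::{error, info};

// Spawn a worker that reports its own outcome before returning it: the error is
// recorded in the logs even if the caller drops the handle without awaiting it.
fn spawn_worker() -> JoinHandle<Result<(), String>> {
    tokio::spawn(async {
        let result: Result<(), String> = Err("replication slot lost".to_string());
        match result {
            Ok(()) => {
                info!("worker completed successfully");
                Ok(())
            }
            Err(err) => {
                error!("worker failed: {err}");
                Err(err)
            }
        }
    })
}

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();
    // Dropping the handle instead of awaiting it would still leave the error in the logs.
    let _ = spawn_worker().await;
}
```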
/// /// This method evaluates the table's replication state to decide if events diff --git a/etl/src/workers/table_sync.rs b/etl/src/workers/table_sync.rs index 30dfa873a..529a1b33a 100644 --- a/etl/src/workers/table_sync.rs +++ b/etl/src/workers/table_sync.rs @@ -19,6 +19,7 @@ use crate::replication::apply::{ ApplyLoopAction, ApplyLoopHook, ApplyLoopResult, start_apply_loop, }; use crate::replication::client::PgReplicationClient; +use crate::replication::masks::ReplicationMasks; use crate::replication::table_sync::{TableSyncResult, start_table_sync}; use crate::state::table::{ RetryPolicy, TableReplicationError, TableReplicationPhase, TableReplicationPhaseType, @@ -364,6 +365,7 @@ pub struct TableSyncWorker { table_id: TableId, store: S, destination: D, + replication_masks: ReplicationMasks, shutdown_rx: ShutdownRx, force_syncing_tables_tx: SignalTx, run_permit: Arc, @@ -383,6 +385,7 @@ impl TableSyncWorker { table_id: TableId, store: S, destination: D, + replication_masks: ReplicationMasks, shutdown_rx: ShutdownRx, force_syncing_tables_tx: SignalTx, run_permit: Arc, @@ -394,6 +397,7 @@ impl TableSyncWorker { table_id, store, destination, + replication_masks, shutdown_rx, force_syncing_tables_tx, run_permit, @@ -430,6 +434,7 @@ where // Clone all the fields we need for retries. let pipeline_id = self.pipeline_id; let destination = self.destination.clone(); + let replication_masks = self.replication_masks.clone(); let shutdown_rx = self.shutdown_rx.clone(); let force_syncing_tables_tx = self.force_syncing_tables_tx.clone(); let run_permit = self.run_permit.clone(); @@ -443,6 +448,7 @@ where table_id, store: store.clone(), destination: destination.clone(), + replication_masks: replication_masks.clone(), shutdown_rx: shutdown_rx.clone(), force_syncing_tables_tx: force_syncing_tables_tx.clone(), run_permit: run_permit.clone(), @@ -621,6 +627,7 @@ where state.clone(), self.store.clone(), self.destination.clone(), + &self.replication_masks, self.shutdown_rx.clone(), self.force_syncing_tables_tx, ) @@ -645,6 +652,7 @@ where self.store.clone(), self.destination, TableSyncWorkerHook::new(self.table_id, state, self.store), + self.replication_masks, self.shutdown_rx, None, ) @@ -865,33 +873,6 @@ where self.try_advance_phase(current_lsn, update_state).await } - /// Handles table replication errors for the table sync worker. - /// - /// This method processes errors specific to the table this worker manages. - /// If the error relates to this worker's table, it updates the state and - /// signals the worker to terminate. Errors for other tables are ignored. - async fn mark_table_errored( - &self, - table_replication_error: TableReplicationError, - ) -> EtlResult { - if self.table_id != table_replication_error.table_id() { - // If the table is different from the one handled by this table sync worker, marking - // the table will be a noop, and we want to continue the loop. - return Ok(ApplyLoopAction::Continue); - } - - // Since we already have access to the table sync worker state, we can avoid going through - // the pool, and we just modify the state here and also update the state store. - let mut inner = self.table_sync_worker_state.lock().await; - inner - .set_and_store(table_replication_error.into(), &self.state_store) - .await?; - - // If a table is marked as errored, this worker should stop processing immediately since there - // is no need to continue, and for this we mark the loop as completed. 
- Ok(ApplyLoopAction::Complete) - } - /// Determines whether changes should be applied for the given table. /// /// For table sync workers, changes are only applied if the table matches diff --git a/etl/tests/failpoints_pipeline.rs b/etl/tests/failpoints_pipeline.rs deleted file mode 100644 index 0f597b64d..000000000 --- a/etl/tests/failpoints_pipeline.rs +++ /dev/null @@ -1,279 +0,0 @@ -#![cfg(all(feature = "test-utils", feature = "failpoints"))] - -use etl::destination::memory::MemoryDestination; -use etl::error::ErrorKind; -use etl::failpoints::{ - START_TABLE_SYNC_BEFORE_DATA_SYNC_SLOT_CREATION, START_TABLE_SYNC_DURING_DATA_SYNC, -}; -use etl::state::table::{RetryPolicy, TableReplicationPhase, TableReplicationPhaseType}; -use etl::test_utils::database::spawn_source_database; -use etl::test_utils::notify::NotifyingStore; -use etl::test_utils::pipeline::create_pipeline; -use etl::test_utils::test_destination_wrapper::TestDestinationWrapper; -use etl::test_utils::test_schema::{TableSelection, insert_users_data, setup_test_database_schema}; -use etl::types::PipelineId; -use etl_telemetry::tracing::init_test_tracing; -use fail::FailScenario; -use rand::random; - -#[tokio::test(flavor = "multi_thread")] -async fn table_copy_fails_after_data_sync_threw_an_error_with_no_retry() { - let _scenario = FailScenario::setup(); - fail::cfg( - START_TABLE_SYNC_BEFORE_DATA_SYNC_SLOT_CREATION, - "1*return(no_retry)", - ) - .unwrap(); - - init_test_tracing(); - - let mut database = spawn_source_database().await; - let database_schema = setup_test_database_schema(&database, TableSelection::UsersOnly).await; - - // Insert initial test data. - let rows_inserted = 10; - insert_users_data( - &mut database, - &database_schema.users_schema().name, - 1..=rows_inserted, - ) - .await; - - let store = NotifyingStore::new(); - let destination = TestDestinationWrapper::wrap(MemoryDestination::new()); - - // We start the pipeline from scratch. - let pipeline_id: PipelineId = random(); - let mut pipeline = create_pipeline( - &database.config, - pipeline_id, - database_schema.publication_name(), - store.clone(), - destination.clone(), - ); - - // Register notifications for table sync phases. - let users_state_notify = store - .notify_on_table_state_type( - database_schema.users_schema().id, - TableReplicationPhaseType::Errored, - ) - .await; - - pipeline.start().await.unwrap(); - - users_state_notify.notified().await; - - // We expect to have a no retry error which is generated by the failpoint. - let err = pipeline.shutdown_and_wait().await.err().unwrap(); - assert_eq!(err.kinds().len(), 1); - assert_eq!(err.kinds()[0], ErrorKind::WithNoRetry); - - // Verify no data is there. - let table_rows = destination.get_table_rows().await; - assert!(table_rows.is_empty()); - - // Verify table schemas were correctly stored. - let table_schemas = store.get_table_schemas().await; - assert!(table_schemas.is_empty()); -} - -#[tokio::test(flavor = "multi_thread")] -async fn table_copy_fails_after_timed_retry_exceeded_max_attempts() { - let _scenario = FailScenario::setup(); - // Since we have table_error_retry_max_attempts: 2, we want to fail 3 times, so that on the 3rd - // time, the system switches to manual retry. - fail::cfg( - START_TABLE_SYNC_BEFORE_DATA_SYNC_SLOT_CREATION, - "3*return(timed_retry)", - ) - .unwrap(); - - init_test_tracing(); - - let mut database = spawn_source_database().await; - let database_schema = setup_test_database_schema(&database, TableSelection::UsersOnly).await; - - // Insert initial test data. 
- let rows_inserted = 10; - insert_users_data( - &mut database, - &database_schema.users_schema().name, - 1..=rows_inserted, - ) - .await; - - let store = NotifyingStore::new(); - let destination = TestDestinationWrapper::wrap(MemoryDestination::new()); - - // We start the pipeline from scratch. - let pipeline_id: PipelineId = random(); - let mut pipeline = create_pipeline( - &database.config, - pipeline_id, - database_schema.publication_name(), - store.clone(), - destination.clone(), - ); - - // Register notifications for waiting on the manual retry which is expected to be flipped by the - // max attempts handling. - let users_state_notify = store - .notify_on_table_state(database_schema.users_schema().id, |phase| { - matches!( - phase, - TableReplicationPhase::Errored { - retry_policy: RetryPolicy::ManualRetry, - .. - } - ) - }) - .await; - - pipeline.start().await.unwrap(); - - users_state_notify.notified().await; - - // We expect to still have the timed retry kind since this is the kind of error that we triggered. - let err = pipeline.shutdown_and_wait().await.err().unwrap(); - assert_eq!(err.kinds().len(), 1); - assert_eq!(err.kinds()[0], ErrorKind::WithTimedRetry); - - // Verify no data is there. - let table_rows = destination.get_table_rows().await; - assert!(table_rows.is_empty()); - - // Verify table schemas were correctly stored. - let table_schemas = store.get_table_schemas().await; - assert!(table_schemas.is_empty()); -} - -#[tokio::test(flavor = "multi_thread")] -async fn table_copy_is_consistent_after_data_sync_threw_an_error_with_timed_retry() { - let _scenario = FailScenario::setup(); - fail::cfg( - START_TABLE_SYNC_BEFORE_DATA_SYNC_SLOT_CREATION, - "1*return(timed_retry)", - ) - .unwrap(); - - init_test_tracing(); - - let mut database = spawn_source_database().await; - let database_schema = setup_test_database_schema(&database, TableSelection::UsersOnly).await; - - // Insert initial test data. - let rows_inserted = 10; - insert_users_data( - &mut database, - &database_schema.users_schema().name, - 1..=rows_inserted, - ) - .await; - - let store = NotifyingStore::new(); - let destination = TestDestinationWrapper::wrap(MemoryDestination::new()); - - // We start the pipeline from scratch. - let pipeline_id: PipelineId = random(); - let mut pipeline = create_pipeline( - &database.config, - pipeline_id, - database_schema.publication_name(), - store.clone(), - destination.clone(), - ); - - // We register the interest in waiting for both table syncs to have started. - let users_state_notify = store - .notify_on_table_state_type( - database_schema.users_schema().id, - TableReplicationPhaseType::SyncDone, - ) - .await; - - pipeline.start().await.unwrap(); - - users_state_notify.notified().await; - - // We expect no errors, since the same table sync worker task is retried. - pipeline.shutdown_and_wait().await.unwrap(); - - // Verify copied data. - let table_rows = destination.get_table_rows().await; - let users_table_rows = table_rows.get(&database_schema.users_schema().id).unwrap(); - assert_eq!(users_table_rows.len(), rows_inserted); - - // Verify table schemas were correctly stored. 
- let table_schemas = store.get_table_schemas().await; - assert_eq!(table_schemas.len(), 1); - assert_eq!( - *table_schemas - .get(&database_schema.users_schema().id) - .unwrap(), - database_schema.users_schema() - ); -} - -#[tokio::test(flavor = "multi_thread")] -async fn table_copy_is_consistent_during_data_sync_threw_an_error_with_timed_retry() { - let _scenario = FailScenario::setup(); - fail::cfg(START_TABLE_SYNC_DURING_DATA_SYNC, "1*return(timed_retry)").unwrap(); - - init_test_tracing(); - - let mut database = spawn_source_database().await; - let database_schema = setup_test_database_schema(&database, TableSelection::UsersOnly).await; - - // Insert initial test data. - let rows_inserted = 10; - insert_users_data( - &mut database, - &database_schema.users_schema().name, - 1..=rows_inserted, - ) - .await; - - let store = NotifyingStore::new(); - let destination = TestDestinationWrapper::wrap(MemoryDestination::new()); - - // We start the pipeline from scratch. - let pipeline_id: PipelineId = random(); - let mut pipeline = create_pipeline( - &database.config, - pipeline_id, - database_schema.publication_name(), - store.clone(), - destination.clone(), - ); - - // We register the interest in waiting for both table syncs to have started. - let users_state_notify = store - .notify_on_table_state_type( - database_schema.users_schema().id, - TableReplicationPhaseType::SyncDone, - ) - .await; - - pipeline.start().await.unwrap(); - - users_state_notify.notified().await; - - // We expect no errors, since the same table sync worker task is retried. - pipeline.shutdown_and_wait().await.unwrap(); - - // Verify copied data. - let table_rows = destination.get_table_rows().await; - let users_table_rows = table_rows.get(&database_schema.users_schema().id).unwrap(); - assert_eq!(users_table_rows.len(), rows_inserted); - - // Verify table schemas were correctly stored. - let table_schemas = store.get_table_schemas().await; - assert_eq!(table_schemas.len(), 1); - assert_eq!( - *table_schemas - .get(&database_schema.users_schema().id) - .unwrap(), - database_schema.users_schema() - ); -} diff --git a/etl/tests/pipeline.rs b/etl/tests/pipeline.rs index 27fcfd83f..804188e3a 100644 --- a/etl/tests/pipeline.rs +++ b/etl/tests/pipeline.rs @@ -7,7 +7,7 @@ use etl::test_utils::database::{spawn_source_database, test_table_name}; use etl::test_utils::event::group_events_by_type_and_table_id; use etl::test_utils::notify::NotifyingStore; use etl::test_utils::pipeline::{create_pipeline, create_pipeline_with}; -use etl::test_utils::table::assert_table_schema; +use etl::test_utils::schema::assert_table_schema_columns; use etl::test_utils::test_destination_wrapper::TestDestinationWrapper; use etl::test_utils::test_schema::{ TableSelection, assert_events_equal, build_expected_orders_inserts, @@ -18,7 +18,7 @@ use etl::types::{Event, EventType, InsertEvent, PipelineId, Type}; use etl_config::shared::BatchConfig; use etl_postgres::below_version; use etl_postgres::replication::slots::EtlReplicationSlot; -use etl_postgres::tokio::test_utils::{TableModification, id_column_schema}; +use etl_postgres::tokio::test_utils::id_column_schema; use etl_postgres::types::ColumnSchema; use etl_postgres::version::POSTGRES_15; use etl_telemetry::tracing::init_test_tracing; @@ -26,6 +26,24 @@ use rand::random; use std::time::Duration; use tokio::time::sleep; +/// Creates a test column schema with sensible defaults. 
+fn test_column( + name: &str, + typ: Type, + ordinal_position: i32, + nullable: bool, + primary_key: bool, +) -> ColumnSchema { + ColumnSchema::new( + name.to_string(), + typ, + -1, + ordinal_position, + if primary_key { Some(1) } else { None }, + nullable, + ) +} + #[tokio::test(flavor = "multi_thread")] async fn pipeline_fails_when_slot_deleted_with_non_init_tables() { init_test_tracing(); @@ -185,7 +203,7 @@ async fn table_schema_copy_survives_pipeline_restarts() { ); // We check that the table schemas have been stored. - let table_schemas = store.get_table_schemas().await; + let table_schemas = store.get_latest_table_schemas().await; assert_eq!(table_schemas.len(), 2); assert_eq!( *table_schemas @@ -414,12 +432,16 @@ async fn publication_for_all_tables_in_schema_ignores_new_tables_until_restart() return; } - // Create first table. + // Create first table and insert one row. let table_1 = test_table_name("table_1"); let table_1_id = database .create_table(table_1.clone(), true, &[("name", "text not null")]) .await .unwrap(); + database + .insert_values(table_1.clone(), &["name"], &[&"test_name_1".to_owned()]) + .await + .unwrap(); // Create a publication for all tables in the test schema. let publication_name = "test_pub_all_schema"; @@ -448,6 +470,18 @@ async fn publication_for_all_tables_in_schema_ignores_new_tables_until_restart() sync_done.notified().await; + // Wait for an insert event in table 1. + let insert_events_notify = destination + .wait_for_events_count(vec![(EventType::Insert, 1)]) + .await; + + database + .insert_values(table_1.clone(), &["name"], &[&"test_name_2".to_owned()]) + .await + .unwrap(); + + insert_events_notify.notified().await; + // Create a new table in the same schema and insert a row. let table_2 = test_table_name("table_2"); let table_2_id = database @@ -459,18 +493,29 @@ async fn publication_for_all_tables_in_schema_ignores_new_tables_until_restart() .await .unwrap(); - // Wait for the events to come in from the new table. + // Wait for the events to come in from the new table to make sure the pipeline reacts to them + // gracefully even if they are not replicated. sleep(Duration::from_secs(2)).await; // Shutdown and verify no errors occurred. pipeline.shutdown_and_wait().await.unwrap(); // Check that only the schemas of the first table were stored. - let table_schemas = store.get_table_schemas().await; + let table_schemas = store.get_latest_table_schemas().await; assert_eq!(table_schemas.len(), 1); assert!(table_schemas.contains_key(&table_1_id)); assert!(!table_schemas.contains_key(&table_2_id)); + // Verify the table rows and events inserted into table 1. + let table_rows = destination.get_table_rows().await; + assert_eq!(table_rows.get(&table_1_id).unwrap().len(), 1); + let events = destination.get_events().await; + let grouped_events = group_events_by_type_and_table_id(&events); + let insert_events = grouped_events + .get(&(EventType::Insert, table_1_id)) + .unwrap(); + assert_eq!(insert_events.len(), 1); + // We restart the pipeline and verify that the new table is now processed. let mut pipeline = create_pipeline( &database.config, @@ -488,14 +533,39 @@ async fn publication_for_all_tables_in_schema_ignores_new_tables_until_restart() sync_done.notified().await; + // We clear the events to make waiting more idiomatic down the line. + destination.clear_events().await; + + // Wait for an insert event in table 2. 
+ let insert_events_notify = destination + .wait_for_events_count(vec![(EventType::Insert, 1)]) + .await; + + database + .insert_values(table_2.clone(), &["value"], &[&2_i32]) + .await + .unwrap(); + + insert_events_notify.notified().await; + // Shutdown and verify no errors occurred. pipeline.shutdown_and_wait().await.unwrap(); // Check that both schemas exist. - let table_schemas = store.get_table_schemas().await; + let table_schemas = store.get_latest_table_schemas().await; assert_eq!(table_schemas.len(), 2); assert!(table_schemas.contains_key(&table_1_id)); assert!(table_schemas.contains_key(&table_2_id)); + + // Verify the table rows and events inserted into table 2. + let table_rows = destination.get_table_rows().await; + assert_eq!(table_rows.get(&table_2_id).unwrap().len(), 1); + let events = destination.get_events().await; + let grouped_events = group_events_by_type_and_table_id(&events); + let insert_events = grouped_events + .get(&(EventType::Insert, table_2_id)) + .unwrap(); + assert_eq!(insert_events.len(), 1); } #[tokio::test(flavor = "multi_thread")] @@ -709,7 +779,7 @@ async fn table_copy_and_sync_streams_new_data() { // Build expected events for verification let expected_users_inserts = build_expected_users_inserts( 11, - database_schema.users_schema().id, + &database_schema.users_schema(), vec![ ("user_11", 11), ("user_12", 12), @@ -719,7 +789,7 @@ async fn table_copy_and_sync_streams_new_data() { ); let expected_orders_inserts = build_expected_orders_inserts( 11, - database_schema.orders_schema().id, + &database_schema.orders_schema(), vec![ "description_11", "description_12", @@ -817,7 +887,7 @@ async fn table_sync_streams_new_data_with_batch_timeout_expired() { // Build expected events for verification let expected_users_inserts = build_expected_users_inserts( 1, - database_schema.users_schema().id, + &database_schema.users_schema(), vec![ ("user_1", 1), ("user_2", 2), @@ -890,146 +960,6 @@ async fn table_processing_converges_to_apply_loop_with_no_events_coming() { assert_eq!(age_sum, expected_age_sum); } -#[tokio::test(flavor = "multi_thread")] -async fn table_processing_with_schema_change_errors_table() { - init_test_tracing(); - let database = spawn_source_database().await; - let database_schema = setup_test_database_schema(&database, TableSelection::OrdersOnly).await; - - // Insert data in the table. - database - .insert_values( - database_schema.orders_schema().name.clone(), - &["description"], - &[&"description_1"], - ) - .await - .unwrap(); - - let store = NotifyingStore::new(); - let destination = TestDestinationWrapper::wrap(MemoryDestination::new()); - - // Start pipeline from scratch. - let pipeline_id: PipelineId = random(); - let mut pipeline = create_pipeline( - &database.config, - pipeline_id, - database_schema.publication_name(), - store.clone(), - destination.clone(), - ); - - // Register notifications for initial table copy completion. - let orders_state_notify = store - .notify_on_table_state_type( - database_schema.orders_schema().id, - TableReplicationPhaseType::FinishedCopy, - ) - .await; - - pipeline.start().await.unwrap(); - - orders_state_notify.notified().await; - - // Register notification for the sync done state. - let orders_state_notify = store - .notify_on_table_state_type( - database_schema.orders_schema().id, - TableReplicationPhaseType::SyncDone, - ) - .await; - - // Insert new data in the table. 
- database - .insert_values( - database_schema.orders_schema().name.clone(), - &["description"], - &[&"description_2"], - ) - .await - .unwrap(); - - orders_state_notify.notified().await; - - // Register notification for the ready state. - let orders_state_notify = store - .notify_on_table_state_type( - database_schema.orders_schema().id, - TableReplicationPhaseType::Ready, - ) - .await; - - // Insert new data in the table. - database - .insert_values( - database_schema.orders_schema().name.clone(), - &["description"], - &[&"description_3"], - ) - .await - .unwrap(); - - orders_state_notify.notified().await; - - // Register notification for the errored state. - let orders_state_notify = store - .notify_on_table_state_type( - database_schema.orders_schema().id, - TableReplicationPhaseType::Errored, - ) - .await; - - // Change the schema of orders by adding a new column. - database - .alter_table( - database_schema.orders_schema().name.clone(), - &[TableModification::AddColumn { - name: "date", - data_type: "integer", - }], - ) - .await - .unwrap(); - - // Insert new data in the table. - database - .insert_values( - database_schema.orders_schema().name.clone(), - &["description", "date"], - &[&"description_with_date", &10], - ) - .await - .unwrap(); - - orders_state_notify.notified().await; - - pipeline.shutdown_and_wait().await.unwrap(); - - // We assert that the schema is the initial one. - let table_schemas = store.get_table_schemas().await; - assert_eq!(table_schemas.len(), 1); - assert_eq!( - *table_schemas - .get(&database_schema.orders_schema().id) - .unwrap(), - database_schema.orders_schema() - ); - - // We check that we got the insert events after the first data of the table has been copied. - let events = destination.get_events().await; - let grouped_events = group_events_by_type_and_table_id(&events); - let orders_inserts = grouped_events - .get(&(EventType::Insert, database_schema.orders_schema().id)) - .unwrap(); - - let expected_orders_inserts = build_expected_orders_inserts( - 2, - database_schema.orders_schema().id, - vec!["description_2", "description_3"], - ); - assert_events_equal(orders_inserts, &expected_orders_inserts); -} - #[tokio::test(flavor = "multi_thread")] async fn table_without_primary_key_is_errored() { init_test_tracing(); @@ -1103,7 +1033,7 @@ async fn pipeline_respects_column_level_publication() { return; } - // Create a table with multiple columns including a sensitive 'email' column. + // Create a table with multiple columns. let table_name = test_table_name("users"); let table_id = database .create_table( @@ -1113,12 +1043,13 @@ async fn pipeline_respects_column_level_publication() { ("name", "text not null"), ("age", "integer not null"), ("email", "text not null"), + ("phone", "text not null"), ], ) .await .unwrap(); - // Create publication with only a subset of columns (excluding 'email'). + // Create publication with only a subset of columns. let publication_name = "test_pub".to_string(); database .run_sql(&format!( @@ -1149,15 +1080,15 @@ async fn pipeline_respects_column_level_publication() { sync_done_notify.notified().await; - // Wait for two insert events to be processed. + // Wait for an insert event to be processed. let insert_events_notify = destination - .wait_for_events_count(vec![(EventType::Insert, 2)]) + .wait_for_events_count(vec![(EventType::Relation, 1), (EventType::Insert, 1)]) .await; - // Insert test data with all columns (including email). + // Insert test data with all columns (including email and phone). 
database .run_sql(&format!( - "insert into {} (name, age, email) values ('Alice', 25, 'alice@example.com'), ('Bob', 30, 'bob@example.com')", + "insert into {} (name, age, email, phone) values ('Alice', 25, 'alice@example.com', '555-0001')", table_name.as_quoted_identifier() )) .await @@ -1165,37 +1096,235 @@ async fn pipeline_respects_column_level_publication() { insert_events_notify.notified().await; - pipeline.shutdown_and_wait().await.unwrap(); - // Verify the events and check that only published columns are included. let events = destination.get_events().await; let grouped_events = group_events_by_type_and_table_id(&events); let insert_events = grouped_events.get(&(EventType::Insert, table_id)).unwrap(); - assert_eq!(insert_events.len(), 2); + assert_eq!(insert_events.len(), 1); + + let initial_relation_event = events + .iter() + .rev() + .find_map(|event| match event { + Event::Relation(relation) if relation.replicated_table_schema.id() == table_id => { + Some(relation.clone()) + } + _ => None, + }) + .expect("Expected relation event for initial publication state"); + + let initial_relation_columns: Vec<&str> = initial_relation_event + .replicated_table_schema + .column_schemas() + .map(|c| c.name.as_str()) + .collect(); + assert_eq!(initial_relation_columns, vec!["id", "name", "age"]); + assert_eq!( + initial_relation_event + .replicated_table_schema + .replication_mask() + .as_slice(), + &[1, 1, 1, 0, 0] + ); + assert_eq!( + initial_relation_event + .replicated_table_schema + .get_inner() + .column_schemas + .len(), + 5 + ); - // Check that each insert event contains only the published columns (id, name, age). - // Since Cell values don't include column names, we verify by checking the count. + // Check that each insert event contains only the published columns (id, name, age) and that the + // schema used is correct. for event in insert_events { - if let Event::Insert(InsertEvent { table_row, .. }) = event { + if let Event::Insert(InsertEvent { + replicated_table_schema, + table_row, + .. + }) = event + { // Verify exactly 3 columns (id, name, age). - // If email was included, there would be 4 values. assert_eq!(table_row.values.len(), 3); + + // Get only the replicated column names from the schema + let replicated_column_names: Vec<&str> = replicated_table_schema + .column_schemas() + .map(|c| c.name.as_str()) + .collect(); + assert_eq!(replicated_column_names, vec!["id", "name", "age"]); + + // The underlying full schema has all 5 columns + let full_schema = replicated_table_schema.get_inner(); + assert_eq!(full_schema.column_schemas.len(), 5); } } - // Also verify the stored table schema only includes published columns. - let table_schemas = state_store.get_table_schemas().await; - let stored_schema = table_schemas.get(&table_id).unwrap(); - let column_names: Vec<&str> = stored_schema - .column_schemas + // Clear events and restart pipeline. + destination.clear_events().await; + + // Add email column to publication -> (id, name, age, email). + database + .run_sql(&format!( + "alter publication {publication_name} set table {} (id, name, age, email)", + table_name.as_quoted_identifier() + )) + .await + .unwrap(); + + // Wait for 1 insert event with 4 columns. 
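The masks asserted in this test (`[1, 1, 1, 0, 0]`, then `[1, 1, 1, 1, 0]`, then `[1, 1, 0, 1, 0]`) mirror the publication's column list laid over the table's five columns. A self-contained sketch of that correspondence; it is not the crate's `ReplicationMask` implementation, just the same idea over plain vectors:

```rust
/// Builds a 0/1 mask over the full column list, marking the published columns.
fn build_mask(all_columns: &[&str], published: &[&str]) -> Vec<u8> {
    all_columns
        .iter()
        .map(|col| u8::from(published.contains(col)))
        .collect()
}

/// Filters a full row down to the values of the replicated columns.
fn apply_mask<'a>(mask: &[u8], row: &[&'a str]) -> Vec<&'a str> {
    mask.iter()
        .zip(row)
        .filter_map(|(&bit, value)| (bit == 1).then_some(*value))
        .collect()
}

fn main() {
    let all_columns = ["id", "name", "age", "email", "phone"];

    // Publication covers (id, name, age, email): the mask matches the test's expectation.
    let mask = build_mask(&all_columns, &["id", "name", "age", "email"]);
    assert_eq!(mask, vec![1, 1, 1, 1, 0]);

    // Only the published values survive, in table column order.
    let row = ["3", "Charlie", "35", "charlie@example.com", "555-0003"];
    assert_eq!(
        apply_mask(&mask, &row),
        vec!["3", "Charlie", "35", "charlie@example.com"]
    );
}
```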
+ let insert_notify = destination + .wait_for_events_count(vec![(EventType::Relation, 1), (EventType::Insert, 1)]) + .await; + + database + .run_sql(&format!( + "insert into {} (name, age, email, phone) values ('Charlie', 35, 'charlie@example.com', '555-0003')", + table_name.as_quoted_identifier() + )) + .await + .unwrap(); + + insert_notify.notified().await; + + // Verify 4 columns arrived (id, name, age, email). + let events = destination.get_events().await; + let grouped = group_events_by_type_and_table_id(&events); + let inserts = grouped.get(&(EventType::Insert, table_id)).unwrap(); + assert_eq!(inserts.len(), 1); + + let relation_after_adding_email = events + .iter() + .rev() + .find_map(|event| match event { + Event::Relation(relation) if relation.replicated_table_schema.id() == table_id => { + Some(relation.clone()) + } + _ => None, + }) + .expect("Expected relation event after adding email to publication"); + + if let Event::Insert(InsertEvent { + replicated_table_schema, + table_row, + .. + }) = &inserts[0] + { + assert_eq!(table_row.values.len(), 4); + let col_names: Vec<&str> = replicated_table_schema + .column_schemas() + .map(|c| c.name.as_str()) + .collect(); + assert_eq!(col_names, vec!["id", "name", "age", "email"]); + } else { + panic!("Expected Insert event"); + } + + let relation_columns: Vec<&str> = relation_after_adding_email + .replicated_table_schema + .column_schemas() + .map(|c| c.name.as_str()) + .collect(); + assert_eq!(relation_columns, vec!["id", "name", "age", "email"]); + assert_eq!( + relation_after_adding_email + .replicated_table_schema + .replication_mask() + .as_slice(), + &[1, 1, 1, 1, 0] + ); + assert_eq!( + relation_after_adding_email + .replicated_table_schema + .get_inner() + .column_schemas + .len(), + 5 + ); + + // Remove age column from publication -> (id, name, email). + database + .run_sql(&format!( + "alter publication {publication_name} set table {} (id, name, email)", + table_name.as_quoted_identifier() + )) + .await + .unwrap(); + + // Clear events and restart pipeline. + destination.clear_events().await; + + // Wait for 1 insert event with 3 columns (different set than before). + let insert_notify = destination + .wait_for_events_count(vec![(EventType::Relation, 1), (EventType::Insert, 1)]) + .await; + + database + .run_sql(&format!( + "insert into {} (name, age, email, phone) values ('Diana', 40, 'diana@example.com', '555-0004')", + table_name.as_quoted_identifier() + )) + .await + .unwrap(); + + insert_notify.notified().await; + + // We shutdown the pipeline. + pipeline.shutdown_and_wait().await.unwrap(); + + // Verify 3 columns arrived (id, name, email) - age and phone excluded. + let events = destination.get_events().await; + let relation_after_removing_age = events .iter() + .rev() + .find_map(|event| match event { + Event::Relation(relation) if relation.replicated_table_schema.id() == table_id => { + Some(relation.clone()) + } + _ => None, + }) + .expect("Expected relation event after removing age from publication"); + let grouped = group_events_by_type_and_table_id(&events); + let inserts = grouped.get(&(EventType::Insert, table_id)).unwrap(); + assert_eq!(inserts.len(), 1); + + if let Event::Insert(InsertEvent { + replicated_table_schema, + table_row, + .. 
+ }) = &inserts[0] + { + assert_eq!(table_row.values.len(), 3); + let col_names: Vec<&str> = replicated_table_schema + .column_schemas() + .map(|c| c.name.as_str()) + .collect(); + assert_eq!(col_names, vec!["id", "name", "email"]); + } else { + panic!("Expected Insert event"); + } + + let relation_columns: Vec<&str> = relation_after_removing_age + .replicated_table_schema + .column_schemas() .map(|c| c.name.as_str()) .collect(); - assert!(column_names.contains(&"id")); - assert!(column_names.contains(&"name")); - assert!(column_names.contains(&"age")); - assert!(!column_names.contains(&"email")); - assert_eq!(stored_schema.column_schemas.len(), 3); + assert_eq!(relation_columns, vec!["id", "name", "email"]); + assert_eq!( + relation_after_removing_age + .replicated_table_schema + .replication_mask() + .as_slice(), + &[1, 1, 0, 1, 0] + ); + assert_eq!( + relation_after_removing_age + .replicated_table_schema + .get_inner() + .column_schemas + .len(), + 5 + ); } #[tokio::test(flavor = "multi_thread")] @@ -1250,27 +1379,16 @@ async fn empty_tables_are_created_at_destination() { pipeline.shutdown_and_wait().await.unwrap(); // Verify the table schema was stored. - let table_schemas = state_store.get_table_schemas().await; - assert_table_schema( - &table_schemas, - table_id, - table_name, + let table_schemas = state_store.get_latest_table_schemas().await; + let table_schema = table_schemas.get(&table_id).unwrap(); + assert_eq!(table_schema.id, table_id); + assert_eq!(table_schema.name, table_name); + assert_table_schema_columns( + table_schema, &[ id_column_schema(), - ColumnSchema { - name: "name".to_string(), - typ: Type::TEXT, - modifier: -1, - nullable: true, - primary: false, - }, - ColumnSchema { - name: "created_at".to_string(), - typ: Type::TIMESTAMP, - modifier: -1, - nullable: true, - primary: false, - }, + test_column("name", Type::TEXT, 2, true, false), + test_column("created_at", Type::TIMESTAMP, 3, true, false), ], ); diff --git a/etl/tests/pipeline_with_failpoints.rs b/etl/tests/pipeline_with_failpoints.rs new file mode 100644 index 000000000..8ced24cec --- /dev/null +++ b/etl/tests/pipeline_with_failpoints.rs @@ -0,0 +1,858 @@ +#![cfg(all(feature = "test-utils", feature = "failpoints"))] + +use etl::destination::memory::MemoryDestination; +use etl::error::ErrorKind; +use etl::failpoints::{ + SEND_STATUS_UPDATE_FP, START_TABLE_SYNC_BEFORE_DATA_SYNC_SLOT_CREATION_FP, + START_TABLE_SYNC_DURING_DATA_SYNC_FP, +}; +use etl::state::table::{RetryPolicy, TableReplicationPhase, TableReplicationPhaseType}; +use etl::test_utils::database::{spawn_source_database, test_table_name}; +use etl::test_utils::event::group_events_by_type_and_table_id; +use etl::test_utils::notify::NotifyingStore; +use etl::test_utils::pipeline::{create_database_and_pipeline_with_table, create_pipeline}; +use etl::test_utils::schema::{ + assert_schema_snapshots_ordering, assert_table_schema_column_names_types, +}; +use etl::test_utils::test_destination_wrapper::TestDestinationWrapper; +use etl::test_utils::test_schema::{TableSelection, insert_users_data, setup_test_database_schema}; +use etl::types::Type; +use etl::types::{Event, EventType, InsertEvent, PipelineId, TableId}; +use etl_postgres::below_version; +use etl_postgres::tokio::test_utils::TableModification; +use etl_postgres::version::POSTGRES_15; +use etl_telemetry::tracing::init_test_tracing; +use fail::FailScenario; +use rand::random; + +#[tokio::test(flavor = "multi_thread")] +async fn table_copy_fails_after_data_sync_threw_an_error_with_no_retry() { + let 
_scenario = FailScenario::setup(); + fail::cfg( + START_TABLE_SYNC_BEFORE_DATA_SYNC_SLOT_CREATION_FP, + "1*return(no_retry)", + ) + .unwrap(); + + init_test_tracing(); + + let mut database = spawn_source_database().await; + let database_schema = setup_test_database_schema(&database, TableSelection::UsersOnly).await; + + // Insert initial test data. + let rows_inserted = 10; + insert_users_data( + &mut database, + &database_schema.users_schema().name, + 1..=rows_inserted, + ) + .await; + + let store = NotifyingStore::new(); + let destination = TestDestinationWrapper::wrap(MemoryDestination::new()); + + // We start the pipeline from scratch. + let pipeline_id: PipelineId = random(); + let mut pipeline = create_pipeline( + &database.config, + pipeline_id, + database_schema.publication_name(), + store.clone(), + destination.clone(), + ); + + // Register notifications for table sync phases. + let users_state_notify = store + .notify_on_table_state_type( + database_schema.users_schema().id, + TableReplicationPhaseType::Errored, + ) + .await; + + pipeline.start().await.unwrap(); + + users_state_notify.notified().await; + + // We expect to have a no retry error which is generated by the failpoint. + let err = pipeline.shutdown_and_wait().await.err().unwrap(); + assert_eq!(err.kinds().len(), 1); + assert_eq!(err.kinds()[0], ErrorKind::WithNoRetry); + + // Verify no data is there. + let table_rows = destination.get_table_rows().await; + assert!(table_rows.is_empty()); + + // Verify table schemas were correctly stored. + let table_schemas = store.get_latest_table_schemas().await; + assert!(table_schemas.is_empty()); +} + +#[tokio::test(flavor = "multi_thread")] +async fn table_copy_fails_after_timed_retry_exceeded_max_attempts() { + let _scenario = FailScenario::setup(); + // Since we have table_error_retry_max_attempts: 2, we want to fail 3 times, so that on the 3rd + // time, the system switches to manual retry. + fail::cfg( + START_TABLE_SYNC_BEFORE_DATA_SYNC_SLOT_CREATION_FP, + "3*return(timed_retry)", + ) + .unwrap(); + + init_test_tracing(); + + let mut database = spawn_source_database().await; + let database_schema = setup_test_database_schema(&database, TableSelection::UsersOnly).await; + + // Insert initial test data. + let rows_inserted = 10; + insert_users_data( + &mut database, + &database_schema.users_schema().name, + 1..=rows_inserted, + ) + .await; + + let store = NotifyingStore::new(); + let destination = TestDestinationWrapper::wrap(MemoryDestination::new()); + + // We start the pipeline from scratch. + let pipeline_id: PipelineId = random(); + let mut pipeline = create_pipeline( + &database.config, + pipeline_id, + database_schema.publication_name(), + store.clone(), + destination.clone(), + ); + + // Register notifications for waiting on the manual retry which is expected to be flipped by the + // max attempts handling. + let users_state_notify = store + .notify_on_table_state(database_schema.users_schema().id, |phase| { + matches!( + phase, + TableReplicationPhase::Errored { + retry_policy: RetryPolicy::ManualRetry, + .. + } + ) + }) + .await; + + pipeline.start().await.unwrap(); + + users_state_notify.notified().await; + + // We expect to still have the timed retry kind since this is the kind of error that we triggered. + let err = pipeline.shutdown_and_wait().await.err().unwrap(); + assert_eq!(err.kinds().len(), 1); + assert_eq!(err.kinds()[0], ErrorKind::WithTimedRetry); + + // Verify no data is there. 
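Aside (illustrative, not part of this patch): the cfg strings used in these tests come from the `fail` crate, where a leading count such as `1*` or `3*` limits how many times the `return(...)` action fires before the failpoint falls back to the normal code path. A self-contained sketch of that mechanism, using a made-up failpoint name rather than the crate's `START_TABLE_SYNC_*` constants, and requiring the `failpoints` feature gated at the top of this file:

    use fail::{FailScenario, fail_point};

    // Hypothetical function guarded by a failpoint; the closure turns the payload of a
    // configured `return(...)` action into this function's error value.
    fn guarded_step() -> Result<(), String> {
        fail_point!("example.before_slot_creation", |payload| {
            Err(payload.unwrap_or_else(|| "injected".to_string()))
        });
        Ok(())
    }

    fn main() {
        let _scenario = FailScenario::setup();
        // Fire once with payload "no_retry"; subsequent calls take the normal path,
        // which is why a single-shot failure can still be retried successfully.
        fail::cfg("example.before_slot_creation", "1*return(no_retry)").unwrap();
        assert_eq!(guarded_step(), Err("no_retry".to_string()));
        assert_eq!(guarded_step(), Ok(()));
    }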
+ let table_rows = destination.get_table_rows().await; + assert!(table_rows.is_empty()); + + // Verify table schemas were correctly stored. + let table_schemas = store.get_latest_table_schemas().await; + assert!(table_schemas.is_empty()); +} + +#[tokio::test(flavor = "multi_thread")] +async fn table_copy_is_consistent_after_data_sync_threw_an_error_with_timed_retry() { + let _scenario = FailScenario::setup(); + fail::cfg( + START_TABLE_SYNC_BEFORE_DATA_SYNC_SLOT_CREATION_FP, + "1*return(timed_retry)", + ) + .unwrap(); + + init_test_tracing(); + + let mut database = spawn_source_database().await; + let database_schema = setup_test_database_schema(&database, TableSelection::UsersOnly).await; + + // Insert initial test data. + let rows_inserted = 10; + insert_users_data( + &mut database, + &database_schema.users_schema().name, + 1..=rows_inserted, + ) + .await; + + let store = NotifyingStore::new(); + let destination = TestDestinationWrapper::wrap(MemoryDestination::new()); + + // We start the pipeline from scratch. + let pipeline_id: PipelineId = random(); + let mut pipeline = create_pipeline( + &database.config, + pipeline_id, + database_schema.publication_name(), + store.clone(), + destination.clone(), + ); + + // We register the interest in waiting for both table syncs to have started. + let users_state_notify = store + .notify_on_table_state_type( + database_schema.users_schema().id, + TableReplicationPhaseType::SyncDone, + ) + .await; + + pipeline.start().await.unwrap(); + + users_state_notify.notified().await; + + // We expect no errors, since the same table sync worker task is retried. + pipeline.shutdown_and_wait().await.unwrap(); + + // Verify copied data. + let table_rows = destination.get_table_rows().await; + let users_table_rows = table_rows.get(&database_schema.users_schema().id).unwrap(); + assert_eq!(users_table_rows.len(), rows_inserted); + + // Verify table schemas were correctly stored. + let table_schemas = store.get_latest_table_schemas().await; + assert_eq!(table_schemas.len(), 1); + assert_eq!( + *table_schemas + .get(&database_schema.users_schema().id) + .unwrap(), + database_schema.users_schema() + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn table_copy_is_consistent_during_data_sync_threw_an_error_with_timed_retry() { + let _scenario = FailScenario::setup(); + fail::cfg( + START_TABLE_SYNC_DURING_DATA_SYNC_FP, + "1*return(timed_retry)", + ) + .unwrap(); + + init_test_tracing(); + + let mut database = spawn_source_database().await; + let database_schema = setup_test_database_schema(&database, TableSelection::UsersOnly).await; + + // Insert initial test data. + let rows_inserted = 10; + insert_users_data( + &mut database, + &database_schema.users_schema().name, + 1..=rows_inserted, + ) + .await; + + let store = NotifyingStore::new(); + let destination = TestDestinationWrapper::wrap(MemoryDestination::new()); + + // We start the pipeline from scratch. + let pipeline_id: PipelineId = random(); + let mut pipeline = create_pipeline( + &database.config, + pipeline_id, + database_schema.publication_name(), + store.clone(), + destination.clone(), + ); + + // We register the interest in waiting for both table syncs to have started. + let users_state_notify = store + .notify_on_table_state_type( + database_schema.users_schema().id, + TableReplicationPhaseType::SyncDone, + ) + .await; + + pipeline.start().await.unwrap(); + + users_state_notify.notified().await; + + // We expect no errors, since the same table sync worker task is retried. 
+ pipeline.shutdown_and_wait().await.unwrap(); + + // Verify copied data. + let table_rows = destination.get_table_rows().await; + let users_table_rows = table_rows.get(&database_schema.users_schema().id).unwrap(); + assert_eq!(users_table_rows.len(), rows_inserted); + + // Verify table schemas were correctly stored. + let table_schemas = store.get_latest_table_schemas().await; + assert_eq!(table_schemas.len(), 1); + assert_eq!( + *table_schemas + .get(&database_schema.users_schema().id) + .unwrap(), + database_schema.users_schema() + ); +} + +#[ignore] +#[tokio::test(flavor = "multi_thread")] +async fn table_schema_snapshots_are_consistent_after_missing_status_update_with_interleaved_ddl() { + let _scenario = FailScenario::setup(); + fail::cfg(SEND_STATUS_UPDATE_FP, "return").unwrap(); + + init_test_tracing(); + + let (database, table_name, table_id, store, destination, pipeline, pipeline_id, publication) = + create_database_and_pipeline_with_table( + "schema_add_column", + &[("name", "text not null"), ("age", "integer not null")], + ) + .await; + + let notify = destination + .wait_for_events_count(vec![(EventType::Relation, 2), (EventType::Insert, 2)]) + .await; + + database + .insert_values(table_name.clone(), &["name", "age"], &[&"Alice", &25]) + .await + .unwrap(); + + database + .alter_table( + table_name.clone(), + &[TableModification::AddColumn { + name: "email", + data_type: "text null", + }], + ) + .await + .unwrap(); + + database + .insert_values( + table_name.clone(), + &["name", "age", "email"], + &[&"Bob", &28, &"bob@example.com"], + ) + .await + .unwrap(); + + notify.notified().await; + pipeline.shutdown_and_wait().await.unwrap(); + + // Assert that we got all the events correctly. + let events = destination.get_events().await; + let grouped = group_events_by_type_and_table_id(&events); + + assert_eq!( + grouped.get(&(EventType::Relation, table_id)).unwrap().len(), + 2 + ); + assert_eq!( + grouped.get(&(EventType::Insert, table_id)).unwrap().len(), + 2 + ); + + // Assert that we have 2 schema snapshots stored in order. + let table_schemas = store.get_table_schemas().await; + let table_schemas_snapshots = table_schemas.get(&table_id).unwrap(); + assert_eq!(table_schemas_snapshots.len(), 2); + assert_schema_snapshots_ordering(table_schemas_snapshots, true); + + // Verify the first snapshot has the original schema (id, name, age). + let (_, first_schema) = &table_schemas_snapshots[0]; + assert_table_schema_column_names_types( + first_schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("age", Type::INT4), + ], + ); + + // Verify the second snapshot has the new column added (id, name, age, email). + let (_, second_schema) = &table_schemas_snapshots[1]; + assert_table_schema_column_names_types( + second_schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("age", Type::INT4), + ("email", Type::TEXT), + ], + ); + + // Clear up the events. + destination.clear_events().await; + + // Restart the pipeline with the failpoint disabled to verify recovery. + let mut pipeline = create_pipeline( + &database.config, + pipeline_id, + publication, + store.clone(), + destination.clone(), + ); + + pipeline.start().await.unwrap(); + + database + .insert_values( + table_name.clone(), + &["name", "age", "email"], + &[&"Charlie", &35, &"charlie@example.com"], + ) + .await + .unwrap(); + + // TODO: figure out how to wait for errors in the apply worker and remove ignore. 
+ + // We expect to have a corrupted table schema error since when we reprocess the events, Postgres + // sends a `Relation` message with the `email` column even for entries before the DDL that added + // the `email`. For now this is a limitation that we are acknowledging, but we would like to find + // a solution for this. + let err = pipeline.shutdown_and_wait().await.err().unwrap(); + assert_eq!(err.kinds().len(), 1); + assert_eq!(err.kinds()[0], ErrorKind::CorruptedTableSchema); +} + +#[tokio::test(flavor = "multi_thread")] +async fn table_schema_snapshots_are_consistent_after_missing_status_update_with_initial_ddl() { + let _scenario = FailScenario::setup(); + fail::cfg(SEND_STATUS_UPDATE_FP, "return").unwrap(); + + init_test_tracing(); + + let (database, table_name, table_id, store, destination, pipeline, pipeline_id, publication) = + create_database_and_pipeline_with_table( + "schema_add_column", + &[("name", "text not null"), ("age", "integer not null")], + ) + .await; + + // The reason for why we wait for two `Relation` messages is that since we have a DDL event before + // DML statements, Postgres likely avoids sending an initial `Relation` message since it's already + // sent given the DDL event. + let notify = destination + .wait_for_events_count(vec![(EventType::Relation, 2), (EventType::Insert, 2)]) + .await; + + // We immediately add a column to the table without any DML, to show the case where we can recover + // in case we immediately start with a DDL event. + database + .alter_table( + table_name.clone(), + &[TableModification::AddColumn { + name: "email", + data_type: "text null", + }], + ) + .await + .unwrap(); + + database + .insert_values( + table_name.clone(), + &["name", "age", "email"], + &[&"Bob", &28, &"bob@example.com"], + ) + .await + .unwrap(); + + database + .alter_table( + table_name.clone(), + &[TableModification::DropColumn { name: "age" }], + ) + .await + .unwrap(); + + database + .insert_values( + table_name.clone(), + &["name", "email"], + &[&"Matt", &"matt@example.com"], + ) + .await + .unwrap(); + + notify.notified().await; + pipeline.shutdown_and_wait().await.unwrap(); + + // Assert that we got all the events correctly. + let events = destination.get_events().await; + let grouped = group_events_by_type_and_table_id(&events); + + assert_eq!( + grouped.get(&(EventType::Relation, table_id)).unwrap().len(), + 2 + ); + assert_eq!( + grouped.get(&(EventType::Insert, table_id)).unwrap().len(), + 2 + ); + + // Assert that we have 3 schema snapshots stored in order (1 base snapshot + 2 relation changes). + let table_schemas = store.get_table_schemas().await; + let table_schemas_snapshots = table_schemas.get(&table_id).unwrap(); + assert_eq!(table_schemas_snapshots.len(), 3); + assert_schema_snapshots_ordering(table_schemas_snapshots, true); + + // Verify the first snapshot has the initial schema (id, name, age). + let (_, first_schema) = &table_schemas_snapshots[0]; + assert_table_schema_column_names_types( + first_schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("age", Type::INT4), + ], + ); + + // Verify the first snapshot has the new schema (id, name, age, email). + let (_, first_schema) = &table_schemas_snapshots[1]; + assert_table_schema_column_names_types( + first_schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("age", Type::INT4), + ("email", Type::TEXT), + ], + ); + + // Verify the second snapshot doesn't have the age column (id, name, email). 
+ let (_, second_schema) = &table_schemas_snapshots[2]; + assert_table_schema_column_names_types( + second_schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("email", Type::TEXT), + ], + ); + + // Clear up the events. + destination.clear_events().await; + + // Restart the pipeline with the failpoint disabled to verify recovery. + let mut pipeline = create_pipeline( + &database.config, + pipeline_id, + publication, + store.clone(), + destination.clone(), + ); + + pipeline.start().await.unwrap(); + + let notify = destination + .wait_for_events_count(vec![(EventType::Relation, 2), (EventType::Insert, 3)]) + .await; + + database + .insert_values( + table_name.clone(), + &["name", "email"], + &[&"Charlie", &"charlie@example.com"], + ) + .await + .unwrap(); + + notify.notified().await; + pipeline.shutdown_and_wait().await.unwrap(); + + // Assert that we got all the events correctly. + let events = destination.get_events().await; + let grouped = group_events_by_type_and_table_id(&events); + + assert_eq!( + grouped.get(&(EventType::Relation, table_id)).unwrap().len(), + 2 + ); + assert_eq!( + grouped.get(&(EventType::Insert, table_id)).unwrap().len(), + 3 + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn table_schema_replication_masks_are_consistent_after_restart() { + let _scenario = FailScenario::setup(); + fail::cfg(SEND_STATUS_UPDATE_FP, "return").unwrap(); + + init_test_tracing(); + let database = spawn_source_database().await; + + // Column filters in publication are only available from Postgres 15+. + if below_version!(database.server_version(), POSTGRES_15) { + eprintln!("Skipping test: PostgreSQL 15+ required for column filters"); + return; + } + + // Create a table with 3 columns (plus auto-generated id). + let table_name = test_table_name("col_removal"); + let table_id = database + .create_table( + table_name.clone(), + true, + &[ + ("name", "text not null"), + ("age", "integer not null"), + ("email", "text not null"), + ], + ) + .await + .unwrap(); + + // Create publication with all 3 columns (plus id) initially. + let publication_name = format!("pub_{}", random::()); + database + .run_sql(&format!( + "create publication {publication_name} for table {} (id, name, age, email)", + table_name.as_quoted_identifier() + )) + .await + .expect("Failed to create publication with column filter"); + + let store = NotifyingStore::new(); + let destination = TestDestinationWrapper::wrap(MemoryDestination::new()); + + let pipeline_id: PipelineId = random(); + let mut pipeline = create_pipeline( + &database.config, + pipeline_id, + publication_name.clone(), + store.clone(), + destination.clone(), + ); + + // Wait for the table to finish syncing. + let sync_done_notify = store + .notify_on_table_state_type(table_id, TableReplicationPhaseType::SyncDone) + .await; + + pipeline.start().await.unwrap(); + + sync_done_notify.notified().await; + + // We expect 3 relation events (one per publication change) and 3 insert events. + let events_notify = destination + .wait_for_events_count(vec![(EventType::Relation, 3), (EventType::Insert, 3)]) + .await; + + // Phase 1: Insert with all 4 columns (id, name, age, email). + database + .run_sql(&format!( + "insert into {} (name, age, email) values ('Alice', 25, 'alice@example.com')", + table_name.as_quoted_identifier() + )) + .await + .unwrap(); + + // Phase 2: Remove email column -> (id, name, age), then insert. 
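Aside (not part of this patch): when debugging the publication phases below, Postgres 15+ can report the currently published column list directly, which should track each `alter publication ... set table` issued here. Kept as a comment so it stays out of the test flow; the names are placeholders:

    // On Postgres 15+ the published columns for a table can be checked manually with:
    //     select attnames
    //     from pg_publication_tables
    //     where pubname = '<publication>' and tablename = '<table>';
    // (substitute the generated publication and table names).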
+ database + .run_sql(&format!( + "alter publication {publication_name} set table {} (id, name, age)", + table_name.as_quoted_identifier() + )) + .await + .unwrap(); + + database + .run_sql(&format!( + "insert into {} (name, age, email) values ('Bob', 30, 'bob@example.com')", + table_name.as_quoted_identifier() + )) + .await + .unwrap(); + + // Phase 3: Remove age column -> (id, name), then insert. + database + .run_sql(&format!( + "alter publication {publication_name} set table {} (id, name)", + table_name.as_quoted_identifier() + )) + .await + .unwrap(); + + database + .run_sql(&format!( + "insert into {} (name, age, email) values ('Charlie', 35, 'charlie@example.com')", + table_name.as_quoted_identifier() + )) + .await + .unwrap(); + + events_notify.notified().await; + + // Helper to verify events after each run. + let verify_events = |events: &[Event], table_id: TableId| { + let grouped = group_events_by_type_and_table_id(events); + + // Verify we have 3 relation events. + let relation_events: Vec<_> = events + .iter() + .filter_map(|event| match event { + Event::Relation(relation) if relation.replicated_table_schema.id() == table_id => { + Some(relation.clone()) + } + _ => None, + }) + .collect(); + assert_eq!( + relation_events.len(), + 3, + "Expected 3 relation events, got {}", + relation_events.len() + ); + + // Verify relation events have decreasing column counts: 4 -> 3 -> 2. + let relation_column_counts: Vec = relation_events + .iter() + .map(|r| r.replicated_table_schema.column_schemas().count()) + .collect(); + assert_eq!( + relation_column_counts, + vec![4, 3, 2], + "Expected relation column counts [4, 3, 2], got {relation_column_counts:?}" + ); + + // Verify relation column names for each phase. + let relation_1_cols: Vec<&str> = relation_events[0] + .replicated_table_schema + .column_schemas() + .map(|c| c.name.as_str()) + .collect(); + assert_eq!(relation_1_cols, vec!["id", "name", "age", "email"]); + + let relation_2_cols: Vec<&str> = relation_events[1] + .replicated_table_schema + .column_schemas() + .map(|c| c.name.as_str()) + .collect(); + assert_eq!(relation_2_cols, vec!["id", "name", "age"]); + + let relation_3_cols: Vec<&str> = relation_events[2] + .replicated_table_schema + .column_schemas() + .map(|c| c.name.as_str()) + .collect(); + assert_eq!(relation_3_cols, vec!["id", "name"]); + + // Verify replication masks. + assert_eq!( + relation_events[0] + .replicated_table_schema + .replication_mask() + .as_slice(), + &[1, 1, 1, 1] + ); + assert_eq!( + relation_events[1] + .replicated_table_schema + .replication_mask() + .as_slice(), + &[1, 1, 1, 0] + ); + assert_eq!( + relation_events[2] + .replicated_table_schema + .replication_mask() + .as_slice(), + &[1, 1, 0, 0] + ); + + // Verify underlying schema always has 4 columns. + for relation in &relation_events { + assert_eq!( + relation + .replicated_table_schema + .get_inner() + .column_schemas + .len(), + 4 + ); + } + + // Verify we have 3 insert events. + let insert_events = grouped.get(&(EventType::Insert, table_id)).unwrap(); + assert_eq!( + insert_events.len(), + 3, + "Expected 3 insert events, got {}", + insert_events.len() + ); + + // Verify insert events have decreasing value counts: 4 -> 3 -> 2. + let insert_value_counts: Vec = insert_events + .iter() + .filter_map(|event| { + if let Event::Insert(InsertEvent { table_row, .. 
}) = event { + Some(table_row.values.len()) + } else { + None + } + }) + .collect(); + assert_eq!( + insert_value_counts, + vec![4, 3, 2], + "Expected insert value counts [4, 3, 2], got {insert_value_counts:?}" + ); + }; + + // Shutdown the pipeline. + pipeline.shutdown_and_wait().await.unwrap(); + + // Verify events from first run. + let events = destination.get_events().await; + verify_events(&events, table_id); + + // Verify schema snapshots are stored correctly. + let table_schemas = store.get_table_schemas().await; + let table_schemas_snapshots = table_schemas.get(&table_id).unwrap(); + assert!( + !table_schemas_snapshots.is_empty(), + "Expected at least 1 schema snapshot" + ); + assert_schema_snapshots_ordering(table_schemas_snapshots, true); + + // The underlying table schema should always have 4 columns. + for (_, schema) in table_schemas_snapshots { + assert_table_schema_column_names_types( + schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("age", Type::INT4), + ("email", Type::TEXT), + ], + ); + } + + // Clear up the events. + destination.clear_events().await; + + // Restart the pipeline - Postgres will resend the data since we don't track progress exactly. + let mut pipeline = create_pipeline( + &database.config, + pipeline_id, + publication_name.clone(), + store.clone(), + destination.clone(), + ); + + // Wait for 3 relation events and 3 insert events again after restart. + let events_notify_restart = destination + .wait_for_events_count(vec![(EventType::Relation, 3), (EventType::Insert, 3)]) + .await; + + pipeline.start().await.unwrap(); + + events_notify_restart.notified().await; + + // Verify the same events are received after restart. + let events_after_restart = destination.get_events().await; + verify_events(&events_after_restart, table_id); + + pipeline.shutdown_and_wait().await.unwrap(); +} diff --git a/etl/tests/pipeline_with_partitioned_table.rs b/etl/tests/pipeline_with_partitioned_table.rs index 0a680ceaf..d24d049d6 100644 --- a/etl/tests/pipeline_with_partitioned_table.rs +++ b/etl/tests/pipeline_with_partitioned_table.rs @@ -75,7 +75,7 @@ async fn partitioned_table_copy_replicates_existing_data() { let _ = pipeline.shutdown_and_wait().await; // Verify table schema was discovered correctly. - let table_schemas = state_store.get_table_schemas().await; + let table_schemas = state_store.get_latest_table_schemas().await; assert!(table_schemas.contains_key(&parent_table_id)); let parent_schema = &table_schemas[&parent_table_id]; @@ -90,21 +90,21 @@ async fn partitioned_table_copy_replicates_existing_data() { assert_eq!(id_column.name, "id"); assert_eq!(id_column.typ, Type::INT8); assert!(!id_column.nullable); - assert!(id_column.primary); + assert!(id_column.primary_key()); // Check data column. let data_column = &parent_schema.column_schemas[1]; assert_eq!(data_column.name, "data"); assert_eq!(data_column.typ, Type::TEXT); assert!(!data_column.nullable); - assert!(!data_column.primary); + assert!(!data_column.primary_key()); // Check partition_key column. 
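Aside (sketch, not part of this patch): these hunks replace reads of the old `primary` field with a `primary_key()` accessor. Judging from the `test_column` helper that appears later in this diff, which passes `Some(1)` or `None` where the bool used to go, primary-key membership now appears to be carried as an optional key position; a rough, hypothetical mirror of that shape:

    // Names and field layout are assumptions for illustration only.
    struct ColumnSchemaSketch {
        // 1-based position of the column within the primary key, `None` if not part of it.
        primary_key_position: Option<i32>,
    }

    impl ColumnSchemaSketch {
        fn primary_key(&self) -> bool {
            self.primary_key_position.is_some()
        }
    }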
let partition_key_column = &parent_schema.column_schemas[2]; assert_eq!(partition_key_column.name, "partition_key"); assert_eq!(partition_key_column.typ, Type::INT4); assert!(!partition_key_column.nullable); - assert!(partition_key_column.primary); + assert!(partition_key_column.primary_key()); let table_rows = destination.get_table_rows().await; let total_rows: usize = table_rows.values().map(|rows| rows.len()).sum(); @@ -1246,7 +1246,7 @@ async fn nested_partitioned_table_copy_and_cdc() { parent_sync_done.notified().await; // Verify table schema was discovered correctly for nested partitioned table. - let table_schemas = state_store.get_table_schemas().await; + let table_schemas = state_store.get_latest_table_schemas().await; assert!(table_schemas.contains_key(&parent_table_id)); let parent_schema = &table_schemas[&parent_table_id]; @@ -1261,28 +1261,28 @@ async fn nested_partitioned_table_copy_and_cdc() { assert_eq!(id_column.name, "id"); assert_eq!(id_column.typ, Type::INT8); assert!(!id_column.nullable); - assert!(id_column.primary); + assert!(id_column.primary_key()); // Check data column. let data_column = &parent_schema.column_schemas[1]; assert_eq!(data_column.name, "data"); assert_eq!(data_column.typ, Type::TEXT); assert!(!data_column.nullable); - assert!(!data_column.primary); + assert!(!data_column.primary_key()); // Check partition_key column (part of primary key). let partition_key_column = &parent_schema.column_schemas[2]; assert_eq!(partition_key_column.name, "partition_key"); assert_eq!(partition_key_column.typ, Type::INT4); assert!(!partition_key_column.nullable); - assert!(partition_key_column.primary); + assert!(partition_key_column.primary_key()); // Check sub_partition_key column (part of primary key for nested partitioning). let sub_partition_key_column = &parent_schema.column_schemas[3]; assert_eq!(sub_partition_key_column.name, "sub_partition_key"); assert_eq!(sub_partition_key_column.typ, Type::INT4); assert!(!sub_partition_key_column.nullable); - assert!(sub_partition_key_column.primary); + assert!(sub_partition_key_column.primary_key()); // Verify initial COPY replicated all 3 rows. 
let table_rows = destination.get_table_rows().await; diff --git a/etl/tests/pipelines_with_schema_changes.rs b/etl/tests/pipelines_with_schema_changes.rs new file mode 100644 index 000000000..e67170dcf --- /dev/null +++ b/etl/tests/pipelines_with_schema_changes.rs @@ -0,0 +1,840 @@ +#![cfg(feature = "test-utils")] + +use std::time::Duration; + +use etl::destination::memory::MemoryDestination; +use etl::state::table::TableReplicationPhaseType; +use etl::test_utils::database::{spawn_source_database, test_table_name}; +use etl::test_utils::event::group_events_by_type_and_table_id; +use etl::test_utils::notify::NotifyingStore; +use etl::test_utils::pipeline::{create_database_and_pipeline_with_table, create_pipeline}; +use etl::test_utils::schema::{ + assert_replicated_schema_column_names_types, assert_schema_snapshots_ordering, + assert_table_schema_column_names_types, +}; +use etl::test_utils::test_destination_wrapper::TestDestinationWrapper; +use etl::test_utils::test_schema::create_partitioned_table; +use etl::types::{Event, EventType, PipelineId, Type}; +use etl_postgres::tokio::test_utils::TableModification; +use etl_postgres::types::TableId; +use etl_telemetry::tracing::init_test_tracing; +use rand::random; +use tokio::time::sleep; + +fn get_last_relation_event(events: &[Event], table_id: TableId) -> &Event { + events + .iter() + .rev() + .find(|e| matches!(e, Event::Relation(r) if r.replicated_table_schema.id() == table_id)) + .expect("no relation events for table") +} + +fn get_last_insert_event(events: &[Event], table_id: TableId) -> &Event { + events + .iter() + .rev() + .find(|e| matches!(e, Event::Insert(i) if i.replicated_table_schema.id() == table_id)) + .expect("no insert events for table") +} + +#[tokio::test(flavor = "multi_thread")] +async fn relation_message_updates_when_column_added() { + init_test_tracing(); + + let (database, table_name, table_id, store, destination, pipeline, _pipeline_id, _publication) = + create_database_and_pipeline_with_table( + "schema_add_column", + &[("name", "text not null"), ("age", "integer not null")], + ) + .await; + + let notify = destination + .wait_for_events_count(vec![(EventType::Relation, 1), (EventType::Insert, 1)]) + .await; + + database + .alter_table( + table_name.clone(), + &[TableModification::AddColumn { + name: "email", + data_type: "text not null", + }], + ) + .await + .unwrap(); + + database + .insert_values( + table_name.clone(), + &["name", "age", "email"], + &[&"Alice", &25, &"alice@example.com"], + ) + .await + .unwrap(); + + notify.notified().await; + pipeline.shutdown_and_wait().await.unwrap(); + + let events = destination.get_events().await; + let grouped = group_events_by_type_and_table_id(&events); + + assert_eq!( + grouped.get(&(EventType::Relation, table_id)).unwrap().len(), + 1 + ); + assert_eq!( + grouped.get(&(EventType::Insert, table_id)).unwrap().len(), + 1 + ); + + let Event::Relation(r) = get_last_relation_event(&events, table_id) else { + panic!("expected relation event"); + }; + assert_replicated_schema_column_names_types( + &r.replicated_table_schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("age", Type::INT4), + ("email", Type::TEXT), + ], + ); + let Event::Insert(i) = get_last_insert_event(&events, table_id) else { + panic!("expected insert event"); + }; + assert_eq!(i.table_row.values.len(), 4); + + // Verify schema snapshots are stored in order. 
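Aside (illustrative, not part of this patch): `assert_schema_snapshots_ordering(snapshots, true)` appears in most of these tests. Assuming it verifies that the stored `(snapshot id, schema)` pairs are in ascending snapshot order, with the flag selecting strict ordering, its core check could look roughly like this sketch:

    // Sketch under the stated assumption; the real helper in `etl::test_utils::schema`
    // may check more than this, and snapshot ids are modelled here as plain u64s.
    fn assert_snapshots_ascending<S>(snapshots: &[(u64, S)], strict: bool) {
        let ordered = snapshots.windows(2).all(|pair| {
            if strict {
                pair[0].0 < pair[1].0
            } else {
                pair[0].0 <= pair[1].0
            }
        });
        assert!(ordered, "schema snapshots are not stored in ascending snapshot order");
    }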
+ let table_schemas = store.get_table_schemas().await; + let snapshots = table_schemas.get(&table_id).unwrap(); + assert_eq!(snapshots.len(), 2); + assert_schema_snapshots_ordering(snapshots, true); + + let (_, first_schema) = &snapshots[0]; + assert_table_schema_column_names_types( + first_schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("age", Type::INT4), + ], + ); + + let (_, second_schema) = &snapshots[1]; + assert_table_schema_column_names_types( + second_schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("age", Type::INT4), + ("email", Type::TEXT), + ], + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn relation_message_updates_when_column_removed() { + init_test_tracing(); + + let (database, table_name, table_id, store, destination, pipeline, _pipeline_id, _publication) = + create_database_and_pipeline_with_table( + "schema_remove_column", + &[("name", "text not null"), ("age", "integer not null")], + ) + .await; + + let notify = destination + .wait_for_events_count(vec![(EventType::Relation, 1), (EventType::Insert, 1)]) + .await; + + database + .alter_table( + table_name.clone(), + &[TableModification::DropColumn { name: "age" }], + ) + .await + .unwrap(); + + database + .insert_values(table_name.clone(), &["name"], &[&"Bob"]) + .await + .unwrap(); + + notify.notified().await; + pipeline.shutdown_and_wait().await.unwrap(); + + let events = destination.get_events().await; + let grouped = group_events_by_type_and_table_id(&events); + + assert_eq!( + grouped.get(&(EventType::Relation, table_id)).unwrap().len(), + 1 + ); + assert_eq!( + grouped.get(&(EventType::Insert, table_id)).unwrap().len(), + 1 + ); + + let Event::Relation(r) = get_last_relation_event(&events, table_id) else { + panic!("expected relation event"); + }; + assert_replicated_schema_column_names_types( + &r.replicated_table_schema, + &[("id", Type::INT8), ("name", Type::TEXT)], + ); + let Event::Insert(i) = get_last_insert_event(&events, table_id) else { + panic!("expected insert event"); + }; + assert_eq!(i.table_row.values.len(), 2); + + // Verify schema snapshots are stored in order. 
+ let table_schemas = store.get_table_schemas().await; + let snapshots = table_schemas.get(&table_id).unwrap(); + assert_eq!(snapshots.len(), 2); + assert_schema_snapshots_ordering(snapshots, true); + + let (_, first_schema) = &snapshots[0]; + assert_table_schema_column_names_types( + first_schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("age", Type::INT4), + ], + ); + + let (_, second_schema) = &snapshots[1]; + assert_table_schema_column_names_types( + second_schema, + &[("id", Type::INT8), ("name", Type::TEXT)], + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn relation_message_updates_when_column_renamed() { + init_test_tracing(); + + let (database, table_name, table_id, store, destination, pipeline, _pipeline_id, _publication) = + create_database_and_pipeline_with_table( + "schema_rename_column", + &[("name", "text not null"), ("age", "integer not null")], + ) + .await; + + let notify = destination + .wait_for_events_count(vec![(EventType::Relation, 1), (EventType::Insert, 1)]) + .await; + + database + .alter_table( + table_name.clone(), + &[TableModification::RenameColumn { + old_name: "name", + new_name: "full_name", + }], + ) + .await + .unwrap(); + + database + .insert_values(table_name.clone(), &["full_name", "age"], &[&"Carol", &41]) + .await + .unwrap(); + + notify.notified().await; + pipeline.shutdown_and_wait().await.unwrap(); + + let events = destination.get_events().await; + let grouped = group_events_by_type_and_table_id(&events); + + assert_eq!( + grouped.get(&(EventType::Relation, table_id)).unwrap().len(), + 1 + ); + assert_eq!( + grouped.get(&(EventType::Insert, table_id)).unwrap().len(), + 1 + ); + + let Event::Relation(r) = get_last_relation_event(&events, table_id) else { + panic!("expected relation event"); + }; + assert_replicated_schema_column_names_types( + &r.replicated_table_schema, + &[ + ("id", Type::INT8), + ("full_name", Type::TEXT), + ("age", Type::INT4), + ], + ); + let Event::Insert(i) = get_last_insert_event(&events, table_id) else { + panic!("expected insert event"); + }; + assert_eq!(i.table_row.values.len(), 3); + + // Verify schema snapshots are stored in order. 
+ let table_schemas = store.get_table_schemas().await; + let snapshots = table_schemas.get(&table_id).unwrap(); + assert_eq!(snapshots.len(), 2); + assert_schema_snapshots_ordering(snapshots, true); + + let (_, first_schema) = &snapshots[0]; + assert_table_schema_column_names_types( + first_schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("age", Type::INT4), + ], + ); + + let (_, second_schema) = &snapshots[1]; + assert_table_schema_column_names_types( + second_schema, + &[ + ("id", Type::INT8), + ("full_name", Type::TEXT), + ("age", Type::INT4), + ], + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn relation_message_updates_when_column_type_changes() { + init_test_tracing(); + + let (database, table_name, table_id, store, destination, pipeline, _pipeline_id, _publication) = + create_database_and_pipeline_with_table( + "schema_change_type", + &[("name", "text not null"), ("age", "integer not null")], + ) + .await; + + let notify = destination + .wait_for_events_count(vec![(EventType::Relation, 1), (EventType::Insert, 1)]) + .await; + + database + .alter_table( + table_name.clone(), + &[TableModification::AlterColumn { + name: "age", + alteration: "type bigint", + }], + ) + .await + .unwrap(); + + database + .insert_values(table_name.clone(), &["name", "age"], &[&"Dave", &45_i64]) + .await + .unwrap(); + + notify.notified().await; + pipeline.shutdown_and_wait().await.unwrap(); + + let events = destination.get_events().await; + let grouped = group_events_by_type_and_table_id(&events); + + assert_eq!( + grouped.get(&(EventType::Relation, table_id)).unwrap().len(), + 1 + ); + assert_eq!( + grouped.get(&(EventType::Insert, table_id)).unwrap().len(), + 1 + ); + + let Event::Relation(r) = get_last_relation_event(&events, table_id) else { + panic!("expected relation event"); + }; + assert_replicated_schema_column_names_types( + &r.replicated_table_schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("age", Type::INT8), + ], + ); + let Event::Insert(i) = get_last_insert_event(&events, table_id) else { + panic!("expected insert event"); + }; + assert_eq!(i.table_row.values.len(), 3); + + // Verify schema snapshots are stored in order. 
+ let table_schemas = store.get_table_schemas().await; + let snapshots = table_schemas.get(&table_id).unwrap(); + assert_eq!(snapshots.len(), 2); + assert_schema_snapshots_ordering(snapshots, true); + + let (_, first_schema) = &snapshots[0]; + assert_table_schema_column_names_types( + first_schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("age", Type::INT4), + ], + ); + + let (_, second_schema) = &snapshots[1]; + assert_table_schema_column_names_types( + second_schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("age", Type::INT8), + ], + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn pipeline_recovers_after_multiple_schema_changes_and_restart() { + init_test_tracing(); + + // Start with initial schema: id (auto), name (text), age (integer), status (text) + let (database, table_name, table_id, store, destination, pipeline, pipeline_id, publication) = + create_database_and_pipeline_with_table( + "schema_multi_change_restart", + &[ + ("name", "text not null"), + ("age", "integer not null"), + ("status", "text not null"), + ], + ) + .await; + + // Phase 1: Add column + insert, then restart + let notify = destination + .wait_for_events_count(vec![(EventType::Relation, 1), (EventType::Insert, 1)]) + .await; + + database + .alter_table( + table_name.clone(), + &[TableModification::AddColumn { + name: "email", + data_type: "text not null", + }], + ) + .await + .unwrap(); + + database + .insert_values( + table_name.clone(), + &["name", "age", "status", "email"], + &[&"Alice", &25, &"active", &"alice@example.com"], + ) + .await + .unwrap(); + + notify.notified().await; + sleep(Duration::from_secs(5)).await; + pipeline.shutdown_and_wait().await.unwrap(); + destination.clear_events().await; + + // Phase 2: Rename column + change type + insert, then restart + let mut pipeline = create_pipeline( + &database.config, + pipeline_id, + publication.clone(), + store.clone(), + destination.clone(), + ); + pipeline.start().await.unwrap(); + + let notify = destination + .wait_for_events_count(vec![(EventType::Relation, 1), (EventType::Insert, 1)]) + .await; + + database + .alter_table( + table_name.clone(), + &[TableModification::RenameColumn { + old_name: "age", + new_name: "years", + }], + ) + .await + .unwrap(); + + database + .alter_table( + table_name.clone(), + &[TableModification::AlterColumn { + name: "years", + alteration: "type bigint", + }], + ) + .await + .unwrap(); + + database + .insert_values( + table_name.clone(), + &["name", "years", "status", "email"], + &[&"Bob", &30_i64, &"pending", &"bob@example.com"], + ) + .await + .unwrap(); + + notify.notified().await; + sleep(Duration::from_secs(1)).await; + pipeline.shutdown_and_wait().await.unwrap(); + destination.clear_events().await; + + // Phase 3: Drop column + insert, then restart + let mut pipeline = create_pipeline( + &database.config, + pipeline_id, + publication.clone(), + store.clone(), + destination.clone(), + ); + pipeline.start().await.unwrap(); + + let notify = destination + .wait_for_events_count(vec![(EventType::Relation, 1), (EventType::Insert, 1)]) + .await; + + database + .alter_table( + table_name.clone(), + &[TableModification::DropColumn { name: "status" }], + ) + .await + .unwrap(); + + database + .insert_values( + table_name.clone(), + &["name", "years", "email"], + &[&"Carol", &35_i64, &"carol@example.com"], + ) + .await + .unwrap(); + + notify.notified().await; + sleep(Duration::from_secs(1)).await; + pipeline.shutdown_and_wait().await.unwrap(); + destination.clear_events().await; + + // 
Phase 4: Add another column + rename existing + insert, then verify + let mut pipeline = create_pipeline( + &database.config, + pipeline_id, + publication, + store.clone(), + destination.clone(), + ); + pipeline.start().await.unwrap(); + + let notify = destination + .wait_for_events_count(vec![(EventType::Relation, 1), (EventType::Insert, 1)]) + .await; + + database + .alter_table( + table_name.clone(), + &[TableModification::AddColumn { + name: "created_at", + data_type: "timestamp not null default now()", + }], + ) + .await + .unwrap(); + + database + .alter_table( + table_name.clone(), + &[TableModification::RenameColumn { + old_name: "email", + new_name: "contact_email", + }], + ) + .await + .unwrap(); + + database + .insert_values( + table_name.clone(), + &["name", "years", "contact_email"], + &[&"Dave", &40_i64, &"dave@example.com"], + ) + .await + .unwrap(); + + notify.notified().await; + pipeline.shutdown_and_wait().await.unwrap(); + + // Final schema should be: id (int8), name (text), years (int8), contact_email (text), created_at (timestamp) + let events = destination.get_events().await; + + let Event::Relation(r) = get_last_relation_event(&events, table_id) else { + panic!("expected relation event"); + }; + assert_replicated_schema_column_names_types( + &r.replicated_table_schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("years", Type::INT8), + ("contact_email", Type::TEXT), + ("created_at", Type::TIMESTAMP), + ], + ); + let Event::Insert(i) = get_last_insert_event(&events, table_id) else { + panic!("expected insert event"); + }; + assert_eq!(i.table_row.values.len(), 5); + + // Verify all schema snapshots are stored in order. + // We have 7 snapshots: + // - Initial (id, name, age, status) + // - After adding email + // - After renaming age -> years + // - After changing years type to bigint + // - After dropping status + // - After adding created_at + // - After renaming email -> contact_email (this is the final schema for the insert) + let table_schemas = store.get_table_schemas().await; + let snapshots = table_schemas.get(&table_id).unwrap(); + assert_eq!(snapshots.len(), 7); + assert_schema_snapshots_ordering(snapshots, true); + + // Initial schema: id, name, age, status + let (_, schema) = &snapshots[0]; + assert_table_schema_column_names_types( + schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("age", Type::INT4), + ("status", Type::TEXT), + ], + ); + + // After adding email: id, name, age, status, email + let (_, schema) = &snapshots[1]; + assert_table_schema_column_names_types( + schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("age", Type::INT4), + ("status", Type::TEXT), + ("email", Type::TEXT), + ], + ); + + // After renaming age -> years: id, name, years, status, email + let (_, schema) = &snapshots[2]; + assert_table_schema_column_names_types( + schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("years", Type::INT4), + ("status", Type::TEXT), + ("email", Type::TEXT), + ], + ); + + // After changing years type to bigint: id, name, years (int8), status, email + let (_, schema) = &snapshots[3]; + assert_table_schema_column_names_types( + schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("years", Type::INT8), + ("status", Type::TEXT), + ("email", Type::TEXT), + ], + ); + + // After dropping status: id, name, years, email + let (_, schema) = &snapshots[4]; + assert_table_schema_column_names_types( + schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("years", Type::INT8), + ("email", Type::TEXT), + 
], + ); + + // After adding created_at: id, name, years, email, created_at + let (_, schema) = &snapshots[5]; + assert_table_schema_column_names_types( + schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("years", Type::INT8), + ("email", Type::TEXT), + ("created_at", Type::TIMESTAMP), + ], + ); + + // Final schema after renaming email -> contact_email: id, name, years, contact_email, created_at + let (_, schema) = &snapshots[6]; + assert_table_schema_column_names_types( + schema, + &[ + ("id", Type::INT8), + ("name", Type::TEXT), + ("years", Type::INT8), + ("contact_email", Type::TEXT), + ("created_at", Type::TIMESTAMP), + ], + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn partitioned_table_schema_change_updates_relation_message() { + init_test_tracing(); + let database = spawn_source_database().await; + + let table_name = test_table_name("partitioned_schema_change"); + let partition_specs = [("p1", "from (1) to (100)"), ("p2", "from (100) to (200)")]; + + let (parent_table_id, _partition_table_ids) = + create_partitioned_table(&database, table_name.clone(), &partition_specs) + .await + .unwrap(); + + // Insert initial data into partitions. + database + .run_sql(&format!( + "insert into {} (data, partition_key) values ('event1', 50), ('event2', 150)", + table_name.as_quoted_identifier() + )) + .await + .unwrap(); + + let publication_name = "test_partitioned_schema_change_pub".to_string(); + database + .create_publication(&publication_name, std::slice::from_ref(&table_name)) + .await + .unwrap(); + + let state_store = NotifyingStore::new(); + let destination = TestDestinationWrapper::wrap(MemoryDestination::new()); + + let pipeline_id: PipelineId = random(); + let mut pipeline = create_pipeline( + &database.config, + pipeline_id, + publication_name, + state_store.clone(), + destination.clone(), + ); + + let parent_sync_done = state_store + .notify_on_table_state_type(parent_table_id, TableReplicationPhaseType::SyncDone) + .await; + + pipeline.start().await.unwrap(); + + parent_sync_done.notified().await; + + // Wait for the Relation event (schema change) and Insert event. + let notify = destination + .wait_for_events_count(vec![(EventType::Relation, 1), (EventType::Insert, 1)]) + .await; + + // Add a new column to the partitioned table. + database + .alter_table( + table_name.clone(), + &[TableModification::AddColumn { + name: "category", + data_type: "text not null default 'default_category'", + }], + ) + .await + .unwrap(); + + // Insert a row with the new column into one of the partitions. + database + .run_sql(&format!( + "insert into {} (data, partition_key, category) values ('event3', 75, 'test_category')", + table_name.as_quoted_identifier() + )) + .await + .unwrap(); + + notify.notified().await; + pipeline.shutdown_and_wait().await.unwrap(); + + let events = destination.get_events().await; + let grouped = group_events_by_type_and_table_id(&events); + + // Verify we received exactly 1 Relation event for the parent table. + assert_eq!( + grouped + .get(&(EventType::Relation, parent_table_id)) + .unwrap() + .len(), + 1 + ); + assert_eq!( + grouped + .get(&(EventType::Insert, parent_table_id)) + .unwrap() + .len(), + 1 + ); + + // Verify the Relation event has the updated schema with the new column. 
+ let Event::Relation(r) = get_last_relation_event(&events, parent_table_id) else { + panic!("expected relation event"); + }; + assert_replicated_schema_column_names_types( + &r.replicated_table_schema, + &[ + ("id", Type::INT8), + ("data", Type::TEXT), + ("partition_key", Type::INT4), + ("category", Type::TEXT), + ], + ); + + // Verify the Insert event has 4 columns. + let Event::Insert(i) = get_last_insert_event(&events, parent_table_id) else { + panic!("expected insert event"); + }; + assert_eq!(i.table_row.values.len(), 4); + + // Verify schema snapshots are stored in order. + let table_schemas = state_store.get_table_schemas().await; + let snapshots = table_schemas.get(&parent_table_id).unwrap(); + assert_eq!(snapshots.len(), 2); + assert_schema_snapshots_ordering(snapshots, true); + + // Initial schema: id, data, partition_key. + let (_, first_schema) = &snapshots[0]; + assert_table_schema_column_names_types( + first_schema, + &[ + ("id", Type::INT8), + ("data", Type::TEXT), + ("partition_key", Type::INT4), + ], + ); + + // After adding category: id, data, partition_key, category. + let (_, second_schema) = &snapshots[1]; + assert_table_schema_column_names_types( + second_schema, + &[ + ("id", Type::INT8), + ("data", Type::TEXT), + ("partition_key", Type::INT4), + ("category", Type::TEXT), + ], + ); +} diff --git a/etl/tests/postgres_store.rs b/etl/tests/postgres_store.rs index 1038c591a..e7f4e48e0 100644 --- a/etl/tests/postgres_store.rs +++ b/etl/tests/postgres_store.rs @@ -1,30 +1,45 @@ #![cfg(feature = "test-utils")] +use etl::state::destination::DestinationTableMetadata; use etl::state::table::{RetryPolicy, TableReplicationPhase}; use etl::store::both::postgres::PostgresStore; use etl::store::cleanup::CleanupStore; use etl::store::schema::SchemaStore; use etl::store::state::StateStore; -use etl::test_utils::database::spawn_source_database_for_store; +use etl::test_utils::database::spawn_source_database; use etl_postgres::replication::connect_to_source_database; -use etl_postgres::types::{ColumnSchema, TableId, TableName, TableSchema}; +use etl_postgres::types::ReplicationMask; +use etl_postgres::types::{ColumnSchema, SnapshotId, TableId, TableName, TableSchema}; use etl_telemetry::tracing::init_test_tracing; use sqlx::postgres::types::Oid as SqlxTableId; use tokio_postgres::types::{PgLsn, Type as PgType}; +/// Creates a test column schema with sensible defaults. 
+fn test_column( + name: &str, + typ: PgType, + modifier: i32, + ordinal_position: i32, + nullable: bool, + primary_key: bool, +) -> ColumnSchema { + ColumnSchema::new( + name.to_string(), + typ, + modifier, + ordinal_position, + if primary_key { Some(1) } else { None }, + nullable, + ) +} + fn create_sample_table_schema() -> TableSchema { let table_id = TableId::new(12345); let table_name = TableName::new("public".to_string(), "test_table".to_string()); let columns = vec![ - ColumnSchema::new("id".to_string(), PgType::INT4, -1, false, true), - ColumnSchema::new("name".to_string(), PgType::TEXT, -1, true, false), - ColumnSchema::new( - "created_at".to_string(), - PgType::TIMESTAMPTZ, - -1, - false, - false, - ), + test_column("id", PgType::INT4, -1, 1, false, true), + test_column("name", PgType::TEXT, -1, 2, true, false), + test_column("created_at", PgType::TIMESTAMPTZ, -1, 3, false, false), ]; TableSchema::new(table_id, table_name, columns) @@ -34,8 +49,8 @@ fn create_another_table_schema() -> TableSchema { let table_id = TableId::new(67890); let table_name = TableName::new("public".to_string(), "another_table".to_string()); let columns = vec![ - ColumnSchema::new("id".to_string(), PgType::INT8, -1, false, true), - ColumnSchema::new("description".to_string(), PgType::VARCHAR, 255, true, false), + test_column("id", PgType::INT8, -1, 1, false, true), + test_column("description", PgType::VARCHAR, 255, 2, true, false), ]; TableSchema::new(table_id, table_name, columns) @@ -45,7 +60,7 @@ fn create_another_table_schema() -> TableSchema { async fn test_state_store_operations() { init_test_tracing(); - let database = spawn_source_database_for_store().await; + let database = spawn_source_database().await; let pipeline_id = 1; let table_id = TableId::new(12345); @@ -112,7 +127,7 @@ async fn test_state_store_operations() { async fn test_state_store_rollback() { init_test_tracing(); - let database = spawn_source_database_for_store().await; + let database = spawn_source_database().await; let pipeline_id = 1; let table_id = TableId::new(12345); @@ -181,7 +196,7 @@ async fn test_state_store_rollback() { async fn test_state_store_load_states() { init_test_tracing(); - let database = spawn_source_database_for_store().await; + let database = spawn_source_database().await; let pipeline_id = 1; let table_id1 = TableId::new(12345); let table_id2 = TableId::new(67890); @@ -223,7 +238,7 @@ async fn test_state_store_load_states() { async fn test_schema_store_operations() { init_test_tracing(); - let database = spawn_source_database_for_store().await; + let database = spawn_source_database().await; let pipeline_id = 1; let store = PostgresStore::new(pipeline_id, database.config.clone()); @@ -231,7 +246,10 @@ async fn test_schema_store_operations() { let table_id = table_schema.id; // Test initial state - should be empty - let schema = store.get_table_schema(&table_id).await.unwrap(); + let schema = store + .get_table_schema(&table_id, SnapshotId::max()) + .await + .unwrap(); assert!(schema.is_none()); let all_schemas = store.get_table_schemas().await.unwrap(); @@ -243,7 +261,10 @@ async fn test_schema_store_operations() { .await .unwrap(); - let schema = store.get_table_schema(&table_id).await.unwrap(); + let schema = store + .get_table_schema(&table_id, SnapshotId::max()) + .await + .unwrap(); assert!(schema.is_some()); let schema = schema.unwrap(); assert_eq!(schema.id, table_schema.id); @@ -271,7 +292,7 @@ async fn test_schema_store_operations() { async fn test_schema_store_load_schemas() { 
init_test_tracing(); - let database = spawn_source_database_for_store().await; + let database = spawn_source_database().await; let pipeline_id = 1; let store = PostgresStore::new(pipeline_id, database.config.clone()); @@ -303,13 +324,19 @@ async fn test_schema_store_load_schemas() { let schemas = new_store.get_table_schemas().await.unwrap(); assert_eq!(schemas.len(), 2); - let schema1 = new_store.get_table_schema(&table_schema1.id).await.unwrap(); + let schema1 = new_store + .get_table_schema(&table_schema1.id, SnapshotId::max()) + .await + .unwrap(); assert!(schema1.is_some()); let schema1 = schema1.unwrap(); assert_eq!(schema1.id, table_schema1.id); assert_eq!(schema1.name, table_schema1.name); - let schema2 = new_store.get_table_schema(&table_schema2.id).await.unwrap(); + let schema2 = new_store + .get_table_schema(&table_schema2.id, SnapshotId::max()) + .await + .unwrap(); assert!(schema2.is_some()); let schema2 = schema2.unwrap(); assert_eq!(schema2.id, table_schema2.id); @@ -317,43 +344,64 @@ async fn test_schema_store_load_schemas() { } #[tokio::test(flavor = "multi_thread")] -async fn test_schema_store_update_existing() { +async fn test_schema_store_versioning() { init_test_tracing(); - let database = spawn_source_database_for_store().await; + let database = spawn_source_database().await; let pipeline_id = 1; let store = PostgresStore::new(pipeline_id, database.config.clone()); let mut table_schema = create_sample_table_schema(); - // Store initial schema + // Store initial schema at snapshot 0 store .store_table_schema(table_schema.clone()) .await .unwrap(); - // Update schema by adding a column - table_schema.add_column_schema(ColumnSchema::new( - "updated_at".to_string(), + // Create a new version with a higher snapshot_id + table_schema.add_column_schema(test_column( + "updated_at", PgType::TIMESTAMPTZ, -1, + 4, true, false, )); + table_schema.snapshot_id = SnapshotId::from(100u64); // New snapshot for the schema change - // Store updated schema + // Store updated schema as new version store .store_table_schema(table_schema.clone()) .await .unwrap(); - // Verify updated schema - let schema = store.get_table_schema(&table_schema.id).await.unwrap(); + // Verify querying at snapshot 100+ returns the updated schema + let schema = store + .get_table_schema(&table_schema.id, SnapshotId::max()) + .await + .unwrap(); assert!(schema.is_some()); let schema = schema.unwrap(); assert_eq!(schema.column_schemas.len(), 4); // Original 3 + 1 new column + assert_eq!(schema.snapshot_id, SnapshotId::from(100u64)); - // Verify the new column was added + // Verify querying at snapshot 50 returns the original schema + let schema = store + .get_table_schema(&table_schema.id, SnapshotId::from(50u64)) + .await + .unwrap(); + assert!(schema.is_some()); + let schema = schema.unwrap(); + assert_eq!(schema.column_schemas.len(), 3); // Original 3 columns + assert_eq!(schema.snapshot_id, SnapshotId::initial()); + + // Verify the new column was added in the latest version + let schema = store + .get_table_schema(&table_schema.id, SnapshotId::max()) + .await + .unwrap() + .unwrap(); let updated_at_column = schema .column_schemas .iter() @@ -361,11 +409,170 @@ async fn test_schema_store_update_existing() { assert!(updated_at_column.is_some()); } +#[tokio::test(flavor = "multi_thread")] +async fn test_schema_store_upsert_replaces_columns() { + init_test_tracing(); + + let database = spawn_source_database().await; + let pipeline_id = 1; + + let store = PostgresStore::new(pipeline_id, database.config.clone()); + + 
// Create initial schema with 3 columns + let table_id = TableId::new(12345); + let table_name = TableName::new("public".to_string(), "test_table".to_string()); + let initial_columns = vec![ + test_column("id", PgType::INT4, -1, 1, false, true), + test_column("name", PgType::TEXT, -1, 2, true, false), + test_column("old_column", PgType::TEXT, -1, 3, true, false), + ]; + let table_schema = TableSchema::new(table_id, table_name.clone(), initial_columns); + + // Store initial schema + store + .store_table_schema(table_schema.clone()) + .await + .unwrap(); + + // Verify initial columns + let schema = store + .get_table_schema(&table_id, SnapshotId::max()) + .await + .unwrap() + .unwrap(); + assert_eq!(schema.column_schemas.len(), 3); + assert!(schema.column_schemas.iter().any(|c| c.name == "old_column")); + + // Create updated schema with SAME snapshot_id but different columns + // (simulating a retry or re-processing scenario) + let updated_columns = vec![ + test_column("id", PgType::INT4, -1, 1, false, true), + test_column("name", PgType::TEXT, -1, 2, true, false), + test_column("new_column", PgType::TEXT, -1, 3, true, false), // replaced old_column + test_column("extra_column", PgType::INT8, -1, 4, true, false), // added column + ]; + let updated_schema = TableSchema::new(table_id, table_name, updated_columns); + + // Store updated schema with same snapshot_id (upsert) + store + .store_table_schema(updated_schema.clone()) + .await + .unwrap(); + + // Verify columns were replaced, not accumulated + // Need to clear cache and reload from DB to verify DB state + let new_store = PostgresStore::new(pipeline_id, database.config.clone()); + let schema = new_store + .get_table_schema(&table_id, SnapshotId::max()) + .await + .unwrap() + .unwrap(); + + assert_eq!(schema.column_schemas.len(), 4); // Should be 4, not 3+4=7 + assert!( + !schema.column_schemas.iter().any(|c| c.name == "old_column"), + "old_column should have been deleted" + ); + assert!( + schema.column_schemas.iter().any(|c| c.name == "new_column"), + "new_column should exist" + ); + assert!( + schema + .column_schemas + .iter() + .any(|c| c.name == "extra_column"), + "extra_column should exist" + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_schema_cache_eviction() { + init_test_tracing(); + + let database = spawn_source_database().await; + let pipeline_id = 1; + + let store = PostgresStore::new(pipeline_id, database.config.clone()); + + // Store 3 schema versions for table 1 + let table_id_1 = TableId::new(12345); + let table_name_1 = TableName::new("public".to_string(), "test_table".to_string()); + for snapshot_id in [0u64, 100, 200] { + let columns = vec![ + test_column("id", PgType::INT4, -1, 1, false, true), + test_column( + &format!("col_at_{snapshot_id}"), + PgType::TEXT, + -1, + 2, + true, + false, + ), + ]; + let mut table_schema = TableSchema::new(table_id_1, table_name_1.clone(), columns); + table_schema.snapshot_id = SnapshotId::from(snapshot_id); + store + .store_table_schema(table_schema.clone()) + .await + .unwrap(); + } + + // Store 3 schemas for table 2 to verify eviction is per-table + let table_id_2 = TableId::new(67890); + let table_name_2 = TableName::new("public".to_string(), "table_2".to_string()); + for snapshot_id in [0u64, 100, 200] { + let columns = vec![test_column("id", PgType::INT4, -1, 1, false, true)]; + let mut schema = TableSchema::new(table_id_2, table_name_2.clone(), columns); + schema.snapshot_id = SnapshotId::from(snapshot_id); + store.store_table_schema(schema).await.unwrap(); + 
} + + // Check cache size - should have 2 schemas per table = 4 total + let cached_schemas = store.get_table_schemas().await.unwrap(); + assert_eq!(cached_schemas.len(), 4, "Should have 2 schemas per table"); + + // Verify eviction keeps newest snapshots (100 and 200), evicts oldest (0) + let table_1_snapshots: Vec<SnapshotId> = cached_schemas + .iter() + .filter(|s| s.id == table_id_1) + .map(|s| s.snapshot_id) + .collect(); + assert!( + table_1_snapshots.contains(&SnapshotId::from(100u64)) + && table_1_snapshots.contains(&SnapshotId::from(200u64)) + ); + assert!( + !table_1_snapshots.contains(&SnapshotId::initial()), + "Snapshot 0 should be evicted" + ); + + let table_2_snapshots: Vec<SnapshotId> = cached_schemas + .iter() + .filter(|s| s.id == table_id_2) + .map(|s| s.snapshot_id) + .collect(); + assert!( + !table_2_snapshots.contains(&SnapshotId::initial()), + "Table 2 snapshot 0 should be evicted" + ); + + // Evicted schemas should still be loadable from DB + let new_store = PostgresStore::new(pipeline_id, database.config.clone()); + let schema_0 = new_store + .get_table_schema(&table_id_1, SnapshotId::initial()) + .await + .unwrap() + .unwrap(); + assert_eq!(schema_0.snapshot_id, SnapshotId::initial()); + assert!(schema_0.column_schemas.iter().any(|c| c.name == "col_at_0")); +} + #[tokio::test(flavor = "multi_thread")] async fn test_multiple_pipelines_isolation() { init_test_tracing(); - let database = spawn_source_database_for_store().await; + let database = spawn_source_database().await; let pipeline_id1 = 1; let pipeline_id2 = 2; let table_id = TableId::new(12345); @@ -373,26 +580,27 @@ async fn test_multiple_pipelines_isolation() { let store1 = PostgresStore::new(pipeline_id1, database.config.clone()); let store2 = PostgresStore::new(pipeline_id2, database.config.clone()); - // Add state to pipeline 1 + // Test state isolation let init_phase = TableReplicationPhase::Init; store1 .update_table_replication_state(table_id, init_phase.clone()) .await .unwrap(); - // Add different state to pipeline 2 for the same table let data_sync_phase = TableReplicationPhase::DataSync; store2 .update_table_replication_state(table_id, data_sync_phase.clone()) .await .unwrap(); - // Verify isolation - each pipeline sees only its own state - let state1 = store1.get_table_replication_state(table_id).await.unwrap(); - assert_eq!(state1, Some(init_phase)); - - let state2 = store2.get_table_replication_state(table_id).await.unwrap(); - assert_eq!(state2, Some(data_sync_phase)); + assert_eq!( + store1.get_table_replication_state(table_id).await.unwrap(), + Some(init_phase) + ); + assert_eq!( + store2.get_table_replication_state(table_id).await.unwrap(), + Some(data_sync_phase) + ); // Test schema isolation let table_schema1 = create_sample_table_schema(); @@ -407,7 +615,6 @@ async fn test_multiple_pipelines_isolation() { .await .unwrap(); - // Each pipeline sees only its own schemas let schemas1 = store1.get_table_schemas().await.unwrap(); assert_eq!(schemas1.len(), 1); assert_eq!(schemas1[0].id, table_schema1.id); @@ -415,13 +622,63 @@ async fn test_multiple_pipelines_isolation() { let schemas2 = store2.get_table_schemas().await.unwrap(); assert_eq!(schemas2.len(), 1); assert_eq!(schemas2[0].id, table_schema2.id); + + // Test destination metadata isolation + let metadata1 = DestinationTableMetadata::new_applied( + "pipeline1_table".to_string(), + SnapshotId::initial(), + ReplicationMask::from_bytes(vec![1, 1, 1]), + ); + let metadata2 = DestinationTableMetadata::new_applied( + "pipeline2_table".to_string(), +
SnapshotId::initial(), + ReplicationMask::from_bytes(vec![1, 1, 1]), + ); + + store1 + .store_destination_table_metadata(table_id, metadata1.clone()) + .await + .unwrap(); + store2 + .store_destination_table_metadata(table_id, metadata2.clone()) + .await + .unwrap(); + + assert_eq!( + store1 + .get_destination_table_metadata(&table_id) + .await + .unwrap() + .map(|m| m.destination_table_id), + Some("pipeline1_table".to_string()) + ); + assert_eq!( + store2 + .get_destination_table_metadata(&table_id) + .await + .unwrap() + .map(|m| m.destination_table_id), + Some("pipeline2_table".to_string()) + ); + + // Verify isolation persists after loading from database + let new_store1 = PostgresStore::new(pipeline_id1, database.config.clone()); + new_store1.load_destination_tables_metadata().await.unwrap(); + assert_eq!( + new_store1 + .get_destination_table_metadata(&table_id) + .await + .unwrap() + .map(|m| m.destination_table_id), + Some("pipeline1_table".to_string()) + ); } #[tokio::test(flavor = "multi_thread")] async fn test_errored_state_with_different_retry_policies() { init_test_tracing(); - let database = spawn_source_database_for_store().await; + let database = spawn_source_database().await; let pipeline_id = 1; let table_id = TableId::new(12345); @@ -461,7 +718,7 @@ async fn test_errored_state_with_different_retry_policies() { async fn test_state_transitions_and_history() { init_test_tracing(); - let database = spawn_source_database_for_store().await; + let database = spawn_source_database().await; let pipeline_id = 1; let table_id = TableId::new(12345); @@ -534,154 +791,28 @@ async fn test_state_transitions_and_history() { } #[tokio::test(flavor = "multi_thread")] -async fn test_table_mappings_basic_operations() { +async fn test_cleanup_deletes_state_schema_and_metadata_for_table() { init_test_tracing(); - let database = spawn_source_database_for_store().await; + let database = spawn_source_database().await; let pipeline_id = 1; let store = PostgresStore::new(pipeline_id, database.config.clone()); - let table_id1 = TableId::new(12345); - let table_id2 = TableId::new(67890); - - // Test initial state - should be empty - let mapping = store.get_table_mapping(&table_id1).await.unwrap(); - assert!(mapping.is_none()); - - let all_mappings = store.get_table_mappings().await.unwrap(); - assert!(all_mappings.is_empty()); - - // Test storing and retrieving mappings + // Test idempotency: cleanup on non-existent table should succeed + let nonexistent_table_id = TableId::new(99999); store - .store_table_mapping(table_id1, "public_users_1".to_string()) + .cleanup_table_state(nonexistent_table_id) .await .unwrap(); - store - .store_table_mapping(table_id2, "public_orders_2".to_string()) - .await - .unwrap(); - - let all_mappings = store.get_table_mappings().await.unwrap(); - assert_eq!(all_mappings.len(), 2); - assert_eq!( - all_mappings.get(&table_id1), - Some(&"public_users_1".to_string()) - ); - assert_eq!( - all_mappings.get(&table_id2), - Some(&"public_orders_2".to_string()) - ); - - // Test updating an existing mapping (upsert) - store - .store_table_mapping(table_id1, "public_users_1_updated".to_string()) - .await - .unwrap(); - - let mapping = store.get_table_mapping(&table_id1).await.unwrap(); - assert_eq!(mapping, Some("public_users_1_updated".to_string())); -} - -#[tokio::test(flavor = "multi_thread")] -async fn test_table_mappings_persistence_and_loading() { - init_test_tracing(); - - let database = spawn_source_database_for_store().await; - let pipeline_id = 1; - - let store = 
PostgresStore::new(pipeline_id, database.config.clone()); - - // Store some mappings - store - .store_table_mapping(TableId::new(1), "dest_table_1".to_string()) - .await - .unwrap(); - store - .store_table_mapping(TableId::new(2), "dest_table_2".to_string()) - .await - .unwrap(); - - // Create a new store instance (simulating restart) - let new_store = PostgresStore::new(pipeline_id, database.config.clone()); - - // Initially empty cache - let mappings = new_store.get_table_mappings().await.unwrap(); - assert!(mappings.is_empty()); - - // Load all mappings from database - let loaded_count = new_store.load_table_mappings().await.unwrap(); - assert_eq!(loaded_count, 2); - - // Verify loaded mappings - let mappings = new_store.get_table_mappings().await.unwrap(); - assert_eq!(mappings.len(), 2); - assert_eq!( - mappings.get(&TableId::new(1)), - Some(&"dest_table_1".to_string()) - ); - assert_eq!( - mappings.get(&TableId::new(2)), - Some(&"dest_table_2".to_string()) - ); -} - -#[tokio::test(flavor = "multi_thread")] -async fn test_table_mappings_pipeline_isolation() { - init_test_tracing(); - - let database = spawn_source_database_for_store().await; - let pipeline_id1 = 1; - let pipeline_id2 = 2; - - let store1 = PostgresStore::new(pipeline_id1, database.config.clone()); - let store2 = PostgresStore::new(pipeline_id2, database.config.clone()); - - let table_id = TableId::new(12345); - - // Store different mappings for the same table ID in different pipelines - store1 - .store_table_mapping(table_id, "pipeline1_table".to_string()) - .await - .unwrap(); - - store2 - .store_table_mapping(table_id, "pipeline2_table".to_string()) - .await - .unwrap(); - - // Verify isolation - each pipeline sees only its own mapping - let mapping1 = store1.get_table_mapping(&table_id).await.unwrap(); - assert_eq!(mapping1, Some("pipeline1_table".to_string())); - - let mapping2 = store2.get_table_mapping(&table_id).await.unwrap(); - assert_eq!(mapping2, Some("pipeline2_table".to_string())); - - // Verify isolation persists after loading from database - let new_store1 = PostgresStore::new(pipeline_id1, database.config.clone()); - new_store1.load_table_mappings().await.unwrap(); - - let loaded_mapping1 = new_store1.get_table_mapping(&table_id).await.unwrap(); - assert_eq!(loaded_mapping1, Some("pipeline1_table".to_string())); -} - -#[tokio::test(flavor = "multi_thread")] -async fn test_cleanup_deletes_state_schema_and_mapping_for_table() { - init_test_tracing(); - - let database = spawn_source_database_for_store().await; - let pipeline_id = 1; - - let store = PostgresStore::new(pipeline_id, database.config.clone()); - // Prepare two tables: one we will delete, one we will keep let table_1_schema = create_sample_table_schema(); let table_1_id = table_1_schema.id; let table_2_schema = create_another_table_schema(); let table_2_id = table_2_schema.id; - // Populate state, schema, and mapping for both tables + // Populate state, schema, and metadata for both tables store .update_table_replication_state(table_1_id, TableReplicationPhase::Ready) .await @@ -700,12 +831,23 @@ async fn test_cleanup_deletes_state_schema_and_mapping_for_table() { .await .unwrap(); + let metadata1 = DestinationTableMetadata::new_applied( + "dest_table_1".to_string(), + SnapshotId::initial(), + ReplicationMask::from_bytes(vec![1, 1, 1]), + ); + let metadata2 = DestinationTableMetadata::new_applied( + "dest_table_2".to_string(), + SnapshotId::initial(), + ReplicationMask::from_bytes(vec![1, 1, 1]), + ); + store - .store_table_mapping(table_1_id, 
"dest_table_1".to_string()) + .store_destination_table_metadata(table_1_id, metadata1) .await .unwrap(); store - .store_table_mapping(table_2_id, "dest_table_2".to_string()) + .store_destination_table_metadata(table_2_id, metadata2) .await .unwrap(); @@ -717,10 +859,16 @@ async fn test_cleanup_deletes_state_schema_and_mapping_for_table() { .unwrap() .is_some() ); - assert!(store.get_table_schema(&table_1_id).await.unwrap().is_some()); assert!( store - .get_table_mapping(&table_1_id) + .get_table_schema(&table_1_id, SnapshotId::max()) + .await + .unwrap() + .is_some() + ); + assert!( + store + .get_destination_table_metadata(&table_1_id) .await .unwrap() .is_some() @@ -737,10 +885,16 @@ async fn test_cleanup_deletes_state_schema_and_mapping_for_table() { .unwrap() .is_none() ); - assert!(store.get_table_schema(&table_1_id).await.unwrap().is_none()); assert!( store - .get_table_mapping(&table_1_id) + .get_table_schema(&table_1_id, SnapshotId::max()) + .await + .unwrap() + .is_none() + ); + assert!( + store + .get_destination_table_metadata(&table_1_id) .await .unwrap() .is_none() @@ -754,10 +908,16 @@ async fn test_cleanup_deletes_state_schema_and_mapping_for_table() { .unwrap() .is_some() ); - assert!(store.get_table_schema(&table_2_id).await.unwrap().is_some()); assert!( store - .get_table_mapping(&table_2_id) + .get_table_schema(&table_2_id, SnapshotId::max()) + .await + .unwrap() + .is_some() + ); + assert!( + store + .get_destination_table_metadata(&table_2_id) .await .unwrap() .is_some() @@ -767,7 +927,7 @@ async fn test_cleanup_deletes_state_schema_and_mapping_for_table() { let new_store = PostgresStore::new(pipeline_id, database.config.clone()); new_store.load_table_replication_states().await.unwrap(); new_store.load_table_schemas().await.unwrap(); - new_store.load_table_mappings().await.unwrap(); + new_store.load_destination_tables_metadata().await.unwrap(); // Table 1 should not be present after reload assert!( @@ -779,14 +939,14 @@ async fn test_cleanup_deletes_state_schema_and_mapping_for_table() { ); assert!( new_store - .get_table_schema(&table_1_id) + .get_table_schema(&table_1_id, SnapshotId::max()) .await .unwrap() .is_none() ); assert!( new_store - .get_table_mapping(&table_1_id) + .get_destination_table_metadata(&table_1_id) .await .unwrap() .is_none() @@ -802,14 +962,14 @@ async fn test_cleanup_deletes_state_schema_and_mapping_for_table() { ); assert!( new_store - .get_table_schema(&table_2_id) + .get_table_schema(&table_2_id, SnapshotId::max()) .await .unwrap() .is_some() ); assert!( new_store - .get_table_mapping(&table_2_id) + .get_destination_table_metadata(&table_2_id) .await .unwrap() .is_some() @@ -817,51 +977,166 @@ async fn test_cleanup_deletes_state_schema_and_mapping_for_table() { } #[tokio::test(flavor = "multi_thread")] -async fn test_cleanup_idempotent_when_no_state_present() { +async fn test_replication_mask_loads_correctly_from_string_bytea() { init_test_tracing(); - let database = spawn_source_database_for_store().await; + let database = spawn_source_database().await; let pipeline_id = 1; + let table_id = TableId::new(12345); + + let pool = connect_to_source_database(&database.config, 1, 1) + .await + .expect("Failed to connect to source database with sqlx"); + + // Manually insert a row with a specific replication mask bytea. + // The mask [1, 0, 1, 1, 0] represents columns: replicated, not replicated, replicated, replicated, not replicated. 
+ let expected_mask_bytes: Vec<u8> = vec![1, 0, 1, 1, 0]; + + sqlx::query( + r#" + INSERT INTO etl.destination_tables_metadata + (pipeline_id, table_id, destination_table_id, snapshot_id, schema_status, replication_mask) + VALUES ($1, $2, 'test_dest_table', '0/0'::pg_lsn, 'applied', '\x0100010100') + "#, + ) + .bind(pipeline_id as i64) + .bind(SqlxTableId(table_id.into_inner())) + .execute(&pool) + .await + .unwrap(); + + // Load metadata using the store let store = PostgresStore::new(pipeline_id, database.config.clone()); + store.load_destination_tables_metadata().await.unwrap(); - let table_schema = create_sample_table_schema(); - let table_id = table_schema.id; + // Verify the loaded replication mask matches what was inserted + let metadata = store + .get_destination_table_metadata(&table_id) + .await + .unwrap() + .expect("Metadata should exist"); - // Ensure no state exists yet - assert!( - store - .get_table_replication_state(table_id) - .await - .unwrap() - .is_none() + assert_eq!( + metadata.replication_mask.as_slice(), + &expected_mask_bytes, + "Loaded replication mask should match inserted bytea" ); - assert!(store.get_table_schema(&table_id).await.unwrap().is_none()); - assert!(store.get_table_mapping(&table_id).await.unwrap().is_none()); + assert_eq!(metadata.destination_table_id, "test_dest_table"); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_replication_mask_various_patterns() { + init_test_tracing(); - // Calling cleanup should succeed even if nothing exists - store.cleanup_table_state(table_id).await.unwrap(); + let database = spawn_source_database().await; + let pipeline_id = 1; - // Add state and clean up again - store - .update_table_replication_state(table_id, TableReplicationPhase::Init) + let pool = connect_to_source_database(&database.config, 1, 1) .await - .unwrap(); - store.store_table_schema(table_schema).await.unwrap(); - store - .store_table_mapping(table_id, "dest_table".to_string()) + .expect("Failed to connect to source database with sqlx"); + + // Test various mask patterns + let test_cases: Vec<(TableId, &str, Vec<u8>)> = vec![ + // All columns replicated + (TableId::new(1001), "all_ones", vec![1, 1, 1, 1, 1]), + // No columns replicated + (TableId::new(1002), "all_zeros", vec![0, 0, 0, 0]), + // Single column replicated + (TableId::new(1003), "single_one", vec![1]), + // Alternating pattern + (TableId::new(1004), "alternating", vec![1, 0, 1, 0, 1, 0]), + // Large mask (20 columns) + ( + TableId::new(1005), + "large", + vec![1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1], + ), + // Empty mask (table with no columns - edge case) + (TableId::new(1006), "empty", vec![]), + ]; + + // Insert all test cases + for (table_id, dest_name, mask_bytes) in &test_cases { + sqlx::query( + r#" + INSERT INTO etl.destination_tables_metadata + (pipeline_id, table_id, destination_table_id, snapshot_id, schema_status, replication_mask) + VALUES ($1, $2, $3, '0/0'::pg_lsn, 'applied', $4) + "#, + ) + .bind(pipeline_id as i64) + .bind(SqlxTableId(table_id.into_inner())) + .bind(*dest_name) + .bind(mask_bytes) + .execute(&pool) .await .unwrap(); + } - store.cleanup_table_state(table_id).await.unwrap(); + // Load all metadata using the store + let store = PostgresStore::new(pipeline_id, database.config.clone()); + store.load_destination_tables_metadata().await.unwrap(); - // Verify everything is gone - assert!( - store - .get_table_replication_state(table_id) + // Verify each test case + for (table_id, dest_name, expected_mask) in
&test_cases { + let metadata = store + .get_destination_table_metadata(table_id) .await .unwrap() - .is_none() + .unwrap_or_else(|| panic!("Metadata for {dest_name} should exist")); + + assert_eq!( + metadata.replication_mask.as_slice(), + expected_mask.as_slice(), + "Mask mismatch for {}: expected {:?}, got {:?}", + dest_name, + expected_mask, + metadata.replication_mask.as_slice() + ); + assert_eq!( + metadata.destination_table_id, *dest_name, + "Destination table ID mismatch" + ); + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_replication_mask_roundtrip() { + init_test_tracing(); + + let database = spawn_source_database().await; + let pipeline_id = 1; + let table_id = TableId::new(54321); + + // Create a store and save metadata with a specific mask + let original_mask = ReplicationMask::from_bytes(vec![1, 0, 1, 0, 1, 1, 0, 0]); + let metadata = DestinationTableMetadata::new_applied( + "roundtrip_table".to_string(), + SnapshotId::initial(), + original_mask.clone(), + ); + + let store = PostgresStore::new(pipeline_id, database.config.clone()); + store + .store_destination_table_metadata(table_id, metadata) + .await + .unwrap(); + + // Create a fresh store and load from database + let new_store = PostgresStore::new(pipeline_id, database.config.clone()); + new_store.load_destination_tables_metadata().await.unwrap(); + + // Verify the loaded mask matches the original + let loaded_metadata = new_store + .get_destination_table_metadata(&table_id) + .await + .unwrap() + .expect("Metadata should exist after loading"); + + assert_eq!( + loaded_metadata.replication_mask.as_slice(), + original_mask.as_slice(), + "Roundtrip should preserve replication mask exactly" ); - assert!(store.get_table_schema(&table_id).await.unwrap().is_none()); - assert!(store.get_table_mapping(&table_id).await.unwrap().is_none()); } diff --git a/etl/tests/replication.rs b/etl/tests/replication.rs index 806ca818f..6ecaee891 100644 --- a/etl/tests/replication.rs +++ b/etl/tests/replication.rs @@ -6,7 +6,7 @@ use etl::error::ErrorKind; use etl::replication::client::PgReplicationClient; use etl::test_utils::database::{spawn_source_database, test_table_name}; use etl::test_utils::pipeline::test_slot_name; -use etl::test_utils::table::assert_table_schema; +use etl::test_utils::schema::assert_table_schema_columns; use etl::test_utils::test_schema::create_partitioned_table; use etl_postgres::below_version; use etl_postgres::tokio::test_utils::{TableModification, id_column_schema}; @@ -20,6 +20,24 @@ use tokio::pin; use tokio_postgres::CopyOutStream; use tokio_postgres::types::{ToSql, Type}; +/// Creates a test column schema with sensible defaults. +fn test_column( + name: &str, + typ: Type, + ordinal_position: i32, + nullable: bool, + primary_key: bool, +) -> ColumnSchema { + ColumnSchema::new( + name.to_string(), + typ, + -1, + ordinal_position, + if primary_key { Some(1) } else { None }, + nullable, + ) +} + async fn count_stream_rows(stream: CopyOutStream) -> u64 { pin!(stream); @@ -172,13 +190,7 @@ async fn test_table_schema_copy_is_consistent() { .await .unwrap(); - let age_schema = ColumnSchema { - name: "age".to_string(), - typ: Type::INT4, - modifier: -1, - nullable: true, - primary: false, - }; + let age_schema = test_column("age", Type::INT4, 2, true, false); let table_1_id = database .create_table(test_table_name("table_1"), true, &[("age", "integer")]) @@ -192,17 +204,12 @@ async fn test_table_schema_copy_is_consistent() { .unwrap(); // We use the transaction to consistently read the table schemas. 
- let table_1_schema = transaction - .get_table_schemas(&[table_1_id], None) - .await - .unwrap(); + let table_1_schemas = transaction.get_table_schemas(&[table_1_id]).await.unwrap(); transaction.commit().await.unwrap(); - assert_table_schema( - &table_1_schema, - table_1_id, - test_table_name("table_1"), - &[id_column_schema(), age_schema.clone()], - ); + let table_1_schema = table_1_schemas.get(&table_1_id).unwrap(); + assert_eq!(table_1_schema.id, table_1_id); + assert_eq!(table_1_schema.name, test_table_name("table_1")); + assert_table_schema_columns(table_1_schema, &[id_column_schema(), age_schema.clone()]); } #[tokio::test(flavor = "multi_thread")] @@ -217,20 +224,8 @@ async fn test_table_schema_copy_across_multiple_connections() { .await .unwrap(); - let age_schema = ColumnSchema { - name: "age".to_string(), - typ: Type::INT4, - modifier: -1, - nullable: true, - primary: false, - }; - let year_schema = ColumnSchema { - name: "year".to_string(), - typ: Type::INT4, - modifier: -1, - nullable: true, - primary: false, - }; + let age_schema = test_column("age", Type::INT4, 2, true, false); + let year_schema = test_column("year", Type::INT4, 3, true, false); let table_1_id = database .create_table(test_table_name("table_1"), true, &[("age", "integer")]) @@ -244,17 +239,12 @@ async fn test_table_schema_copy_across_multiple_connections() { .unwrap(); // We use the transaction to consistently read the table schemas. - let table_1_schema = transaction - .get_table_schemas(&[table_1_id], None) - .await - .unwrap(); + let table_1_schemas = transaction.get_table_schemas(&[table_1_id]).await.unwrap(); transaction.commit().await.unwrap(); - assert_table_schema( - &table_1_schema, - table_1_id, - test_table_name("table_1"), - &[id_column_schema(), age_schema.clone()], - ); + let table_1_schema = table_1_schemas.get(&table_1_id).unwrap(); + assert_eq!(table_1_schema.id, table_1_id); + assert_eq!(table_1_schema.name, test_table_name("table_1")); + assert_table_schema_columns(table_1_schema, &[id_column_schema(), age_schema.clone()]); // We create a new table in the database and update the schema of the old one. let table_2_id = database @@ -279,27 +269,20 @@ async fn test_table_schema_copy_across_multiple_connections() { .unwrap(); // We use the transaction to consistently read the table schemas. 
- let table_1_schema = transaction - .get_table_schemas(&[table_1_id], None) - .await - .unwrap(); - let table_2_schema = transaction - .get_table_schemas(&[table_2_id], None) - .await - .unwrap(); + let table_1_schemas = transaction.get_table_schemas(&[table_1_id]).await.unwrap(); + let table_2_schemas = transaction.get_table_schemas(&[table_2_id]).await.unwrap(); transaction.commit().await.unwrap(); - assert_table_schema( - &table_1_schema, - table_1_id, - test_table_name("table_1"), + let table_1_schema = table_1_schemas.get(&table_1_id).unwrap(); + assert_eq!(table_1_schema.id, table_1_id); + assert_eq!(table_1_schema.name, test_table_name("table_1")); + assert_table_schema_columns( + table_1_schema, &[id_column_schema(), age_schema.clone(), year_schema.clone()], ); - assert_table_schema( - &table_2_schema, - table_2_id, - test_table_name("table_2"), - &[id_column_schema(), year_schema], - ); + let table_2_schema = table_2_schemas.get(&table_2_id).unwrap(); + assert_eq!(table_2_schema.id, table_2_id); + assert_eq!(table_2_schema.name, test_table_name("table_2")); + assert_table_schema_columns(table_2_schema, &[id_column_schema(), year_schema]); } #[tokio::test(flavor = "multi_thread")] @@ -343,18 +326,9 @@ async fn test_table_copy_stream_is_consistent() { .unwrap(); // We create a transaction to copy the table data consistently. + let columns = [test_column("age", Type::INT4, 2, true, false)]; let stream = transaction - .get_table_copy_stream( - table_1_id, - &[ColumnSchema { - name: "age".to_string(), - typ: Type::INT4, - modifier: -1, - nullable: true, - primary: false, - }], - None, - ) + .get_table_copy_stream(table_1_id, columns.iter(), None) .await .unwrap(); @@ -419,18 +393,9 @@ async fn test_table_copy_stream_respects_row_filter() { .unwrap(); // We create a transaction to copy the table data consistently. + let columns = [test_column("age", Type::INT4, 2, true, false)]; let stream = transaction - .get_table_copy_stream( - test_table_id, - &[ColumnSchema { - name: "age".to_string(), - typ: Type::INT4, - modifier: -1, - nullable: true, - primary: false, - }], - Some("test_pub"), - ) + .get_table_copy_stream(test_table_id, columns.iter(), Some("test_pub")) .await .unwrap(); @@ -444,7 +409,7 @@ async fn test_table_copy_stream_respects_row_filter() { } #[tokio::test(flavor = "multi_thread")] -async fn test_table_copy_stream_respects_column_filter() { +async fn test_get_replicated_column_names_respects_column_filter() { init_test_tracing(); let database = spawn_source_database().await; @@ -485,71 +450,286 @@ async fn test_table_copy_stream_respects_column_filter() { .await .unwrap(); - // Insert test data with all columns. + // Create the slot when the database schema contains the test data. + let (transaction, _) = parent_client + .create_slot_with_transaction(&test_slot_name("my_slot")) + .await + .unwrap(); + + // Get table schema without publication filter - should include ALL columns. + let table_schemas = transaction + .get_table_schemas(&[test_table_id]) + .await + .unwrap(); + let table_schema = &table_schemas[&test_table_id]; + + // Verify all columns are present in the schema. 
+ assert_eq!(table_schema.id, test_table_id); + assert_eq!(table_schema.name, test_table_name); + assert_table_schema_columns( + table_schema, + &[ + id_column_schema(), + test_column("name", Type::TEXT, 2, true, false), + test_column("age", Type::INT4, 3, true, false), + test_column("email", Type::TEXT, 4, true, false), + ], + ); + + // Get replicated column names from the publication - should only include published columns. + let replicated_columns = transaction + .get_replicated_column_names(test_table_id, table_schema, publication_name) + .await + .unwrap(); + + // Transaction should be committed after queries are done. + transaction.commit().await.unwrap(); + + // Verify only the published columns are returned (id, name, age - not email). + assert_eq!(replicated_columns.len(), 3); + assert!(replicated_columns.contains("id")); + assert!(replicated_columns.contains("name")); + assert!(replicated_columns.contains("age")); + assert!(!replicated_columns.contains("email")); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_get_replicated_column_names_for_all_tables_publication() { + init_test_tracing(); + let database = spawn_source_database().await; + + // Column filters in publication are only available from Postgres 15+. + if below_version!(database.server_version(), POSTGRES_15) { + eprintln!("Skipping test: PostgreSQL 15+ required for column filters"); + return; + } + + // Create a table with multiple columns. + let test_table_name = test_table_name("table_1"); + let test_table_id = database + .create_table( + test_table_name.clone(), + true, + &[("name", "text"), ("age", "integer"), ("email", "text")], + ) + .await + .unwrap(); + database .run_sql(&format!( - "insert into {test_table_name} (name, age, email) values ('Alice', 25, 'alice@example.com')" + "alter table {test_table_name} replica identity full" )) .await .unwrap(); + + // Create a FOR ALL TABLES publication. Column filtering is NOT supported with this type. + let publication_name = "test_pub_all_tables"; database .run_sql(&format!( - "insert into {test_table_name} (name, age, email) values ('Bob', 30, 'bob@example.com')" + "create publication {publication_name} for all tables" )) .await .unwrap(); - // Create the slot when the database schema contains the test data. + let parent_client = PgReplicationClient::connect(database.config.clone()) + .await + .unwrap(); + let (transaction, _) = parent_client .create_slot_with_transaction(&test_slot_name("my_slot")) .await .unwrap(); - // Get table schema with the publication - should only include published columns. + // Get table schema. let table_schemas = transaction - .get_table_schemas(&[test_table_id], Some(publication_name)) + .get_table_schemas(&[test_table_id]) .await .unwrap(); - assert_table_schema( - &table_schemas, - test_table_id, - test_table_name, - &[ - id_column_schema(), - ColumnSchema { - name: "name".to_string(), - typ: Type::TEXT, - modifier: -1, - nullable: true, - primary: false, - }, - ColumnSchema { - name: "age".to_string(), - typ: Type::INT4, - modifier: -1, - nullable: true, - primary: false, - }, - ], - ); + let table_schema = &table_schemas[&test_table_id]; - // Get table copy stream with the publication. - let stream = transaction - .get_table_copy_stream( - test_table_id, - &table_schemas[&test_table_id].column_schemas, - Some("test_pub"), + // Get replicated column names - FOR ALL TABLES doesn't support column filtering, + // so all columns should be returned. 
+ let replicated_columns = transaction + .get_replicated_column_names(test_table_id, table_schema, publication_name) + .await + .unwrap(); + + transaction.commit().await.unwrap(); + + // All columns should be returned since FOR ALL TABLES doesn't support column filtering. + assert_eq!(replicated_columns.len(), 4); + assert!(replicated_columns.contains("id")); + assert!(replicated_columns.contains("name")); + assert!(replicated_columns.contains("age")); + assert!(replicated_columns.contains("email")); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_get_replicated_column_names_for_tables_in_schema_publication() { + init_test_tracing(); + let database = spawn_source_database().await; + + // Column filters in publication are only available from Postgres 15+. + if below_version!(database.server_version(), POSTGRES_15) { + eprintln!("Skipping test: PostgreSQL 15+ required for column filters"); + return; + } + + // Create a table with multiple columns. + let test_table_name = test_table_name("table_1"); + let test_table_id = database + .create_table( + test_table_name.clone(), + true, + &[("name", "text"), ("age", "integer"), ("email", "text")], ) .await .unwrap(); - let rows_count = count_stream_rows(stream).await; + database + .run_sql(&format!( + "alter table {test_table_name} replica identity full" + )) + .await + .unwrap(); + + // Create a FOR TABLES IN SCHEMA publication. Column filtering is NOT supported with this type. + // Note: Tables are created in the "test" schema by test_table_name(). + let publication_name = "test_pub_schema"; + database + .run_sql(&format!( + "create publication {publication_name} for tables in schema test" + )) + .await + .unwrap(); + + let parent_client = PgReplicationClient::connect(database.config.clone()) + .await + .unwrap(); + + let (transaction, _) = parent_client + .create_slot_with_transaction(&test_slot_name("my_slot")) + .await + .unwrap(); + + // Get table schema. + let table_schemas = transaction + .get_table_schemas(&[test_table_id]) + .await + .unwrap(); + let table_schema = &table_schemas[&test_table_id]; + + // Get replicated column names - FOR TABLES IN SCHEMA doesn't support column filtering, + // so all columns should be returned. + let replicated_columns = transaction + .get_replicated_column_names(test_table_id, table_schema, publication_name) + .await + .unwrap(); + + transaction.commit().await.unwrap(); + + // All columns should be returned since FOR TABLES IN SCHEMA doesn't support column filtering. + assert_eq!(replicated_columns.len(), 4); + assert!(replicated_columns.contains("id")); + assert!(replicated_columns.contains("name")); + assert!(replicated_columns.contains("age")); + assert!(replicated_columns.contains("email")); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_get_replicated_column_names_errors_when_table_not_in_publication() { + init_test_tracing(); + let database = spawn_source_database().await; + + // Column filters in publication are only available from Postgres 15+. + if below_version!(database.server_version(), POSTGRES_15) { + eprintln!("Skipping test: PostgreSQL 15+ required for column filters"); + return; + } + + // Create a table with multiple columns. 
+ let table_1_name = test_table_name("table_1"); + let table_1_id = database + .create_table( + table_1_name.clone(), + true, + &[("name", "text"), ("age", "integer")], + ) + .await + .unwrap(); + + database + .run_sql(&format!("alter table {table_1_name} replica identity full")) + .await + .unwrap(); + + // Create a second table that WILL be in the publication. + let table_2_name = test_table_name("table_2"); + database + .create_table(table_2_name.clone(), true, &[("data", "text")]) + .await + .unwrap(); + + // Create publication for only the second table, NOT including table_1. + let publication_name = "test_pub_other"; + database + .run_sql(&format!( + "create publication {publication_name} for table {table_2_name}" + )) + .await + .unwrap(); + + let parent_client = PgReplicationClient::connect(database.config.clone()) + .await + .unwrap(); + + let (transaction, _) = parent_client + .create_slot_with_transaction(&test_slot_name("my_slot")) + .await + .unwrap(); + + // Get table schema for the table NOT in the publication. + let table_schemas = transaction.get_table_schemas(&[table_1_id]).await.unwrap(); + let table_schema = &table_schemas[&table_1_id]; + + // Attempting to get replicated column names for a table not in the publication should error. + let result = transaction + .get_replicated_column_names(table_1_id, table_schema, publication_name) + .await; - // Transaction should be committed after the copy stream is exhausted. transaction.commit().await.unwrap(); - // We expect to have 2 rows (the ones we inserted). - assert_eq!(rows_count, 2); + // Should return a ConfigError since the table is not in the publication. + assert!(result.is_err()); + let err = result.unwrap_err(); + assert_eq!(err.kind(), ErrorKind::ConfigError); + assert!(err.to_string().contains("not included in publication")); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_get_publication_table_ids_errors_when_empty() { + init_test_tracing(); + let database = spawn_source_database().await; + + // Create an empty publication (no tables). + let publication_name = "test_pub_empty"; + database + .run_sql(&format!("create publication {publication_name}")) + .await + .unwrap(); + + let client = PgReplicationClient::connect(database.config.clone()) + .await + .unwrap(); + + // Attempting to get table IDs from an empty publication should error. + let result = client.get_publication_table_ids(publication_name).await; + + // Should return a ConfigError since the publication has no tables. + assert!(result.is_err()); + let err = result.unwrap_err(); + assert_eq!(err.kind(), ErrorKind::ConfigError); + assert!(err.to_string().contains("does not contain any tables")); } #[tokio::test(flavor = "multi_thread")] @@ -594,18 +774,9 @@ async fn test_table_copy_stream_no_row_filter() { .unwrap(); // We create a transaction to copy the table data consistently. 
+ let columns = [test_column("age", Type::INT4, 2, true, false)]; let stream = transaction - .get_table_copy_stream( - test_table_id, - &[ColumnSchema { - name: "age".to_string(), - typ: Type::INT4, - modifier: -1, - nullable: true, - primary: false, - }], - Some("test_pub"), - ) + .get_table_copy_stream(test_table_id, columns.iter(), Some("test_pub")) .await .unwrap(); diff --git a/scripts/docker-compose.yaml b/scripts/docker-compose.yaml index bc4d07fc0..91c5856a4 100644 --- a/scripts/docker-compose.yaml +++ b/scripts/docker-compose.yaml @@ -19,12 +19,16 @@ services: - "${POSTGRES_PORT:-5430}:5432" volumes: - ${POSTGRES_DATA_VOLUME:-postgres_data}:/var/lib/postgresql/data + # These parameters are configured for local testing only. + # In production, use PostgreSQL defaults or your own tuned values. + # `wal_sender_timeout/2` is the keepalive timeout for streaming replication. command: > postgres -N 1000 -c wal_level=logical -c max_wal_senders=100 -c max_replication_slots=100 + -c wal_sender_timeout=10s restart: unless-stopped healthcheck: test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-postgres}"] diff --git a/scripts/init.sh b/scripts/init.sh index 7b95708fb..c21c07402 100755 --- a/scripts/init.sh +++ b/scripts/init.sh @@ -73,7 +73,7 @@ echo "🔗 Database URL: ${DATABASE_URL}" # Run database migrations echo "🔄 Running database migrations..." SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -bash "${SCRIPT_DIR}/../etl-api/scripts/run_migrations.sh" +bash "${SCRIPT_DIR}/run_migrations.sh" # Seed default replicator image (idempotent). echo "🖼️ Seeding default replicator image..." diff --git a/scripts/run_migrations.sh b/scripts/run_migrations.sh new file mode 100755 index 000000000..f8fe08d7b --- /dev/null +++ b/scripts/run_migrations.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash +set -eo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(dirname "$SCRIPT_DIR")" + +usage() { + echo "Usage: $0 [OPTIONS] [TARGETS...]" + echo "" + echo "Run database migrations for etl components." + echo "" + echo "Targets:" + echo " etl-api Run etl-api migrations (public schema)" + echo " etl Run etl migrations (etl schema)" + echo " all Run all migrations (default if no target specified)" + echo "" + echo "Options:" + echo " -h, --help Show this help message" + echo "" + echo "Examples:" + echo " $0 # Run all migrations" + echo " $0 etl-api # Run only etl-api migrations" + echo " $0 etl # Run only etl migrations" + echo " $0 etl-api etl # Run both explicitly" +} + +check_sqlx() { + if ! [ -x "$(command -v sqlx)" ]; then + echo >&2 "Error: SQLx CLI is not installed." + echo >&2 "To install it, run:" + echo >&2 " cargo install --version='~0.7' sqlx-cli --no-default-features --features rustls,postgres" + exit 1 + fi +} + +check_psql() { + if ! [ -x "$(command -v psql)" ]; then + echo >&2 "Error: Postgres client (psql) is not installed." + echo >&2 "Please install it using your system's package manager." + exit 1 + fi +} + +setup_database_url() { + DB_USER="${POSTGRES_USER:=postgres}" + DB_PASSWORD="${POSTGRES_PASSWORD:=postgres}" + DB_NAME="${POSTGRES_DB:=postgres}" + DB_PORT="${POSTGRES_PORT:=5430}" + DB_HOST="${POSTGRES_HOST:=localhost}" + + export DATABASE_URL="postgres://${DB_USER}:${DB_PASSWORD}@${DB_HOST}:${DB_PORT}/${DB_NAME}" +} + +run_etl_api_migrations() { + local migrations_dir="${ROOT_DIR}/etl-api/migrations" + + if [ ! 
-d "$migrations_dir" ]; then + echo >&2 "Error: 'etl-api/migrations' folder not found at $migrations_dir" + exit 1 + fi + + echo "Running etl-api migrations..." + sqlx database create + sqlx migrate run --source "$migrations_dir" + echo "etl-api migrations complete!" +} + +run_etl_migrations() { + local migrations_dir="${ROOT_DIR}/etl/migrations" + + if [ ! -d "$migrations_dir" ]; then + echo >&2 "Error: 'etl/migrations' folder not found at $migrations_dir" + exit 1 + fi + + echo "Running etl migrations..." + + # Create the etl schema if it doesn't exist. + # This matches the behavior in etl/src/migrations.rs. + psql "${DATABASE_URL}" -v ON_ERROR_STOP=1 -c "create schema if not exists etl;" > /dev/null + + # Create a temporary sqlx-cli compatible database URL that sets the search_path. + # This ensures the _sqlx_migrations table is created in the etl schema. + local sqlx_migrations_opts="options=-csearch_path%3Detl" + local migration_url="${DATABASE_URL}?${sqlx_migrations_opts}" + + sqlx database create --database-url "${DATABASE_URL}" + sqlx migrate run --source "$migrations_dir" --database-url "${migration_url}" + echo "etl migrations complete!" +} + +# Parse arguments +RUN_ETL_API=false +RUN_ETL=false + +if [ $# -eq 0 ]; then + RUN_ETL_API=true + RUN_ETL=true +fi + +for arg in "$@"; do + case "$arg" in + -h|--help) + usage + exit 0 + ;; + etl-api) + RUN_ETL_API=true + ;; + etl) + RUN_ETL=true + ;; + all) + RUN_ETL_API=true + RUN_ETL=true + ;; + *) + echo >&2 "Error: Unknown argument '$arg'" + usage + exit 1 + ;; + esac +done + +# Check dependencies +check_sqlx +if [ "$RUN_ETL" = true ]; then + check_psql +fi + +# Setup database URL +setup_database_url + +# Run migrations +if [ "$RUN_ETL_API" = true ]; then + run_etl_api_migrations +fi + +if [ "$RUN_ETL" = true ]; then + run_etl_migrations +fi + +echo "All requested migrations complete!"