Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 156 additions & 9 deletions crates/core/src/operations/merge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1501,16 +1501,19 @@ fn modify_schema(
let error = arrow::error::ArrowError::SchemaError("Schema evolved fields cannot have generated expressions. Recreate the table to achieve this.".to_string());
return Err(DeltaTableError::Arrow { source: error });
}

if let Ok(target_field) = target_schema.field_from_column(columns) {
// for nested data types we need to first merge, then see if there is a change, then replace the pre-existing field
let new_field = merge_arrow_field(target_field, source_field, true)?;
if &new_field == target_field {
continue;
//Check if the columns in the source schema exist in the target schema
match target_schema.field_from_column(columns) {
Ok(target_field) => {
// This case is when there is an added column in a nested datatype
let new_field = merge_arrow_field(target_field, source_field, true)?;
if &new_field != target_field {
ending_schema.try_merge(&Arc::new(new_field))?;
}
}
Err(_) => {
// This function is called multiple times with different operations, so this handles any collisions
ending_schema.try_merge(&Arc::new(source_field.to_owned().with_nullable(true)))?;
}
ending_schema.try_merge(&Arc::new(new_field))?;
} else {
ending_schema.push(source_field.to_owned().with_nullable(true));
}
}
Ok(())
Expand Down Expand Up @@ -2277,6 +2280,150 @@ mod tests {
);
assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_schema_evolution_simple_update_with_simple_insert() {
    // Schema evolution: the source carries an extra `inserted_by` column that
    // the target table does not yet have. Registering the UPDATE clause before
    // the INSERT clause must still evolve the target schema, and pre-existing
    // rows must get nulls for the new column.
    let (table, _) = setup().await;

    let schema = Arc::new(ArrowSchema::new(vec![
        Field::new("id", ArrowDataType::Utf8, true),
        Field::new("value", ArrowDataType::Int32, true),
        Field::new("modified", ArrowDataType::Utf8, true),
        Field::new("inserted_by", ArrowDataType::Utf8, true),
    ]));
    let ctx = SessionContext::new();
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
            Arc::new(arrow::array::Int32Array::from(vec![50, 200, 30])),
            Arc::new(arrow::array::StringArray::from(vec![
                "2021-02-02",
                "2023-07-04",
                "2023-07-04",
            ])),
            Arc::new(arrow::array::StringArray::from(vec!["B1", "C1", "X1"])),
        ],
    )
    .unwrap();
    let source = ctx.read_batch(batch).unwrap();

    let (table, _) = DeltaOps(table)
        .merge(source, col("target.id").eq(col("source.id")))
        .with_source_alias("source")
        .with_target_alias("target")
        .with_merge_schema(true)
        .when_matched_update(|update| {
            update
                .update("value", col("source.value").add(lit(1)))
                .update("modified", col("source.modified"))
                .update("inserted_by", col("source.inserted_by"))
        })
        .unwrap()
        .when_not_matched_insert(|insert| {
            insert
                .set("id", col("source.id"))
                .set("value", col("source.value"))
                .set("modified", col("source.modified"))
                // Use `col(...)` like the sibling clauses rather than a raw
                // SQL string, keeping the expression style consistent.
                .set("inserted_by", col("source.inserted_by"))
        })
        .unwrap()
        .await
        .unwrap();

    let last_commit = table.last_commit().await.unwrap();
    let parameters = last_commit.operation_parameters.clone().unwrap();
    assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
    // B and C are updated (value incremented), X is inserted, A and D are
    // untouched and get null for the evolved `inserted_by` column.
    let expected = vec![
        "+----+-------+------------+-------------+",
        "| id | value | modified   | inserted_by |",
        "+----+-------+------------+-------------+",
        "| A  | 1     | 2021-02-01 |             |",
        "| B  | 51    | 2021-02-02 | B1          |",
        "| C  | 201   | 2023-07-04 | C1          |",
        "| D  | 100   | 2021-02-02 |             |",
        "| X  | 30    | 2023-07-04 | X1          |",
        "+----+-------+------------+-------------+",
    ];
    let actual = get_data(&table).await;
    let expected_schema_struct: StructType = Arc::clone(&schema).try_into_kernel().unwrap();
    assert_eq!(
        &expected_schema_struct,
        table.snapshot().unwrap().schema().as_ref()
    );
    assert_batches_sorted_eq!(&expected, &actual);
}
#[tokio::test]
async fn test_merge_schema_evolution_simple_insert_with_simple_update() {
    // Mirror of the update-then-insert test: registering the INSERT clause
    // before the UPDATE clause must produce the same schema evolution result,
    // since `modify_schema` is invoked once per operation.
    let (table, _) = setup().await;

    let schema = Arc::new(ArrowSchema::new(vec![
        Field::new("id", ArrowDataType::Utf8, true),
        Field::new("value", ArrowDataType::Int32, true),
        Field::new("modified", ArrowDataType::Utf8, true),
        Field::new("inserted_by", ArrowDataType::Utf8, true),
    ]));
    let ctx = SessionContext::new();
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])),
            Arc::new(arrow::array::Int32Array::from(vec![50, 200, 30])),
            Arc::new(arrow::array::StringArray::from(vec![
                "2021-02-02",
                "2023-07-04",
                "2023-07-04",
            ])),
            Arc::new(arrow::array::StringArray::from(vec!["B1", "C1", "X1"])),
        ],
    )
    .unwrap();
    let source = ctx.read_batch(batch).unwrap();

    let (table, _) = DeltaOps(table)
        .merge(source, col("target.id").eq(col("source.id")))
        .with_source_alias("source")
        .with_target_alias("target")
        .with_merge_schema(true)
        .when_not_matched_insert(|insert| {
            insert
                .set("id", col("source.id"))
                .set("value", col("source.value"))
                .set("modified", col("source.modified"))
                // Use `col(...)` like the sibling clauses rather than a raw
                // SQL string, keeping the expression style consistent.
                .set("inserted_by", col("source.inserted_by"))
        })
        .unwrap()
        .when_matched_update(|update| {
            update
                .update("value", col("source.value").add(lit(1)))
                .update("modified", col("source.modified"))
                .update("inserted_by", col("source.inserted_by"))
        })
        .unwrap()
        .await
        .unwrap();

    let last_commit = table.last_commit().await.unwrap();
    let parameters = last_commit.operation_parameters.clone().unwrap();
    assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
    // Same outcome as the update-first variant: B/C updated, X inserted,
    // A/D padded with null for the evolved `inserted_by` column.
    let expected = vec![
        "+----+-------+------------+-------------+",
        "| id | value | modified   | inserted_by |",
        "+----+-------+------------+-------------+",
        "| A  | 1     | 2021-02-01 |             |",
        "| B  | 51    | 2021-02-02 | B1          |",
        "| C  | 201   | 2023-07-04 | C1          |",
        "| D  | 100   | 2021-02-02 |             |",
        "| X  | 30    | 2023-07-04 | X1          |",
        "+----+-------+------------+-------------+",
    ];
    let actual = get_data(&table).await;
    let expected_schema_struct: StructType = Arc::clone(&schema).try_into_kernel().unwrap();
    assert_eq!(
        &expected_schema_struct,
        table.snapshot().unwrap().schema().as_ref()
    );
    assert_batches_sorted_eq!(&expected, &actual);
}

#[tokio::test]
async fn test_merge_schema_evolution_simple_insert() {
Expand Down
77 changes: 77 additions & 0 deletions python/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,83 @@ def test_merge_when_matched_update_wo_predicate_with_schema_evolution(
assert result == expected


def test_merge_when_matched_update_wo_predicate_and_insert_with_schema_evolution(
    tmp_path: pathlib.Path, sample_table: Table
):
    """Merge with schema evolution: matched rows (ids 4 and 5) are updated and
    the source-only `customer` column is added to the target schema, with nulls
    for the pre-existing rows."""
    write_deltalake(tmp_path, sample_table, mode="append")
    dt = DeltaTable(tmp_path)

    # Source overlaps the target on ids 4 and 5 and carries an extra column.
    source = Table(
        {
            "id": Array(
                ["4", "5"],
                ArrowField("id", type=DataType.string(), nullable=True),
            ),
            "price": Array(
                [10, 100],
                ArrowField("price", type=DataType.int64(), nullable=True),
            ),
            "sold": Array(
                [10, 20],
                ArrowField("sold", type=DataType.int32(), nullable=True),
            ),
            "customer": Array(
                ["john", "doe"],
                ArrowField("customer", type=DataType.string(), nullable=True),
            ),
        },
    )

    merger = dt.merge(
        source=source,
        predicate="t.id = s.id",
        source_alias="s",
        target_alias="t",
        merge_schema=True,
    )
    merger = merger.when_matched_update(
        {"price": "s.price", "sold": "s.sold+int'10'", "customer": "s.customer"}
    )
    merger.when_not_matched_insert_all().execute()

    expected = Table(
        {
            "id": Array(
                ["1", "2", "3", "4", "5"],
                ArrowField("id", type=DataType.string(), nullable=True),
            ),
            "price": Array(
                [0, 1, 2, 10, 100],
                ArrowField("price", type=DataType.int64(), nullable=True),
            ),
            "sold": Array(
                [0, 1, 2, 20, 30],
                ArrowField("sold", type=DataType.int32(), nullable=True),
            ),
            "deleted": Array(
                [False] * 5,
                ArrowField("deleted", type=DataType.bool(), nullable=True),
            ),
            "customer": Array(
                [None, None, None, "john", "doe"],
                ArrowField("customer", type=DataType.string(), nullable=True),
            ),
        },
    )

    actual = (
        QueryBuilder()
        .register("tbl", dt)
        .execute("select * from tbl order by id asc")
        .read_all()
    )
    last_action = dt.history(1)[0]

    assert last_action["operation"] == "MERGE"
    assert actual.schema == expected.schema
    assert actual == expected


@pytest.mark.parametrize("streaming", (True, False))
def test_merge_when_matched_update_all_wo_predicate(
tmp_path: pathlib.Path, sample_table: Table, streaming: bool
Expand Down
Loading