Skip to content

Commit 9126733

Browse files
authored
feat(api): implement upsert() using MERGE INTO (#11624)
## Description of changes Implement `Backend.upsert()` using `sqlglot.expressions.merge()` under the hood. Upsert support is very important, especially for data engineering use cases. Starting with a most basic implementation, including only supporting one join column. I think this could be expanded to support a list without much effort. `MERGE INTO` support is limited. [DuckDB only added support for `MERGE` statements earlier today in 1.4.0](https://duckdb.org/2025/09/16/announcing-duckdb-140.html#merge-statement), and many other backends don't support it. However, it seems like the more standard/correct approach for supporting upserts, and it doesn't require merge keys defined ahead of time on tables. Backends that work: - DuckDB (from 1.4.0) - Flink - Oracle ~(currently using a hack to work around "AS" getting added to MERGE statement)~ - MS SQL (currently throwing a `;` onto the end of ~every~ statement 😅) - Postgres Should work, need help to test: - Databricks - Snowflake - BigQuery Backends that don't work: - PySpark ("MERGE INTO TABLE is not supported temporarily.") - Clickhouse - DataFusion - SQLite (supports the nonstandard UPSERT statement) - Impala ("The `MERGE` statement is only supported for Iceberg tables.") - MySQL - Polars - RisingWave - Athena ("`MERGE INTO` is transactional and is supported only for Apache Iceberg tables in Athena engine version 3.") - Trino ("connector does not support modifying table rows") ## Issues closed * Resolves #5391 --------- Signed-off-by: Deepyaman Datta <[email protected]>
1 parent 185fdf5 commit 9126733

File tree

4 files changed

+318
-10
lines changed

4 files changed

+318
-10
lines changed

ibis/backends/sql/__init__.py

Lines changed: 123 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ def insert(
423423
Parameters
424424
----------
425425
name
426-
The name of the table to which data needs will be inserted
426+
The name of the table to which data will be inserted
427427
obj
428428
The source data or expression to insert
429429
database
@@ -453,22 +453,30 @@ def insert(
453453
with self._safe_raw_sql(query):
454454
pass
455455

456-
def _build_insert_from_table(
456+
def _get_columns_to_insert(
457457
self, *, target: str, source, db: str | None = None, catalog: str | None = None
458458
):
459-
compiler = self.compiler
460-
quoted = compiler.quoted
461459
# Compare the columns between the target table and the object to be inserted
462460
# If source is a subset of target, use source columns for insert list
463461
# Otherwise, assume auto-generated column names and use positional ordering.
464462
target_cols = self.get_schema(target, catalog=catalog, database=db).keys()
465463

466-
columns = (
464+
return (
467465
source_cols
468466
if (source_cols := source.schema().keys()) <= target_cols
469467
else target_cols
470468
)
471469

470+
def _build_insert_from_table(
471+
self, *, target: str, source, db: str | None = None, catalog: str | None = None
472+
):
473+
compiler = self.compiler
474+
quoted = compiler.quoted
475+
476+
columns = self._get_columns_to_insert(
477+
target=target, source=source, db=db, catalog=catalog
478+
)
479+
472480
query = sge.insert(
473481
expression=self.compile(source),
474482
into=sg.table(target, db=db, catalog=catalog, quoted=quoted),
@@ -526,6 +534,116 @@ def _build_insert_template(
526534
),
527535
).sql(self.dialect)
528536

537+
def upsert(
538+
self,
539+
name: str,
540+
/,
541+
obj: pd.DataFrame | ir.Table | list | dict,
542+
on: str,
543+
*,
544+
database: str | None = None,
545+
) -> None:
546+
"""Upsert data into a table.
547+
548+
::: {.callout-note}
549+
## Ibis does not use the word `schema` to refer to database hierarchy.
550+
551+
A collection of `table` is referred to as a `database`.
552+
A collection of `database` is referred to as a `catalog`.
553+
554+
These terms are mapped onto the corresponding features in each
555+
backend (where available), regardless of whether the backend itself
556+
uses the same terminology.
557+
:::
558+
559+
Parameters
560+
----------
561+
name
562+
The name of the table to which data will be upserted
563+
obj
564+
The source data or expression to upsert
565+
on
566+
Column name to join on
567+
database
568+
Name of the attached database that the table is located in.
569+
570+
For backends that support multi-level table hierarchies, you can
571+
pass in a dotted string path like `"catalog.database"` or a tuple of
572+
strings like `("catalog", "database")`.
573+
"""
574+
table_loc = self._to_sqlglot_table(database)
575+
catalog, db = self._to_catalog_db_tuple(table_loc)
576+
577+
if not isinstance(obj, ir.Table):
578+
obj = ibis.memtable(obj)
579+
580+
self._run_pre_execute_hooks(obj)
581+
582+
query = self._build_upsert_from_table(
583+
target=name, source=obj, on=on, db=db, catalog=catalog
584+
)
585+
586+
with self._safe_raw_sql(query):
587+
pass
588+
589+
def _build_upsert_from_table(
590+
self,
591+
*,
592+
target: str,
593+
source,
594+
on: str,
595+
db: str | None = None,
596+
catalog: str | None = None,
597+
):
598+
compiler = self.compiler
599+
quoted = compiler.quoted
600+
601+
columns = self._get_columns_to_insert(
602+
target=target, source=source, db=db, catalog=catalog
603+
)
604+
605+
source_alias = util.gen_name("source")
606+
target_alias = util.gen_name("target")
607+
query = sge.merge(
608+
sge.When(
609+
matched=True,
610+
then=sge.Update(
611+
expressions=[
612+
sg.column(col, quoted=quoted).eq(
613+
sg.column(col, table=source_alias, quoted=quoted)
614+
)
615+
for col in columns
616+
if col != on
617+
]
618+
),
619+
),
620+
sge.When(
621+
matched=False,
622+
then=sge.Insert(
623+
this=sge.Tuple(
624+
expressions=[sg.column(col, quoted=quoted) for col in columns]
625+
),
626+
expression=sge.Tuple(
627+
expressions=[
628+
sg.column(col, table=source_alias, quoted=quoted)
629+
for col in columns
630+
]
631+
),
632+
),
633+
),
634+
into=sg.table(target, db=db, catalog=catalog, quoted=quoted).as_(
635+
sg.to_identifier(target_alias, quoted=quoted), table=True
636+
),
637+
using=f"({self.compile(source)}) AS {sg.to_identifier(source_alias, quoted=quoted)}",
638+
on=sge.Paren(
639+
this=sg.column(on, table=target_alias, quoted=quoted).eq(
640+
sg.column(on, table=source_alias, quoted=quoted)
641+
)
642+
),
643+
dialect=compiler.dialect,
644+
)
645+
return query
646+
529647
def truncate_table(self, name: str, /, *, database: str | None = None) -> None:
530648
"""Delete all rows from a table.
531649

ibis/backends/tests/conftest.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
11
from __future__ import annotations
22

3+
import sqlite3
4+
35
import pytest
6+
from packaging.version import parse as vparse
47

58
import ibis.common.exceptions as com
6-
from ibis.backends.tests.errors import MySQLOperationalError
9+
from ibis.backends.tests.errors import (
10+
ClickHouseDatabaseError,
11+
ImpalaHiveServer2Error,
12+
MySQLOperationalError,
13+
MySQLProgrammingError,
14+
PsycoPg2InternalError,
15+
Py4JJavaError,
16+
PySparkUnsupportedOperationException,
17+
TrinoUserError,
18+
)
719

820

921
def combine_marks(marks: list) -> callable:
@@ -50,7 +62,6 @@ def decorator(func):
5062
]
5163
NO_ARRAY_SUPPORT = combine_marks(NO_ARRAY_SUPPORT_MARKS)
5264

53-
5465
NO_STRUCT_SUPPORT_MARKS = [
5566
pytest.mark.never(["mysql", "sqlite", "mssql"], reason="No struct support"),
5667
pytest.mark.notyet(["impala"]),
@@ -78,3 +89,53 @@ def decorator(func):
7889
pytest.mark.notimpl(["datafusion", "exasol", "mssql", "druid", "oracle"]),
7990
]
8091
NO_JSON_SUPPORT = combine_marks(NO_JSON_SUPPORT_MARKS)
92+
93+
try:
94+
import pyspark
95+
96+
pyspark_merge_exception = (
97+
PySparkUnsupportedOperationException
98+
if vparse(pyspark.__version__) >= vparse("3.5")
99+
else Py4JJavaError
100+
)
101+
except ImportError:
102+
pyspark_merge_exception = None
103+
104+
NO_MERGE_SUPPORT_MARKS = [
105+
pytest.mark.notyet(
106+
["clickhouse"],
107+
raises=ClickHouseDatabaseError,
108+
reason="MERGE INTO is not supported",
109+
),
110+
pytest.mark.notyet(["datafusion"], reason="MERGE INTO is not supported"),
111+
pytest.mark.notyet(
112+
["impala"],
113+
raises=ImpalaHiveServer2Error,
114+
reason="target table must be an Iceberg table",
115+
),
116+
pytest.mark.notyet(
117+
["mysql"], raises=MySQLProgrammingError, reason="MERGE INTO is not supported"
118+
),
119+
pytest.mark.notimpl(["polars"], reason="`upsert` method not implemented"),
120+
pytest.mark.notyet(
121+
["pyspark"],
122+
raises=pyspark_merge_exception,
123+
reason="MERGE INTO TABLE is not supported temporarily",
124+
),
125+
pytest.mark.notyet(
126+
["risingwave"],
127+
raises=PsycoPg2InternalError,
128+
reason="MERGE INTO is not supported",
129+
),
130+
pytest.mark.notyet(
131+
["sqlite"],
132+
raises=sqlite3.OperationalError,
133+
reason="MERGE INTO is not supported",
134+
),
135+
pytest.mark.notyet(
136+
["trino"],
137+
raises=TrinoUserError,
138+
reason="connector does not support modifying table rows",
139+
),
140+
]
141+
NO_MERGE_SUPPORT = combine_marks(NO_MERGE_SUPPORT_MARKS)

ibis/backends/tests/errors.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,18 @@
5555
from pyspark.errors.exceptions.base import ParseException as PySparkParseException
5656
from pyspark.errors.exceptions.base import PySparkValueError
5757
from pyspark.errors.exceptions.base import PythonException as PySparkPythonException
58+
from pyspark.errors.exceptions.base import (
59+
UnsupportedOperationException as PySparkUnsupportedOperationException,
60+
)
5861
from pyspark.errors.exceptions.connect import (
5962
SparkConnectGrpcException as PySparkConnectGrpcException,
6063
)
6164
except ImportError:
6265
PySparkParseException = PySparkAnalysisException = PySparkArithmeticException = (
6366
PySparkPythonException
64-
) = PySparkConnectGrpcException = PySparkValueError = None
67+
) = PySparkUnsupportedOperationException = PySparkConnectGrpcException = (
68+
PySparkValueError
69+
) = None
6570

6671
try:
6772
from google.api_core.exceptions import BadRequest as GoogleBadRequest

0 commit comments

Comments (0)