Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 54 additions & 17 deletions drt/destinations/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,39 @@ def _serialize_value(
return value


def _split_qualified(table: str) -> tuple[str | None, str]:
"""Split an optional ``schema.table`` name into schema and relation parts."""
if "." not in table:
return None, table
schema, relation = table.split(".", 1)
return schema or None, relation


def _join_qualified(schema: str | None, relation: str) -> str:
if schema is None:
return relation
return f"{schema}.{relation}"


def _qualified_ident(table: str) -> Any:
"""Return a psycopg2 Identifier that quotes each qualified-name part."""
from psycopg2 import sql as _pgsql

schema, relation = _split_qualified(table)
if schema is None:
return _pgsql.Identifier(relation)
return _pgsql.Identifier(schema, relation)


def _relation_name(table: str) -> str:
return _split_qualified(table)[1]


def _with_relation_suffix(table: str, suffix: str) -> str:
schema, relation = _split_qualified(table)
return _join_qualified(schema, f"{relation}{suffix}")


class PostgresDestination:
"""Upsert or replace records into a PostgreSQL table."""

Expand Down Expand Up @@ -166,7 +199,7 @@ def get_row_count(self, config: DestinationConfig) -> int:
try:
cur = conn.cursor()
query = sql.SQL("SELECT COUNT(*) FROM {}").format(
sql.Identifier(config.table)
_qualified_ident(config.table)
)
cur.execute(query)
row = cur.fetchone()
Expand Down Expand Up @@ -199,7 +232,7 @@ def _load_replace(
result = SyncResult()

if not self._replace_truncated:
cur.execute(_pgsql.SQL("TRUNCATE TABLE {}").format(_pgsql.Identifier(table)))
cur.execute(_pgsql.SQL("TRUNCATE TABLE {}").format(_qualified_ident(table)))
self._replace_truncated = True

query = self._build_insert_sql(table, columns)
Expand All @@ -225,7 +258,11 @@ def _load_replace(
conn.rollback()
cur = conn.cursor()
if not self._replace_truncated:
cur.execute(_pgsql.SQL("TRUNCATE TABLE {}").format(_pgsql.Identifier(table)))
cur.execute(
_pgsql.SQL("TRUNCATE TABLE {}").format(
_qualified_ident(table)
)
)
self._replace_truncated = True
continue

Expand All @@ -245,16 +282,16 @@ def _load_replace_swap(
"""Build a shadow table per sync; atomic rename happens in finalize_sync."""
from psycopg2 import sql as _pgsql
result = SyncResult()
shadow = f"{table}__drt_swap"
shadow = _with_relation_suffix(table, "__drt_swap")

if not self._swap_shadow_created:
cur.execute(
_pgsql.SQL("DROP TABLE IF EXISTS {}").format(_pgsql.Identifier(shadow))
_pgsql.SQL("DROP TABLE IF EXISTS {}").format(_qualified_ident(shadow))
)
cur.execute(
_pgsql.SQL("CREATE TABLE {} (LIKE {} INCLUDING ALL)").format(
_pgsql.Identifier(shadow),
_pgsql.Identifier(table),
_qualified_ident(shadow),
_qualified_ident(table),
)
)
self._swap_shadow_created = True
Expand Down Expand Up @@ -283,7 +320,7 @@ def _load_replace_swap(
cur = conn.cursor()
cur.execute(
_pgsql.SQL("DROP TABLE IF EXISTS {}").format(
_pgsql.Identifier(shadow)
_qualified_ident(shadow)
)
)
conn.commit()
Expand All @@ -308,8 +345,8 @@ def finalize_sync(

assert isinstance(config, PostgresDestinationConfig)
table = self._swap_table
shadow = f"{table}__drt_swap"
old = f"{table}__drt_old"
shadow = _with_relation_suffix(table, "__drt_swap")
old = _with_relation_suffix(table, "__drt_old")

conn = self._connect(config)
try:
Expand All @@ -319,19 +356,19 @@ def finalize_sync(
# the schema is preserved automatically.
cur.execute(
_pgsql.SQL("ALTER TABLE {} RENAME TO {}").format(
_pgsql.Identifier(table),
_pgsql.Identifier(old.split(".")[-1]),
_qualified_ident(table),
_pgsql.Identifier(_relation_name(old)),
)
)
cur.execute(
_pgsql.SQL("ALTER TABLE {} RENAME TO {}").format(
_pgsql.Identifier(shadow),
_pgsql.Identifier(table.split(".")[-1]),
_qualified_ident(shadow),
_pgsql.Identifier(_relation_name(table)),
)
)
conn.commit()
# DROP old in separate tx (failure here doesn't break the swap).
cur.execute(_pgsql.SQL("DROP TABLE {}").format(_pgsql.Identifier(old)))
cur.execute(_pgsql.SQL("DROP TABLE {}").format(_qualified_ident(old)))
conn.commit()
finally:
conn.close()
Expand Down Expand Up @@ -387,7 +424,7 @@ def _load_upsert(
def _build_insert_sql(table: str, columns: list[str]) -> Any:
from psycopg2 import sql as _pgsql
return _pgsql.SQL("INSERT INTO {} ({}) VALUES ({})").format(
_pgsql.Identifier(table),
_qualified_ident(table),
_pgsql.SQL(", ").join(_pgsql.Identifier(c) for c in columns),
_pgsql.SQL(", ").join(_pgsql.Placeholder() for _ in columns),
)
Expand All @@ -414,7 +451,7 @@ def _build_upsert_sql(
return _pgsql.SQL(
"INSERT INTO {} ({}) VALUES ({}) ON CONFLICT ({}) {}"
).format(
_pgsql.Identifier(table),
_qualified_ident(table),
_pgsql.SQL(", ").join(_pgsql.Identifier(c) for c in columns),
_pgsql.SQL(", ").join(_pgsql.Placeholder() for _ in columns),
_pgsql.SQL(", ").join(_pgsql.Identifier(c) for c in upsert_key),
Expand Down
Loading
Loading