51 changes: 34 additions & 17 deletions src/import_single_table_to_bigquery.py
@@ -1,5 +1,5 @@
import os
from typing import List, Optional
from typing import List, Optional, Union

import dlt
from dlt.sources.sql_database import sql_database
@@ -10,6 +10,7 @@
from utilities.setup import (
    get_jdbc_connection_string,
    set_dlt_environment_variables,
    validate_source_tables,
    validate_write_dispostiion,
)

@@ -35,21 +36,36 @@ def type_adapter_callback(sql_type):
def run_import(
    vendor_name: str,
    source_schema_name: str,
    source_table_names: List[str],
    source_table_names: Union[List[str], None],
    destination_schema_name: str,
    connection_string: str,
    write_disposition: str,
    row_chunk_size: Optional[int] = 10_000,
    include_views: bool = True,
):
"""
Executes an import from a remote host to the destination warehouse

Args:
vendor_name: Name of the vendor to sync (for alerting purposes)
source_schema_name: Schema to replicate on the source database
source_table_names: List of tables to replicate OR `None` (this will sync all tables)
destination_schema_name: Schema to write to in TMC's system
connection_string: JDBC string to authenticate source database
write_disposition: One of `append`, `replace`, or `drop`
row_chunk_size: Number of rows to return in a single request
include_views: If `True`, views on the source database will be replicated
"""

    logger.info(f"Beginning sync to {destination_schema_name}")
    for table in source_table_names:
        logger.info(
            f"{source_schema_name}.{table} -> {destination_schema_name}.{table}"
        )
    if source_table_names:
        for table in source_table_names:
            logger.info(
                f"{source_schema_name}.{table} -> {destination_schema_name}.{table}"
            )
    else:
        logger.info("BE ADVISED - All tables in the source schema will be replicated")

    # Establish pipeline connection to BigQuery
    pipeline = dlt.pipeline(
        pipeline_name=f"tmc_{vendor_name}",
@@ -65,6 +81,7 @@ def run_import(
        chunk_size=row_chunk_size,
        query_adapter_callback=table_adapter_callback,
        type_adapter_callback=type_adapter_callback,
        include_views=include_views,
    )
    source_postgres_connection.max_table_nesting = 0

@@ -81,29 +98,29 @@ def run_import(
ENV_CONFIG = {**BIGQUERY_DESTINATION_CONFIG, **SQL_SOURCE_CONFIG}
set_dlt_environment_variables(ENV_CONFIG)

# Source parameters
CONNECTION_STRING = get_jdbc_connection_string(config=SQL_SOURCE_CONFIG)

###

VENDOR_NAME = os.environ["VENDOR_NAME"]

SOURCE_SCHEMA_NAME = os.environ["SOURCE_SCHEMA_NAME"]
SOURCE_TABLE_NAMES = [
    table.strip() for table in os.environ["SOURCE_TABLE_NAME"].split(",")
]
SOURCE_TABLE_NAMES = validate_source_tables(os.environ["SOURCE_TABLE_NAME"])
INCLUDE_VIEWS = os.environ.get("INCLUDE_VIEWS") != "false"

# Destination parameters
DESTINATION_SCHEMA_NAME = os.environ["DESTINATION_SCHEMA_NAME"]

# Sync parameters
VENDOR_NAME = os.environ["VENDOR_NAME"]
ROW_CHUNK_SIZE = int(os.environ.get("ROW_CHUNK_SIZE", 10_000))

write_disposition = validate_write_dispostiion(
WRITE_DISPOSITION = validate_write_dispostiion(
    os.environ["SOURCE_WRITE_DISPOSITION"]
)

run_import(
    vendor_name=VENDOR_NAME.lower().replace(" ", "_"),
    source_schema_name=SOURCE_SCHEMA_NAME,
    source_table_names=SOURCE_TABLE_NAMES,
    destination_schema_name=DESTINATION_SCHEMA_NAME,
    connection_string=CONNECTION_STRING,
    row_chunk_size=ROW_CHUNK_SIZE,
    write_disposition=write_disposition,
    write_disposition=WRITE_DISPOSITION,
    include_views=INCLUDE_VIEWS,
)
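
For context on how the new environment contract behaves end to end, here is a minimal sketch of a runtime environment that exercises the new `ALL` and `INCLUDE_VIEWS` options. Every value below is made up for illustration; only the variable names come from the `os.environ` lookups in this file.

```python
import os

# Hypothetical environment for this entrypoint; names match the lookups
# above, values are illustrative only.
os.environ.update(
    {
        "VENDOR_NAME": "Example Vendor",  # becomes pipeline "tmc_example_vendor"
        "SOURCE_SCHEMA_NAME": "public",
        "SOURCE_TABLE_NAME": "ALL",  # new: a lone 'ALL' means "sync every table"
        "INCLUDE_VIEWS": "false",  # new: only the literal string "false" disables views
        "DESTINATION_SCHEMA_NAME": "example_vendor_raw",
        "SOURCE_WRITE_DISPOSITION": "replace",
    }
)
```

With this configuration, `SOURCE_TABLE_NAMES` resolves to `None`, `INCLUDE_VIEWS` to `False`, and `run_import` logs the "BE ADVISED" warning instead of a per-table sync plan.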
25 changes: 25 additions & 0 deletions src/utilities/setup.py
@@ -1,4 +1,5 @@
import os
from typing import List, Union
from urllib.parse import quote # Import the quote function for URL encoding

from utilities.logger import logger
@@ -53,12 +54,15 @@ def validate_write_dispostiion(write_disposition: str) -> None:
    Validates the write disposition value.
    Raises ValueError if the value is not valid.
    """

    valid_dispositions = ["append", "replace", "merge", "drop"]

    if write_disposition not in valid_dispositions:
        raise ValueError(
            f"Invalid write disposition: {write_disposition}. "
            f"Valid options are: {', '.join(valid_dispositions)}."
        )

    if write_disposition == "drop":
        write_disposition = None
    # TODO - Someday maybe
@@ -67,3 +71,24 @@ def run_import(
"We're not supporting merge as a write disposition yet - all pseduo-incremental loads are handled in dbt"
)
return write_disposition


def validate_source_tables(source_table_string: str) -> Union[List[str], None]:
    """
    If the user supplies 'ALL' in the runtime environment,
    we pass `None` to the import function (passing `None` is
    how dlt is told to replicate every table in a source schema)

    Args:
        source_table_string: A comma-separated list of table names from the environment (or 'ALL')

    Returns:
        List of table names, or `None` if all tables are to be targeted
    """

    target_tables = [table.strip() for table in source_table_string.split(",")]

    if len(target_tables) == 1 and target_tables[0].upper() == "ALL":
        return None

    return target_tables
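
As a quick sanity check on the two validators, here is a doctest-style sketch of their behavior as written in this diff (assuming the definitions above are importable from `utilities.setup`; the merge branch is inferred from the visible error string and the valid-dispositions list):

```python
from utilities.setup import validate_source_tables, validate_write_dispostiion

# validate_source_tables: only a lone 'ALL' (case-insensitive) maps to None.
assert validate_source_tables("users, orders") == ["users", "orders"]
assert validate_source_tables("all") is None
assert validate_source_tables("ALL, users") == ["ALL", "users"]  # not special: two entries

# validate_write_dispostiion: 'drop' maps to None, 'merge' raises for now.
assert validate_write_dispostiion("append") == "append"
assert validate_write_dispostiion("drop") is None
try:
    validate_write_dispostiion("merge")
except ValueError:
    pass  # expected until merge support lands
```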