diff --git a/src/import_single_table_to_bigquery.py b/src/import_single_table_to_bigquery.py
index c3c4191..f7b6038 100644
--- a/src/import_single_table_to_bigquery.py
+++ b/src/import_single_table_to_bigquery.py
@@ -1,5 +1,5 @@
 import os
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import dlt
 from dlt.sources.sql_database import sql_database
@@ -10,6 +10,7 @@
 from utilities.setup import (
     get_jdbc_connection_string,
     set_dlt_environment_variables,
+    validate_source_tables,
     validate_write_dispostiion,
 )
 
@@ -35,21 +36,36 @@ def type_adapter_callback(sql_type):
 
 def run_import(
     vendor_name: str,
     source_schema_name: str,
-    source_table_names: List[str],
+    source_table_names: Union[List[str], None],
     destination_schema_name: str,
     connection_string: str,
     write_disposition: str,
     row_chunk_size: Optional[int] = 10_000,
+    include_views: bool = True,
 ):
     """
     Executes an import from a remote host to the destination warehouse
+
+    Args:
+        vendor_name: Name of the vendor to sync (for alerting purposes)
+        source_schema_name: Schema to replicate on the source database
+        source_table_names: List of tables to replicate OR `None` (this will sync all tables)
+        destination_schema_name: Schema to write to in TMC's system
+        connection_string: JDBC string used to authenticate against the source database
+        write_disposition: One of `append`, `replace`, or `drop`
+        row_chunk_size: Number of rows to return in a single request
+        include_views: If `True`, views on the source database will be replicated
     """
     logger.info(f"Beginning sync to {destination_schema_name}")
-    for table in source_table_names:
-        logger.info(
-            f"{source_schema_name}.{table} -> {destination_schema_name}.{table}"
-        )
+    if source_table_names:
+        for table in source_table_names:
+            logger.info(
+                f"{source_schema_name}.{table} -> {destination_schema_name}.{table}"
+            )
+    else:
+        logger.info("BE ADVISED - All tables in the source schema will be replicated")
+
     # Establish pipeline connection to BigQuery
     pipeline = dlt.pipeline(
         pipeline_name=f"tmc_{vendor_name}",
@@ -65,6 +81,7 @@
         chunk_size=row_chunk_size,
         query_adapter_callback=table_adapter_callback,
         type_adapter_callback=type_adapter_callback,
+        include_views=include_views,
     )
 
     source_postgres_connection.max_table_nesting = 0
@@ -81,23 +98,23 @@
     ENV_CONFIG = {**BIGQUERY_DESTINATION_CONFIG, **SQL_SOURCE_CONFIG}
     set_dlt_environment_variables(ENV_CONFIG)
 
+    # Source parameters
     CONNECTION_STRING = get_jdbc_connection_string(config=SQL_SOURCE_CONFIG)
-
-    ###
-
-    VENDOR_NAME = os.environ["VENDOR_NAME"]
-    SOURCE_SCHEMA_NAME = os.environ["SOURCE_SCHEMA_NAME"]
-    SOURCE_TABLE_NAMES = [
-        table.strip() for table in os.environ["SOURCE_TABLE_NAME"].split(",")
-    ]
+    SOURCE_SCHEMA_NAME = os.environ["SOURCE_SCHEMA_NAME"]
+    SOURCE_TABLE_NAMES = validate_source_tables(os.environ["SOURCE_TABLE_NAME"])
+    INCLUDE_VIEWS = os.environ.get("INCLUDE_VIEWS") != "false"
+
+    # Destination parameters
     DESTINATION_SCHEMA_NAME = os.environ["DESTINATION_SCHEMA_NAME"]
 
+    # Sync parameters
+    VENDOR_NAME = os.environ["VENDOR_NAME"]
     ROW_CHUNK_SIZE = int(os.environ.get("ROW_CHUNK_SIZE", 10_000))
-
-    write_disposition = validate_write_dispostiion(
+    WRITE_DISPOSITION = validate_write_dispostiion(
         os.environ["SOURCE_WRITE_DISPOSITION"]
     )
+
     run_import(
         vendor_name=VENDOR_NAME.lower().replace(" ", "_"),
         source_schema_name=SOURCE_SCHEMA_NAME,
@@ -105,5 +122,6 @@
         destination_schema_name=DESTINATION_SCHEMA_NAME,
         connection_string=CONNECTION_STRING,
         row_chunk_size=ROW_CHUNK_SIZE,
-        write_disposition=write_disposition,
+        write_disposition=WRITE_DISPOSITION,
+        include_views=INCLUDE_VIEWS,
     )
diff --git a/src/utilities/setup.py b/src/utilities/setup.py
index 62b73b4..8ef91d5 100644
--- a/src/utilities/setup.py
+++ b/src/utilities/setup.py
@@ -1,4 +1,5 @@
 import os
+from typing import List, Union
 from urllib.parse import quote  # Import the quote function for URL encoding
 
 from utilities.logger import logger
@@ -53,12 +54,15 @@ def validate_write_dispostiion(write_disposition: str) -> None:
     Validates the write disposition value.
     Raises ValueError if the value is not valid.
     """
+    valid_dispositions = ["append", "replace", "merge", "drop"]
+
     if write_disposition not in valid_dispositions:
         raise ValueError(
             f"Invalid write disposition: {write_disposition}. "
             f"Valid options are: {', '.join(valid_dispositions)}."
         )
+
     if write_disposition == "drop":
         write_disposition = None
     # TODO - Someday maybe
@@ -67,3 +71,24 @@
             "We're not supporting merge as a write disposition yet - all pseduo-incremental loads are handled in dbt"
         )
     return write_disposition
+
+
+def validate_source_tables(source_table_string: str) -> Union[List[str], None]:
+    """
+    If the user supplies 'ALL' in the runtime environment
+    we want to pass in `None` to the import function (this
+    is how dlt expects to handle all tables in a source schema)
+
+    Args:
+        source_table_string: This should be a comma-separated list of tables in the environment (or 'ALL')
+
+    Returns:
+        Array of table names, or `None` if all tables are to be targeted
+    """
+
+    target_tables = [table.strip() for table in source_table_string.split(",")]
+
+    if len(target_tables) == 1 and target_tables[0].upper() == "ALL":
+        return None
+
+    return target_tables
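
Reviewer note: below is a minimal sketch of the table-selection contract this patch
introduces, assuming the validate_source_tables helper added in src/utilities/setup.py.
It is illustrative only and not part of the patch; the table names and environment
values are hypothetical.

    # sketch: illustrative usage, not shipped with this change
    import os

    from utilities.setup import validate_source_tables

    # A comma-separated SOURCE_TABLE_NAME is trimmed into discrete table names
    os.environ["SOURCE_TABLE_NAME"] = "users, orders , invoices"  # hypothetical tables
    assert validate_source_tables(os.environ["SOURCE_TABLE_NAME"]) == [
        "users",
        "orders",
        "invoices",
    ]

    # 'ALL' in any casing collapses to None, which the import script passes
    # through to dlt to mean "replicate every table in the source schema"
    os.environ["SOURCE_TABLE_NAME"] = "all"
    assert validate_source_tables(os.environ["SOURCE_TABLE_NAME"]) is None

    # INCLUDE_VIEWS is opt-out: only the literal string "false" disables view
    # replication; an unset variable (None != "false") leaves it enabled
    os.environ.pop("INCLUDE_VIEWS", None)
    assert (os.environ.get("INCLUDE_VIEWS") != "false") is True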