Commit 29f326d

Add row_hash via PyArrow to ingested data
1 parent 4193b8a commit 29f326d

File tree

  • opendata_stack_platform_project/opendata_stack_platform/dlt/sources/taxi_trip

2 files changed: +109 -22 lines changed
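
The core of the change is a deterministic per-row hash: for each record, the values of a small set of key columns are serialized to a sorted JSON string and hashed with MD5, and the resulting row_hash becomes the natural key for the bronze tables. A minimal sketch of that scheme, using made-up values rather than repo code:

import hashlib
import json

# Hypothetical key-column values for one trip record
row = {
    "pickup_datetime": "2024-01-01T00:15:00",
    "pu_location_id": 132,
    "do_location_id": 230,
    "partition_key": "2024-01",
}

# sort_keys makes the serialization, and therefore the hash, order-independent
row_hash = hashlib.md5(
    json.dumps(row, sort_keys=True, default=str).encode()
).hexdigest()
print(row_hash)  # 32-character hex digest, stable for identical key values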

opendata_stack_platform_project/opendata_stack_platform/dlt/sources/taxi_trip/__init__.py

Lines changed: 37 additions & 16 deletions
@@ -13,40 +13,57 @@
 BUCKET_URL = "s3://datalake"
 
 
-@dlt.source(name="taxi_trip_source")
-def taxi_trip_source(dataset_type: str, partition_key: Optional[str] = None) -> DltSource:
-    """Source for taxi trips data (yellow, green, or FHV) based on file path.
+def get_key_columns_for_dataset(dataset_type: str) -> list[str]:
+    """
+    Get the list of columns to use for row hash calculation based on dataset type.
 
     Args:
-        dataset_type: Type of dataset ('yellow', 'green', or 'fhvhv')
-        partition_key: Optional partition key for filtering data
+        dataset_type (str): Type of dataset ('yellow', 'green', or 'fhvhv')
 
     Returns:
-        DltSource: A data source for the specified taxi trip type
+        List[str]: List of column names to use for row hash
     """
     if dataset_type not in ["yellow", "green", "fhvhv"]:
         raise ValueError("dataset_type must be one of 'yellow', 'green', or 'fhvhv'.")
 
-    # Define natural key based on dataset type, using snake_case for DuckDB
     if dataset_type in ["yellow", "green"]:
         pickup_datetime = (
             "tpep_pickup_datetime" if dataset_type == "yellow" else "lpep_pickup_datetime"
         )
-        natural_key = (
-            "vendor_id",
+        return [
             pickup_datetime,
             "pu_location_id",
             "do_location_id",
             "partition_key",
-        )
+        ]
     else:  # fhvhv
-        natural_key = (
-            "hvfhs_license_num",
+        return [
             "pickup_datetime",
             "pu_location_id",
             "do_location_id",
             "partition_key",
-        )
+        ]
+
+
+@dlt.source(name="taxi_trip_source")
+def taxi_trip_source(dataset_type: str, partition_key: Optional[str] = None) -> DltSource:
+    """Source for taxi trips data (yellow, green, or FHV) based on file path.
+
+    Args:
+        dataset_type: Type of dataset ('yellow', 'green', or 'fhvhv')
+        partition_key: Optional partition key for filtering data
+
+    Returns:
+        DltSource: A data source for the specified taxi trip type
+    """
+    if dataset_type not in ["yellow", "green", "fhvhv"]:
+        raise ValueError("dataset_type must be one of 'yellow', 'green', or 'fhvhv'.")
+
+    # Get key columns for row hash from utility function
+    key_columns = get_key_columns_for_dataset(dataset_type)
+
+    # Natural key is always the row hash
+    natural_key = ["row_hash"]
 
     # Construct file glob pattern for the dataset type
     file_glob = constants.TAXI_TRIPS_RAW_KEY_TEMPLATE.format(
@@ -62,9 +79,13 @@ def taxi_trip_source(dataset_type: str, partition_key: Optional[str] = None) ->
         raw_files.add_filter(lambda item: partition_key[:-3] in item["file_name"])
 
     # Create source with transformations
-    source = (raw_files | read_parquet_custom(partition_key=partition_key)).with_name(
-        f"{dataset_type}_taxi_trip_bronze"
-    )
+    source = (
+        raw_files
+        | read_parquet_custom(
+            partition_key=partition_key,
+            key_columns=key_columns,
+        )
+    ).with_name(f"{dataset_type}_taxi_trip_bronze")
 
     # Apply write configuration hints
     source.apply_hints(
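
As a rough usage sketch of the new helper (the import path is assumed from the file location in this commit), the key column sets differ only in the pickup timestamp column, while the source itself now keys on row_hash:

from opendata_stack_platform.dlt.sources.taxi_trip import get_key_columns_for_dataset

print(get_key_columns_for_dataset("yellow"))
# ['tpep_pickup_datetime', 'pu_location_id', 'do_location_id', 'partition_key']

print(get_key_columns_for_dataset("fhvhv"))
# ['pickup_datetime', 'pu_location_id', 'do_location_id', 'partition_key']

# Anything else is rejected up front:
# get_key_columns_for_dataset("uber")  -> ValueError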

opendata_stack_platform_project/opendata_stack_platform/dlt/sources/taxi_trip/utils.py

Lines changed: 72 additions & 6 deletions
@@ -1,6 +1,8 @@
+import hashlib
+import json
+
 from collections.abc import Iterator
 from datetime import date, datetime, timezone
-from typing import Optional
 
 import dlt
 import pyarrow as pa
@@ -33,10 +35,68 @@ def add_partition_column(batch: pa.RecordBatch, partition_key: str) -> pa.RecordBatch
     return new_batch
 
 
+def add_row_hash(
+    batch: pa.RecordBatch, key_columns: list[str], hash_column_name: str = "row_hash"
+) -> pa.RecordBatch:
+    """
+    Add a hash column to a PyArrow RecordBatch based on selected columns.
+
+    Args:
+        batch: PyArrow RecordBatch to process
+        key_columns: List of column names to include in hash
+        hash_column_name: Name for the new hash column
+
+    Returns:
+        PyArrow RecordBatch with added hash column
+    """
+    # Filter out columns that don't exist in the batch
+    existing_columns = [col for col in key_columns if col in batch.schema.names]
+
+    if not existing_columns:
+        raise ValueError(
+            f"None of the key columns {key_columns} exist in the batch schema"
+        )
+
+    # Initialize a list to store hash values
+    hash_values = []
+
+    # Process each row directly using PyArrow
+    for i in range(batch.num_rows):
+        # Create a dictionary for this row
+        row_dict = {}
+
+        # Extract values for each column in this row
+        for col in existing_columns:
+            # Get the column array
+            col_array = batch.column(batch.schema.get_field_index(col))
+            # Get the value at this row index
+            value = col_array[i].as_py()
+
+            # Only include non-None values
+            if value is not None:
+                row_dict[col] = value
+
+        # Convert to sorted JSON string for consistent hashing
+        json_str = json.dumps(row_dict, sort_keys=True, default=str)
+
+        # Create hash
+        hash_obj = hashlib.md5(json_str.encode())
+        hash_value = hash_obj.hexdigest()
+        hash_values.append(hash_value)
+
+    # Create PyArrow array from hash values
+    hash_array = pa.array(hash_values, type=pa.string())
+
+    # Add hash column to batch
+    new_batch = batch.append_column(hash_column_name, hash_array)
+    return new_batch
+
+
 @dlt.transformer(standalone=True)
 def read_parquet_custom(
     items: Iterator[FileItemDict],
-    partition_key: Optional[str] = None,
+    partition_key: str,
+    key_columns: list[str],
     batch_size: int = 64_000,
 ) -> Iterator[pa.RecordBatch]:
     """
@@ -45,6 +105,8 @@ def read_parquet_custom(
 
     Args:
         items (Iterator[FileItemDict]): Iterator over file items.
+        partition_key (Optional[str]): Partition key to add to the data.
+        key_columns (Optional[List[str]]): Columns to use for row hash calculation.
         batch_size (int, optional): Maximum number of rows to process per batch
 
     Yields:
@@ -54,8 +116,12 @@ def read_parquet_custom(
         with file_obj.open() as f:
             parquet_file = pq.ParquetFile(f)
             # Iterate over RecordBatch objects
-            for batch in parquet_file.iter_batches(batch_size=batch_size):
-                # Create a new RecordBatch with the existing columns and the new column
-                batch_with_metadata = add_partition_column(batch, partition_key)
+            for raw_batch in parquet_file.iter_batches(batch_size=batch_size):
+                # Add partition column
+                processed_batch = add_partition_column(raw_batch, partition_key)
+
+                # Add row hash
+                processed_batch = add_row_hash(processed_batch, key_columns)
+
                 # Yield the enriched RecordBatch
-                yield batch_with_metadata
+                yield processed_batch
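
For a quick, self-contained check of add_row_hash outside the pipeline (import path assumed from this commit's file layout; column names follow the FHV key set), something like the following should work:

import pyarrow as pa

from opendata_stack_platform.dlt.sources.taxi_trip.utils import add_row_hash

# Two-row batch with the FHV key columns
batch = pa.RecordBatch.from_pydict(
    {
        "pickup_datetime": ["2024-01-01T00:15:00", "2024-01-01T00:20:00"],
        "pu_location_id": [132, 138],
        "do_location_id": [230, 230],
        "partition_key": ["2024-01", "2024-01"],
    }
)

hashed = add_row_hash(
    batch,
    key_columns=["pickup_datetime", "pu_location_id", "do_location_id", "partition_key"],
)

print(hashed.schema.names)  # original columns plus "row_hash"
print(hashed.column(hashed.schema.get_field_index("row_hash")))  # one MD5 hex digest per row

Note that missing key columns are silently skipped; only when none of them exist does the function raise ValueError, so schema drift in a single column will not break ingestion but will change the hash inputs.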
