109 changes: 105 additions & 4 deletions application_sdk/activities/query_extraction/sql.py
@@ -202,18 +202,21 @@ async def fetch_queries(

try:
state = await self._get_state(workflow_args)

sql_input = SQLQueryInput(
engine=state.sql_client.engine,
query=self.get_formatted_query(self.fetch_queries_sql, workflow_args),
chunk_size=None,
)
sql_input = await sql_input.get_dataframe()

sql_input.columns = [str(c).upper() for c in sql_input.columns]

raw_output = ParquetOutput(
output_path=workflow_args["output_path"],
output_suffix="raw/query",
chunk_size=workflow_args["miner_args"].get("chunk_size", 100000),
start_marker=workflow_args["start_marker"],
start_marker=str(workflow_args["start_marker"]),
end_marker=workflow_args["end_marker"],
)
await raw_output.write_dataframe(sql_input)
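
> A small aside on the new column-normalization line above: it forces every column name to an uppercase string before the raw Parquet write. A minimal standalone illustration, assuming `get_dataframe()` returns a pandas DataFrame (the sample columns here are made up, not taken from the connector):
>
> ```python
> import pandas as pd
>
> # Hypothetical result frame; real column names come from fetch_queries_sql.
> df = pd.DataFrame({"query_text": ["select 1"], "session_id": [42]})
>
> # Same normalization as the diff: coerce every column name to an uppercase string
> # so downstream readers of the raw parquet can rely on a single casing.
> df.columns = [str(c).upper() for c in df.columns]
> print(list(df.columns))  # ['QUERY_TEXT', 'SESSION_ID']
> ```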
@@ -528,9 +531,107 @@ async def get_query_batches(
store_name=UPSTREAM_OBJECT_STORE_NAME,
)

# Persist the full marker list in StateStore to avoid oversized activity results
try:
await self.write_marker(parallel_markers, workflow_args)
from application_sdk.services.statestore import StateStore, StateType

workflow_id: str = workflow_args.get("workflow_id") or get_workflow_id()
await StateStore.save_state(
key="query_batches",
value=parallel_markers,
id=workflow_id,
type=StateType.WORKFLOWS,
)
logger.info(
f"Saved {len(parallel_markers)} query batches to StateStore for {workflow_id}"
)
except Exception as e:
logger.warning(f"Failed to write marker file: {e}")
logger.error(f"Failed to save query batches in StateStore: {e}")
# Re-raise to ensure the workflow can retry per standards
raise

return parallel_markers
# Return a small handle to keep activity result size minimal
return [{"state_key": "query_batches", "count": len(parallel_markers)}]
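
> The comments above spell out the pattern: the heavy batch list is persisted out-of-band under the workflow id, and only a reference plus a count is returned, keeping the activity result well under Temporal's payload size limits. A self-contained sketch of that write side, with an in-memory dict standing in for `StateStore` (helper and variable names here are illustrative, not the SDK's API):
>
> ```python
> from typing import Any, Dict, List
>
> # In-memory stand-in for StateStore, shaped as {workflow_id: {state_key: value}}.
> _FAKE_STATE: Dict[str, Dict[str, Any]] = {}
>
> def persist_batches(workflow_id: str, batches: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
>     # The heavy batch list goes out-of-band, keyed by workflow id (sync here for brevity;
>     # the real activity is async and writes through StateStore.save_state).
>     _FAKE_STATE.setdefault(workflow_id, {})["query_batches"] = batches
>     # Only this tiny handle is returned as the activity result.
>     return [{"state_key": "query_batches", "count": len(batches)}]
> ```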

@activity.defn
@auto_heartbeater
async def load_query_batches(
self, workflow_args: Dict[str, Any]
) -> List[Dict[str, Any]]:
"""Load previously saved query batches from StateStore.

Args:
workflow_args (Dict[str, Any]): Workflow arguments containing workflow_id

Returns:
List[Dict[str, Any]]: The list of parallelized query batch descriptors

Raises:
Exception: If retrieval from StateStore fails
"""
try:
from application_sdk.services.statestore import StateStore, StateType

workflow_id: str = workflow_args.get("workflow_id") or get_workflow_id()
state = await StateStore.get_state(workflow_id, StateType.WORKFLOWS)
batches: List[Dict[str, Any]] = state.get("query_batches", [])
logger.info(
f"Loaded {len(batches)} query batches from StateStore for {workflow_id}"
)
return batches
except Exception as e:
logger.error(
f"Failed to load query batches from StateStore: {e}",
exc_info=True,
)
raise

@activity.defn
@auto_heartbeater
async def fetch_single_batch(
self, workflow_args: Dict[str, Any], batch_index: int
) -> Dict[str, Any]:
"""Fetch a single batch by index from StateStore.

Args:
workflow_args (Dict[str, Any]): Workflow arguments containing workflow_id
batch_index (int): Index of the batch to fetch

Returns:
Dict[str, Any]: The single batch data

Raises:
Exception: If batch retrieval fails
"""
try:
from application_sdk.services.statestore import StateStore, StateType

workflow_id: str = workflow_args.get("workflow_id") or get_workflow_id()
state = await StateStore.get_state(workflow_id, StateType.WORKFLOWS)
batches: List[Dict[str, Any]] = state.get("query_batches", [])

if batch_index >= len(batches):
raise IndexError(
f"Batch index {batch_index} out of range for {len(batches)} batches"
)

batch = batches[batch_index]
logger.info(f"Fetched batch {batch_index + 1}/{len(batches)}")
return batch

except Exception as e:
logger.error(f"Failed to fetch batch {batch_index}: {e}")
raise
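
> `load_query_batches` and `fetch_single_batch` are the read side of the same pattern: look up the workflow's saved state, then return either the whole list or a single bounds-checked entry. A hedged, self-contained sketch against the same kind of in-memory stand-in (the `sql`/`start`/`end` keys mirror the workflow code further down; the marker values are invented for illustration):
>
> ```python
> from typing import Any, Dict, List
>
> # In-memory stand-in for StateStore; marker values are made up.
> _FAKE_STATE: Dict[str, Dict[str, Any]] = {
>     "wf-123": {
>         "query_batches": [
>             {"sql": "SELECT ...", "start": "1700000000", "end": "1700003600"},
>             {"sql": "SELECT ...", "start": "1700003600", "end": "1700007200"},
>         ]
>     }
> }
>
> def load_batches(workflow_id: str) -> List[Dict[str, Any]]:
>     # A missing key degrades to an empty list, mirroring state.get("query_batches", []).
>     return _FAKE_STATE.get(workflow_id, {}).get("query_batches", [])
>
> def fetch_batch(workflow_id: str, batch_index: int) -> Dict[str, Any]:
>     batches = load_batches(workflow_id)
>     if batch_index >= len(batches):
>         raise IndexError(f"Batch index {batch_index} out of range for {len(batches)} batches")
>     return batches[batch_index]
>
> print(fetch_batch("wf-123", 1)["start"])  # 1700003600
> ```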

@activity.defn
@auto_heartbeater
async def write_final_marker(self, workflow_args: Dict[str, Any]) -> None:
"""Write final marker after all fetches complete.

Loads batches from StateStore and writes the last end marker as the marker file.
"""
try:
batches = await self.load_query_batches(workflow_args)
await self.write_marker(batches, workflow_args)
except Exception as e:
logger.warning(f"Failed to write final marker file: {e}")
18 changes: 9 additions & 9 deletions application_sdk/outputs/__init__.py
@@ -96,7 +96,6 @@ def path_gen(
chunk_count: Optional[int] = None,
chunk_part: int = 0,
start_marker: Optional[str] = None,
end_marker: Optional[str] = None,
) -> str:
"""Generate a file path for a chunk.

@@ -109,9 +108,12 @@
Returns:
str: Generated file path for the chunk.
"""
# For Query Extraction - use start and end markers without chunk count
if start_marker and end_marker:
return f"{start_marker}_{end_marker}{self._EXTENSION}"
# For Query Extraction - use start marker
if start_marker:
if chunk_count is None:
return f"atlan_raw_mined_{str(start_marker)}_{str(chunk_part)}{self._EXTENSION}"
else:
return f"atlan_raw_mined_{str(start_marker)}_{str(chunk_count)}_{str(chunk_part)}{self._EXTENSION}"

# For regular chunking - include chunk count
if chunk_count is None:
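
> With the end marker dropped from the file name, the naming scheme is easiest to see with concrete values. A short sketch of the two start-marker branches above, with the extension and marker values assumed for illustration:
>
> ```python
> # Sketch of path_gen's start-marker branches; the extension value is assumed.
> _EXTENSION = ".parquet"
>
> def path_for(start_marker, chunk_count, chunk_part):
>     if chunk_count is None:
>         return f"atlan_raw_mined_{start_marker}_{chunk_part}{_EXTENSION}"
>     return f"atlan_raw_mined_{start_marker}_{chunk_count}_{chunk_part}{_EXTENSION}"
>
> print(path_for("1700000000", None, 0))  # atlan_raw_mined_1700000000_0.parquet
> print(path_for("1700000000", 3, 1))     # atlan_raw_mined_1700000000_3_1.parquet
> ```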
@@ -213,7 +215,7 @@ async def write_dataframe(self, dataframe: "pd.DataFrame"):
self.current_buffer_size_bytes + chunk_size_bytes
> self.max_file_size_bytes
):
output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, self.chunk_part)}"
output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, self.chunk_part, self.start_marker)}"
if os.path.exists(output_file_name):
await self._upload_file(output_file_name)
self.chunk_part += 1
@@ -227,7 +229,7 @@ async def write_dataframe(self, dataframe: "pd.DataFrame"):

if self.current_buffer_size_bytes > 0:
# Finally upload the final file to the object store
output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, self.chunk_part)}"
output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, self.chunk_part, self.start_marker)}"
if os.path.exists(output_file_name):
await self._upload_file(output_file_name)
self.chunk_part += 1
@@ -361,9 +363,7 @@ async def _flush_buffer(self, chunk: "pd.DataFrame", chunk_part: int):
try:
if not is_empty_dataframe(chunk):
self.total_record_count += len(chunk)
output_file_name = (
f"{self.output_path}/{self.path_gen(self.chunk_count, chunk_part)}"
)
output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, chunk_part, self.start_marker)}"
await self.write_chunk(chunk, output_file_name)

self.current_buffer_size = 0
38 changes: 32 additions & 6 deletions application_sdk/workflows/query_extraction/sql.py
@@ -62,7 +62,10 @@ def get_activities(
"""
return [
activities.get_query_batches,
activities.load_query_batches,
activities.fetch_single_batch,
activities.fetch_queries,
activities.write_final_marker,
activities.preflight_check,
activities.get_workflow_args,
]
@@ -97,22 +100,36 @@ async def run(self, workflow_config: Dict[str, Any]):
backoff_coefficient=2,
)

results: List[Dict[str, Any]] = await workflow.execute_activity_method(
# Generate and persist batch markers (activity returns only a small handle)
batch_handles: List[Dict[str, Any]] = await workflow.execute_activity_method(
self.activities_cls.get_query_batches,
workflow_args,
retry_policy=retry_policy,
start_to_close_timeout=self.default_start_to_close_timeout,
heartbeat_timeout=self.default_heartbeat_timeout,
)

batch_count = batch_handles[0]["count"] if batch_handles else 0
logger.info(f"Processing {batch_count} query batches")

miner_activities: List[Coroutine[Any, Any, None]] = []

# Extract Queries
for result in results:
# Fetch and process each batch individually to avoid size limits
for batch_index in range(batch_count):
# Fetch the specific batch
batch_data = await workflow.execute_activity(
self.activities_cls.fetch_single_batch,
args=[workflow_args, batch_index],
retry_policy=retry_policy,
start_to_close_timeout=self.default_start_to_close_timeout,
heartbeat_timeout=self.default_heartbeat_timeout,
)

# Create activity args for this specific batch
activity_args = workflow_args.copy()
activity_args["sql_query"] = result["sql"]
activity_args["start_marker"] = result["start"]
activity_args["end_marker"] = result["end"]
activity_args["sql_query"] = batch_data["sql"]
activity_args["start_marker"] = batch_data["start"]
activity_args["end_marker"] = batch_data["end"]

miner_activities.append(
workflow.execute_activity(
Expand All @@ -126,4 +143,13 @@ async def run(self, workflow_config: Dict[str, Any]):

await asyncio.gather(*miner_activities)

# Write marker only after all fetches complete
await workflow.execute_activity_method(
self.activities_cls.write_final_marker,
workflow_args,
retry_policy=retry_policy,
start_to_close_timeout=self.default_start_to_close_timeout,
heartbeat_timeout=self.default_heartbeat_timeout,
)

logger.info(f"Miner workflow completed for {workflow_id}")