diff --git a/code_generator_python/python_code_generator/templates/transform_single_file.py b/code_generator_python/python_code_generator/templates/transform_single_file.py index 8fd6d2987..4f2d1ebe3 100644 --- a/code_generator_python/python_code_generator/templates/transform_single_file.py +++ b/code_generator_python/python_code_generator/templates/transform_single_file.py @@ -3,10 +3,6 @@ import time from pathlib import Path import generated_transformer -import awkward as ak -import uproot -import pyarrow.parquet as pq -import numpy as np instance = os.environ.get('INSTANCE_NAME', 'Unknown') default_tree_name = "servicex" default_branch_name = "branch" @@ -22,58 +18,74 @@ def transform_single_file(file_path: str, output_path: Path, output_format: str) try: stime = time.time() - output = generated_transformer.run_query(file_path) - - ttime = time.time() - - if output_format == 'root-file': + # We first see if the function has the signature to directly write output + # If it doesn't, then we assume it's giving us back awkward array results + try: + generated_transformer.run_query(file_path, str(output_path)) + if not output_path.exists(): + raise RuntimeError("Transformation did not produce expected output file " + f"{output_path}") + ttime = time.time() etime = time.time() - if isinstance(output, ak.Array): - awkward_arrays = {default_tree_name: output} - elif isinstance(output, dict): - awkward_arrays = output - with open(output_path, 'b+w') as wfile: - with uproot.recreate(wfile) as writer: - for key in awkward_arrays.keys(): - total_events = awkward_arrays[key].__len__() - if awkward_arrays[key].fields and total_events: - o_dict = {field: awkward_arrays[key][field] - for field in awkward_arrays[key].fields} - elif awkward_arrays[key].fields and not total_events: - o_dict = {field: np.array([]) - for field in awkward_arrays[key].fields} - elif not awkward_arrays[key].fields and total_events: - o_dict = {default_branch_name: awkward_arrays[key]} - else: - o_dict = {default_branch_name: np.array([])} - writer[key] = o_dict - wtime = time.time() - elif output_format == 'raw-file': - etime = time.time() total_events = 0 - output_path = output - wtime = time.time() - else: - if isinstance(output, dict): - tree_name = list(output.keys())[0] - awkward_array = output[tree_name] - print(f'Returned type from your Python function is a dictionary - ' - f'Only the first key {tree_name} will be written as parquet files. ' - f'Please use root-file output to write all trees.') + except AttributeError: + import awkward as ak + import uproot + import pyarrow.parquet as pq + import numpy as np + + output = generated_transformer.run_query(file_path) + + ttime = time.time() + if output_format == 'root-file': + etime = time.time() + if isinstance(output, ak.Array): + awkward_arrays = {default_tree_name: output} + elif isinstance(output, dict): + awkward_arrays = output + with open(output_path, 'b+w') as wfile: + with uproot.recreate(wfile) as writer: + for key in awkward_arrays.keys(): + total_events = awkward_arrays[key].__len__() + if awkward_arrays[key].fields and total_events: + o_dict = {field: awkward_arrays[key][field] + for field in awkward_arrays[key].fields} + elif awkward_arrays[key].fields and not total_events: + o_dict = {field: np.array([]) + for field in awkward_arrays[key].fields} + elif not awkward_arrays[key].fields and total_events: + o_dict = {default_branch_name: awkward_arrays[key]} + else: + o_dict = {default_branch_name: np.array([])} + writer[key] = o_dict + + wtime = time.time() + elif output_format == 'raw-file': + etime = time.time() + total_events = 0 + output_path = output + wtime = time.time() else: - awkward_array = output + if isinstance(output, dict): + tree_name = list(output.keys())[0] + awkward_array = output[tree_name] + print(f'Returned type from your Python function is a dictionary - ' + f'Only the first key {tree_name} will be written as parquet files. ' + f'Please use root-file output to write all trees.') + else: + awkward_array = output - total_events = ak.num(awkward_array, axis=0) - arrow = ak.to_arrow_table(awkward_array) + total_events = ak.num(awkward_array, axis=0) + arrow = ak.to_arrow_table(awkward_array) - etime = time.time() + etime = time.time() - writer = pq.ParquetWriter(output_path, arrow.schema) - writer.write_table(table=arrow) - writer.close() + writer = pq.ParquetWriter(output_path, arrow.schema) + writer.write_table(table=arrow) + writer.close() - wtime = time.time() + wtime = time.time() output_size = os.stat(output_path).st_size print(f'Detailed transformer times. query_time:{round(ttime - stime, 3)} '