# --- main.py (post-patch state) ----------------------------------------------
import sys

import hydra
from omegaconf import OmegaConf

from sdp.run_processors import SDPRunner

# Resolvers used by config files to derive fields from other fields.
OmegaConf.register_new_resolver("subfield", lambda node, field: node[field])
OmegaConf.register_new_resolver("not", lambda x: not x)
OmegaConf.register_new_resolver("equal", lambda field, value: field == value)


@hydra.main(version_base=None)
def main(cfg):
    """Entry point: build the SDP runner from the hydra config and execute it."""
    sdp = SDPRunner(cfg)
    sdp.run()


# --- sdp/data_units/abc_unit.py ----------------------------------------------
from abc import ABC, abstractmethod
from typing import Any, Dict, List

from sdp.data_units.data_entry import DataEntry


class DataSource(ABC):
    """Abstract reader/writer over a sequence of manifest-like entries.

    Concrete implementations (``Manifest``, ``Stream``) define how entries are
    read from / written to ``source``; this base accumulates the bookkeeping
    (metrics, duration, entry count) shared by all of them.
    """

    def __init__(self, source: Any):
        self.source = source
        self.number_of_entries = 0
        self.total_duration = 0.0
        self.metrics = []

    @abstractmethod
    def read_entry(self) -> Dict:
        """Yield parsed entries one at a time."""

    @abstractmethod
    def read_entries(self, in_memory_chunksize: int = None) -> List[Dict]:
        """Read entries, optionally grouped into chunks of the given size."""

    @abstractmethod
    def write_entry(self, data_entry: DataEntry):
        """Write a single ``DataEntry`` to the underlying source."""

    @abstractmethod
    def write_entries(self, data_entries: List[DataEntry]):
        """Write a batch of ``DataEntry`` objects to the underlying source."""

    def update_metrics(self, data_entry: DataEntry):
        """Accumulate per-entry metrics, total duration and entry count."""
        if data_entry.metrics is not None:
            self.metrics.append(data_entry.metrics)
        # BUGFIX: the original tested ``data_entry.data is dict``, which
        # compares the payload with the ``dict`` *type object* and is always
        # False, so total_duration was never accumulated.
        if isinstance(data_entry.data, dict):
            self.total_duration += data_entry.data.get("duration", 0)
        self.number_of_entries += 1


class DataSetter(ABC):
    """Base for objects that wire processor configs to data units."""

    def __init__(self, processors_cfgs: List[Dict]):
        self.processors_cfgs = processors_cfgs

    @staticmethod
    def get_resolvable_link(*args) -> str:
        """Build an OmegaConf interpolation string, e.g. ``${processors.0.output}``.

        BUGFIX: originally defined without ``self`` or ``@staticmethod``, so an
        instance call would have swallowed the first argument as ``self``.
        """
        return "${{{}}}".format(".".join(map(str, args)))
# --- sdp/data_units/cache.py -------------------------------------------------
import os
from tempfile import TemporaryDirectory
from uuid import uuid4


class CacheDir:
    """A self-cleaning temporary directory for intermediate pipeline files.

    Args:
        cache_dirpath: optional parent directory; created if missing. When
            ``None``, the system default temp location is used.
        prefix / suffix: forwarded to ``TemporaryDirectory`` for the dir name.
    """

    def __init__(self, cache_dirpath: str = None, prefix: str = None, suffix: str = None):
        if cache_dirpath:
            # TemporaryDirectory requires its parent to exist already.
            os.makedirs(cache_dirpath, exist_ok=True)
        self.cache_dir = TemporaryDirectory(dir=cache_dirpath, prefix=prefix, suffix=suffix)

    def make_tmp_filepath(self, extension: str = "") -> str:
        """Return a unique (uuid4-based) filepath inside this cache dir.

        ``extension`` (e.g. ``".json"``) is appended when given; defaults to
        the original extension-less behavior.
        """
        return os.path.join(self.cache_dir.name, str(uuid4()) + extension)

    def makedir(self, **kwargs) -> "CacheDir":
        """Create a nested ``CacheDir`` inside this one (kwargs as in ``__init__``)."""
        return CacheDir(cache_dirpath=self.cache_dir.name, **kwargs)

    def cleanup(self):
        """Remove the directory and everything inside it."""
        self.cache_dir.cleanup()


# Shared process-wide cache used for auto-generated manifest files.
CACHE_DIR = CacheDir()


# --- sdp/data_units/data_entry.py --------------------------------------------
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class DataEntry:
    """A wrapper for data entry + any additional metrics."""

    data: Optional[Dict]  # can be None to drop the entry
    metrics: Any = None
# --- sdp/data_units/manifest.py ----------------------------------------------
import json
import os
from typing import Dict, List

from tqdm import tqdm

from sdp.data_units.abc_unit import DataSetter, DataSource
from sdp.data_units.cache import CACHE_DIR, CacheDir
from sdp.data_units.data_entry import DataEntry


class Manifest(DataSource):
    """A JSON-lines manifest file acting as a processor's input/output unit.

    When ``filepath`` is omitted, a unique temporary file inside ``cache_dir``
    is used instead.
    """

    def __init__(self, filepath: str = None, cache_dir: CacheDir = CACHE_DIR):
        self.write_mode = "w"  # first open truncates; subsequent opens append
        self.encoding = "utf8"
        self.file = None

        if not filepath:
            filepath = cache_dir.make_tmp_filepath()

        super().__init__(filepath)
        os.makedirs(os.path.dirname(self.source), exist_ok=True)

    def read_entry(self):
        """Yield one JSON-decoded entry per manifest line."""
        if self.source is None:
            raise NotImplementedError("Override this method if the processor creates initial manifest")

        # use self.encoding instead of a second hard-coded literal
        with open(file=self.source, mode="r", encoding=self.encoding) as file:
            for line in file:
                yield json.loads(line)

    def read_entries(self, in_memory_chunksize=None):
        """Yield entries one-by-one, or in lists of ``in_memory_chunksize``."""
        manifest_chunk = []
        for idx, data_entry in enumerate(self.read_entry(), 1):
            if not in_memory_chunksize:
                yield data_entry
            else:
                manifest_chunk.append(data_entry)
                if idx % in_memory_chunksize == 0:
                    yield manifest_chunk
                    manifest_chunk = []
        if manifest_chunk:
            yield manifest_chunk

    def write_entry(self, data_entry: DataEntry):
        """Append one entry (skipping ``data is None`` drops) and update metrics."""
        if not self.file:
            # BUGFIX: the mode was hard-coded to "w", so reopening the file
            # after close() truncated all previously written entries; honor
            # self.write_mode ("w" first, "a" afterwards).
            self.file = open(file=self.source, mode=self.write_mode, encoding=self.encoding)
            self.write_mode = "a"

        self.update_metrics(data_entry)
        if data_entry.data:
            json.dump(data_entry.data, self.file, ensure_ascii=False)
            self.file.write("\n")

    def write_entries(self, data_entries):
        """Write a batch of entries, then flush/close the file handle."""
        for data_entry in tqdm(data_entries):
            self.write_entry(data_entry)
        self.close()

    def close(self):
        if self.file:
            self.file.close()
            self.file = None


class ManifestsSetter(DataSetter):
    """Assigns input/output ``Manifest`` objects to each processor config."""

    def __init__(self, processors_cfgs: List[Dict]):
        super().__init__(processors_cfgs)

    def is_manifest_resolvable(self, processor_idx: int):
        """Raise if the processor has no usable input manifest source."""
        processor_cfg = self.processors_cfgs[processor_idx]

        if "input_manifest_file" not in processor_cfg:
            if processor_idx == 0:  # ToDo: first processor may create the manifest itself
                return

            previous_cfg = self.processors_cfgs[processor_idx - 1]
            if not ("output" in previous_cfg and isinstance(previous_cfg["output"], Manifest)):
                # BUGFIX: bare ValueError() gave no diagnostic at all
                raise ValueError(
                    f"Input manifest for processor #{processor_idx} cannot be resolved: "
                    "no 'input_manifest_file' given and the previous processor "
                    "has no Manifest output."
                )

    def set_processor_manifests(self, processor_idx: int):
        """Pop ``*_manifest_file`` keys and replace them with Manifest objects."""
        self.is_manifest_resolvable(processor_idx)

        processor_cfg = self.processors_cfgs[processor_idx]

        if "input_manifest_file" in processor_cfg:
            input_manifest = Manifest(processor_cfg.pop("input_manifest_file"))
        else:
            if processor_idx == 0:
                input_manifest = None  # ToDo: initial processor creates the manifest
            else:
                input_manifest = self.processors_cfgs[processor_idx - 1]["output"]

        processor_cfg["input"] = input_manifest

        if "output_manifest_file" in processor_cfg:
            output_manifest = Manifest(processor_cfg.pop("output_manifest_file"))
        else:
            output_manifest = Manifest()  # auto temp file in the shared cache

        processor_cfg["output"] = output_manifest
        # debug ``print(processor_idx, processor_cfg)`` removed
        self.processors_cfgs[processor_idx] = processor_cfg
# --- sdp/data_units/stream.py (Stream) ---------------------------------------
import json
import pickle
from io import BytesIO
from typing import List

from tqdm import tqdm

from sdp.data_units.abc_unit import DataSetter, DataSource
from sdp.data_units.data_entry import DataEntry
from sdp.data_units.manifest import Manifest


class Stream(DataSource):
    """In-memory data unit: a ``BytesIO`` buffer of JSON lines.

    ``rw_limit`` counts how many read/write operations the stream is expected
    to serve (the producing write plus one per registered reader — setters
    increment it via ``resolve_stream``). Once ``rw_amount`` reaches the
    limit, the buffer is discarded to release memory.
    """

    def __init__(self):
        self.rw_amount = 0
        self.rw_limit = 1
        self.encoding = "utf8"
        super().__init__(BytesIO())

    def rw_control(func):
        """Decorator: rewind before the operation; free the buffer after the
        last expected consumer has used it."""
        def wrapper(self, *args, **kwargs):
            self.source.seek(0)
            result = func(self, *args, **kwargs)
            self.rw_amount += 1
            if self.rw_amount >= self.rw_limit:
                # BUGFIX: truncate(0) alone leaves the cursor beyond the new
                # end-of-buffer, so any later write would zero-pad the gap;
                # rewind as well.
                self.source.truncate(0)
                self.source.seek(0)
            return result
        return wrapper

    def read_entry(self):
        """Yield one JSON-decoded entry per buffered line."""
        for line in self.source:
            yield json.loads(line.decode(self.encoding))

    @rw_control
    def read_entries(self, in_memory_chunksize=None):
        """Return all entries; wrapped in a single chunk if chunking is requested."""
        data_entries = list(self.read_entry())
        if in_memory_chunksize:
            data_entries = [data_entries]
        return data_entries

    def write_entry(self, data_entry: DataEntry):
        """Append one entry (skipping ``data is None`` drops) and update metrics."""
        self.update_metrics(data_entry)
        if data_entry.data:
            # BUGFIX(consistency): match Manifest.write_entry, which serializes
            # with ensure_ascii=False — otherwise a Manifest->Stream->Manifest
            # round trip changes the on-disk byte representation.
            line = json.dumps(data_entry.data, ensure_ascii=False) + "\n"
            self.source.write(line.encode(self.encoding))

    @rw_control
    def write_entries(self, data_entries):
        for data_entry in tqdm(data_entries):
            self.write_entry(data_entry)
# --- sdp/data_units/stream.py (StreamsSetter) --------------------------------
class StreamsSetter(DataSetter):
    """Wires processor configs together with in-memory ``Stream`` objects.

    Configs reference streams with strings of the form ``stream:init`` (a
    brand-new stream) or ``stream:<processor_idx>.<field>`` (a stream already
    assigned to another processor's config).
    """

    def __init__(self, processors_cfgs):
        super().__init__(processors_cfgs)
        self.reference_stream_prefix = "stream"

    def resolve_stream(self, reference: str):
        """Resolve a ``stream:...`` reference string to a ``Stream`` object.

        Registers one more consumer on the resolved stream (``rw_limit``).
        """
        reference = reference.replace(self.reference_stream_prefix + ":", "")
        if reference == "init":
            return Stream()

        current_item = self.processors_cfgs
        for key in reference.split("."):
            if key.isdigit():
                key = int(key)
            # TODO: replace io_stream fields with "input" / "output"
            if isinstance(key, str) and key.endswith("_stream"):
                key = key.replace("_stream", "")
            current_item = current_item[key]

        if not isinstance(current_item, Stream):
            # BUGFIX: bare ValueError() gave no diagnostic at all
            raise ValueError(f"Reference 'stream:{reference}' does not resolve to a Stream object.")

        current_item.rw_limit += 1
        return current_item

    def is_manifest_to_stream(self, processor_idx, dry_run: bool = False):
        """Check whether the processor is a valid ManifestToStream adapter.

        With ``dry_run=True`` returns False instead of raising on invalid setups.
        """
        processor_cfg = self.processors_cfgs[processor_idx]
        if processor_cfg['_target_'] != "sdp.processors.ManifestToStream":
            return False

        if "input_manifest_file" not in processor_cfg:
            previous_cfg = self.processors_cfgs[processor_idx - 1]
            if not ("output" in previous_cfg and isinstance(previous_cfg["output"], Manifest)):
                if dry_run:
                    return False
                raise ValueError(
                    f"ManifestToStream #{processor_idx}: no 'input_manifest_file' and "
                    "previous processor has no Manifest output."
                )

        if "output_stream" in processor_cfg:
            if processor_cfg["output_stream"] != f"{self.reference_stream_prefix}:init":
                if dry_run:
                    return False
                raise ValueError(
                    f"ManifestToStream #{processor_idx}: 'output_stream' must be "
                    f"'{self.reference_stream_prefix}:init'."
                )

        return True

    def as_manifest_to_stream(self, processor_idx):
        """Replace manifest/stream keys of a ManifestToStream cfg with data units."""
        processor_cfg = self.processors_cfgs[processor_idx]

        if "input_manifest_file" in processor_cfg:
            input_manifest = Manifest(processor_cfg.pop("input_manifest_file"))
        else:
            input_manifest = self.processors_cfgs[processor_idx - 1]['output']

        processor_cfg["input"] = input_manifest

        if "output_stream" in processor_cfg:
            output_stream = self.resolve_stream(processor_cfg.pop("output_stream"))
        else:
            output_stream = Stream()

        processor_cfg["output"] = output_stream
        return processor_cfg

    def is_stream_to_manifest(self, processor_idx, dry_run: bool = False):
        """Check whether the processor is a valid StreamToManifest adapter."""
        processor_cfg = self.processors_cfgs[processor_idx]
        if processor_cfg['_target_'] != "sdp.processors.StreamToManifest":
            return False

        if "input_stream" in processor_cfg:
            # reading from a freshly initialized (empty) stream makes no sense
            if processor_cfg["input_stream"] == f"{self.reference_stream_prefix}:init":
                if dry_run:
                    return False
                raise ValueError(
                    f"StreamToManifest #{processor_idx}: 'input_stream' cannot be "
                    f"'{self.reference_stream_prefix}:init'."
                )
        else:
            previous_cfg = self.processors_cfgs[processor_idx - 1]
            if not ("output" in previous_cfg and isinstance(previous_cfg["output"], Stream)):
                if dry_run:
                    return False
                raise ValueError(
                    f"StreamToManifest #{processor_idx}: no 'input_stream' and "
                    "previous processor has no Stream output."
                )

        return True

    def as_stream_to_manifest(self, processor_idx):
        """Replace stream/manifest keys of a StreamToManifest cfg with data units."""
        processor_cfg = self.processors_cfgs[processor_idx]

        if "input_stream" in processor_cfg:
            input_stream = self.resolve_stream(processor_cfg.pop("input_stream"))
        else:
            input_stream = self.processors_cfgs[processor_idx - 1]["output"]
            input_stream.rw_limit += 1  # this processor is one more consumer

        processor_cfg["input"] = input_stream

        if "output_manifest_file" in processor_cfg:
            output_manifest = Manifest(processor_cfg.pop("output_manifest_file"))
        else:
            output_manifest = Manifest()

        processor_cfg["output"] = output_manifest
        return processor_cfg

    def traverse_processor(self, cfg):
        """Recursively replace every ``stream:...`` string in cfg with a Stream."""
        if isinstance(cfg, list):
            for i, item in enumerate(cfg):
                cfg[i] = self.traverse_processor(item)
        elif isinstance(cfg, dict):
            for key, value in cfg.items():
                cfg[key] = self.traverse_processor(value)
        elif isinstance(cfg, str) and cfg.startswith(self.reference_stream_prefix):
            cfg = self.resolve_stream(cfg)
        return cfg

    def is_stream_resolvable(self, processor_idx, dry_run: bool = False):
        """Check whether a generic processor's stream references can be resolved."""
        processor_cfg = self.processors_cfgs[processor_idx]

        if "input_stream" in processor_cfg:
            if not processor_cfg["input_stream"].startswith(self.reference_stream_prefix):
                if dry_run:
                    return False
                raise ValueError(
                    f"Processor #{processor_idx}: 'input_stream' must start with "
                    f"'{self.reference_stream_prefix}'."
                )
            if processor_cfg["input_stream"] == f"{self.reference_stream_prefix}:init":
                if dry_run:
                    return False
                raise ValueError(
                    f"Processor #{processor_idx}: 'input_stream' cannot be "
                    f"'{self.reference_stream_prefix}:init'."
                )
        else:
            previous_cfg = self.processors_cfgs[processor_idx - 1]
            if not ("output" in previous_cfg and isinstance(previous_cfg["output"], Stream)):
                if dry_run:
                    return False
                raise ValueError(
                    f"Processor #{processor_idx}: no 'input_stream' and previous "
                    "processor has no Stream output."
                )

        if "output_stream" in processor_cfg:
            if processor_cfg["output_stream"] != f"{self.reference_stream_prefix}:init":
                if dry_run:
                    return False
                raise ValueError(
                    f"Processor #{processor_idx}: 'output_stream' must be "
                    f"'{self.reference_stream_prefix}:init'."
                )

        return True

    def set_processor_streams(self, processor_idx: int):
        """Wire one processor config with resolved Stream input/output units."""
        if self.is_manifest_to_stream(processor_idx):
            processor_cfg = self.as_manifest_to_stream(processor_idx)

        elif self.is_stream_to_manifest(processor_idx):
            processor_cfg = self.as_stream_to_manifest(processor_idx)

        elif self.is_stream_resolvable(processor_idx):
            processor_cfg = self.processors_cfgs[processor_idx]
            # resolves all "stream:..." strings (incl. input/output_stream)
            processor_cfg = self.traverse_processor(processor_cfg)

            if "input_stream" in processor_cfg:
                input_stream = processor_cfg.pop("input_stream")
            else:
                input_stream = self.processors_cfgs[processor_idx - 1]["output"]
                self.processors_cfgs[processor_idx - 1]["output"].rw_limit += 1

            processor_cfg["input"] = input_stream
            processor_cfg["output"] = processor_cfg.pop("output_stream", Stream())

        # debug ``print(processor_idx, processor_cfg)`` and commented-out
        # leftovers removed
        self.processors_cfgs[processor_idx] = processor_cfg


# --- sdp/processors/__init__.py (line appended by the patch) -----------------
# from sdp.processors.stream.adapters import StreamToManifest, ManifestToStream
# --- sdp/processors/base_processor.py (patched definitions) ------------------
# Imports after the patch (DataEntry moved out to sdp.data_units.data_entry):
#   from sdp.data_units.data_entry import DataEntry
#   from sdp.data_units.manifest import Manifest
#   from sdp.data_units.stream import Stream
# NOTE(review): the patch also adds ThreadPoolExecutor / Process / Queue
# imports that nothing in the visible code uses — candidates for removal.


class BaseProcessor(ABC):
    """Abstract class for SDP processors.

    Processors now exchange data through ``Manifest``/``Stream`` data units
    instead of raw manifest file paths.

    Args:
        output: data unit the processor writes to.
        input: optional data unit the processor reads from (``None`` for
            processors that create the initial manifest themselves).
    """

    # BUGFIX: the patch declared ``Union[Manifest | Stream]``, mixing the
    # typing.Union and PEP 604 syntaxes (a TypeError before Python 3.10 and
    # redundant after); the intended annotation is ``Union[Manifest, Stream]``.
    def __init__(self, output: Union[Manifest, Stream], input: Optional[Union[Manifest, Stream]] = None):
        if output and input and (output == input):
            # we cannot have the same input and output unit specified because
            # we need to be able to read and write at the same time
            raise ValueError("A processor's specified input and output cannot be the same.")

        self.output = output
        self.input = input

    @abstractmethod
    def process(self):
        """Should run all the processing code defined in derived classes."""

    def test(self):
        """This method can be used to perform "runtime" tests.

        Optional self-consistency checks; subclasses may override.
        """


# --- BaseParallelProcessor.process (patched) ---------------------------------
# (other BaseParallelProcessor members are unchanged by the diff)

    def process(self):
        """Parallelized implementation of the data processing.

        Reads chunks from ``self.input``, maps ``process_dataset_entry`` over
        them with a process pool, and writes results to ``self.output``.
        """
        self.prepare()

        for manifest_chunk in self.input.read_entries(self.in_memory_chunksize):
            # this will unroll all inner lists returned by process_dataset_entry
            data = itertools.chain(
                *process_map(
                    self.process_dataset_entry,
                    manifest_chunk,
                    max_workers=self.max_workers,
                    chunksize=self.chunksize,
                )
            )
            self.output.write_entries(data)

        # bookkeeping is now accumulated by the output data unit
        self.number_of_entries = self.output.number_of_entries
        self.total_duration = self.output.total_duration
        self.finalize(self.output.metrics)


# --- BaseParallelProcessor.test (patched) ------------------------------------

    def test(self):
        """Applies processing to "test_cases" and raises an error in case of mismatch."""
        super().test()  # added by the patch: run base-class runtime checks too

        for test_case in self.test_cases:
            generated_outputs = self.process_dataset_entry(test_case["input"].copy())
            # (remainder of the method is unchanged by the diff)
# --- sdp/processors/stream/adapters.py ---------------------------------------
# (the diff above this point only deletes BaseProcessor.read_manifest, which
# moved into sdp.data_units.manifest.Manifest.read_entry)
from sdp.processors.base_processor import BaseProcessor
from sdp.data_units.data_entry import DataEntry
from sdp.data_units.manifest import Manifest
from sdp.data_units.stream import Stream


class ManifestToStream(BaseProcessor):
    """Adapter: copies entries from a ``Manifest`` into an in-memory ``Stream``."""

    def __init__(self, output: Stream, input: Manifest):
        super().__init__(output=output, input=input)

    def process(self):
        # stream entries lazily instead of materializing the whole manifest
        entries = (DataEntry(data) for data in self.input.read_entry())
        self.output.write_entries(entries)

    def test(self):
        # BUGFIX: assertion messages were empty strings — useless diagnostics
        assert type(self.input) is Manifest, (
            f"ManifestToStream expects a Manifest input, got {type(self.input)}"
        )
        assert type(self.output) is Stream, (
            f"ManifestToStream expects a Stream output, got {type(self.output)}"
        )


class StreamToManifest(BaseProcessor):
    """Adapter: dumps entries from an in-memory ``Stream`` into a ``Manifest``."""

    def __init__(self, output: Manifest, input: Stream):
        super().__init__(output=output, input=input)

    def process(self):
        entries = (DataEntry(data) for data in self.input.read_entry())
        self.output.write_entries(entries)

    def test(self):
        assert type(self.input) is Stream, (
            f"StreamToManifest expects a Stream input, got {type(self.input)}"
        )
        assert type(self.output) is Manifest, (
            f"StreamToManifest expects a Manifest output, got {type(self.output)}"
        )
# --- sdp/run_processors.py (patched) -----------------------------------------
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License").

import logging
import traceback

import hydra
import yaml
from omegaconf import OmegaConf

from sdp.data_units.cache import CACHE_DIR
from sdp.data_units.manifest import ManifestsSetter
from sdp.data_units.stream import StreamsSetter
from sdp.logging import logger

# (the logger handler configuration block from the original file is unchanged
# by the diff and sits between these imports and the class below)


class SDPRunner(ManifestsSetter, StreamsSetter):
    """Drives a full SDP run: resolves the config, wires processors with
    Manifest/Stream data units, instantiates them, runs their self-tests,
    executes them in order, and finally cleans up the shared cache."""

    def __init__(self, cfg: OmegaConf):
        OmegaConf.resolve(cfg)
        self.processors_from_cfg = cfg.processors
        self.processors_cfgs = self.select_processors_to_run(cfg.get("processors_to_run", "all"))
        self.processors = []
        self.use_streams = cfg.get("use_streams", False)
        super().__init__(self.processors_cfgs)

    def select_processors_to_run(self, processors_to_select: str):
        """Select processors by index or Python-slice string.

        ``"all"`` selects everything, ``"3"`` only the 4th processor and
        ``"2:5"`` behaves like a normal slice; configs with
        ``should_run: False`` are filtered out afterwards.
        """
        all_indices = list(range(len(self.processors_from_cfg)))
        if processors_to_select == "all":
            selected_indices = all_indices
        elif ":" not in processors_to_select:
            selected_indices = [all_indices[int(processors_to_select)]]
        else:
            slice_obj = slice(*map(lambda x: int(x.strip()) if x.strip() else None,
                                   processors_to_select.split(":")))
            selected_indices = all_indices[slice_obj]

        # remember which full-config index each selected cfg came from —
        # needed by infer_init_input (see BUGFIX note there)
        self._selected_indices = []
        processors_cfgs = []
        for idx in selected_indices:
            processor_cfg = OmegaConf.to_container(self.processors_from_cfg[idx])
            if processor_cfg.pop("should_run", True):
                processors_cfgs.append(processor_cfg)
                self._selected_indices.append(idx)

        return processors_cfgs

    def infer_init_input(self):
        """If the run starts mid-pipeline without an explicit input manifest,
        reuse the previous processor's ``output_manifest_file``.

        BUGFIX: the original located the starting processor by identity
        (``processor_cfg is self.processors_cfgs[0]``) while iterating the raw
        config list, but ``select_processors_to_run`` converts every selected
        cfg with ``OmegaConf.to_container`` — so the identity check could never
        succeed and the input was never inferred. We use the recorded indices.
        """
        if not self.processors_cfgs:
            return
        first_idx = self._selected_indices[0]
        if first_idx > 0 and "input_manifest_file" not in self.processors_cfgs[0]:
            previous_cfg = self.processors_from_cfg[first_idx - 1]
            if "output_manifest_file" in previous_cfg:
                self.processors_cfgs[0]["input_manifest_file"] = previous_cfg["output_manifest_file"]

    def set(self):
        """Wire every processor config with Manifest and/or Stream data units."""
        self.infer_init_input()

        for processor_idx in range(len(self.processors_cfgs)):
            if not self.use_streams:
                self.set_processor_manifests(processor_idx)
            elif (self.is_manifest_to_stream(processor_idx, dry_run=True)
                  or self.is_stream_to_manifest(processor_idx, dry_run=True)
                  or self.is_stream_resolvable(processor_idx, dry_run=True)):
                self.set_processor_streams(processor_idx)
            else:
                self.set_processor_manifests(processor_idx)

    def build_processors(self):
        """Instantiate all processors up-front so config errors fail early."""
        for processor_cfg in self.processors_cfgs:
            self.processors.append(hydra.utils.instantiate(processor_cfg))

    def test_processors(self):
        """Run each processor's runtime self-tests before any processing."""
        for processor in self.processors:
            processor.test()

    def run(self):
        """Set up, validate and execute the whole pipeline."""
        try:
            self.set()
            logger.info(
                "Specified to run the following processors:\n %s",
                yaml.dump(self.processors_cfgs, default_flow_style=False),
            )

            self.build_processors()
            self.test_processors()

            for processor in self.processors:
                logger.info('=> Running processor "%s"', processor)
                processor.process()

        except Exception:
            # BUGFIX: failures were only print()-ed and swallowed, so a broken
            # pipeline still exited with status 0; log and re-raise instead.
            logger.error("An error occurred: %s", traceback.format_exc())
            raise

        finally:
            # always remove the shared temp cache, even on failure
            CACHE_DIR.cleanup()