1
- from multiprocess import Pool
1
+ from concurrent . futures import ThreadPoolExecutor , as_completed
2
2
3
3
from tqdm import tqdm
4
4
9
9
10
10
class Benchmark :
11
11
def __init__ (self , dataloader : Dataloader , workflows : list [Workflow ],
12
- flows_factory : AggregateFlowsFactory | None = None ) -> None :
12
+ flows_factory : AggregateFlowsFactory | None = None , mode : str = "threaded" ) -> None :
13
+
14
+ """
15
+ :param str mode: "single" for sequential execution, "threaded" for multithreading.
16
+ """
17
+
13
18
self .dataloader = dataloader
14
19
self .workflows = workflows
15
20
self .results : list [AggregateFlows ] = []
@@ -28,16 +33,30 @@ def run_sample(factory: AggregateFlowsFactory, workflows, sample):
28
33
return aggregate
29
34
30
35
def run_tests(self):
    """Run every workflow over the dataloader using the configured mode.

    Dispatches to ``_run_sequential`` when the mode is ``"single"`` and to
    ``_run_multithreaded`` when it is ``"threaded"``.

    :raises ValueError: if the mode is neither ``"single"`` nor ``"threaded"``.
    """
    # NOTE(review): the visible part of __init__ accepts ``mode`` but does not
    # appear to store it on the instance -- confirm and add ``self.mode = mode``
    # there. Until then, fall back to the constructor's default ("threaded")
    # instead of failing with AttributeError.
    mode = getattr(self, "mode", "threaded")

    if mode == "single":
        self._run_sequential()
    elif mode == "threaded":
        self._run_multithreaded()
    else:
        raise ValueError(f"Invalid mode: {mode}. Choose 'single' or 'threaded'")
42
+
43
+
44
def _run_sequential(self):
    """Process every sample of every batch one after another in the calling thread.

    Each sample is passed to ``run_sample`` together with the factory and the
    workflows, and the returned value is appended to ``self.results``.
    """
    progress = tqdm(self.dataloader, desc="processing batches")
    for batch in progress:
        for sample in batch:
            self.results.append(
                self.run_sample(self.factory, self.workflows, sample)
            )
50
+
51
def _run_multithreaded(self):
    """Process each batch with a thread pool, one worker thread per sample.

    Batches are handled one at a time. Within a batch every sample is
    submitted to ``run_sample`` and the results are appended to
    ``self.results`` in the order the samples appear in the batch. A sample
    whose run raises is reported on stdout and skipped; the remaining samples
    are still collected.
    """
    for batch in tqdm(self.dataloader, desc="processing batches"):
        worker_count = len(batch)
        if worker_count == 0:
            # ThreadPoolExecutor rejects max_workers <= 0 with ValueError,
            # and an empty batch has nothing to run anyway.
            continue

        with ThreadPoolExecutor(max_workers=worker_count) as executor:
            futures = [
                executor.submit(self.run_sample, self.factory, self.workflows, sample)
                for sample in batch
            ]
            # Collect in submission order so results line up with the samples.
            # The pool is drained before the next batch either way, so this
            # costs nothing compared with completion-order collection.
            for future in futures:
                try:
                    result = future.result()
                except Exception as e:
                    # Best effort: one failing sample must not abort the batch.
                    print(f"Error processing sample: {e}")
                else:
                    self.results.append(result)
41
62
42
- def __call__ (self , sample ):
43
- return Benchmark .run_sample (self .factory , self .workflows , sample )
0 commit comments