Skip to content

Commit d6d2a5d

Browse files
committed
resolve the CUDA limit issue
1 parent d1b03ae commit d6d2a5d

File tree

1 file changed

+33
-14
lines changed

1 file changed

+33
-14
lines changed

octopipes/benchmark.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from multiprocess import Pool
1+
from concurrent.futures import ThreadPoolExecutor, as_completed
22

33
from tqdm import tqdm
44

@@ -9,7 +9,12 @@
99

1010
class Benchmark:
1111
def __init__(self, dataloader: Dataloader, workflows: list[Workflow],
12-
flows_factory: AggregateFlowsFactory | None = None) -> None:
12+
flows_factory: AggregateFlowsFactory | None = None, mode: str = "threaded") -> None:
13+
14+
"""
15+
:param str mode: "single" for sequential execution, "threaded" for multithreading.
16+
"""
17+
1318
self.dataloader = dataloader
1419
self.workflows = workflows
1520
self.results: list[AggregateFlows] = []
@@ -28,16 +33,30 @@ def run_sample(factory: AggregateFlowsFactory, workflows, sample):
2833
return aggregate
2934

3035
def run_tests(self):
31-
for batch in tqdm(self.dataloader):
32-
with Pool(processes=len(batch)) as pool:
33-
for result in pool.map(func=Run(self.factory, self.workflows), iterable=batch):
34-
self.results.append(result)
35-
36-
37-
class Run:
38-
def __init__(self, factory, workflows) -> None:
39-
self.factory = factory
40-
self.workflows = workflows
36+
if self.mode == "single":
37+
self._run_sequential()
38+
elif self.mode == "threaded":
39+
self._run_multithreaded()
40+
else:
41+
raise ValueError(f"Invalid model: {self.mode}. Choose 'single' or 'threaded' ")
42+
43+
44+
def _run_sequential(self):
45+
"""Runs workflows sequentially (single-threaded execution)."""
46+
for batch in tqdm(self.dataloader, desc= "processing batches"):
47+
for sample in batch:
48+
result=self.run_sample(self.factory, self.workflows, sample)
49+
self.results.append(result)
50+
51+
def _run_multithreaded(self):
52+
"""Runs workflows using multithreading (threaded execution)."""
53+
for batch in tqdm(self.dataloader,desc= "processing batches"):
54+
with ThreadPoolExecutor(max_workers=len(batch)) as executor:
55+
future_to_sample= {executor.submit(self.run_sample,self.factory, self.workflows, sample): sample for sample in batch}
56+
for future in as_completed(future_to_sample):
57+
try:
58+
result=future.result()
59+
self.results.append(result)
60+
except Exception as e:
61+
print(f"Error processing sample: {e}")
4162

42-
def __call__(self, sample):
43-
return Benchmark.run_sample(self.factory, self.workflows, sample)

0 commit comments

Comments
 (0)