Skip to content

Commit 745e61c

Browse files
authored
Update mc_multicpu_test.py
1 parent b0fd838 commit 745e61c

File tree

1 file changed

+4
-13
lines changed

1 file changed

+4
-13
lines changed

MCintegration/mc_multicpu_test.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,10 @@
22
import torch.distributed as dist
33
import torch.multiprocessing as mp
44
import os
5-
import pytest
65
from integrators import MonteCarlo, MarkovChainMonteCarlo
76

8-
@pytest.fixture
9-
def rank():
10-
return 0
117

12-
@pytest.fixture
13-
def world_size():
14-
return 8
15-
16-
def test_init_process(rank, world_size, fn, backend="gloo"):
8+
def init_process(rank, world_size, fn, backend="gloo"):
179
# Set MASTER_ADDR and MASTER_PORT appropriately
1810
# Assuming environment variables are set by the cluster's job scheduler
1911
master_addr = os.getenv("MASTER_ADDR", "localhost")
@@ -26,7 +18,7 @@ def test_init_process(rank, world_size, fn, backend="gloo"):
2618
fn(rank, world_size)
2719

2820

29-
def test_run_mcmc(rank, world_size):
21+
def run_mcmc(rank, world_size):
3022
# Instantiate the MarkovChainMonteCarlo class
3123
bounds = [(-1, 1), (-1, 1)]
3224
n_eval = 8000000
@@ -60,7 +52,6 @@ def two_integrands(x, f):
6052
# Only rank 0 prints the result
6153
print("MarkovChainMonteCarlo Result:", mcmc_result)
6254

63-
64-
if __name__ == "__main__":
55+
def test_mcmc():
6556
world_size = 8 # Number of processes to launch
66-
mp.spawn(test_init_process, args=(world_size, test_run_mcmc), nprocs=world_size, join=True)
57+
mp.spawn(init_process, args=(world_size, run_mcmc), nprocs=world_size, join=True)

0 commit comments

Comments
 (0)