 import torch.distributed as dist
 import torch.multiprocessing as mp
 import os
-import pytest
 from integrators import MonteCarlo, MarkovChainMonteCarlo
 
-@pytest.fixture
-def rank():
-    return 0
 
-@pytest.fixture
-def world_size():
-    return 8
-
-def test_init_process(rank, world_size, fn, backend="gloo"):
+def init_process(rank, world_size, fn, backend="gloo"):
     # Set MASTER_ADDR and MASTER_PORT appropriately
     # Assuming environment variables are set by the cluster's job scheduler
     master_addr = os.getenv("MASTER_ADDR", "localhost")
@@ -26,7 +18,7 @@ def test_init_process(rank, world_size, fn, backend="gloo"):
     fn(rank, world_size)
 
 
-def test_run_mcmc(rank, world_size):
+def run_mcmc(rank, world_size):
     # Instantiate the MarkovChainMonteCarlo class
     bounds = [(-1, 1), (-1, 1)]
     n_eval = 8000000
@@ -60,7 +52,6 @@ def two_integrands(x, f):
     # Only rank 0 prints the result
     print("MarkovChainMonteCarlo Result:", mcmc_result)
 
-
-if __name__ == "__main__":
+def test_mcmc():
     world_size = 8  # Number of processes to launch
-    mp.spawn(test_init_process, args=(world_size, test_run_mcmc), nprocs=world_size, join=True)
+    mp.spawn(init_process, args=(world_size, run_mcmc), nprocs=world_size, join=True)
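
Both hunks above elide function bodies: the remainder of init_process and the two_integrands integrand. The visible lines (the env-derived MASTER_ADDR, the "gloo" default backend, and the closing fn(rank, world_size) call) match PyTorch's standard process-group bootstrap, so a minimal sketch of what init_process is doing could look like the following; the 29500 port fallback and the destroy_process_group teardown are assumptions, not necessarily the file's exact body.

import os
import torch.distributed as dist

def init_process(rank, world_size, fn, backend="gloo"):
    # Rendezvous settings come from the cluster's job scheduler when present;
    # the localhost/29500 fallbacks (assumed here) cover single-node runs.
    os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost")
    os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "29500")

    # Join the process group, run the workload, then tear down cleanly.
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    try:
        fn(rank, world_size)
    finally:
        dist.destroy_process_group()

With the test_ prefix now confined to test_mcmc, plain pytest discovery collects a single entry point; mp.spawn launches the eight workers itself, which is why the fixtures and the __main__ block could be removed.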
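The second hunk's header names def two_integrands(x, f), whose body is also elided. Purely as an illustration of the two-output, fill-in-place convention that signature suggests, with invented integrands over the [(-1, 1), (-1, 1)] bounds (not the file's actual definitions):

import torch

def two_integrands(x, f):
    # Hypothetical example: each column of f holds one integrand's values.
    f[:, 0] = (x[:, 0] ** 2 + x[:, 1] ** 2 < 1).to(f.dtype)  # unit-disk indicator; integral = pi
    f[:, 1] = torch.exp(-(x[:, 0] ** 2 + x[:, 1] ** 2))  # Gaussian bump
    return f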