from iris._mpi_helpers import mpi_allgather
# from examples.common.utils import read_realtime

+
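# Device-side timer, defined inline here rather than imported from
# examples.common.utils: samples the GPU's real-time counter via
# tl.inline_asm_elementwise (body elided in this hunk).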
@triton.jit
def read_realtime():
    tmp = tl.inline_asm_elementwise(
@@ -23,21 +24,25 @@ def read_realtime():
    )
    return tmp

+
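# Each rank writes its row of per-peer latencies into the global matrix
# that lives on rank 0's heap.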
@triton.jit()
def gather_latencies(
-    local_latency,
-    global_latency,
-    curr_rank,
-    num_ranks,
-    BLOCK_SIZE: tl.constexpr,
-    heap_bases: tl.tensor
+    local_latency, global_latency, curr_rank, num_ranks, BLOCK_SIZE: tl.constexpr, heap_bases: tl.tensor
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    latency_mask = offsets < num_ranks
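    # One latency entry per peer; the put targets row curr_rank of the
    # (num_ranks x num_ranks) matrix on rank 0.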
-    iris.put(local_latency + offsets, global_latency + curr_rank * num_ranks + offsets, curr_rank, 0, heap_bases, mask=latency_mask)
+    iris.put(
+        local_latency + offsets,
+        global_latency + curr_rank * num_ranks + offsets,
+        curr_rank,
+        0,
+        heap_bases,
+        mask=latency_mask,
+    )
+

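# Ping-pong kernel (signature and body partially elided in this diff): the two
# ranks appear to take turns putting data and bumping a flag token, with
# timestamps bracketing each iteration.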
@triton.jit()
def ping_pong(
@@ -66,7 +71,7 @@ def ping_pong(
    start = read_realtime()
    tl.atomic_xchg(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
    first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
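    # Alternate the initiating rank each iteration.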
-    token_first_done = i + 1
+    token_first_done = i + 1
    token_second_done = i + 2
    if curr_rank == first_rank:
        iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
@@ -82,8 +87,9 @@ def ping_pong(
    stop = read_realtime()
    tl.atomic_xchg(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)

+
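# Host driver: runs the ping-pong between every ordered pair of ranks, then
# gathers and writes the latency matrix.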
if __name__ == "__main__":
-    dtype = torch.int32
+    dtype = torch.int32
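    # 1 << 32 bytes = a 4 GiB Iris heap shared by all ranks.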
    heap_size = 1 << 32
    shmem = iris.iris(heap_size)
    num_ranks = shmem.get_num_ranks()
@@ -96,42 +102,48 @@ def ping_pong(
    iter = 200
    skip = 1
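    # Timed iterations per pair; `skip` warm-up iterations are presumably
    # excluded from timing inside the kernel.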
    mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
-    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
+    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")

-    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
+    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")

    source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
    result_buffer = shmem.zeros_like(source_buffer)
-    flag = shmem.ones(1, dtype=dtype)
+    flag = shmem.ones(1, dtype=dtype)

    grid = lambda meta: (1,)
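    # Launch a single program instance; only the two ranks in each pair run
    # the kernel, but every rank hits the barrier.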
    for source_rank in range(num_ranks):
        for destination_rank in range(num_ranks):
            if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
                peer_for_me = destination_rank if cur_rank == source_rank else source_rank
-                ping_pong[grid](source_buffer,
-                    BUFFER_LEN,
-                    skip, iter,
-                    flag,
-                    cur_rank, peer_for_me,
-                    BLOCK_SIZE,
-                    heap_bases,
-                    mm_begin_timestamp,
-                    mm_end_timestamp)
+                ping_pong[grid](
+                    source_buffer,
+                    BUFFER_LEN,
+                    skip,
+                    iter,
+                    flag,
+                    cur_rank,
+                    peer_for_me,
+                    BLOCK_SIZE,
+                    heap_bases,
+                    mm_begin_timestamp,
+                    mm_end_timestamp,
+                )
            shmem.barrier()
-
+
    for destination_rank in range(num_ranks):
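        # Average round-trip time per iteration for this peer, in raw
        # clock ticks from read_realtime.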
-        local_latency[destination_rank] = (mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]) / iter
-
+        local_latency[destination_rank] = (
+            mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]
+        ) / iter
+
    latency_matrix = mpi_allgather(local_latency.cpu())

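    # Rank 0 writes the full matrix as comma-separated text, one row per rank.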
    if cur_rank == 0:
-        with open(f"latency.txt", "w") as f:
+        with open("latency.txt", "w") as f:
            f.write(" ," + ", ".join(f"R{j}" for j in range(num_ranks)) + "\n")
            for i in range(num_ranks):
                row_entries = []
                for j in range(num_ranks):
                    val = float(latency_matrix[i, j])
                    row_entries.append(f"{val:0.6f}")
                line = f"R{i}," + ", ".join(row_entries) + "\n"
-                f.write(line)
+                f.write(line)