Skip to content

Commit 459f636

Browse files
Apply Ruff auto-fixes
1 parent ad03093 commit 459f636

File tree

1 file changed

+38
-26
lines changed

1 file changed

+38
-26
lines changed

tests/examples/test_load_latency.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from iris._mpi_helpers import mpi_allgather
1212
# from examples.common.utils import read_realtime
1313

14+
1415
@triton.jit
1516
def read_realtime():
1617
tmp = tl.inline_asm_elementwise(
@@ -23,21 +24,25 @@ def read_realtime():
2324
)
2425
return tmp
2526

27+
2628
@triton.jit()
2729
def gather_latencies(
28-
local_latency,
29-
global_latency,
30-
curr_rank,
31-
num_ranks ,
32-
BLOCK_SIZE: tl.constexpr,
33-
heap_bases: tl.tensor
30+
local_latency, global_latency, curr_rank, num_ranks, BLOCK_SIZE: tl.constexpr, heap_bases: tl.tensor
3431
):
3532
pid = tl.program_id(0)
3633
block_start = pid * BLOCK_SIZE
3734
offsets = block_start + tl.arange(0, BLOCK_SIZE)
3835

3936
latency_mask = offsets < num_ranks
40-
iris.put(local_latency + offsets, global_latency + curr_rank * num_ranks + offsets, curr_rank, 0, heap_bases, mask=latency_mask)
37+
iris.put(
38+
local_latency + offsets,
39+
global_latency + curr_rank * num_ranks + offsets,
40+
curr_rank,
41+
0,
42+
heap_bases,
43+
mask=latency_mask,
44+
)
45+
4146

4247
@triton.jit()
4348
def ping_pong(
@@ -66,7 +71,7 @@ def ping_pong(
6671
start = read_realtime()
6772
tl.atomic_xchg(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
6873
first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
69-
token_first_done = i + 1
74+
token_first_done = i + 1
7075
token_second_done = i + 2
7176
if curr_rank == first_rank:
7277
iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
@@ -82,8 +87,9 @@ def ping_pong(
8287
stop = read_realtime()
8388
tl.atomic_xchg(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)
8489

90+
8591
if __name__ == "__main__":
86-
dtype = torch.int32
92+
dtype = torch.int32
8793
heap_size = 1 << 32
8894
shmem = iris.iris(heap_size)
8995
num_ranks = shmem.get_num_ranks()
@@ -96,42 +102,48 @@ def ping_pong(
96102
iter = 200
97103
skip = 1
98104
mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
99-
mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
105+
mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
100106

101-
local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
107+
local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
102108

103109
source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
104110
result_buffer = shmem.zeros_like(source_buffer)
105-
flag = shmem.ones(1, dtype=dtype)
111+
flag = shmem.ones(1, dtype=dtype)
106112

107113
grid = lambda meta: (1,)
108114
for source_rank in range(num_ranks):
109115
for destination_rank in range(num_ranks):
110116
if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
111117
peer_for_me = destination_rank if cur_rank == source_rank else source_rank
112-
ping_pong[grid](source_buffer,
113-
BUFFER_LEN,
114-
skip, iter,
115-
flag,
116-
cur_rank, peer_for_me,
117-
BLOCK_SIZE,
118-
heap_bases,
119-
mm_begin_timestamp,
120-
mm_end_timestamp)
118+
ping_pong[grid](
119+
source_buffer,
120+
BUFFER_LEN,
121+
skip,
122+
iter,
123+
flag,
124+
cur_rank,
125+
peer_for_me,
126+
BLOCK_SIZE,
127+
heap_bases,
128+
mm_begin_timestamp,
129+
mm_end_timestamp,
130+
)
121131
shmem.barrier()
122-
132+
123133
for destination_rank in range(num_ranks):
124-
local_latency[destination_rank] = (mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]) / iter
125-
134+
local_latency[destination_rank] = (
135+
mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]
136+
) / iter
137+
126138
latency_matrix = mpi_allgather(local_latency.cpu())
127139

128140
if cur_rank == 0:
129-
with open(f"latency.txt", "w") as f:
141+
with open("latency.txt", "w") as f:
130142
f.write(" ," + ", ".join(f"R{j}" for j in range(num_ranks)) + "\n")
131143
for i in range(num_ranks):
132144
row_entries = []
133145
for j in range(num_ranks):
134146
val = float(latency_matrix[i, j])
135147
row_entries.append(f"{val:0.6f}")
136148
line = f"R{i}," + ", ".join(row_entries) + "\n"
137-
f.write(line)
149+
f.write(line)

0 commit comments

Comments
 (0)