@@ -19,18 +19,14 @@ def load_kernel(
 ):
     pid = tl.program_id(0)
+    partner = int((source_rank + num_ranks // 2) % num_ranks)

     # Compute start index of this block
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)

     # Guard for out-of-bounds accesses
     mask = offsets < BLOCK_SIZE
-
-    result = tl.zeros([BLOCK_SIZE], dtype=data.type.element_ty)
-    for target_rank in range(num_ranks):
-        result += iris.load(data + offsets, source_rank, target_rank, heap_bases, mask=mask)
-
-    # Store data to result buffer
+    result = iris.load(data + offsets, source_rank, partner, heap_bases, mask=mask)

     tl.store(results + offsets, result, mask=mask)
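The pairing formula introduced in this hunk, (source_rank + num_ranks // 2) % num_ranks, points each rank at the rank half the world away. A minimal host-side sketch of the mapping (plain Python; the world size of 4 is chosen purely for illustration):

# Illustrative only: the same partner formula as in the kernel, evaluated on the host.
def partner_of(source_rank: int, num_ranks: int) -> int:
    return (source_rank + num_ranks // 2) % num_ranks

# With 4 ranks each rank reads from the rank opposite it: 0 -> 2, 1 -> 3, 2 -> 0, 3 -> 1.
assert [partner_of(r, 4) for r in range(4)] == [2, 3, 0, 1]

With a single rank the formula reduces to (0 + 0) % 1 = 0, so the kernel simply reads its own buffer and the test still has a well-defined expectation.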
@@ -58,16 +54,17 @@ def test_load_api(dtype, BLOCK_SIZE):
     num_ranks = shmem.get_num_ranks()
     heap_bases = shmem.get_heap_bases()
     source_rank = shmem.get_rank()
+    partner = int((source_rank + num_ranks // 2) % num_ranks)

-    data = shmem.ones(BLOCK_SIZE, dtype=dtype)
+    data = shmem.full((BLOCK_SIZE,), source_rank, dtype=dtype)
     results = shmem.zeros_like(data)

     grid = lambda meta: (1,)
     load_kernel[grid](data, results, source_rank, num_ranks, BLOCK_SIZE, heap_bases)
     shmem.barrier()

     # Verify the result
-    expected = torch.ones(BLOCK_SIZE, dtype=dtype, device="cuda") * num_ranks
+    expected = torch.ones(BLOCK_SIZE, dtype=dtype, device="cuda") * partner
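Because the test now fills each rank's buffer with its own rank id via shmem.full((BLOCK_SIZE,), source_rank, ...), a load from the partner's heap should observe the partner's rank id in every element, which is exactly what the new expected tensor encodes. A minimal model of that expectation in plain Python (no GPU, no iris; the world size of 4 and buffer length of 8 are hypothetical):

# Each rank's symmetric buffer holds its own rank id, mirroring shmem.full((BLOCK_SIZE,), source_rank, ...).
num_ranks = 4
buffers = {rank: [rank] * 8 for rank in range(num_ranks)}
for source_rank in range(num_ranks):
    partner = (source_rank + num_ranks // 2) % num_ranks
    # Reading the partner's buffer yields the partner's rank id, matching `expected` in the test.
    assert buffers[partner] == [partner] * 8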