Skip to content

Commit d1cc73f

Browse files
committed
Merge branch 'main' into astroC86/get-or-put-to-copy
2 parents 504b3f4 + 7564882 commit d1cc73f

File tree

2 files changed

+7 additions, -10 deletions (lines changed)

2 files changed

+7 additions, -10 deletions (lines changed)

iris/iris.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -356,7 +356,7 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None):
356356
Returns:
357357
Block: The loaded value from the target memory location.
358358
"""
359-
translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases)
359+
translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases)
360360
result = tl.load(translated_ptr, mask=mask)
361361
return result
362362

@@ -407,7 +407,7 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
407407
Returns:
408408
None
409409
"""
410-
translated_from_ptr = __translate(from_ptr, to_rank, from_rank, heap_bases)
410+
translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases)
411411

412412
data = tl.load(translated_from_ptr, mask=mask)
413413

tests/unittests/test_load.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -19,18 +19,14 @@ def load_kernel(
1919
):
2020
pid = tl.program_id(0)
2121

22+
partner = int((source_rank + num_ranks // 2) % num_ranks)
2223
# Compute start index of this block
2324
block_start = pid * BLOCK_SIZE
2425
offsets = block_start + tl.arange(0, BLOCK_SIZE)
2526

2627
# Guard for out-of-bounds accesses
2728
mask = offsets < BLOCK_SIZE
28-
29-
result = tl.zeros([BLOCK_SIZE], dtype=data.type.element_ty)
30-
for target_rank in range(num_ranks):
31-
result += iris.load(data + offsets, source_rank, target_rank, heap_bases, mask=mask)
32-
33-
# Store data to result buffer
29+
result = iris.load(data + offsets, source_rank, partner, heap_bases, mask=mask)
3430
tl.store(results + offsets, result, mask=mask)
3531

3632

@@ -58,16 +54,17 @@ def test_load_api(dtype, BLOCK_SIZE):
5854
num_ranks = shmem.get_num_ranks()
5955
heap_bases = shmem.get_heap_bases()
6056
source_rank = shmem.get_rank()
57+
partner = int((source_rank + num_ranks // 2) % num_ranks)
6158

62-
data = shmem.ones(BLOCK_SIZE, dtype=dtype)
59+
data = shmem.full((BLOCK_SIZE,), source_rank, dtype=dtype)
6360
results = shmem.zeros_like(data)
6461

6562
grid = lambda meta: (1,)
6663
load_kernel[grid](data, results, source_rank, num_ranks, BLOCK_SIZE, heap_bases)
6764
shmem.barrier()
6865

6966
# Verify the result
70-
expected = torch.ones(BLOCK_SIZE, dtype=dtype, device="cuda") * num_ranks
67+
expected = torch.ones(BLOCK_SIZE, dtype=dtype, device="cuda") * partner
7168

7269
try:
7370
torch.testing.assert_close(results, expected, rtol=0, atol=0)

0 commit comments

Comments
 (0)