Skip to content

Commit c2ca89c

Browse files
committed
Intial impl of copy
1 parent d1cc73f commit c2ca89c

File tree

2 files changed

+23
-15
lines changed

2 files changed

+23
-15
lines changed

iris/iris.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -407,11 +407,7 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
407407
Returns:
408408
None
409409
"""
410-
translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases)
411-
412-
data = tl.load(translated_from_ptr, mask=mask)
413-
414-
tl.store(to_ptr, data, mask=mask)
410+
copy(from_ptr, to_ptr, from_rank, to_rank , heap_bases, mask)
415411

416412

417413
@triton.jit
@@ -434,15 +430,29 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
434430
Returns:
435431
None
436432
"""
437-
translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases)
433+
copy(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask)
438434

439-
data = tl.load(from_ptr, mask=mask)
440435

441-
tl.store(translated_to_ptr, data, mask=mask)
436+
@triton.jit
437+
def copy(src_ptr, dst_ptr, from_rank, to_rank, heap_bases, mask=None):
438+
"""
439+
Copies data from the specified rank's memory into the destination rank's memory.
440+
This function performs the transfer by translating src_ptr from the from_rank's address
441+
space to the to_rank's address space, performing a masked load from the translated
442+
source, and storing the loaded data to dst_ptr in the to_rank memory location.
443+
If from_rank and to_rank are the same, this function performs a local copy operation.
442444
445+
Args:
446+
src_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the from_rank's local memory from which to read data.
447+
dst_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the to_rank's local memory where the data will be written.
448+
from_rank (int): The rank ID that owns src_ptr (source rank).
449+
to_rank (int): The rank ID that will receive the data (destination rank).
450+
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
451+
mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None.
443452
444-
@triton.jit
445-
def copy(dst_ptr, src_ptr, from_rank, to_rank, heap_bases, mask=None):
453+
Returns:
454+
None
455+
"""
446456
translated_src = __translate(src_ptr, from_rank, to_rank, heap_bases)
447457
data = tl.load(translated_src, mask=mask)
448458
tl.store(dst_ptr, data, mask=mask)

tests/unittests/test_copy.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def copy_kernel(
2626
src_data = data + BLOCK_SIZE * cur_rank
2727
dest_data = results + BLOCK_SIZE * target_rank
2828
iris.copy(
29-
dest_data + offsets,
3029
src_data + offsets,
30+
dest_data + offsets,
3131
cur_rank,
3232
target_rank,
3333
heap_bases,
@@ -77,15 +77,13 @@ def test_copy_get_semantics(dtype, BLOCK_SIZE):
7777
shmem.barrier()
7878

7979
expected = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=dtype)
80-
expected_2 = torch.zeros((num_ranks, BLOCK_SIZE), dtype=dtype, device="cuda")
8180
for rank_id in range(num_ranks):
82-
expected[rank_id, :] = 999999
83-
expected_2[rank_id, :] = (rank_id + num_ranks) * (cur_rank + 1)
81+
expected[rank_id, :] = (rank_id + num_ranks) * (cur_rank + 1)
8482

8583
try:
8684
torch.testing.assert_close(results, expected, rtol=0, atol=0)
8785
except AssertionError as e:
8886
print(e)
89-
print("Expected:", expected_2)
87+
print("Expected:", expected)
9088
print("Actual:", results)
9189
raise

0 commit comments

Comments
 (0)