Skip to content

Commit d5a876d

Browse files
authored
Implementation of Atomic And, Min, and Max (#92)
1 parent 2e095fb commit d5a876d

File tree

6 files changed

+369
-2
lines changed

6 files changed

+369
-2
lines changed

iris/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
atomic_cas,
1919
atomic_xchg,
2020
atomic_xor,
21-
atomic_or
21+
atomic_or,
22+
atomic_and,
23+
atomic_min,
24+
atomic_max,
2225
)
2326

2427
from .util import (
@@ -60,6 +63,9 @@
6063
"atomic_xchg",
6164
"atomic_xor",
6265
"atomic_or",
66+
"atomic_and",
67+
"atomic_min",
68+
"atomic_max",
6369
"do_bench",
6470
"memset_tensor",
6571
"hip",

iris/iris.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,33 @@ def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None
576576
return tl.atomic_xor(translated_ptr, val, mask=mask, sem=sem, scope=scope)
577577

578578

579+
@triton.jit
580+
def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):
581+
"""
582+
Performs an atomic and at the specified rank's memory location.
583+
584+
This function performs an atomic and operation by translating the pointer
585+
from the from_rank's address space to the to_rank's address space and atomically
586+
ANDing the provided data into the to_rank memory location. If the from_rank and to_rank are the same,
587+
this function performs a local atomic and operation.
588+
589+
Args:
590+
pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the from_rank's address space that will be translated to the to_rank's address space. Must be the current rank where the pointer is local.
591+
val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation.
592+
from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local.
593+
to_rank (int): The rank ID to which the atomic operation will be performed.
594+
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
595+
mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None.
596+
sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics.
597+
scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu".
598+
599+
Returns:
600+
Block: The data stored at pointer before the atomic operation.
601+
"""
602+
translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases)
603+
return tl.atomic_and(translated_ptr, val, mask=mask, sem=sem, scope=scope)
604+
605+
579606
@triton.jit
580607
def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):
581608
"""
@@ -603,6 +630,60 @@ def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None,
603630
return tl.atomic_or(translated_ptr, val, mask=mask, sem=sem, scope=scope)
604631

605632

633+
@triton.jit
634+
def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):
635+
"""
636+
Performs an atomic min at the specified rank's memory location.
637+
638+
This function performs an atomic min operation by translating the pointer
639+
from the from_rank's address space to the to_rank's address space and atomically
640+
performing the min on the provided data to the to_rank memory location. If the from_rank and to_rank are the same,
641+
this function performs a local atomic min operation.
642+
643+
Args:
644+
pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the from_rank's address space that will be translated to the to_rank's address space. Must be the current rank where the pointer is local.
645+
val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation.
646+
from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local.
647+
to_rank (int): The rank ID to which the atomic operation will be performed.
648+
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
649+
mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None.
650+
sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics.
651+
scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu".
652+
653+
Returns:
654+
Block: The data stored at pointer before the atomic operation.
655+
"""
656+
translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases)
657+
return tl.atomic_min(translated_ptr, val, mask=mask, sem=sem, scope=scope)
658+
659+
660+
@triton.jit
661+
def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None):
662+
"""
663+
Performs an atomic max at the specified rank's memory location.
664+
665+
This function performs an atomic max operation by translating the pointer
666+
from the from_rank's address space to the to_rank's address space and atomically
667+
performing the max on the provided data to the to_rank memory location. If the from_rank and to_rank are the same,
668+
this function performs a local atomic max operation.
669+
670+
Args:
671+
pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the from_rank's address space that will be translated to the to_rank's address space. Must be the current rank where the pointer is local.
672+
val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation.
673+
from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local.
674+
to_rank (int): The rank ID to which the atomic operation will be performed.
675+
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
676+
mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None.
677+
sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics.
678+
scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu".
679+
680+
Returns:
681+
Block: The data stored at pointer before the atomic operation.
682+
"""
683+
translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases)
684+
return tl.atomic_max(translated_ptr, val, mask=mask, sem=sem, scope=scope)
685+
686+
606687
def iris(heap_size=1 << 30):
607688
"""
608689
Create and return an Iris instance with the specified heap size.

tests/unittests/test_atomic_and.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# SPDX-License-Identifier: MIT
2+
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
3+
4+
import torch
5+
import triton
6+
import triton.language as tl
7+
import pytest
8+
import iris
9+
10+
11+
@triton.jit
12+
def atomic_and_kernel(
13+
results,
14+
sem: tl.constexpr,
15+
scope: tl.constexpr,
16+
cur_rank: tl.constexpr,
17+
num_ranks: tl.constexpr,
18+
BLOCK_SIZE: tl.constexpr,
19+
heap_bases: tl.tensor,
20+
):
21+
pid = tl.program_id(0)
22+
block_start = pid * BLOCK_SIZE
23+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
24+
mask = offsets < BLOCK_SIZE
25+
26+
bit = (cur_rank // 32) % 2
27+
val = bit << (cur_rank % results.type.element_ty.primitive_bitwidth)
28+
acc = tl.full([BLOCK_SIZE], val, dtype=results.type.element_ty)
29+
30+
for target_rank in range(num_ranks):
31+
iris.atomic_and(results + offsets, acc, cur_rank, target_rank, heap_bases, mask, sem=sem, scope=scope)
32+
33+
34+
@pytest.mark.parametrize(
35+
"dtype",
36+
[
37+
torch.int32,
38+
torch.int64,
39+
],
40+
)
41+
@pytest.mark.parametrize(
42+
"sem",
43+
[
44+
"acquire",
45+
"release",
46+
"acq_rel",
47+
],
48+
)
49+
@pytest.mark.parametrize(
50+
"scope",
51+
[
52+
"cta",
53+
"gpu",
54+
"sys",
55+
],
56+
)
57+
@pytest.mark.parametrize(
58+
"BLOCK_SIZE",
59+
[
60+
1,
61+
8,
62+
16,
63+
32,
64+
],
65+
)
66+
def test_atomic_and_api(dtype, sem, scope, BLOCK_SIZE):
67+
# TODO: Adjust heap size.
68+
shmem = iris.iris(1 << 20)
69+
num_ranks = shmem.get_num_ranks()
70+
heap_bases = shmem.get_heap_bases()
71+
cur_rank = shmem.get_rank()
72+
73+
bit_width = 32 if dtype == torch.int32 else 64
74+
effective_bits = min(num_ranks, bit_width)
75+
initial_mask = (1 << effective_bits) - 1
76+
77+
results = shmem.full((BLOCK_SIZE,), initial_mask, dtype=dtype)
78+
79+
grid = lambda meta: (1,)
80+
atomic_and_kernel[grid](results, sem, scope, cur_rank, num_ranks, BLOCK_SIZE, heap_bases)
81+
shmem.barrier()
82+
83+
# All ranks start out with a full mask vector (initial_mask: the low effective_bits bits set)
84+
# All ranks then take turns in clearing their bit position in the mask
85+
# By the end we would have effective_bits - num_ranks many ones followed by num_ranks zeros
86+
expected_scalar = ~((1 << num_ranks) - 1) & initial_mask
87+
expected = torch.full((BLOCK_SIZE,), expected_scalar, dtype=dtype, device="cuda")
88+
89+
try:
90+
torch.testing.assert_close(results, expected, rtol=0, atol=0)
91+
except AssertionError as e:
92+
print(e)
93+
print("Expected:", expected)
94+
print("Actual :", results)
95+
raise

tests/unittests/test_atomic_max.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# SPDX-License-Identifier: MIT
2+
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
3+
4+
import torch
5+
import triton
6+
import triton.language as tl
7+
import pytest
8+
import iris
9+
10+
11+
@triton.jit
12+
def atomic_max_kernel(
13+
results,
14+
sem: tl.constexpr,
15+
scope: tl.constexpr,
16+
cur_rank: tl.constexpr,
17+
num_ranks: tl.constexpr,
18+
BLOCK_SIZE: tl.constexpr,
19+
heap_bases: tl.tensor,
20+
):
21+
pid = tl.program_id(0)
22+
block_start = pid * BLOCK_SIZE
23+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
24+
mask = offsets < BLOCK_SIZE
25+
26+
acc = tl.full([BLOCK_SIZE], cur_rank + 1, dtype=results.type.element_ty)
27+
28+
for target_rank in range(num_ranks):
29+
iris.atomic_max(results + offsets, acc, cur_rank, target_rank, heap_bases, mask, sem=sem, scope=scope)
30+
31+
32+
@pytest.mark.parametrize(
33+
"dtype",
34+
[
35+
torch.int32,
36+
torch.int64,
37+
],
38+
)
39+
@pytest.mark.parametrize(
40+
"sem",
41+
[
42+
"acquire",
43+
"release",
44+
"acq_rel",
45+
],
46+
)
47+
@pytest.mark.parametrize(
48+
"scope",
49+
[
50+
"cta",
51+
"gpu",
52+
"sys",
53+
],
54+
)
55+
@pytest.mark.parametrize(
56+
"BLOCK_SIZE",
57+
[
58+
1,
59+
8,
60+
16,
61+
32,
62+
],
63+
)
64+
def test_atomic_max_api(dtype, sem, scope, BLOCK_SIZE):
65+
# TODO: Adjust heap size.
66+
shmem = iris.iris(1 << 20)
67+
num_ranks = shmem.get_num_ranks()
68+
heap_bases = shmem.get_heap_bases()
69+
cur_rank = shmem.get_rank()
70+
71+
min_val = torch.iinfo(dtype).min
72+
results = shmem.full((BLOCK_SIZE,), min_val, dtype=dtype)
73+
74+
grid = lambda meta: (1,)
75+
atomic_max_kernel[grid](results, sem, scope, cur_rank, num_ranks, BLOCK_SIZE, heap_bases)
76+
shmem.barrier()
77+
78+
# All ranks participate in performing the max operation
79+
# Each rank performs the atomic operation: max(rank_id + 1)
80+
# The result equals the ID of the last rank + 1
81+
expected = torch.full((BLOCK_SIZE,), num_ranks, dtype=dtype, device="cuda")
82+
83+
try:
84+
torch.testing.assert_close(results, expected, rtol=0, atol=0)
85+
except AssertionError as e:
86+
print(e)
87+
print("Expected:", expected)
88+
print("Actual :", results)
89+
raise

tests/unittests/test_atomic_min.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# SPDX-License-Identifier: MIT
2+
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
3+
4+
import torch
5+
import triton
6+
import triton.language as tl
7+
import pytest
8+
import iris
9+
10+
11+
@triton.jit
12+
def atomic_min_kernel(
13+
results,
14+
sem: tl.constexpr,
15+
scope: tl.constexpr,
16+
cur_rank: tl.constexpr,
17+
num_ranks: tl.constexpr,
18+
BLOCK_SIZE: tl.constexpr,
19+
heap_bases: tl.tensor,
20+
):
21+
pid = tl.program_id(0)
22+
block_start = pid * BLOCK_SIZE
23+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
24+
mask = offsets < BLOCK_SIZE
25+
26+
acc = tl.full([BLOCK_SIZE], cur_rank + 1, dtype=results.type.element_ty)
27+
28+
for target_rank in range(num_ranks):
29+
iris.atomic_min(results + offsets, acc, cur_rank, target_rank, heap_bases, mask, sem=sem, scope=scope)
30+
31+
32+
@pytest.mark.parametrize(
33+
"dtype",
34+
[
35+
torch.int32,
36+
torch.int64,
37+
],
38+
)
39+
@pytest.mark.parametrize(
40+
"sem",
41+
[
42+
"acquire",
43+
"release",
44+
"acq_rel",
45+
],
46+
)
47+
@pytest.mark.parametrize(
48+
"scope",
49+
[
50+
"cta",
51+
"gpu",
52+
"sys",
53+
],
54+
)
55+
@pytest.mark.parametrize(
56+
"BLOCK_SIZE",
57+
[
58+
1,
59+
8,
60+
16,
61+
32,
62+
],
63+
)
64+
def test_atomic_min_api(dtype, sem, scope, BLOCK_SIZE):
65+
# TODO: Adjust heap size.
66+
shmem = iris.iris(1 << 20)
67+
num_ranks = shmem.get_num_ranks()
68+
heap_bases = shmem.get_heap_bases()
69+
cur_rank = shmem.get_rank()
70+
71+
max_val = torch.iinfo(dtype).max
72+
results = shmem.full((BLOCK_SIZE,), max_val, dtype=dtype)
73+
74+
grid = lambda meta: (1,)
75+
atomic_min_kernel[grid](results, sem, scope, cur_rank, num_ranks, BLOCK_SIZE, heap_bases)
76+
shmem.barrier()
77+
# All ranks participate in performing the min operation
78+
# Each rank performs the atomic operation: min(rank_id + 1)
79+
# The result equals the ID of the first rank + 1
80+
expected = torch.full((BLOCK_SIZE,), 1, dtype=dtype, device="cuda")
81+
82+
try:
83+
torch.testing.assert_close(results, expected, rtol=0, atol=0)
84+
except AssertionError as e:
85+
print(e)
86+
print("Expected:", expected)
87+
print("Actual :", results)
88+
raise

0 commit comments

Comments
 (0)