Performs an atomic AND at the specified rank's memory location.

This function performs an atomic AND operation by translating the pointer
from the from_rank's address space to the to_rank's address space and atomically
ANDing the provided data into the to_rank memory location. If from_rank and to_rank are the same,
this function performs a local atomic AND operation.

Args:
    pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the from_rank's address space that will be translated to the to_rank's address space. Must be local to the current rank.
    val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation.
    from_rank (int): The rank ID from which the pointer originates. Must be the current rank, where the pointer is local.
    to_rank (int): The rank ID on which the atomic operation will be performed.
    heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
    mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None.
    sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". Defaults to "acq_rel".
    scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu", "cta" (cooperative thread array, i.e. thread block), and "sys" (stands for "SYSTEM"). Defaults to "gpu".

Returns:
    Block: The data stored at pointer before the atomic operation.
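
The translation step described above can be modeled in plain Python. This is only a sketch of the documented behavior, not the kernel code from this change: it assumes translation preserves the offset from the heap base, and `translate`, `model_atomic_and`, the `memory` dict, and the heap addresses are all hypothetical names and values chosen for illustration.

```python
# Pure-Python model of the documented behavior; not the library's kernel code.
# Assumption: translation keeps the offset from the source rank's heap base
# and applies it to the target rank's heap base.

def translate(pointer, from_rank, to_rank, heap_bases):
    """Map an address in from_rank's heap to the same offset in to_rank's heap."""
    offset = pointer - heap_bases[from_rank]
    return heap_bases[to_rank] + offset


def model_atomic_and(memory, pointer, val, from_rank, to_rank, heap_bases):
    """AND val into the translated location and return the previous value."""
    target = translate(pointer, from_rank, to_rank, heap_bases)
    old = memory[target]
    memory[target] = old & val
    return old


# Two hypothetical ranks whose heaps start at 0x1000 and 0x2000.
heap_bases = [0x1000, 0x2000]
memory = {0x1000 + 8: 0b1111, 0x2000 + 8: 0b1010}

local_ptr = heap_bases[0] + 8  # an address that is valid on rank 0
old = model_atomic_and(
    memory, local_ptr, 0b0110, from_rank=0, to_rank=1, heap_bases=heap_bases
)
```

In this model the call updates rank 1's copy of the location and leaves rank 0's copy untouched, which mirrors the from_rank/to_rank distinction in the docstring.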
Performs an atomic min at the specified rank's memory location.

This function performs an atomic min operation by translating the pointer
from the from_rank's address space to the to_rank's address space and atomically
taking the minimum of the provided data and the value at the to_rank memory location. If from_rank and to_rank are the same,
this function performs a local atomic min operation.

Args:
    pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the from_rank's address space that will be translated to the to_rank's address space. Must be local to the current rank.
    val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation.
    from_rank (int): The rank ID from which the pointer originates. Must be the current rank, where the pointer is local.
    to_rank (int): The rank ID on which the atomic operation will be performed.
    heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
    mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None.
    sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". Defaults to "acq_rel".
    scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu", "cta" (cooperative thread array, i.e. thread block), and "sys" (stands for "SYSTEM"). Defaults to "gpu".

Returns:
    Block: The data stored at pointer before the atomic operation.
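
The effect of `mask` on a block of locations can be illustrated with a small list-based model. This is a sketch under assumptions: `model_masked_atomic_min` is a hypothetical helper, and the docstring does not say what is returned for masked-off lanes, so the model simply reports the untouched value there.

```python
def model_masked_atomic_min(target, vals, mask=None):
    """Element-wise min into target where mask is true; return previous values.

    Assumption: masked-off lanes report their unchanged value. The docstring
    does not specify the returned value for those lanes.
    """
    if mask is None:
        mask = [True] * len(vals)  # no mask means every lane participates
    old = list(target)
    for idx, (v, m) in enumerate(zip(vals, mask)):
        if m:
            target[idx] = min(target[idx], v)
    return old


# Hypothetical contents of four locations on the target rank.
remote = [5, 9, 3, 7]
previous = model_masked_atomic_min(
    remote, vals=[4, 1, 8, 2], mask=[True, False, True, True]
)
```

Lane 1 keeps its value of 9 even though the candidate 1 is smaller, because its mask entry is false.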
Performs an atomic max at the specified rank's memory location.

This function performs an atomic max operation by translating the pointer
from the from_rank's address space to the to_rank's address space and atomically
taking the maximum of the provided data and the value at the to_rank memory location. If from_rank and to_rank are the same,
this function performs a local atomic max operation.

Args:
    pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the from_rank's address space that will be translated to the to_rank's address space. Must be local to the current rank.
    val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation.
    from_rank (int): The rank ID from which the pointer originates. Must be the current rank, where the pointer is local.
    to_rank (int): The rank ID on which the atomic operation will be performed.
    heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
    mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None.
    sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". Defaults to "acq_rel".
    scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu", "cta" (cooperative thread array, i.e. thread block), and "sys" (stands for "SYSTEM"). Defaults to "gpu".

Returns:
    Block: The data stored at pointer before the atomic operation.
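
One plausible use of a remote atomic max is a cross-rank reduction, where every rank targets the same to_rank location. The sketch below models that with Python threads; the `Slot` class and the contribution values are hypothetical, and the lock merely stands in for the hardware atomicity that the real operation provides.

```python
import threading


class Slot:
    """A single memory location on the target rank (hypothetical model)."""

    def __init__(self, value):
        self.value = value
        self._lock = threading.Lock()  # stands in for hardware atomicity

    def atomic_max(self, val):
        """Store max(current, val) and return the previous value."""
        with self._lock:
            old = self.value
            self.value = max(old, val)
            return old


slot = Slot(0)
contributions = [17, 42, 8, 23]  # one value per simulated rank

threads = [
    threading.Thread(target=slot.atomic_max, args=(v,)) for v in contributions
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Whatever order the simulated ranks run in, the slot ends up holding the largest contribution. The value each call returns depends on that order, which is why a caller should not assume a particular "previous" value.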