From 66338820a60690feae8f3f16bae2d803fd373b64 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 30 Aug 2025 21:59:00 +0000 Subject: [PATCH 1/6] Initial plan From 971b264752aff663afbcafe2151844acba6a12a2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 30 Aug 2025 22:02:55 +0000 Subject: [PATCH 2/6] Add pytest for 03_all_store/all_store_bench.py Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- tests/examples/test_all_store_bench.py | 83 ++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 tests/examples/test_all_store_bench.py diff --git a/tests/examples/test_all_store_bench.py b/tests/examples/test_all_store_bench.py new file mode 100644 index 00000000..19acf9c4 --- /dev/null +++ b/tests/examples/test_all_store_bench.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import pytest +import torch +import triton +import triton.language as tl +import numpy as np +import iris + +import importlib.util +from pathlib import Path + +current_dir = Path(__file__).parent +file_path = (current_dir / "../../examples/03_all_store/all_store_bench.py").resolve() +module_name = "all_store_bench" +spec = importlib.util.spec_from_file_location(module_name, file_path) +module = importlib.util.module_from_spec(spec) +spec.loader.exec_module(module) + + +@pytest.mark.parametrize( + "dtype", + [ + torch.int8, + torch.float16, + torch.bfloat16, + torch.float32, + ], +) +@pytest.mark.parametrize( + "buffer_size, heap_size", + [ + ((1 << 32), (1 << 33)), + ], +) +@pytest.mark.parametrize( + "block_size", + [ + 512, + 1024, + ], +) +def test_all_store_bench(dtype, buffer_size, heap_size, block_size): + shmem = iris.iris(heap_size) + num_ranks = shmem.get_num_ranks() + + element_size_bytes = torch.tensor([], dtype=dtype).element_size() + n_elements = buffer_size // element_size_bytes + buffer = shmem.zeros(n_elements, device="cuda", dtype=dtype) + + shmem.barrier() + + # Create arguments dict similar to what parse_args() would return + # Using minimal required parameters for testing + args = { + "datatype": _torch_dtype_to_str(dtype), + "block_size": block_size, + "verbose": False, + "validate": False, + "num_experiments": 1, # Minimal for testing + "num_warmup": 0, # Skip warmup for testing + "active_ranks": min(num_ranks, 8), # Use available ranks or 8, whichever is smaller + } + + # Call the run_experiment function from the module + bandwidth_gbps = module.run_experiment(shmem, args, buffer) + + # Basic validation that we got a reasonable bandwidth value + assert bandwidth_gbps >= 0.0, f"Bandwidth should be non-negative, got {bandwidth_gbps}" + assert bandwidth_gbps < 10000.0, f"Bandwidth seems unreasonably high: {bandwidth_gbps} GiB/s" + + +def _torch_dtype_to_str(dtype): + """Helper function to convert torch dtype to string format expected by the module""" + dtype_map = { + torch.float16: "fp16", + torch.float32: "fp32", + torch.int8: "int8", + torch.bfloat16: "bf16", + } + return dtype_map[dtype] From 831c2de624bad4885464a0ad629830aeeff3132f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 30 Aug 2025 22:04:13 +0000 Subject: [PATCH 3/6] Remove assertions from test to match load_bench test pattern Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- tests/examples/test_all_store_bench.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/examples/test_all_store_bench.py b/tests/examples/test_all_store_bench.py index 19acf9c4..7b764440 100644 --- a/tests/examples/test_all_store_bench.py +++ b/tests/examples/test_all_store_bench.py @@ -67,10 +67,6 @@ def test_all_store_bench(dtype, buffer_size, heap_size, block_size): # Call the run_experiment function from the module bandwidth_gbps = module.run_experiment(shmem, args, buffer) - # Basic validation that we got a reasonable bandwidth value - assert bandwidth_gbps >= 0.0, f"Bandwidth should be non-negative, got {bandwidth_gbps}" - assert bandwidth_gbps < 10000.0, f"Bandwidth seems unreasonably high: {bandwidth_gbps} GiB/s" - def _torch_dtype_to_str(dtype): """Helper function to convert torch dtype to string format expected by the module""" From 760841f31451c5aa40ad77b347d07da0cc388283 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 30 Aug 2025 22:05:21 +0000 Subject: [PATCH 4/6] Apply final formatting to test_all_store_bench.py Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- tests/examples/test_all_store_bench.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/examples/test_all_store_bench.py b/tests/examples/test_all_store_bench.py index 7b764440..cb5684a1 100644 --- a/tests/examples/test_all_store_bench.py +++ b/tests/examples/test_all_store_bench.py @@ -60,7 +60,7 @@ def test_all_store_bench(dtype, buffer_size, heap_size, block_size): "verbose": False, "validate": False, "num_experiments": 1, # Minimal for testing - "num_warmup": 0, # Skip warmup for testing + "num_warmup": 0, # Skip warmup for testing "active_ranks": min(num_ranks, 8), # Use available ranks or 8, whichever is smaller } From d5979a47e204515d729b20c5016669e6e2c8dad4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 31 Aug 2025 03:36:52 +0000 Subject: [PATCH 5/6] Fix test_all_store_bench.py to avoid barrier crash by testing kernel directly Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- tests/examples/test_all_store_bench.py | 31 +++++++++++++------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/tests/examples/test_all_store_bench.py b/tests/examples/test_all_store_bench.py index cb5684a1..9556959b 100644 --- a/tests/examples/test_all_store_bench.py +++ b/tests/examples/test_all_store_bench.py @@ -45,27 +45,26 @@ def test_all_store_bench(dtype, buffer_size, heap_size, block_size): shmem = iris.iris(heap_size) num_ranks = shmem.get_num_ranks() + cur_rank = shmem.get_rank() element_size_bytes = torch.tensor([], dtype=dtype).element_size() n_elements = buffer_size // element_size_bytes buffer = shmem.zeros(n_elements, device="cuda", dtype=dtype) - shmem.barrier() - - # Create arguments dict similar to what parse_args() would return - # Using minimal required parameters for testing - args = { - "datatype": _torch_dtype_to_str(dtype), - "block_size": block_size, - "verbose": False, - "validate": False, - "num_experiments": 1, # Minimal for testing - "num_warmup": 0, # Skip warmup for testing - "active_ranks": min(num_ranks, 8), # Use available ranks or 8, whichever is smaller - } - - # Call the run_experiment function from the module - bandwidth_gbps = module.run_experiment(shmem, args, buffer) + # Simple test similar to load_bench - just test the kernel functionality + # without the complex benchmarking infrastructure + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + + # Test all_store_kernel directly, similar to how load_bench tests the load_kernel + if cur_rank < min(num_ranks, 8): # Only test with a reasonable number of ranks + module.all_store_kernel[grid]( + buffer, + cur_rank, + n_elements, + num_ranks, + block_size, + shmem.get_heap_bases(), + ) def _torch_dtype_to_str(dtype): From e2442926e506ed215b998496a9384af058a52034 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sun, 31 Aug 2025 04:04:52 +0000 Subject: [PATCH 6/6] Apply Ruff auto-fixes --- tests/examples/test_all_store_bench.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/examples/test_all_store_bench.py b/tests/examples/test_all_store_bench.py index 9556959b..4baf6e92 100644 --- a/tests/examples/test_all_store_bench.py +++ b/tests/examples/test_all_store_bench.py @@ -54,7 +54,7 @@ def test_all_store_bench(dtype, buffer_size, heap_size, block_size): # Simple test similar to load_bench - just test the kernel functionality # without the complex benchmarking infrastructure grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - + # Test all_store_kernel directly, similar to how load_bench tests the load_kernel if cur_rank < min(num_ranks, 8): # Only test with a reasonable number of ranks module.all_store_kernel[grid](