From e6171c205b79cb468a16320e7491d5b07700cb96 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 30 Aug 2025 21:58:35 +0000 Subject: [PATCH 1/3] Initial plan From 98cbda9dfcd2934afeffd6754cb2072a9c55abe5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 30 Aug 2025 22:03:43 +0000 Subject: [PATCH 2/3] Add bench_store function and pytest test for store_bench.py Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- examples/01_store/store_bench.py | 57 ++++++++++++++++++++--------- tests/examples/test_store_bench.py | 59 ++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 17 deletions(-) create mode 100644 tests/examples/test_store_bench.py diff --git a/examples/01_store/store_bench.py b/examples/01_store/store_bench.py index 835809f2..a3d551bb 100755 --- a/examples/01_store/store_bench.py +++ b/examples/01_store/store_bench.py @@ -85,13 +85,21 @@ def parse_args(): return vars(parser.parse_args()) -def run_experiment(shmem, args, source_rank, destination_rank, buffer): - dtype = torch_dtype_from_str(args["datatype"]) +def bench_store( + shmem, + source_rank, + destination_rank, + buffer, + BLOCK_SIZE, + dtype, + verbose=False, + validate=False, + num_experiments=1, + num_warmup=0, +): cur_rank = shmem.get_rank() world_size = shmem.get_num_ranks() - # Allocate source and destination buffers on the symmetric heap - if source_rank >= world_size: raise ValueError( f"Source rank must be less than or equal to the world size. World size is {world_size} and source rank is {source_rank}." @@ -101,29 +109,27 @@ def run_experiment(shmem, args, source_rank, destination_rank, buffer): f"Destination rank must be less than or equal to the world size. World size is {world_size} and destination rank is {destination_rank}." ) if cur_rank == 0: - if args["verbose"]: + if verbose: shmem.info(f"Measuring bandwidth between the ranks {source_rank} and {destination_rank}...") n_elements = buffer.numel() grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - def run_experiment(): + def run_store(): if cur_rank == source_rank: - kk = store_kernel[grid]( + store_kernel[grid]( buffer, n_elements, source_rank, destination_rank, - args["block_size"], + BLOCK_SIZE, shmem.get_heap_bases(), ) # Warmup - run_experiment() + run_store() shmem.barrier() - triton_ms = iris.do_bench( - run_experiment, shmem.barrier, n_repeat=args["num_experiments"], n_warmup=args["num_warmup"] - ) + triton_ms = iris.do_bench(run_store, shmem.barrier, n_repeat=num_experiments, n_warmup=num_warmup) bandwidth_gbps = 0 if cur_rank == source_rank: @@ -131,15 +137,15 @@ def run_experiment(): element_size_bytes = torch.tensor([], dtype=dtype).element_size() total_bytes = n_elements * element_size_bytes bandwidth_gbps = total_bytes / triton_sec / 2**30 - if args["verbose"]: + if verbose: shmem.info(f"Copied {total_bytes / 2**30:.2f} GiB in {triton_sec:.4f} seconds") shmem.info(f"Bandwidth between {source_rank} and {destination_rank} is {bandwidth_gbps:.4f} GiB/s") shmem.barrier() bandwidth_gbps = shmem.broadcast(bandwidth_gbps, source_rank) success = True - if args["validate"] and cur_rank == destination_rank: - if args["verbose"]: + if validate and cur_rank == destination_rank: + if verbose: shmem.info("Validating output...") expected = torch.arange(n_elements, dtype=dtype, device="cuda") @@ -157,15 +163,32 @@ def run_experiment(): success = False break - if success and args["verbose"]: + if success and verbose: shmem.info("Validation successful.") - if not success and args["verbose"]: + if not success and verbose: shmem.error("Validation failed.") shmem.barrier() return bandwidth_gbps +def run_experiment(shmem, args, source_rank, destination_rank, buffer): + dtype = torch_dtype_from_str(args["datatype"]) + + return bench_store( + shmem, + source_rank, + destination_rank, + buffer, + args["block_size"], + dtype, + verbose=args["verbose"], + validate=args["validate"], + num_experiments=args["num_experiments"], + num_warmup=args["num_warmup"], + ) + + def print_bandwidth_matrix(matrix, label="Unidirectional STORE bandwidth GiB/s [Remote write]", output_file=None): num_ranks = matrix.shape[0] col_width = 10 # Adjust for alignment diff --git a/tests/examples/test_store_bench.py b/tests/examples/test_store_bench.py new file mode 100644 index 00000000..f48172d2 --- /dev/null +++ b/tests/examples/test_store_bench.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import pytest +import torch +import triton +import triton.language as tl +import numpy as np +import iris + +import importlib.util +from pathlib import Path + +current_dir = Path(__file__).parent +file_path = (current_dir / "../../examples/01_store/store_bench.py").resolve() +module_name = "store_bench" +spec = importlib.util.spec_from_file_location(module_name, file_path) +module = importlib.util.module_from_spec(spec) +spec.loader.exec_module(module) + + +@pytest.mark.parametrize( + "dtype", + [ + torch.int8, + torch.float16, + torch.bfloat16, + torch.float32, + ], +) +@pytest.mark.parametrize( + "buffer_size, heap_size", + [ + ((1 << 32), (1 << 33)), + ], +) +@pytest.mark.parametrize( + "block_size", + [ + 512, + 1024, + ], +) +def test_store_bench(dtype, buffer_size, heap_size, block_size): + shmem = iris.iris(heap_size) + num_ranks = shmem.get_num_ranks() + + bandwidth_matrix = np.zeros((num_ranks, num_ranks), dtype=np.float32) + element_size_bytes = torch.tensor([], dtype=dtype).element_size() + buffer = shmem.zeros(buffer_size // element_size_bytes, dtype=dtype) + + shmem.barrier() + + for source_rank in range(num_ranks): + for destination_rank in range(num_ranks): + bandwidth_gbps = module.bench_store(shmem, source_rank, destination_rank, buffer, block_size, dtype) + bandwidth_matrix[source_rank, destination_rank] = bandwidth_gbps + shmem.barrier() From e753e36d7a8980c0439bd265d785e44259510304 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 30 Aug 2025 22:38:59 +0000 Subject: [PATCH 3/3] Fix barrier deadlock in bench_store function Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- examples/01_store/store_bench.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/01_store/store_bench.py b/examples/01_store/store_bench.py index a3d551bb..9dfdfa76 100755 --- a/examples/01_store/store_bench.py +++ b/examples/01_store/store_bench.py @@ -125,10 +125,6 @@ def run_store(): shmem.get_heap_bases(), ) - # Warmup - run_store() - shmem.barrier() - triton_ms = iris.do_bench(run_store, shmem.barrier, n_repeat=num_experiments, n_warmup=num_warmup) bandwidth_gbps = 0