Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 39 additions & 20 deletions examples/01_store/store_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,21 @@ def parse_args():
return vars(parser.parse_args())


def run_experiment(shmem, args, source_rank, destination_rank, buffer):
dtype = torch_dtype_from_str(args["datatype"])
def bench_store(
shmem,
source_rank,
destination_rank,
buffer,
BLOCK_SIZE,
dtype,
verbose=False,
validate=False,
num_experiments=1,
num_warmup=0,
):
cur_rank = shmem.get_rank()
world_size = shmem.get_num_ranks()

# Allocate source and destination buffers on the symmetric heap

if source_rank >= world_size:
raise ValueError(
f"Source rank must be less than or equal to the world size. World size is {world_size} and source rank is {source_rank}."
Expand All @@ -101,45 +109,39 @@ def run_experiment(shmem, args, source_rank, destination_rank, buffer):
f"Destination rank must be less than or equal to the world size. World size is {world_size} and destination rank is {destination_rank}."
)
if cur_rank == 0:
if args["verbose"]:
if verbose:
shmem.info(f"Measuring bandwidth between the ranks {source_rank} and {destination_rank}...")
n_elements = buffer.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

def run_experiment():
def run_store():
if cur_rank == source_rank:
kk = store_kernel[grid](
store_kernel[grid](
buffer,
n_elements,
source_rank,
destination_rank,
args["block_size"],
BLOCK_SIZE,
shmem.get_heap_bases(),
)

# Warmup
run_experiment()
shmem.barrier()

triton_ms = iris.do_bench(
run_experiment, shmem.barrier, n_repeat=args["num_experiments"], n_warmup=args["num_warmup"]
)
triton_ms = iris.do_bench(run_store, shmem.barrier, n_repeat=num_experiments, n_warmup=num_warmup)

bandwidth_gbps = 0
if cur_rank == source_rank:
triton_sec = triton_ms * 1e-3
element_size_bytes = torch.tensor([], dtype=dtype).element_size()
total_bytes = n_elements * element_size_bytes
bandwidth_gbps = total_bytes / triton_sec / 2**30
if args["verbose"]:
if verbose:
shmem.info(f"Copied {total_bytes / 2**30:.2f} GiB in {triton_sec:.4f} seconds")
shmem.info(f"Bandwidth between {source_rank} and {destination_rank} is {bandwidth_gbps:.4f} GiB/s")
shmem.barrier()
bandwidth_gbps = shmem.broadcast(bandwidth_gbps, source_rank)

success = True
if args["validate"] and cur_rank == destination_rank:
if args["verbose"]:
if validate and cur_rank == destination_rank:
if verbose:
shmem.info("Validating output...")

expected = torch.arange(n_elements, dtype=dtype, device="cuda")
Expand All @@ -157,15 +159,32 @@ def run_experiment():
success = False
break

if success and args["verbose"]:
if success and verbose:
shmem.info("Validation successful.")
if not success and args["verbose"]:
if not success and verbose:
shmem.error("Validation failed.")

shmem.barrier()
return bandwidth_gbps


def run_experiment(shmem, args, source_rank, destination_rank, buffer):
dtype = torch_dtype_from_str(args["datatype"])

return bench_store(
shmem,
source_rank,
destination_rank,
buffer,
args["block_size"],
dtype,
verbose=args["verbose"],
validate=args["validate"],
num_experiments=args["num_experiments"],
num_warmup=args["num_warmup"],
)


def print_bandwidth_matrix(matrix, label="Unidirectional STORE bandwidth GiB/s [Remote write]", output_file=None):
num_ranks = matrix.shape[0]
col_width = 10 # Adjust for alignment
Expand Down
59 changes: 59 additions & 0 deletions tests/examples/test_store_bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import pytest
import torch
import triton
import triton.language as tl
import numpy as np
import iris

import importlib.util
from pathlib import Path

current_dir = Path(__file__).parent
file_path = (current_dir / "../../examples/01_store/store_bench.py").resolve()
module_name = "store_bench"
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)


@pytest.mark.parametrize(
"dtype",
[
torch.int8,
torch.float16,
torch.bfloat16,
torch.float32,
],
)
@pytest.mark.parametrize(
"buffer_size, heap_size",
[
((1 << 32), (1 << 33)),
],
)
@pytest.mark.parametrize(
"block_size",
[
512,
1024,
],
)
def test_store_bench(dtype, buffer_size, heap_size, block_size):
shmem = iris.iris(heap_size)
num_ranks = shmem.get_num_ranks()

bandwidth_matrix = np.zeros((num_ranks, num_ranks), dtype=np.float32)
element_size_bytes = torch.tensor([], dtype=dtype).element_size()
buffer = shmem.zeros(buffer_size // element_size_bytes, dtype=dtype)

shmem.barrier()

for source_rank in range(num_ranks):
for destination_rank in range(num_ranks):
bandwidth_gbps = module.bench_store(shmem, source_rank, destination_rank, buffer, block_size, dtype)
bandwidth_matrix[source_rank, destination_rank] = bandwidth_gbps
shmem.barrier()
Loading