Nvls channel setup inside the container #477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
rajagond opened this issue Mar 6, 2025 · 11 comments
Open

Nvls channel setup inside the container #477

rajagond opened this issue Mar 6, 2025 · 11 comments

Comments

@rajagond

rajagond commented Mar 6, 2025

Hi,

I am trying to use MSCCL++ inside a Singularity container with NVLS support. The setup below doesn't work inside the container. Is there a possible workaround? Thanks!

For GPUs with NVLS support, the IMEX channels should be set up (refer to [cuMemCreate](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VA.html#group__CUDA__VA_1g899d69a862bba36449789c64b430dc7c)). You can set up the channels manually via:

sudo nvidia-modprobe -s -i <start:number of minors>
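As a quick sanity check, a minimal sketch like the following can be run inside the container to see whether any IMEX channel device nodes are visible. The /dev/nvidia-caps-imex-channels path is an assumption about where nvidia-modprobe exposes the channel nodes and may differ on your driver/container setup:

import os

# Hypothetical check: list IMEX channel device nodes visible in this container.
# The directory below is assumed; adjust it for your driver/container setup.
IMEX_DIR = "/dev/nvidia-caps-imex-channels"

if os.path.isdir(IMEX_DIR):
    channels = sorted(os.listdir(IMEX_DIR))
    print(f"Found {len(channels)} IMEX channel node(s): {channels}")
else:
    print(f"{IMEX_DIR} not found; no IMEX channels are visible in this container")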
@Binyang2014
Contributor

AFAIK, it requires executing the command on the host/VM side... Can you access the VM? BTW, PyTorch doesn't support NVLS right now, so it would be good to know more details about your use case.

@rajagond
Author

rajagond commented Mar 6, 2025

I don't have access to the VM. It looks like NCCL doesn't require this.

I set the following environment variable and then ran the allreduce operation:

export NCCL_ALGO="allreduce:nvls"
# NCCL version 2.21.5+cuda12.4

It didn’t give me any errors. Am I missing anything?
I ran the NCCL allreduce via torch.distributed.all_reduce.

@Binyang2014
Contributor

I think PyTorch is still working on NVLS support; it requires memory registration before running communication collectives (pytorch/pytorch#136567). You can test the latency/busBW for allreduce: if it reaches numbers like those below for 1 GB of data, then NVLS is enabled (see the sketch after the table for how busbw is derived). Otherwise, I think it just ignores the environment variable.

#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw #wrong     time   algbw   busbw #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
   268435456      67108864     float     sum      -1   1017.9  263.70  461.48      0   1017.5  263.83  461.70      0
   536870912     134217728     float     sum      -1   1987.8  270.08  472.64      0   1987.5  270.13  472.73      0
  1073741824     268435456     float     sum      -1   3930.7  273.17  478.05      0   3931.2  273.13  477.98      0
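
For reference, a minimal sketch of how the algbw/busbw columns above are derived for allreduce, assuming the standard nccl-tests definitions (algbw = bytes / time, busbw = algbw * 2*(n-1)/n) with n = 8 GPUs:

def allreduce_bandwidth(size_bytes, time_us, n_gpus=8):
    # nccl-tests reports algbw in GB/s (decimal GB): bytes per microsecond / 1000
    algbw = size_bytes / time_us / 1e3
    # Bus bandwidth applies the allreduce correction factor 2*(n-1)/n
    busbw = algbw * 2 * (n_gpus - 1) / n_gpus
    return algbw, busbw

# The 1 GiB out-of-place row above: ~273 GB/s algbw, ~478 GB/s busbw
print(allreduce_bandwidth(1073741824, 3930.7))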

@rajagond
Author

rajagond commented Mar 7, 2025

Just to confirm, there is no other way to use SHARP without creating an IMEX channel on an intranode H100 setup, right?

node-0:493516:493645 [1] NCCL INFO Connected NVLS tree
node-0:493516:493645 [1] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
node-0:493516:493645 [1] NCCL INFO 24 coll channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer
node-0:493516:493647 [3] NCCL INFO Connected NVLS tree
node-0:493516:493647 [3] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
node-0:493516:493647 [3] NCCL INFO 24 coll channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer
node-0:493516:493648 [4] NCCL INFO Connected NVLS tree
node-0:493516:493648 [4] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
node-0:493516:493648 [4] NCCL INFO 24 coll channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer
node-0:493516:493644 [0] NCCL INFO Connected NVLS tree
node-0:493516:493646 [2] NCCL INFO Connected NVLS tree
node-0:493516:493646 [2] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
node-0:493516:493646 [2] NCCL INFO 24 coll channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer
node-0:493516:493644 [0] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
node-0:493516:493644 [0] NCCL INFO 24 coll channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer
node-0:493516:493649 [5] NCCL INFO Connected NVLS tree
node-0:493516:493649 [5] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
node-0:493516:493649 [5] NCCL INFO 24 coll channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer
node-0:493516:493650 [6] NCCL INFO Connected NVLS tree
node-0:493516:493650 [6] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
node-0:493516:493650 [6] NCCL INFO 24 coll channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer
node-0:493516:493651 [7] NCCL INFO Connected NVLS tree
node-0:493516:493651 [7] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
node-0:493516:493651 [7] NCCL INFO 24 coll channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer
#
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw #wrong     time   algbw   busbw #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
   268435456      67108864     float     sum      -1   1138.6  235.76  412.58      0   1139.5  235.57  412.24      0
   536870912     134217728     float     sum      -1   2192.9  244.82  428.43      0   2199.2  244.12  427.21      0
  1073741824     268435456     float     sum      -1   4174.9  257.19  450.08      0   4177.6  257.02  449.79      0
  2147483648     536870912     float     sum      -1   8041.8  267.04  467.32      0   8032.8  267.34  467.84      0

I am getting this using nccl-tests.

@rajagond
Author

rajagond commented Mar 7, 2025

I checked the other algos also.

#!/bin/bash

# exported so the debug settings apply to the runs below
export NCCL_DEBUG="INFO" NCCL_DEBUG_SUBSYS="INIT,ENV,TUNING"
NCCL_ALGO=Ring ./build/all_reduce_perf -b 256M -e 2G -f 2 -g 8 > results_ring.txt
NCCL_ALGO=Tree ./build/all_reduce_perf -b 256M -e 2G -f 2 -g 8 > results_tree.txt
NCCL_ALGO=NVLS ./build/all_reduce_perf -b 256M -e 2G -f 2 -g 8 > results_nvls.txt

# nThread 1 nGpus 8 minBytes 268435456 maxBytes 2147483648 step: 2(factor) warmup iters: 5 iters: 20 agg iters: 1 validation: 1 graph: 0
#
# Using devices
#  Rank  0 Group  0 Pid 493516 on     node-0 device  0 [0001:00:00] NVIDIA H100 80GB HBM3
#  Rank  1 Group  0 Pid 493516 on     node-0 device  1 [0002:00:00] NVIDIA H100 80GB HBM3
#  Rank  2 Group  0 Pid 493516 on     node-0 device  2 [0003:00:00] NVIDIA H100 80GB HBM3
#  Rank  3 Group  0 Pid 493516 on     node-0 device  3 [0008:00:00] NVIDIA H100 80GB HBM3
#  Rank  4 Group  0 Pid 493516 on     node-0 device  4 [0009:00:00] NVIDIA H100 80GB HBM3
#  Rank  5 Group  0 Pid 493516 on     node-0 device  5 [000a:00:00] NVIDIA H100 80GB HBM3
#  Rank  6 Group  0 Pid 493516 on     node-0 device  6 [000b:00:00] NVIDIA H100 80GB HBM3
#  Rank  7 Group  0 Pid 493516 on     node-0 device  7 [000c:00:00] NVIDIA H100 80GB HBM3
node-0:493516:493516 [0] NCCL INFO cudaDriverVersion 12040
NCCL version 2.19.4+cuda12.4
# NVLS
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw #wrong     time   algbw   busbw #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
   268435456      67108864     float     sum      -1   1141.1  235.23  411.66      0   1139.9  235.50  412.12      0
   536870912     134217728     float     sum      -1   2195.5  244.53  427.93      0   2192.0  244.93  428.62      0
  1073741824     268435456     float     sum      -1   4183.2  256.68  449.19      0   4175.4  257.16  450.03      0
  2147483648     536870912     float     sum      -1   8041.4  267.05  467.34      0   8031.7  267.38  467.91      0

# Ring
#
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw #wrong     time   algbw   busbw #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
   268435456      67108864     float     sum      -1   1305.7  205.58  359.77      0   1306.7  205.42  359.49      0
   536870912     134217728     float     sum      -1   2573.4  208.62  365.09      0   2575.1  208.48  364.85      0
  1073741824     268435456     float     sum      -1   5105.9  210.29  368.01      0   5109.6  210.14  367.75      0
  2147483648     536870912     float     sum      -1    10194  210.66  368.66      0    10180  210.95  369.17      0

# Tree
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw #wrong     time   algbw   busbw #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
   268435456      67108864     float     sum      -1   2076.5  129.27  226.23      0   2057.7  130.45  228.29      0
   536870912     134217728     float     sum      -1   3401.8  157.82  276.19      0   3407.1  157.57  275.75      0
  1073741824     268435456     float     sum      -1   6746.8  159.15  278.51      0   6744.5  159.20  278.60      0
  2147483648     536870912     float     sum      -1    13071  164.30  287.52      0    13098  163.95  286.91      0

Using PyTorch

# Pytorch default
Running on NVIDIA H100 80GB HBM3
Number of GPUs: 8
Tensor size: 65536 elements (1.00 GB for bfloat16)
Number of iterations: 10
Number of warmup iterations: 5
Number of trials: 10
Trial 1: Elapsed time = 40.701 ms, Bandwidth = 491.390 GB/s, Latency = 4070.086 us
Trial 2: Elapsed time = 40.438 ms, Bandwidth = 494.581 GB/s, Latency = 4043.824 us
Trial 3: Elapsed time = 40.369 ms, Bandwidth = 495.426 GB/s, Latency = 4036.928 us
Trial 4: Elapsed time = 40.468 ms, Bandwidth = 494.213 GB/s, Latency = 4046.835 us
Trial 5: Elapsed time = 40.373 ms, Bandwidth = 495.385 GB/s, Latency = 4037.261 us
Trial 6: Elapsed time = 40.345 ms, Bandwidth = 495.723 GB/s, Latency = 4034.512 us
Trial 7: Elapsed time = 40.398 ms, Bandwidth = 495.077 GB/s, Latency = 4039.776 us
Trial 8: Elapsed time = 40.397 ms, Bandwidth = 495.083 GB/s, Latency = 4039.725 us
Trial 9: Elapsed time = 40.398 ms, Bandwidth = 495.072 GB/s, Latency = 4039.814 us
Trial 10: Elapsed time = 40.355 ms, Bandwidth = 495.597 GB/s, Latency = 4035.536 us

Median Results:
Elapsed time = 40.398 ms
Bandwidth = 495.080 GB/s
Latency = 4039.750 us
e-0:518094:518094 [0] NCCL INFO 1073741824 Bytes -> Algo 4 proto 2 time 4051.531982

@Binyang2014
Contributor

Right now, in MSCCL++, we cannot use NVLS without an IMEX channel. But it seems your torch program worked with NVLS support. Could you share the torch script and torch version you used? We want to know whether we can enable NVLS in PyTorch with some changes.

@rajagond
Author

rajagond commented Mar 8, 2025

(vllm) aiscuser@node-0:~/c2_overlap$ python3
Python 3.12.9 | packaged by Anaconda, Inc. | (main, Feb  6 2025, 18:56:27) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.__version__
'2.5.1+cu124'

import os
import time
import argparse
import numpy as np
import torch
import torch.distributed as dist
import torch.cuda.nvtx as nvtx
import torch.multiprocessing as mp
from torch.cuda import Event

def init_process(rank, world_size, args, backend='nccl'):
    """Initialize the distributed process."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    run(rank, world_size, args)
    dist.destroy_process_group()

def run(rank, world_size, args):
    """Run the benchmark on a single process."""
    # Set device
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")
    
    # Print information about the device
    if rank == 0:
        print(f"Running on {torch.cuda.get_device_name(device)}")
        print(f"Number of GPUs: {world_size}")
        print(f"Tensor size: {args.tokens} elements ({(args.tokens * 8192 * 2) / (1024**3):.2f} GB for bfloat16)")
        print(f"Number of iterations: {args.iters}")
        print(f"Number of warmup iterations: {args.warmup}")
        print(f"Number of trials: {args.trials}")
    
    # Create CUDA streams
    measure_stream = torch.cuda.Stream()
    spin_stream = torch.cuda.Stream()
    
    # Create tensors
    tensor = torch.rand(args.tokens, 8192, dtype=torch.bfloat16, device=device)
    
    # Create a GPU flag tensor for synchronization
    flag = torch.zeros(1, dtype=torch.int32, device=device)
    
    # Define a simple spin kernel as a PyTorch operation
    def spin_wait():
        with torch.cuda.stream(spin_stream):
            # Keep the GPU busy until flag is updated
            while flag.item() == 0:
                torch.cuda._sleep(100)  # Sleep for a short time to reduce polling overhead
    
    # Warmup
    for _ in range(args.warmup):
        dist.all_reduce(tensor)
    
    # Synchronize before starting measurements
    torch.cuda.synchronize()
    
    # Create events for timing
    timings = []
    
    for trial in range(args.trials):
        torch.cuda.synchronize()
        
        # Measurement code
        with torch.cuda.stream(measure_stream):
            # Create events for timing
            start_event = Event(enable_timing=True)
            end_event = Event(enable_timing=True)
            
            # Record start event
            start_event.record(measure_stream)
            
            # Run iterations of all_reduce
            for _ in range(args.iters):
                dist.all_reduce(tensor)
            
            # Record end event
            end_event.record(measure_stream)
        
        # Wait for the measurement stream to finish
        measure_stream.synchronize()
        
        # Calculate elapsed time (ms)
        elapsed_time = start_event.elapsed_time(end_event)
        
        # Calculate bandwidth (GB/s)
        tensor_size_bytes = tensor.numel() * tensor.element_size()
        total_bytes = tensor_size_bytes * args.iters * 2  # factor of 2 for all-reduce (send and receive)
        bandwidth = (total_bytes / (elapsed_time / 1000)) / (1024**3)  # GB/s
        
        # Calculate latency (us)
        latency = (elapsed_time * 1000) / args.iters  # us
        
        timings.append((elapsed_time, bandwidth, latency))
        
        if rank == 0:
            print(f"Trial {trial+1}: Elapsed time = {elapsed_time:.3f} ms, "
                  f"Bandwidth = {bandwidth:.3f} GB/s, "
                  f"Latency = {latency:.3f} us")
    
    # Report median results
    if rank == 0:
        elapsed_times, bandwidths, latencies = zip(*timings)
        median_time = np.median(elapsed_times)
        median_bandwidth = np.median(bandwidths)
        median_latency = np.median(latencies)
        
        print("\nMedian Results:")
        print(f"Elapsed time = {median_time:.3f} ms")
        print(f"Bandwidth = {median_bandwidth:.3f} GB/s")
        print(f"Latency = {median_latency:.3f} us")

def main():
    parser = argparse.ArgumentParser(description='PyTorch All-Reduce Benchmark')
    parser.add_argument('--tokens', type=int, default=65536, 
                        help='Tensor size in number of elements')
    parser.add_argument('--iters', type=int, default=10, 
                        help='Number of iterations per trial')
    parser.add_argument('--trials', type=int, default=10, 
                        help='Number of trials to run')
    parser.add_argument('--warmup', type=int, default=5, 
                        help='Number of warmup iterations')
    args = parser.parse_args()



    world_size = 8
    mp.set_start_method('spawn')
    
    processes = []
    for rank in range(world_size):
        p = mp.Process(target=init_process, args=(rank, world_size, args))
        p.start()
        processes.append(p)
    
    for p in processes:
        p.join()

if __name__ == "__main__":
    main()

@rajagond
Author

rajagond commented Mar 8, 2025

Right now, in MSCCL++, we cannot use NVLS without an IMEX channel. But it seems your torch program worked with NVLS support. Could you share the torch script and torch version you used? We want to know whether we can enable NVLS in PyTorch with some changes.

As per #1632, an IMEX channel is not required for the intra-node setup. Please let me know once you integrate PyTorch support.

@Binyang2014
Contributor

Just checked the NCCL code: if the memory is not registered, NCCL copies the data to a temporary buffer, does the NVLS allreduce on that buffer, and then copies the result to the output buffer. We are now working on NVLS support.
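
A rough conceptual sketch of that fallback path (an illustration only, not NCCL's actual code; the helpers below are trivial stand-ins so the snippet runs):

def nvls_allreduce(buf):
    # Stand-in for the in-place NVLS reduction, which needs a registered buffer.
    pass

def allreduce(user_buffer, registered):
    if registered:
        # Registered path: NVLS reduction runs directly on the user buffer.
        nvls_allreduce(user_buffer)
    else:
        # Unregistered path: stage through a pre-registered temporary buffer,
        # which adds a copy-in and a copy-out around the NVLS reduction.
        temp = list(user_buffer)     # copy data into the temp buffer
        nvls_allreduce(temp)         # NVLS allreduce on the temp buffer
        user_buffer[:] = temp        # copy the result back out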

@Binyang2014
Contributor

Binyang2014 commented Apr 3, 2025

Hi @rajagond, we have fixed this issue on the branch binyli/export_dlpack and will merge it to the main branch soon. Here is sample code showing how to use NVLS inside a container:

import os
import argparse
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda import Event

from mscclpp import RawGpuBuffer
from mscclpp.utils import GpuBuffer


def init_process(rank, world_size, args, backend='nccl'):
    """Initialize the distributed process."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    run(rank, world_size, args)
    dist.destroy_process_group()

def run(rank, world_size, args):
    """Run the benchmark on a single process."""
    # Set device
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")
    
    # Print information about the device
    if rank == 0:
        print(f"Running on {torch.cuda.get_device_name(device)}")
        print(f"Number of GPUs: {world_size}")
        print(f"Tensor size: {args.tokens} elements ({(args.tokens * 8192 * 2) / (1024**3):.2f} GB for bfloat16)")
        print(f"Number of iterations: {args.iters}")
        print(f"Number of warmup iterations: {args.warmup}")
        print(f"Number of trials: {args.trials}")
    
    # Create CUDA streams
    measure_stream = torch.cuda.Stream()
    spin_stream = torch.cuda.Stream()

    buffer = RawGpuBuffer(args.tokens * 8192 * 2) # 2 bytes for bfloat16
    dl_pack = buffer.to_dlpack(str(torch.bfloat16))
    tensor = torch.utils.dlpack.from_dlpack(dl_pack)
    tensor.uniform_()

    # Warmup
    for _ in range(args.warmup):
        dist.all_reduce(tensor)
    
    # Synchronize before starting measurements
    torch.cuda.synchronize()
    
    # Create events for timing
    timings = []

    # dist.barrier()
    
    for trial in range(args.trials):
        torch.cuda.synchronize()
        
        # Measurement code
        with torch.cuda.stream(measure_stream):
            # Create events for timing
            start_event = Event(enable_timing=True)
            end_event = Event(enable_timing=True)
            
            torch.cuda.synchronize()
            # Record start event
            start_event.record(measure_stream)
            
            # Run iterations of all_reduce
            for _ in range(args.iters):
                dist.all_reduce(tensor)

            # Record end event
            end_event.record(measure_stream)
            measure_stream.synchronize()
        
        
        # Calculate elapsed time (ms)
        elapsed_time = start_event.elapsed_time(end_event)
        
        # Calculate bandwidth (GB/s)
        tensor_size_bytes = tensor.numel() * tensor.element_size()
        total_bytes = tensor_size_bytes * args.iters * 7 * 2 / 8  # 2*(n-1)/n bus-bandwidth factor for n = 8 GPUs
        bandwidth = (total_bytes / (elapsed_time / 1000)) / (1024**3)  # GB/s
        
        # Calculate latency (us)
        latency = (elapsed_time * 1000) / args.iters  # us
        
        timings.append((elapsed_time, bandwidth, latency))
        
        if rank == 0:
            print(f"Trial {trial+1}: Elapsed time = {elapsed_time:.3f} ms, "
                  f"Bandwidth = {bandwidth:.3f} GB/s, "
                  f"Latency = {latency:.3f} us")
    
    # Report median results
    if rank == 0:
        elapsed_times, bandwidths, latencies = zip(*timings)
        median_time = np.median(elapsed_times)
        median_bandwidth = np.median(bandwidths)
        median_latency = np.median(latencies)
        
        print("\nMedian Results:")
        print(f"Elapsed time = {median_time:.3f} ms")
        print(f"Bandwidth = {median_bandwidth:.3f} GB/s")
        print(f"Latency = {median_latency:.3f} us")

def main():
    parser = argparse.ArgumentParser(description='PyTorch All-Reduce Benchmark')
    parser.add_argument('--tokens', type=int, default=65536, 
                        help='Tensor size in number of elements')
    parser.add_argument('--iters', type=int, default=10, 
                        help='Number of iterations per trial')
    parser.add_argument('--trials', type=int, default=10, 
                        help='Number of trials to run')
    parser.add_argument('--warmup', type=int, default=5, 
                        help='Number of warmup iterations')
    args = parser.parse_args()



    world_size = 8
    mp.set_start_method('spawn')
    
    processes = []
    for rank in range(world_size):
        p = mp.Process(target=init_process, args=(rank, world_size, args))
        p.start()
        processes.append(p)
    
    for p in processes:
        p.join()

if __name__ == "__main__":
    main()

Run with NCCL via: python3 torch_test.py
Run with MSCCL++ via: LD_PRELOAD=${MSCCLPP_HOME}/build/apps/nccl/libmscclpp_nccl.so MSCCLPP_EXECUTION_PLAN_DIR=${EXECUTION_PLAN_DIR} python3 torch_test.py (the LD_PRELOAD makes the PyTorch NCCL backend load MSCCL++'s NCCL-compatible library instead of NCCL).
Please note: you need to generate the NVLS execution plan first via:

mkdir -p ${EXECUTION_PLAN_DIR}
python3 ${MSCCLPP_HOME}/python/examples/allreduce_nvls.py 8 8 > ${EXECUTION_PLAN_DIR}/allreduce.json

@Binyang2014
Contributor

The perf for NCCL:

Number of GPUs: 8
Tensor size: 65536 elements (1.00 GB for bfloat16)
Number of iterations: 10
Number of warmup iterations: 5
Number of trials: 10
Trial 1: Elapsed time = 40.509 ms, Bandwidth = 431.998 GB/s, Latency = 4050.947 us
Trial 2: Elapsed time = 40.439 ms, Bandwidth = 432.752 GB/s, Latency = 4043.888 us
Trial 3: Elapsed time = 40.402 ms, Bandwidth = 433.147 GB/s, Latency = 4040.202 us
Trial 4: Elapsed time = 40.326 ms, Bandwidth = 433.963 GB/s, Latency = 4032.598 us
Trial 5: Elapsed time = 40.393 ms, Bandwidth = 433.244 GB/s, Latency = 4039.293 us
Trial 6: Elapsed time = 40.616 ms, Bandwidth = 430.868 GB/s, Latency = 4061.571 us
Trial 7: Elapsed time = 40.349 ms, Bandwidth = 433.718 GB/s, Latency = 4034.880 us
Trial 8: Elapsed time = 40.371 ms, Bandwidth = 433.479 GB/s, Latency = 4037.107 us
Trial 9: Elapsed time = 40.581 ms, Bandwidth = 431.237 GB/s, Latency = 4058.089 us
Trial 10: Elapsed time = 40.433 ms, Bandwidth = 432.817 GB/s, Latency = 4043.283 us

Median Results:
Elapsed time = 40.417 ms
Bandwidth = 432.982 GB/s
Latency = 4041.742 us

The perf for MSCCL++:

Number of GPUs: 8
Tensor size: 65536 elements (1.00 GB for bfloat16)
Number of iterations: 10
Number of warmup iterations: 5
Number of trials: 10
Trial 1: Elapsed time = 41.816 ms, Bandwidth = 418.505 GB/s, Latency = 4181.552 us
Trial 2: Elapsed time = 39.202 ms, Bandwidth = 446.411 GB/s, Latency = 3920.157 us
Trial 3: Elapsed time = 39.191 ms, Bandwidth = 446.528 GB/s, Latency = 3919.127 us
Trial 4: Elapsed time = 41.031 ms, Bandwidth = 426.505 GB/s, Latency = 4103.117 us
Trial 5: Elapsed time = 39.242 ms, Bandwidth = 445.950 GB/s, Latency = 3924.208 us
Trial 6: Elapsed time = 41.658 ms, Bandwidth = 420.085 GB/s, Latency = 4165.827 us
Trial 7: Elapsed time = 39.251 ms, Bandwidth = 445.845 GB/s, Latency = 3925.130 us
Trial 8: Elapsed time = 39.198 ms, Bandwidth = 446.453 GB/s, Latency = 3919.786 us
Trial 9: Elapsed time = 39.204 ms, Bandwidth = 446.379 GB/s, Latency = 3920.432 us
Trial 10: Elapsed time = 41.404 ms, Bandwidth = 422.667 GB/s, Latency = 4140.374 us

Median Results:
Elapsed time = 39.247 ms
Bandwidth = 445.897 GB/s
Latency = 3924.669 us
