[DCU] Support the deployment and operation of tilelang on the Hygon DCU backend #1145
base: main
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the required checks. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough
This pull request adds comprehensive DCU (Hygon) GPU support to TileLang, including DCU-specific GEMM layouts, matrix intrinsic codegen paths (MMAC), device-side HIP templates for core operations, a new matmul intrinsic emitter, build configuration adjustments, and example/test code.
Changes
Sequence Diagram(s)
sequenceDiagram
participant Python as Python (tl_matmul)
participant Lower as TileLang Lowering
participant Codegen as HIP Codegen
participant Device as Device Code (HIP)
Python->>Lower: JIT compile GEMM with MatrixCoreIntrinEmitter
Lower->>Codegen: Emit intrinsics for DCU target
alt TargetIsDCU check
Codegen->>Codegen: Use makeGemmFragmentCDCU layout
Codegen->>Codegen: Dispatch tl::tvm_mmac() path
else Non-DCU target
Codegen->>Codegen: Use makeGemmFragmentCCDNA layout
end
Codegen->>Codegen: Generate __builtin_amdgcn_mmac_* calls
Codegen-->>Device: Emit HIP kernel code
Device->>Device: ldmatrix load A/B tiles
Device->>Device: mmac multiply-accumulate
Device->>Device: stmatrix store results
Device-->>Python: Return kernel + profiler
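To make the K-loop in the diagram concrete, here is a tiny host-side model of the pipelined load-then-mmac schedule in plain NumPy; it is only an illustration of the pattern, not the device code TileLang generates.

```python
import numpy as np

def pipelined_matmul(A, B, block_k=32):
    """Toy model of the pipelined K loop: prefetch the next tile while the
    current one feeds the multiply-accumulate ("mmac") stage."""
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N), dtype=np.float32)
    num_k = K // block_k
    a_buf, b_buf = A[:, :block_k], B[:block_k, :]          # prologue load
    for ko in range(num_k):
        a_cur, b_cur = a_buf, b_buf
        if ko + 1 < num_k:                                  # prefetch the next K tile
            s = (ko + 1) * block_k
            a_buf, b_buf = A[:, s:s + block_k], B[s:s + block_k, :]
        C += a_cur.astype(np.float32) @ b_cur.astype(np.float32)  # compute stage
    return C

A = np.random.rand(64, 128).astype(np.float16)
B = np.random.rand(128, 64).astype(np.float16)
assert np.allclose(pipelined_matmul(A, B), A.astype(np.float32) @ B.astype(np.float32), atol=1e-3)
```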
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Areas requiring extra attention:
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 17
🧹 Nitpick comments (16)
tilelang/contrib/rocm.py (1)
230-230: Note: Inconsistent default paths across the file. While this change targets DCU support, other functions in this file still default to /opt/rocm (see lines 57, 140-143, 280). For a complete DCU integration, you may want to apply similar path detection logic to those functions as well, or alternatively use find_rocm_path() (which checks environment variables) as the source of truth for the installation path.
examples/minference/ops/vertical_slash_index.hip (3)
1-14: Remove unused include and tighten device helper annotations. ATen/dtk_macros.h appears unused and may break some builds. Also consider adding restrict/inline hints for device helpers.
Apply:
-#include <ATen/dtk_macros.h>
+#include <ATen/ATen.h>
Optional tweaks:
-__device__ void save_blocks(
+__device__ __forceinline__ void save_blocks(
And annotate kernel params as restrict for better codegen:
-    const int* seqlens,
+    const int* __restrict__ seqlens,
(Apply restrict to other pointer params similarly.)
21-35: Consider documenting input ordering and monotonicity assumptions. The kernel assumes slash/vertical indexes are per-(batch, head) slices and monotonically increasing; otherwise the logic breaks. Add comments or TORCH_CHECKs to assert sortedness in debug paths.
Would you like a lightweight CPU validator to assert monotonicity/sizes before launch?
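A minimal host-side validator along those lines might look like the sketch below (it assumes torch tensors shaped [BATCH, N_HEADS, NNZ]; the helper name and the ascending-order assumption follow the comment above and are not part of the PR).

```python
import torch

def check_sparse_indexes(vertical_indexes: torch.Tensor, slash_indexes: torch.Tensor) -> None:
    """Debug-only check: both index tensors are [BATCH, N_HEADS, NNZ] and sorted along
    the last dimension (flip the comparison if the kernel expects descending order)."""
    for name, idx in (("vertical_indexes", vertical_indexes), ("slash_indexes", slash_indexes)):
        assert idx.dim() == 3, f"{name}: expected [BATCH, N_HEADS, NNZ], got {tuple(idx.shape)}"
        assert bool((idx[..., 1:] >= idx[..., :-1]).all()), f"{name}: must be sorted ascending"
```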
95-113: Optional: add launch_bounds to guide occupancy. Since N_THREADS is fixed at 64, annotate the kernel for better occupancy/compile-time checks.
Apply:
-__global__ void convert_vertical_slash_indexes_kernel(
+__launch_bounds__(64)
+__global__ void convert_vertical_slash_indexes_kernel(
tilelang/intrinsics/mmac_macro_generator.py (2)
19-31: Consider annotating class-level constants with ClassVar. The dtype_abbrv dictionary is a class-level constant that should be annotated with typing.ClassVar to clarify intent and satisfy static analysis.
Apply this diff:
+from typing import ClassVar
+
 class MatrixCoreIntrinEmitter(object):
     """
     To eliminate Python syntax within TIR Macro.
     """

     M_DIM = 16
     N_DIM = 16
     WARP_SIZE = 64
-    dtype_abbrv = {
+    dtype_abbrv: ClassVar[dict[str, str]] = {
         "float16": "fp16",
         "bfloat16": "bf16",
         "float32": "fp32",
         "int8": "int8",
         "int32": "int32",
         "float8_e4m3": "e4m3",
         "float8_e5m2": "e5m2",
         "float8_e4m3fnuz": "e4m3fnuz",
     }
582-582: Consider using consistent notation for thread transformations. Lines 582 and 592 use the bitwise form ((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4), while lines 293 and 304 use the arithmetic form (tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16 for the same transformation. Both are functionally equivalent, but mixing styles reduces readability.
Also applies to: 592-592
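A quick host-side check of the claimed equivalence (pure Python, illustration only):

```python
# The bitwise and arithmetic forms of the lane remapping agree for every lane of a 64-wide wave.
for tx in range(64):
    bitwise = ((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4)
    arithmetic = (tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16
    assert bitwise == arithmetic, tx
```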
src/layout/gemm_layouts.cc (1)
750-750: Clarify the purpose of the commented alternative. The commented-out call to makeHalfBankSwizzleLayout suggests an alternative swizzling strategy. Consider either:
- Removing it if it's no longer relevant
- Adding a brief comment explaining why this alternative exists and under what conditions it might be preferred
src/op/gemm.cc (1)
7-7: Remove unused include. The <fstream> header appears to be unused in this file. Consider removing it to keep the codebase clean.
-#include <fstream>
src/tl_templates/dcu_hip/threadblock_swizzle.h (1)
1-46: Consider extracting the duplicated ceil_div lambda. Both rasterization2DRow and rasterization2DColumn define identical ceil_div lambdas. While this duplication is minor and the functions are templates (so there's no runtime cost), extracting it to a shared helper would improve maintainability.
namespace tl {
namespace detail {
constexpr auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
}

template <int panel_width> TL_DEVICE dim3 rasterization2DRow() {
  const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
  const unsigned int grid_size = gridDim.x * gridDim.y;
  const unsigned int panel_size = panel_width * gridDim.x;
  const unsigned int panel_offset = block_idx % panel_size;
  const unsigned int panel_idx = block_idx / panel_size;
  const unsigned int total_panel = detail::ceil_div(grid_size, panel_size);
  // ... rest of implementation
}

template <int panel_width> TL_DEVICE dim3 rasterization2DColumn() {
  const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
  const unsigned int grid_size = gridDim.x * gridDim.y;
  const unsigned int panel_size = panel_width * gridDim.y;
  const unsigned int panel_offset = block_idx % panel_size;
  const unsigned int panel_idx = block_idx / panel_size;
  const unsigned int total_panel = detail::ceil_div(grid_size, panel_size);
  // ... rest of implementation
}
} // namespace tl
testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py (1)
66-66: Remove unused variable. The variable cache_write_shared is assigned but never used. Consider removing it to keep the code clean.
 shared_scope = "shared"
-cache_write_shared = False
examples/gemm/example_gemm_intrinsics_dcu.py (1)
17-25: Swizzle enablement is overly strict. The 512-bit row check only enables swizzle for 64B rows (e.g., fp16, K=32). For int8 (row_bytes = 32B) swizzle remains off, though the layout util supports 32B/64B/128B. Consider enabling when row_bytes % 32 == 0 and passing swizzle_bytes accordingly.
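A sketch of the relaxed gating, with illustrative names rather than the PR's actual helper signature:

```python
def should_swizzle(continuous_elems: int, dtype_bits: int) -> tuple[bool, int]:
    """Return (enable, swizzle_bytes): swizzle whenever a shared-memory row is a
    multiple of 32 bytes instead of requiring exactly 512 bits (hypothetical helper)."""
    row_bytes = continuous_elems * dtype_bits // 8
    if row_bytes % 32 != 0:
        return False, 0
    return True, min(row_bytes, 128)  # pick a 32B/64B/128B pattern

assert should_swizzle(32, 16) == (True, 64)  # fp16, K=32: 64B rows (already swizzled today)
assert should_swizzle(32, 8) == (True, 32)   # int8, K=32: 32B rows, now also swizzled
```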
src/tl_templates/dcu_hip/reduce.h (1)
44-53: AllReduce requires non-null red_buf for threads ≥ warpSize. When offset >= warpSize, you unconditionally write to red_buf. Document and assert this precondition, or switch to an internal shared-memory buffer.
Would you like me to wire a templated shared-memory scratch allocation path for this?
src/tl_templates/dcu_hip/copy.h (2)
58-62: Clarify fence selection in cp_async_wait. The function defaults to async_gld_fence(N) with a commented alternative async_gld_sld_fence(N). Consider documenting when each fence type should be used, or provide template parameters to select the appropriate fence based on memory access patterns.
64-73: Unused template parameter pre_nop. The template parameter pre_nop is declared but never used in the function body. Consider removing it if not needed, or add a TODO comment if it's reserved for future functionality.
src/tl_templates/dcu_hip/gemm.h (2)
32-51: Address const-correctness in bfloat16 MFMA. Similar to the int8 case, lines 39-40 use const_cast to remove constness. Additionally, the manual loop (lines 43-46) to copy data could be simplified if alignment is guaranteed.
Consider:
- Remove const_cast and use const pointers throughout
- If alignment is guaranteed, use direct casting instead of manual copying:
- short *b_short = reinterpret_cast<short *>(const_cast<bfloat16_t *>(b));
- short *a_short = reinterpret_cast<short *>(const_cast<bfloat16_t *>(a));
+ const short *b_short = reinterpret_cast<const short *>(b);
+ const short *a_short = reinterpret_cast<const short *>(a);
- // Copy the data
- for (int i = 0; i < 4; ++i) {
-   b_vec[i] = b_short[i];
-   a_vec[i] = a_short[i];
- }
+ b_vec = *reinterpret_cast<const bfloat16x4_vec *>(b_short);
+ a_vec = *reinterpret_cast<const bfloat16x4_vec *>(a_short);
72-72: Resolve commented-out static_assert. Line 72 has a commented-out static_assert for clear_accum. Either uncomment it if the feature is unsupported, or remove it if support has been added.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (23)
examples/gemm/example_gemm_intrinsics_dcu.py (1 hunks)
examples/minference/ops/vertical_slash_index.hip (1 hunks)
src/layout/gemm_layouts.cc (2 hunks)
src/layout/layout.h (1 hunks)
src/op/gemm.cc (2 hunks)
src/target/codegen_hip.cc (3 hunks)
src/target/intrin_rule_hip.cc (2 hunks)
src/target/utils.cc (2 hunks)
src/target/utils.h (1 hunks)
src/tl_templates/dcu_hip/common.h (1 hunks)
src/tl_templates/dcu_hip/copy.h (1 hunks)
src/tl_templates/dcu_hip/core.hpp (1 hunks)
src/tl_templates/dcu_hip/debug.h (1 hunks)
src/tl_templates/dcu_hip/gemm.h (1 hunks)
src/tl_templates/dcu_hip/hip_fp8.h (1 hunks)
src/tl_templates/dcu_hip/ldsm.h (1 hunks)
src/tl_templates/dcu_hip/reduce.h (1 hunks)
src/tl_templates/dcu_hip/threadblock_swizzle.h (1 hunks)
testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py (1 hunks)
tilelang/contrib/hipcc.py (1 hunks)
tilelang/contrib/rocm.py (1 hunks)
tilelang/engine/lower.py (1 hunks)
tilelang/intrinsics/mmac_macro_generator.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (13)
src/op/gemm.cc (2)
src/target/utils.cc (2)
TargetIsDCU (82-90), TargetIsDCU (82-82)
src/layout/gemm_layouts.cc (4)
makeGemmFragmentCDCU (159-174), makeGemmFragmentCDCU (159-161), makeGemmFragmentCCDNA (176-191), makeGemmFragmentCCDNA (176-178)
src/layout/layout.h (1)
src/layout/gemm_layouts.cc (2)
makeGemmFragmentCDCU (159-174), makeGemmFragmentCDCU (159-161)
src/tl_templates/dcu_hip/reduce.h (2)
src/tl_templates/dcu_hip/core.hpp (20)
T (31-34), T (37-40), T (43-46), T (69-72), T (75-78), T (81-84), max (31-31), max (37-37), max (43-43), max (49-49), max (55-58), max (55-55), max (62-62), min (69-69), min (75-75), min (81-81), min (87-87), min (93-96), min (93-93), min (100-100)
src/tl_templates/dcu_hip/common.h (1)
__half(113-113)
examples/gemm/example_gemm_intrinsics_dcu.py (10)
tilelang/intrinsics/mma_layout.py (1)
get_swizzle_layout (166-201)
tilelang/intrinsics/mmac_macro_generator.py (7)
MatrixCoreIntrinEmitter (14-396), ldmatrix_a (227-266), ldmatrix_a (453-525), ldmatrix_b (268-312), ldmatrix_b (527-602), mmac (314-346), stmatrix (348-396)
tilelang/transform/simplify.py (1)
simplify_prim_func (53-59)
tilelang/env.py (1)
disable_cache (271-272)
tilelang/language/allocate.py (2)
alloc_shared (24-39), alloc_local (42-53)
tilelang/language/annotations.py (2)
annotate_layout (25-36), use_swizzle (17-22)
tilelang/language/fill.py (1)
clear (24-48)
tilelang/language/pipeline.py (1)
Pipelined (9-46)
tilelang/language/parallel.py (1)
Parallel (9-29)
tilelang/profiler/__init__.py (1)
assert_allclose (77-146)
src/tl_templates/dcu_hip/threadblock_swizzle.h (3)
tilelang/carver/roller/rasterization.py (1)
panel_width (16-18)
tilelang/carver/roller/hint.py (1)
stride (45-46)
src/tl_templates/cuda/threadblock_swizzle.h (1)
rasterization2DColumn(25-41)
src/tl_templates/dcu_hip/common.h (2)
src/tl_templates/dcu_hip/gemm.h (4)
half (23-29), bfloat16_t (32-51), void (155-234), void (236-296)
src/tl_templates/dcu_hip/reduce.h (3)
_Float16 (32-32), T (92-164), __half (29-29)
src/target/utils.h (1)
src/target/utils.cc (2)
TargetIsDCU (82-90), TargetIsDCU (82-82)
src/tl_templates/dcu_hip/copy.h (1)
src/tl_templates/cuda/copy.h (1)
cp_async_wait(20-26)
src/tl_templates/dcu_hip/hip_fp8.h (1)
src/tl_templates/cuda/cuda_fp8.h (3)
fp8_e4_2_t (9-12), make_fp8_e4_4_t (88-97), make_fp8_e4_8_t (100-109)
testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py (12)
tilelang/intrinsics/mma_layout.py (1)
get_swizzle_layout (166-201)
tilelang/intrinsics/mmac_macro_generator.py (7)
MatrixCoreIntrinEmitter (14-396), ldmatrix_a (227-266), ldmatrix_a (453-525), ldmatrix_b (268-312), ldmatrix_b (527-602), mmac (314-346), stmatrix (348-396)
tilelang/transform/simplify.py (1)
simplify_prim_func (53-59)
tilelang/testing/__init__.py (1)
set_random_seed (30-35)
tilelang/env.py (1)
disable_cache (271-272)
tilelang/language/allocate.py (2)
alloc_shared (24-39), alloc_local (42-53)
tilelang/language/annotations.py (2)
annotate_layout (25-36), use_swizzle (17-22)
tilelang/language/fill.py (1)
clear (24-48)
tilelang/language/pipeline.py (1)
Pipelined (9-46)
tilelang/language/copy.py (1)
copy (11-87)
tilelang/language/parallel.py (1)
Parallel (9-29)
tilelang/jit/__init__.py (1)
compile(30-79)
src/tl_templates/dcu_hip/debug.h (1)
src/tl_templates/dcu_hip/common.h (1)
half_t(116-116)
src/tl_templates/dcu_hip/gemm.h (1)
tilelang/intrinsics/mfma_layout.py (1)
make_mfma_swizzle_layout(130-152)
tilelang/intrinsics/mmac_macro_generator.py (5)
tilelang/intrinsics/utils.py (1)
mfma_store_index_map (85-86)
tilelang/tileop/gemm/gemm_base.py (2)
k_pack (111-112), chunk (63-64)
tilelang/language/kernel.py (2)
threads (215-219), KernelLaunchFrame (95-226)
tilelang/intrinsics/mfma_layout.py (16)
shared_16x4_to_local_64x1_layout_A (6-8), shared_4x16_to_local_64x1_layout_B (17-19), shared_16x16_to_local_64x4_layout_A (46-49), shared_16x16_to_local_64x4_layout_B (58-61), shared_16x32_to_local_64x8_layout_A (88-91), shared_16x32_to_local_64x8_layout_B (100-103), shared_16x64_to_local_64x16_layout_A (112-115), shared_16x64_to_local_64x16_layout_B (124-127), thread_id_shared_access_64x1_to_16x4_layout_A (11-14), thread_id_shared_access_64x1_to_4x16_layout_B (22-25), thread_id_shared_access_64x4_to_16x16_layout_A (40-43), thread_id_shared_access_64x4_to_16x16_layout_B (52-55), thread_id_shared_access_64x8_to_16x32_layout_A (82-85), thread_id_shared_access_64x8_to_16x32_layout_B (94-97), thread_id_shared_access_64x16_to_16x64_layout_A (106-109), thread_id_shared_access_64x16_to_16x64_layout_B (118-121)
tilelang/language/ast/ir.py (2)
index_map (1673-1679), meta_var (1731-1750)
🪛 Ruff (0.14.1)
testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py
66-66: Local variable cache_write_shared is assigned to but never used
Remove assignment to unused variable cache_write_shared
(F841)
tilelang/intrinsics/mmac_macro_generator.py
22-31: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
96-96: Avoid specifying long messages outside the exception class
(TRY003)
194-194: Avoid specifying long messages outside the exception class
(TRY003)
253-253: Ambiguous variable name: l
(E741)
261-261: Ambiguous variable name: l
(E741)
294-294: Ambiguous variable name: l
(E741)
305-305: Ambiguous variable name: l
(E741)
481-481: Ambiguous variable name: l
(E741)
490-490: Ambiguous variable name: l
(E741)
509-509: Ambiguous variable name: l
(E741)
519-519: Ambiguous variable name: l
(E741)
555-555: Ambiguous variable name: l
(E741)
564-564: Ambiguous variable name: l
(E741)
583-583: Ambiguous variable name: l
(E741)
593-593: Ambiguous variable name: l
(E741)
🔇 Additional comments (29)
examples/minference/ops/vertical_slash_index.hip (1)
121-158: The review comment is incorrect; the Python binding exists. The pybind11 binding for convert_vertical_slash_indexes is already present in examples/minference/ops/kernels.cpp (lines 13-16), explicitly exposing the function to Python via m.def("convert_vertical_slash_indexes", ...). No action is needed.
Likely an incorrect or invalid review comment.
src/layout/layout.h (1)
153-155: LGTM! DCU fragment factory follows established patterns. The new makeGemmFragmentCDCU declaration is consistent with existing GEMM fragment factories (CCDNA, Hopper, etc.) and has a corresponding implementation.
src/target/utils.h (1)
25-25: LGTM! Target predicate follows established conventions. The TargetIsDCU declaration is consistent with existing target detection APIs.
src/target/intrin_rule_hip.cc (1)
243-251: LGTM! Intrinsic registration follows HIP patterns. The tir.hip.__shfl registration correctly implements a non-sync shuffle variant with appropriate arguments (var, lane, width).
tilelang/intrinsics/mmac_macro_generator.py (3)
227-266: LGTM! Matrix load logic correctly handles transposed and non-transposed layouts. The ldmatrix_a implementation properly:
- Extracts thread bindings and applies reverse index mapping
- Handles both transposed and non-transposed cases with appropriate coordinate calculations
- Uses vectorized loads for efficiency
314-346: LGTM! MMAC intrinsic invocation correctly handles vectorized types. The method properly:
- Constructs vectorized data type strings when local_size > 1
- Computes correct buffer offsets for A, B, and C matrices
- Invokes the tvm_mmac intrinsic with appropriate layout and type parameters
293-293: Request manual verification of thread ID transformation logic for the DCU MMAC B matrix. Lines 293 and 304 in the base MatrixCoreIntrinEmitter.ldmatrix_b() method apply the thread ID transformation (tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16. This transformation is also present (in equivalent bit-shift form) at lines 582 and 592 in the derived class's preshuffle implementation, but notably absent from the preshuffle variant's main path (lines 554, 563).
The transformation appears intentional but lacks documentation or tests explaining:
- Why the DCU MMAC B-matrix requires thread reordering while the A-matrix doesn't (ldmatrix_a uses plain tx)
- Why the preshuffle optimization eliminates this transformation requirement
- Whether this formula is architecturally correct for gfx936
Verify this transformation is correct for your target DCU architecture and document the design rationale.
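One cheap host-side sanity check: the remapping should at least be a permutation of the 64 lanes in a wave. This does not prove architectural correctness on gfx936, only internal consistency.

```python
# The transformation swaps the two 2-bit fields inside (tx & 15) and keeps the upper bits,
# so it must permute lanes within each group of 16.
remap = [(tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16 for tx in range(64)]
assert sorted(remap) == list(range(64))
```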
tilelang/engine/lower.py (1)
109-109: Verify the optimization level change is intentional. The addition of -O1 represents a significant reduction from the typical -O3 level. While this may be necessary for DCU compatibility or to work around compiler issues, it could negatively impact performance.
Please confirm:
- Is this optimization level required for DCU correctness?
- Have performance implications been evaluated?
- Can a higher optimization level be used once DCU support matures?
src/tl_templates/dcu_hip/ldsm.h (1)
1-3: LGTM! Minimal DCU HIP header structure is appropriate. The header correctly uses #pragma once and includes the common DCU HIP definitions.
tilelang/contrib/hipcc.py (1)
64-65: Verify compiler flag changes for DCU. Two significant changes:
- Optimization level: changed to -O1 (matching the change in tilelang/engine/lower.py). This reduces optimization but may be required for DCU compatibility.
- Warning suppression: added -Wno-invalid-constexpr. This suggests the generated HIP code contains constexpr usage that doesn't meet HIP compiler requirements.
- Are these flags specifically required for DCU/gfx936 targets?
- What constexpr issues does the warning suppression address?
- Is there a plan to fix the underlying constexpr issues rather than suppressing warnings?
src/layout/gemm_layouts.cc (1)
159-174: LGTM! DCU-specific fragment layout follows the established pattern. The new makeGemmFragmentCDCU function mirrors the structure of makeGemmFragmentCCDNA, with the key distinction being the final Repeat call parameters (true, true) vs (true, false). This differentiation aligns with DCU-specific layout requirements while maintaining consistency with the existing codebase architecture.
src/op/gemm.cc (1)
831-840: LGTM! DCU fragment selection follows the established pattern. The conditional selection between makeGemmFragmentCDCU and makeGemmFragmentCCDNA based on TargetIsDCU(T.target) correctly routes DCU targets to their specialized fragment path while preserving existing CDNA behavior.
src/target/codegen_hip.cc (2)
140-140: LGTM! Macro enables required warp synchronization features. Defining HIP_ENABLE_WARP_SYNC_BUILTINS before including hip_runtime.h is necessary for accessing warp-level synchronization primitives on HIP.
150-155: LGTM! Include path updates align with DCU-specific template organization. The switch from tl_templates/hip/... to tl_templates/dcu_hip/... correctly routes to the DCU-specific implementations introduced in this PR.
testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py (2)
15-27: LGTM! Swizzle layout helper correctly implements bank conflict avoidance. The make_swizzle_layout function appropriately checks for 512-bit row alignment before applying swizzling, ensuring optimal shared memory access patterns for DCU.
30-179: LGTM! Matrix multiplication kernel demonstrates proper DCU intrinsic usage. The tl_matmul function correctly orchestrates:
- Shared memory allocation with appropriate swizzling
- Pipelined outer loop for K-dimension blocking
- MMAC intrinsic-based computation via MatrixCoreIntrinEmitter
- Proper store-back from local fragments to global memory
The implementation serves as a solid reference for DCU-based GEMM operations.
examples/gemm/example_gemm_intrinsics_dcu.py (1)
123-124: Confirm panel_size=10 is supported by the swizzle pattern. The threadblock swizzle device function is templated by panel size; many implementations assume powers of two. Please confirm 10 is valid for your target, or change to 8/16.
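As a quick host-side check, a hypothetical model of a panel-based column rasterization (not the actual device template) shows that a non-power-of-two panel width such as 10 can still remap blocks one-to-one; whether the shipped template handles the last, narrower panel this way is what needs confirming.

```python
def rasterize_column(block_idx, grid_x, grid_y, panel_width):
    # Column-major walk inside panels of `panel_width` columns; the last panel
    # may be narrower when grid_x is not a multiple of panel_width.
    panel_size = panel_width * grid_y
    panel_idx = block_idx // panel_size
    panel_offset = block_idx % panel_size
    width = min(panel_width, grid_x - panel_idx * panel_width)
    return panel_idx * panel_width + panel_offset % width, panel_offset // width

grid_x, grid_y, panel = 13, 7, 10  # 10 is not a power of two
mapped = {rasterize_column(i, grid_x, grid_y, panel) for i in range(grid_x * grid_y)}
assert len(mapped) == grid_x * grid_y                     # unique (col, row) per block
assert all(c < grid_x and r < grid_y for c, r in mapped)  # stays inside the grid
```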
src/tl_templates/dcu_hip/core.hpp (1)
28-66: LGTM for min/max helpers and host/device guards. The overload set covers scalar and variadic forms with proper device specializations for float/double.
Also applies to: 68-105
src/tl_templates/dcu_hip/copy.h (4)
1-14: LGTM: Type aliases and includes are well-structured. The type aliases are clear and the use of ck_tile::int32x4_t from the common header provides consistency across the DCU backend.
16-31: LGTM: Buffer resource construction follows DCU patterns. The buffer_resource struct and make_wave_buffer_resource function correctly construct a buffer descriptor and normalize lanes using __builtin_amdgcn_readfirstlane, which is the appropriate pattern for ensuring uniform values across a wave.
33-39: LGTM: M0 register manipulation is correct. These helpers correctly use inline assembly to manipulate the M0 register, which is standard practice for controlling LDS (Local Data Share) operations on AMDGPU.
82-86: Pointer arithmetic analysis verified; clarify or simplify the N=4 access pattern. Your analysis is mathematically correct. The pointer operations cause all threads to access the same location:
- Base: global_base_ptr - threadIdx.x * 4 (in bytes)
- Offset: threadIdx.x * 4 (equals N when N=4)
- Effective: (global_base_ptr - threadIdx.x * 4) + threadIdx.x * 4 = global_base_ptr
This same pattern appears in both cp_async_gs and cp_async_gs_conditional functions in src/tl_templates/dcu_hip/copy.h (lines 84-85, 102-103) and mirrors the same code in src/tl_templates/hip/copy.h.
The unusual pointer subtraction followed by offset reinstatement should either be:
- Documented to clarify the wave-level intent, or
- Simplified to make_wave_buffer_resource(global_base_ptr), threadIdx.x * N if per-thread access is intended.
Given this pattern's consistency across multiple locations, verify whether this is intentional wave-level behavior or an unintended cancellation that needs correction.
src/tl_templates/dcu_hip/debug.h (2)
1-99: LGTM: Debug print specializations are comprehensive and type-safe. The debug_print_var specializations cover a good range of types with appropriate format specifiers. The casting of narrow types (char, short) to int/unsigned int for printf is correct practice for variadic functions.
101-191: LGTM: Buffer value printing is well-implemented. The debug_print_buffer_value specializations correctly handle buffer context along with value printing. The half_t specialization appropriately converts to float for printf compatibility.
src/tl_templates/dcu_hip/gemm.h (5)
23-29: LGTM: Half-precision MFMA correctly uses vector types. The specialization appropriately casts to float16x4 pointers for the MFMA intrinsic. This assumes proper alignment, which should be guaranteed by the calling context.
53-64: LGTM: FP8 specialization correctly handles const pointers. The FP8 MFMA specialization properly uses reinterpret_cast<const int64_t *> without const_cast violations. Good example of proper const-correctness.
117-139: LGTM: Swizzle layout matches reference implementation. The make_mfma_swizzle_layout function correctly implements the bank conflict avoidance pattern consistent with the Python reference in tilelang/intrinsics/mfma_layout.py.
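For readers unfamiliar with the pattern, a generic XOR-based swizzle looks like the sketch below; it illustrates the general bank-conflict-avoidance idea and is not a copy of make_mfma_swizzle_layout.

```python
def xor_swizzle(row: int, col: int, vec: int = 4) -> int:
    """Map (row, col) in a 16-wide tile to a swizzled column: XOR a few row bits
    into the index of the 4-element vector that holds the column."""
    group = col // vec
    swizzled_group = group ^ (row % 4)
    return swizzled_group * vec + col % vec

# The transform permutes each row's 16 columns, so no element is lost or duplicated.
for row in range(16):
    assert sorted(xor_swizzle(row, c) for c in range(16)) == list(range(16))
```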
155-234: LGTM: GEMM body implements standard MFMA pattern. The shared-to-register GEMM computation correctly implements the fetch-and-compute pattern for MFMA instructions. The lane ID swizzling (line 168) and transpose logic are appropriate for the MFMA 16x16 instruction.
236-323: LGTM: Register-shared GEMM and wrappers are correctly implemented. The body_rs function appropriately handles the case where matrix A is already in registers, and the free function wrappers provide a clean API for both shared-shared and register-shared GEMM variants.
| if out_dtype == "int32": | ||
| micro_size_k = 32 | ||
|
|
micro_size_k should depend on in_dtype, not out_dtype
Int8 paths require K-fragment=32 regardless of the final out dtype. Gate on in_dtype to avoid incorrect tiling when out_dtype is int32 but inputs aren’t int8.
Apply this diff:
- if out_dtype == "int32":
- micro_size_k = 32
+ if in_dtype == "int8":
+ micro_size_k = 32
🤖 Prompt for AI Agents
In examples/gemm/example_gemm_intrinsics_dcu.py around lines 50 to 52, the code
sets micro_size_k based on out_dtype but it should be gated on in_dtype: change
the condition to check if in_dtype == "int8" (or equivalent int8 type sentinel
used elsewhere) and set micro_size_k = 32 for int8 input paths; leave other
dtype branches unchanged so int8 inputs always use K-fragment=32 even if
out_dtype is int32.
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
🧩 Analysis chain
Tile edge handling missing; kernel reads/writes out of bounds for non-multiple sizes
Grid uses ceildiv, but loads/stores assume perfect multiples of block_M/N/K. Either add guards in load/store loops or assert divisibility.
Minimal safety precondition (quick fix):
block_K = chunk
@@
- A_shape = (M, K)
+ # Precondition for this example kernel (simplifies control flow)
+ assert M % block_M == 0 and N % block_N == 0 and K % block_K == 0, \
+ "Example kernel currently requires M,N,K to be multiples of block sizes"
+ A_shape = (M, K)
Recommended follow-up (optional): add bounds checks in the A/B loads and C stores so arbitrary M, N, K work. I can draft that if you want.
Also applies to: 108-159
I need to examine the actual code to verify the concerns raised in the review comment.
🏁 Script executed:
#!/bin/bash
# Retrieve the file and examine the kernel implementation
if [ -f "examples/gemm/example_gemm_intrinsics_dcu.py" ]; then
echo "=== File Content ==="
cat examples/gemm/example_gemm_intrinsics_dcu.py
else
echo "File not found at examples/gemm/example_gemm_intrinsics_dcu.py"
find . -name "example_gemm_intrinsics_dcu.py" -type f 2>/dev/null
fi
Length of output: 5708
Add bounds-checking guards or assertion for non-multiple tensor dimensions
The kernel grid uses ceildiv (line 99) but load/store loops assume perfect multiples. With M, N, or K not divisible by block_M, block_N, or block_K, the loops read/write out of bounds. The current test (M=N=K=16384) masks this because all are divisible by their block sizes.
Either add bounds checks in the A/B loads (lines 117–121) and C store (lines 138–145), or add an early assertion:
block_K = chunk
@@
+ assert M % block_M == 0 and N % block_N == 0 and K % block_K == 0, \
+ "Example kernel requires M, N, K to be multiples of block sizes (M:%d, N:%d, K:%d; block_M:%d, block_N:%d, block_K:%d)" % (M, N, K, block_M, block_N, block_K)
  A_shape = (M, K)
🤖 Prompt for AI Agents
In examples/gemm/example_gemm_intrinsics_dcu.py around lines 65–68, the computed
block sizes (block_M, block_N, block_K) are used later by loads/stores that
assume tensor dimensions are exact multiples; add either runtime assertions
early (assert M % block_M == 0 and N % block_N == 0 and K % block_K == 0) or,
preferably, add bounds-check guards around A/B loads (lines ~117–121) and the C
store (lines ~138–145): before reading A or B elements check the computed global
row/col indices against M/N/K and substitute zero (or a safe value) for
out-of-bounds loads; before writing C check indices and skip stores outside M/N,
ensuring no out-of-bounds memory access.
C[by * block_M + i, bx * block_N + j] = C_shared[
    j // micro_size_y,
    i // micro_size_x,
    i % micro_size_x,
    j % micro_size_y,
]
C_shared indexing appears transposed vs declared shape
C_shared is declared as (block_M // micro_size_x, block_N // micro_size_y, micro_size_x, micro_size_y), but store uses [j-group, i-group, i-rem, j-rem]. Swap i/j groups to match the shape.
Apply this diff:
- C[by * block_M + i, bx * block_N + j] = C_shared[
- j // micro_size_y,
- i // micro_size_x,
- i % micro_size_x,
- j % micro_size_y,
- ]
+ C[by * block_M + i, bx * block_N + j] = C_shared[
+ i // micro_size_x,
+ j // micro_size_y,
+ i % micro_size_x,
+ j % micro_size_y,
+ ]
🤖 Prompt for AI Agents
In examples/gemm/example_gemm_intrinsics_dcu.py around lines 153 to 158, the
indexing used to read from C_shared is transposed relative to its declared shape
(C_shared declared as (block_M // micro_size_x, block_N // micro_size_y,
micro_size_x, micro_size_y)); fix the read to match that shape by swapping the
group indices so the access becomes C_shared[i // micro_size_x, j //
micro_size_y, i % micro_size_x, j % micro_size_y] instead of the current [j //
micro_size_y, i // micro_size_x, i % micro_size_x, j % micro_size_y].
int seqlen = seqlens[batch_idx];
int block_idx_m = group_idx * blockDim.x + threadIdx.x;
int start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= seqlen) {
  return;
}
int end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
Guard against row index overflow.
Threads with block_idx_m ≥ N_ROWS can compute a valid start_m < seqlen (if seqlen > context), causing OOB on row_offset. Add an explicit guard.
Apply:
int seqlen = seqlens[batch_idx];
int block_idx_m = group_idx * blockDim.x + threadIdx.x;
+ if (block_idx_m >= N_ROWS) {
+ return;
+ }
int start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= seqlen) {
return;
}
🤖 Prompt for AI Agents
In examples/minference/ops/vertical_slash_index.hip around lines 40 to 54,
threads where block_idx_m >= N_ROWS can still have start_m < seqlen and will
compute row_offset and use it causing out-of-bounds accesses; add an explicit
guard right after computing block_idx_m (before computing row_offset and any
row-dependent offsets) that returns when block_idx_m >= N_ROWS so subsequent
uses of row_offset, block_count/offset, column_count/index are safe.
int tmp_col_cnt = 0, tmp_blk_cnt = 0;
int s = 0, v = 0;
int v_idx = vertical_indexes[v++];
int s_idx = slash_indexes[s++];
while (s_idx >= end_m) {
  s_idx = slash_indexes[s++];
}
s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
Fix OOB reads when NNZ_S/NNZ_V are zero and bound the pre-scan.
Accessing vertical_indexes[v++] and slash_indexes[s++] without checking NNZ_* risks OOB. The pre-loop while also lacks a bound on s.
Apply:
- int tmp_col_cnt = 0, tmp_blk_cnt = 0;
- int s = 0, v = 0;
- int v_idx = vertical_indexes[v++];
- int s_idx = slash_indexes[s++];
- while (s_idx >= end_m) {
- s_idx = slash_indexes[s++];
- }
- s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
+ int tmp_col_cnt = 0, tmp_blk_cnt = 0;
+ int s = 0, v = 0;
+ // Safe init of v_idx
+ int v_idx = (NNZ_V > 0) ? vertical_indexes[v++] : (end_m + BLOCK_SIZE_M);
+ // Handle NNZ_S == 0 early
+ if (NNZ_S == 0) {
+ block_count[0] = 0;
+ column_count[0] = 0;
+ return;
+ }
+ int s_idx = slash_indexes[s++];
+ while (s < NNZ_S && s_idx >= end_m) {
+ s_idx = slash_indexes[s++];
+ }
+ if (s_idx >= end_m) {
+ // No slash indices relevant for this row
+ block_count[0] = 0;
+ column_count[0] = 0;
+ return;
+ }
+ s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
template <> struct MfmaTraits<int8_t> {
  template <typename AccType>
  static TL_DEVICE void mfma_op(const int8_t *b, const int8_t *a, AccType *c) {
    int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b));
    int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a));

    *c = __builtin_amdgcn_mmac_i32_16x16x32i8(*b_packed, *a_packed, *c);
  }
};
Remove const_cast from const pointers.
Lines 15-16 use const_cast to remove constness from the input pointers before casting to int64_t*. While the MFMA intrinsic likely doesn't modify the data, casting away const violates const-correctness and can lead to undefined behavior if the underlying data is actually const.
The intrinsic should accept const pointers. If it doesn't, consider using reinterpret_cast<const int64_t*> directly without const_cast, or verify the intrinsic signature:
- int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b));
- int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a));
+ const int64_t *b_packed = reinterpret_cast<const int64_t *>(b);
+ const int64_t *a_packed = reinterpret_cast<const int64_t *>(a);Then update line 18 if needed based on the intrinsic's signature.
🤖 Prompt for AI Agents
In src/tl_templates/dcu_hip/gemm.h around lines 12 to 20, remove the const_cast
usage that strips constness from the input pointers: change the local packed
pointer types to const int64_t* and use reinterpret_cast<const int64_t*>(b) and
reinterpret_cast<const int64_t*>(a) instead of const_cast; then pass *b_packed
and *a_packed to the MFMA intrinsic as before (or adjust the call if the
intrinsic requires non-pointer/value types) so that const-correctness is
preserved.
__device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
                                      fp8_e4_t w) {
  // reinterpret the 4 fp8_e4_t values to signed char value and shift
  signed char x_char = *reinterpret_cast<signed char *>(&x);
  signed char y_char = *reinterpret_cast<signed char *>(&y);
  signed char z_char = *reinterpret_cast<signed char *>(&z);
  signed char w_char = *reinterpret_cast<signed char *>(&w);
  int res = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char;
  return *reinterpret_cast<fp8_e4_4_t *>(&res);
}
Avoid UB from type-punning via reinterpret_cast; construct via fields
The current helpers reinterpret int buffers to fp8_e4_4_t/fp8_e4_8_t, which risks aliasing/alignment UB. Build via named members instead.
Apply this diff:
-__device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
- fp8_e4_t w) {
- // reinterpret the 4 fp8_e4_t values to signed char value and shift
- signed char x_char = *reinterpret_cast<signed char *>(&x);
- signed char y_char = *reinterpret_cast<signed char *>(&y);
- signed char z_char = *reinterpret_cast<signed char *>(&z);
- signed char w_char = *reinterpret_cast<signed char *>(&w);
- int res = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char;
- return *reinterpret_cast<fp8_e4_4_t *>(&res);
-}
+__device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
+ fp8_e4_t w) {
+ fp8_e4_4_t out{};
+ out.x = x; out.y = y; out.z = z; out.w = w;
+ return out;
+}
@@
-__device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
- fp8_e4_t w, fp8_e4_t v, fp8_e4_t u,
- fp8_e4_t t, fp8_e4_t s) {
- signed char x_char = *reinterpret_cast<signed char *>(&x);
- signed char y_char = *reinterpret_cast<signed char *>(&y);
- signed char z_char = *reinterpret_cast<signed char *>(&z);
- signed char w_char = *reinterpret_cast<signed char *>(&w);
- signed char v_char = *reinterpret_cast<signed char *>(&v);
- signed char u_char = *reinterpret_cast<signed char *>(&u);
- signed char t_char = *reinterpret_cast<signed char *>(&t);
- signed char s_char = *reinterpret_cast<signed char *>(&s);
- int a = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char;
- int b = (s_char << 24) | (t_char << 16) | (u_char << 8) | v_char;
- fp8_e4_8_t res;
- res.x = *reinterpret_cast<fp8_e4_4_t *>(&a);
- res.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
- return res;
-}
+__device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x, fp8_e4_t y, fp8_e4_t z,
+ fp8_e4_t w, fp8_e4_t v, fp8_e4_t u,
+ fp8_e4_t t, fp8_e4_t s) {
+ fp8_e4_8_t out{};
+ out.x = make_fp8_e4_4_t(x, y, z, w);
+ out.y = make_fp8_e4_4_t(v, u, t, s);
+ return out;
+}
Also applies to: 57-74
🤖 Prompt for AI Agents
In src/tl_templates/dcu_hip/hip_fp8.h around lines 46-55 (and similarly 57-74),
the function uses reinterpret_cast to type-pun fp8 bytes into signed char and to
reinterpret an int as fp8_e4_4_t, which can cause aliasing/alignment UB; replace
the reinterpret_casts by directly constructing and returning the fp8_e4_4_t
using its named byte/member fields (assign x,y,z,w to the struct members in the
correct order) and similarly for the 8-wide variant, avoiding any pointer casts
or reinterpreting integers as the struct — ensure you preserve byte ordering and
use explicit member assignment to build the result.
static_assert(threads == 1024 or threads == 512 or threads == 256 or
              threads == 128 or threads == 64 or threads == 32);
template <typename T, int SEG = 32>
static TL_DEVICE T run(const T *__restrict__ src, T *__restrict__ dst, int H,
                       int W) {

  constexpr int TILE_H = threads / SEG;
  constexpr uint64_t MASK = 0xffffffffffffffffULL;
  const int num_blocks = (H + TILE_H - 1) / TILE_H;
  const int tid = threadIdx.x;
  const int lane = tid % 64;
Fix compile/runtime issues in CumSum2D
- Return type should be void (it writes into dst).
- Don't return; inside the row-bounds check; use continue;.
- Use a 32-bit mask type expected by HIP intrinsics.
- Pass integer lane indices to __shfl_sync, not T-typed zeros.
Apply this diff:
- template <typename T, int SEG = 32>
- static TL_DEVICE T run(const T *__restrict__ src, T *__restrict__ dst, int H,
- int W) {
+ template <typename T, int SEG = 32>
+ static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst, int H,
+ int W) {
@@
- constexpr uint64_t MASK = 0xffffffffffffffffULL;
+ const unsigned MASK = 0xFFFFFFFFu;
@@
- if (gRow >= H)
- return;
+ if (gRow >= H)
+ continue;
@@
- T segSum = (T)__shfl_sync(MASK, val, (T)0);
+ T segSum = (T)__shfl_sync(MASK, val, 0);
@@
- carry = (T)__shfl_sync(MASK, carry, (T)0);
+ carry = (T)__shfl_sync(MASK, carry, 0);
Also applies to: 102-107, 131-135, 157-161
🤖 Prompt for AI Agents
In src/tl_templates/dcu_hip/reduce.h around lines 89-99 (and similarly adjust
102-107, 131-135, 157-161), fix the CumSum2D routine to: change the function
signature return type from T to void since it writes into dst; replace any early
"return;" inside the row-bounds check with "continue;" to let threads proceed to
next iteration; use a 32-bit mask type (uint32_t) for HIP shuffle/warp
intrinsics instead of uint64_t; and when calling __shfl_sync pass integer lane
indices (e.g., int lane or int offset) rather than T-typed zeros so the shuffle
arguments are integer types. Ensure these changes are applied consistently in
the other referenced line ranges.
                         B.T.to(torch.float32)).to(getattr(torch, out_dtype))
elif a_transposed and not b_transposed:
    # Get Reference Result
    ref_c = torch.matmul(A.Tto(torch.float32),
Fix syntax error in transpose operation.
Line 224 contains a typo: A.Tto(torch.float32) should be A.T.to(torch.float32). This will cause a runtime AttributeError.
- ref_c = torch.matmul(A.Tto(torch.float32),
+ ref_c = torch.matmul(A.T.to(torch.float32),
🤖 Prompt for AI Agents
In testing/python/dcu/test_tilelang_gemm_mmac_intrinsic.py around line 224,
there's a typo in the transpose/convert chain: replace the incorrect attribute
call `A.Tto(torch.float32)` with `A.T.to(torch.float32)` so the tensor is first
transposed (A.T) and then converted to float32 (A.T.to(...)); update that call
accordingly to avoid the AttributeError.
@tvm.ffi.register_func("tvm_callback_rocm_get_arch", override=True)
-def get_rocm_arch(rocm_path="/opt/rocm"):
+def get_rocm_arch(rocm_path="/opt/dtk"):
Breaking change: default path breaks standard ROCm installations.
Changing the default from /opt/rocm to /opt/dtk will cause existing ROCm users to fall back to the default gfx900 architecture instead of detecting their actual GPU. This could lead to suboptimal performance or compatibility issues.
Consider trying multiple paths in order to support both standard ROCm and Hygon DCU installations:
-def get_rocm_arch(rocm_path="/opt/dtk"):
+def get_rocm_arch(rocm_path=None):
"""Utility function to get the AMD GPU architecture
Parameters
----------
rocm_path : str
The path to rocm installation directory
Returns
-------
gpu_arch : str
The AMD GPU architecture
"""
gpu_arch = "gfx900"
+ # Try multiple common installation paths if not specified
+ if rocm_path is None:
+ for candidate in ["/opt/dtk", "/opt/rocm"]:
+ if os.path.exists(candidate):
+ rocm_path = candidate
+ break
+ else:
+ print("ROCm/DTK not detected, using default gfx900")
+ return gpu_arch
+
# check if rocm is installed
if not os.path.exists(rocm_path):
- print("ROCm not detected, using default gfx900")
+ print(f"ROCm/DTK not found at {rocm_path}, using default gfx900")
        return gpu_arch
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In tilelang/contrib/rocm.py around line 230, the function get_rocm_arch
currently defaults to "/opt/dtk", which breaks standard ROCm installs; change
the implementation to try multiple install roots in order (e.g. ['/opt/rocm',
'/opt/dtk']) rather than a single hardcoded default: accept an optional path or
list, iterate over the candidate roots, attempt detection against each root
until one succeeds, fall back to existing gfx900 behavior only if none match,
and update the function signature/docstring to reflect that it checks multiple
standard locations.
This PR introduces initial support for running tilelang on the Hygon DCU backend (tested on the BW200 platform). The key changes include:
Summary by CodeRabbit
New Features
Chores