diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl
index c12fe526ca..fe3ddbaf76 100644
--- a/src/device/intrinsics/wmma.jl
+++ b/src/device/intrinsics/wmma.jl
@@ -14,6 +14,7 @@ const map_ptx_to_jl_array = Dict(
     "s8" => Int8,
     "s32" => Int32,
     "f16" => Float16,
+    "tf32" => Float32,
     "f32" => Float32
 )
 
@@ -23,6 +24,7 @@ const map_ptx_to_jl_frag = Dict(
     "s8" => UInt32,
     "s32" => Int32,
     "f16" => NTuple{2, VecElement{Float16}},
+    "tf32" => Float32,
     "f32" => Float32
 )
 
@@ -36,10 +38,12 @@ const map_frag_sizes = Dict(
     "a.s8.m16n16k16" => 2,
     "a.s8.m8n32k16" => 1,
     "a.s8.m32n8k16" => 4,
-    
+
     "a.f16.m16n16k16" => 8,
     "a.f16.m8n32k16" => 8,
     "a.f16.m32n8k16" => 8,
+
+    "a.tf32.m16n16k8" => 4,
     # B
     "b.u8.m16n16k16" => 2,
     "b.u8.m8n32k16" => 4,
@@ -52,7 +56,9 @@ const map_frag_sizes = Dict(
     "b.f16.m16n16k16" => 8,
     "b.f16.m8n32k16" => 8,
     "b.f16.m32n8k16" => 8,
-    # C
+
+    "b.tf32.m16n16k8" => 4,
+    # C
     "c.s32.m16n16k16" => 8,
     "c.s32.m8n32k16" => 8,
     "c.s32.m32n8k16" => 8,
@@ -64,6 +70,7 @@ const map_frag_sizes = Dict(
     "c.f32.m16n16k16" => 8,
     "c.f32.m8n32k16" => 8,
     "c.f32.m32n8k16" => 8,
+    "c.f32.m16n16k8" => 8,
     # D
     "d.s32.m16n16k16" => 8,
     "d.s32.m8n32k16" => 8,
@@ -76,6 +83,7 @@ const map_frag_sizes = Dict(
     "d.f32.m16n16k16" => 8,
     "d.f32.m8n32k16" => 8,
     "d.f32.m32n8k16" => 8,
+    "d.f32.m16n16k8" => 8,
 )
 
 # Maps PTX AS to CUDA.AS
@@ -87,6 +95,10 @@ const map_ptx_as_to_as_ty = Dict(
 
 # Valid WMMA Operation configurations: Shape (M,N,K), Matrix, Element Type
 
+# TF32-Precision Floating Point
+const ldst_tf32_ab_ops = [(16,16,8)], ["a", "b"], ["tf32"]
+const ldst_tf32_cd_ops = [(16,16,8)], ["c", "d"], ["f32"]
+const wmma_tf32_ops = [(16,16,8)], ["tf32"], ["f32"], ["f32"]
 # Half-Precision Floating Point
 const ldst_half_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["f16"]
 const ldst_half_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["f16", "f32"]
@@ -97,11 +109,12 @@ const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
 const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"]
 
 const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
-                          ldst_int_ab_ops, ldst_int_cd_ops)
-const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops)
+                          ldst_int_ab_ops, ldst_int_cd_ops,
+                          ldst_tf32_ab_ops, ldst_tf32_cd_ops)
+const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_tf32_ops)
 
 # Valid WMMA operation shapes
-const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
+const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (16, 16, 8)]
################################################################################
# HELPER FUNCTIONS
################################################################################
@@ -160,10 +173,10 @@ Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.load.{matrix}.sync.{layout}.{
 # Placeholders
 - `{matrix}`: The matrix to load. Can be `a`, `b` or `c`.
 - `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively.
-- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k16`, `m32n8k16`, and `m8n32k16`.
+- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k8`, `m16n16k16`, `m32n8k16`, and `m8n32k16`.
 - `{addr_space}`: The address space of `src_addr`. Can be empty (generic addressing), `shared` or `global`.
 - `{elem_type}`: The type of each element in the matrix. For `a` and `b` matrices, valid values are `u8` (byte unsigned integer),
-    `s8` (byte signed integer), and `f16` (half precision floating point). For `c` and `d` matrices, valid values are
+    `s8` (byte signed integer), `f16` (half precision floating point), and `tf32` (TensorFloat-32). For `c` and `d` matrices, valid values are
     `s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point).
 """
 llvm_wmma_load() = error("Cannot call llvm_wmma_load without values for placeholders!")
@@ -217,10 +230,10 @@ Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.store.d.sync.{layout}.{shape}
 # Placeholders
 - `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively.
-- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k16`, `m32n8k16`, and `m8n32k16`.
+- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k8`, `m16n16k16`, `m32n8k16`, and `m8n32k16`.
 - `{addr_space}`: The address space of `src_addr`. Can be empty (generic addressing), `shared` or `global`.
 - `{elem_type}`: The type of each element in the matrix. For `a` and `b` matrices, valid values are `u8` (byte unsigned integer),
-    `s8` (byte signed integer), and `f16` (half precision floating point). For `c` and `d` matrices, valid values are
+    `s8` (byte signed integer), `f16` (half precision floating point), and `tf32` (TensorFloat-32). For `c` and `d` matrices, valid values are
     `s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point).
 """
 llvm_wmma_store() = error("Cannot call llvm_wmma_store without values for placeholders!")
@@ -283,8 +296,8 @@ For all other operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma
 # Placeholders
 - `{a_layout}`: The storage layout for matrix ``A``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation.
 - `{b_layout}`: The storage layout for matrix ``B``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation.
-- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k16`, `m32n8k16`, and `m8n32k16`.
-- `{a_elem_type}`: The type of each element in the ``A`` matrix. Valid values are `u8` (byte unsigned integer), `s8` (byte signed integer), and `f16` (half precision floating point).
+- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k8`, `m16n16k16`, `m32n8k16`, and `m8n32k16`.
+- `{a_elem_type}`: The type of each element in the ``A`` matrix. Valid values are `u8` (byte unsigned integer), `s8` (byte signed integer), `f16` (half precision floating point), and `tf32` (TensorFloat-32).
 - `{d_elem_type}`: The type of each element in the resultant ``D`` matrix. Valid values are `s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point).
 - `{c_elem_type}`: The type of each element in the ``C`` matrix. Valid values are `s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point).
@@ -309,7 +322,7 @@ for ops in all_wmma_ops,
 
     # Name of the LLVM intrinsic
     # If integer/sub-byte/bit A/B types, name is determined by A/B types
-    if d_elem_type == "s32"
+    if d_elem_type == "s32" || a_elem_type == "tf32"
         llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$a_elem_type"
         # Name of the Julia wrapper function
         func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type]), "_"))
@@ -647,7 +660,7 @@ mma
 
     _, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T, shape)
     d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T, shape)
-    
+
     # Name of the Julia wrapper
     wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_arr_str, c_arr_str]), "_"))
 
diff --git a/test/device/intrinsics/wmma.jl b/test/device/intrinsics/wmma.jl
index e0d47e34ca..4dfb149bc1 100644
--- a/test/device/intrinsics/wmma.jl
+++ b/test/device/intrinsics/wmma.jl
@@ -6,8 +6,9 @@ map_ptx_to_jl_frag = Dict(
     "u32" => UInt32(42),
     "s32" => Int32(42),
     "f16" => ntuple(i -> VecElement{Float16}(42), 2),
-    "f32" => Float32(42)
-    )
+    "f32" => Float32(42),
+    "tf32" => Float32(42)
+    )
 # Return specific matrix shape given operation configuration
 function get_array_shape(mat, mnk, layout)
     if !(mat in ["a","b","c","d"])
@@ -41,18 +42,23 @@ end
             continue
         end
 
+        if mnk == (16,16,8) && VERSION < v"1.8"
+            # TensorFloat-32 (TF32) tests require at least Julia 1.8
+            continue
+        end
+
         shape = CUDA.WMMA.get_hl_shape(mnk[1], mnk[2], mnk[3])
 
         # Type-dependent variables
         array_ty = CUDA.WMMA.map_ptx_to_jl_array[elem_type]
         expected = map_ptx_to_jl_frag[elem_type]
-        
+
         # Address-space dependent variables
         do_shared_test = (addr_space == "_shared")
 
         # Get the function name
         func = Symbol("llvm_wmma_load_$(mat)_$(layout)_$(shape)$(addr_space)_stride_$(elem_type)")
-        
+
         input_shape = get_array_shape(mat, mnk, layout)
         input = array_ty(42) * ones(array_ty, input_shape)
         input_dev = CuArray(input)
@@ -96,12 +102,17 @@ end
         elem_type in ops[3],
         addr_space in ["", "_global", "_shared"],
         stride in ["stride"]
-        
+
         # Skip all but d matrices
         if mat != "d"
             continue
         end
 
+        if mnk == (16,16,8) && VERSION < v"1.8"
+            # TensorFloat-32 (TF32) tests require at least Julia 1.8
+            continue
+        end
+
         shape = CUDA.WMMA.get_hl_shape(mnk[1], mnk[2], mnk[3])
 
         # Type-dependent variables
@@ -156,6 +167,11 @@ end
         d_elem_type in ops[4],
         c_elem_type in ops[3]
 
+        if mnk == (16,16,8) && VERSION < v"1.8"
+            # TensorFloat-32 (TF32) tests require at least Julia 1.8
+            continue
+        end
+
         # Type-dependent variables
         d_ty = CUDA.WMMA.map_ptx_to_jl_array[d_elem_type]
         c_ty = CUDA.WMMA.map_ptx_to_jl_array[c_elem_type]
@@ -169,9 +185,9 @@ end
         ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_$(shape)_global_stride_$(c_elem_type)"))
         # Account for half and int/subint mma different naming conventions
         # Int/subint mma functions are distinguished by the a/b element type
-        mma_sym = d_ty == Int32 ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
+        mma_sym = (d_ty == Int32 || ab_elem_type == "tf32") ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
                               Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)")
-        mma_func = getfield(Main, mma_sym) 
+        mma_func = getfield(Main, mma_sym)
         std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_$(shape)_global_stride_$(d_elem_type)"))
 
         a_shape = get_array_shape("a", mnk, a_layout)
@@ -205,9 +221,9 @@ end
         new_a = (a_layout == "col" ? a : transpose(a))
         new_b = (b_layout == "col" ? b : transpose(b))
         # Alter test depending on a/b element Type
-        if ab_ty == Float16
+        if ab_ty == Float16 || ab_elem_type == "tf32"
             @test new_a * new_b + c ≈ Array(d_dev) rtol=Base.rtoldefault(Float16)
-        else # Cast a and b to prevent UInt8 rollover of resultant data 
+        else # Cast a and b to prevent UInt8 rollover of resultant data
             @test Int32.(new_a) * Int32.(new_b) + c == Array(d_dev)
         end
     end
@@ -344,4 +360,4 @@ end
     @test !occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.f32", ptx)
     @test occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.shared.f32", ptx)
 end
-end
\ No newline at end of file
+end
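
For reviewers, a minimal device-side sketch of how the `m16n16k8` TF32 wrappers generated by this change would be used. The wrapper names follow the naming scheme established above (TF32 `mma` wrappers are keyed on the A/B element type, per the new `a_elem_type == "tf32"` branch); the `WMMA` module qualification, strides, and layouts are assumptions for illustration, not part of this diff.

```julia
using CUDA
using CUDA: WMMA

# Sketch: one warp computes D = A * B + C at the m16n16k8 TF32 shape.
# A is 16x8, B is 8x16, C and D are 16x16, all column major and stored as
# Float32; the tensor cores round A and B to TF32 precision internally.
function tf32_mma_kernel(a_dev, b_dev, c_dev, d_dev)
    # Load the A, B, and C fragments (strides are the leading dimensions).
    a_frag = WMMA.llvm_wmma_load_a_col_m16n16k8_global_stride_tf32(pointer(a_dev), 16)
    b_frag = WMMA.llvm_wmma_load_b_col_m16n16k8_global_stride_tf32(pointer(b_dev), 8)
    c_frag = WMMA.llvm_wmma_load_c_col_m16n16k8_global_stride_f32(pointer(c_dev), 16)
    # TF32 mma wrappers are named after the A/B element type, like the
    # integer variants.
    d_frag = WMMA.llvm_wmma_mma_col_col_m16n16k8_tf32(a_frag, b_frag, c_frag)
    # Store the D fragment back to global memory.
    WMMA.llvm_wmma_store_d_col_m16n16k8_global_stride_f32(pointer(d_dev), d_frag, 16)
    return nothing
end
```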
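A corresponding host-side check, mirroring the tolerance used in the tests above (TF32 keeps roughly as many mantissa bits as Float16). The single-warp launch is likewise an assumed setup, and running it needs an Ampere-class (sm_80 or newer) GPU:

```julia
a = rand(Float32, 16, 8)
b = rand(Float32, 8, 16)
c = rand(Float32, 16, 16)
a_dev, b_dev, c_dev = CuArray(a), CuArray(b), CuArray(c)
d_dev = CUDA.zeros(Float32, 16, 16)

# One warp is enough for a single 16x16x8 WMMA tile.
@cuda threads=32 tf32_mma_kernel(a_dev, b_dev, c_dev, d_dev)

# TF32 has a 10-bit mantissa, so compare with a Float16-like tolerance.
@assert isapprox(a * b + c, Array(d_dev); rtol=Base.rtoldefault(Float16))
```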