
WMMA TensorFloat32 (TF32) #1419

Draft · wants to merge 7 commits into base: master
src/device/intrinsics/wmma.jl: 26 additions & 13 deletions
@@ -14,6 +14,7 @@ const map_ptx_to_jl_array = Dict(
"s8" => Int8,
"s32" => Int32,
"f16" => Float16,
"tf32" => Float32,
"f32" => Float32
)

@@ -23,6 +24,7 @@ const map_ptx_to_jl_frag = Dict(
"s8" => UInt32,
"s32" => Int32,
"f16" => NTuple{2, VecElement{Float16}},
"tf32" => Float32,
"f32" => Float32
)

@@ -36,10 +38,12 @@ const map_frag_sizes = Dict(
"a.s8.m16n16k16" => 2,
"a.s8.m8n32k16" => 1,
"a.s8.m32n8k16" => 4,

"a.f16.m16n16k16" => 8,
"a.f16.m8n32k16" => 8,
"a.f16.m32n8k16" => 8,

"a.tf32.m16n16k8" => 4,
# B
"b.u8.m16n16k16" => 2,
"b.u8.m8n32k16" => 4,
@@ -52,7 +56,9 @@ const map_frag_sizes = Dict(
"b.f16.m16n16k16" => 8,
"b.f16.m8n32k16" => 8,
"b.f16.m32n8k16" => 8,
# C

"b.tf32.m16n16k8" => 4,
# C
"c.s32.m16n16k16" => 8,
"c.s32.m8n32k16" => 8,
"c.s32.m32n8k16" => 8,
@@ -64,6 +70,7 @@ const map_frag_sizes = Dict(
"c.f32.m16n16k16" => 8,
"c.f32.m8n32k16" => 8,
"c.f32.m32n8k16" => 8,
"c.f32.m16n16k8" => 8,
# D
"d.s32.m16n16k16" => 8,
"d.s32.m8n32k16" => 8,
@@ -76,6 +83,7 @@ const map_frag_sizes = Dict(
"d.f32.m16n16k16" => 8,
"d.f32.m8n32k16" => 8,
"d.f32.m32n8k16" => 8,
"d.f32.m16n16k8" => 8,
)

# Maps PTX AS to CUDA.AS
@@ -87,6 +95,10 @@ const map_ptx_as_to_as_ty = Dict(

# Valid WMMA Operation configurations: Shape (M,N,K), Matrix, Element Type

# TensorFloat-32 (TF32) Floating Point
const ldst_tf32_ab_ops = [(16, 16, 8)], ["a", "b"], ["tf32"]
const ldst_tf32_cd_ops = [(16, 16, 8)], ["c", "d"], ["f32"]
const wmma_tf32_ops = [(16, 16, 8)], ["tf32"], ["f32"], ["f32"]
# Half-Precision Floating Point
const ldst_half_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["f16"]
const ldst_half_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["f16", "f32"]
@@ -97,11 +109,12 @@ const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"]

const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
ldst_int_ab_ops, ldst_int_cd_ops)
const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops)
ldst_int_ab_ops, ldst_int_cd_ops,
ldst_tf32_ab_ops, ldst_tf32_cd_ops)
const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_tf32_ops)

# Valid WMMA operation shapes
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (16, 16, 8)]

################################################################################
# HELPER FUNCTIONS
@@ -160,10 +173,10 @@ Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.load.{matrix}.sync.{layout}.{
# Placeholders
- `{matrix}`: The matrix to load. Can be `a`, `b` or `c`.
- `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively.
- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k16`, `m32n8k16`, and `m8n32k16`.
- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k8`, `m16n16k16`, `m32n8k16`, and `m8n32k16`.
- `{addr_space}`: The address space of `src_addr`. Can be empty (generic addressing), `shared` or `global`.
- `{elem_type}`: The type of each element in the matrix. For `a` and `b` matrices, valid values are `u8` (byte unsigned integer),
`s8` (byte signed integer), and `f16` (half precision floating point). For `c` and `d` matrices, valid values are
`s8` (byte signed integer), `f16` (half precision floating point), and `tf32` (TensorFloat-32). For `c` and `d` matrices, valid values are
`s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point).
"""
llvm_wmma_load() = error("Cannot call llvm_wmma_load without values for placeholders!")
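For context, a minimal device-side sketch of loading the new TF32 fragments. The wrapper names (e.g. `llvm_wmma_load_a_col_m16n16k8_global_stride_tf32`) are assumed to follow the naming scheme of the generator loop, and an Ampere-class (sm_80+) GPU is assumed:

```julia
using CUDA

# One warp cooperatively loads the A and B fragments of an m16n16k8 TF32 MAC.
# Inputs are plain column-major Float32 arrays; the second argument is the
# leading dimension (stride) in elements.
function tf32_load_kernel(a_dev, b_dev)
    a_frag = CUDA.WMMA.llvm_wmma_load_a_col_m16n16k8_global_stride_tf32(pointer(a_dev), 16)
    b_frag = CUDA.WMMA.llvm_wmma_load_b_col_m16n16k8_global_stride_tf32(pointer(b_dev), 8)
    # a_frag and b_frag are NTuple{4, Float32}, matching the "a.tf32.m16n16k8"
    # and "b.tf32.m16n16k8" fragment sizes above.
    return nothing
end

a = CUDA.rand(Float32, 16, 8)   # A is M×K = 16×8
b = CUDA.rand(Float32, 8, 16)   # B is K×N = 8×16
@cuda threads=32 tf32_load_kernel(a, b)
```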
@@ -217,10 +230,10 @@ Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.store.d.sync.{layout}.{shape}

# Placeholders
- `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively.
- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k16`, `m32n8k16`, and `m8n32k16`.
- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k8`, `m16n16k16`, `m32n8k16`, and `m8n32k16`.
- `{addr_space}`: The address space of `src_addr`. Can be empty (generic addressing), `shared` or `global`.
- `{elem_type}`: The type of each element in the matrix. For `a` and `b` matrices, valid values are `u8` (byte unsigned integer),
`s8` (byte signed integer), and `f16` (half precision floating point). For `c` and `d` matrices, valid values are
`s8` (byte signed integer), `f16` (half precision floating point), and `tf32` (TensorFloat-32). For `c` and `d` matrices, valid values are
`s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point).
"""
llvm_wmma_store() = error("Cannot call llvm_wmma_store without values for placeholders!")
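Similarly, a hedged sketch of the matching `f32` store for the m16n16k8 shape; the wrapper name is again assumed from the generator's naming scheme, and the D fragment is an 8-element Float32 tuple per the `d.f32.m16n16k8` entry above:

```julia
using CUDA

# One warp stores an 8-element Float32 D fragment into a 16×16 column-major
# result matrix in global memory. The fragment contents here are placeholders;
# in practice they come from the MMA intrinsic documented below.
function tf32_store_kernel(d_dev)
    d_frag = ntuple(i -> Float32(i), Val(8))
    CUDA.WMMA.llvm_wmma_store_d_col_m16n16k8_global_stride_f32(pointer(d_dev), d_frag, 16)
    return nothing
end

d = CUDA.zeros(Float32, 16, 16)
@cuda threads=32 tf32_store_kernel(d)
```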
@@ -283,8 +296,8 @@ For all other operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma
# Placeholders
- `{a_layout}`: The storage layout for matrix ``A``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation.
- `{b_layout}`: The storage layout for matrix ``B``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation.
- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k16`, `m32n8k16`, and `m8n32k16`.
- `{a_elem_type}`: The type of each element in the ``A`` matrix. Valid values are `u8` (byte unsigned integer), `s8` (byte signed integer), and `f16` (half precision floating point).
- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k8`, `m16n16k16`, `m32n8k16`, and `m8n32k16`.
- `{a_elem_type}`: The type of each element in the ``A`` matrix. Valid values are `u8` (byte unsigned integer), `s8` (byte signed integer), `f16` (half precision floating point), and `tf32` (TensorFloat-32).
- `{d_elem_type}`: The type of each element in the resultant ``D`` matrix. Valid values are `s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point).
- `{c_elem_type}`: The type of each element in the ``C`` matrix. Valid values are `s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point).
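Putting the pieces together, a rough end-to-end sketch of a single m16n16k8 TF32 MAC (D = A·B + C). Wrapper names are assumed from the generator loops below, and the relaxed tolerance mirrors the test suite, since TF32 keeps only a 10-bit mantissa:

```julia
using CUDA

# One warp computes D = A * B + C for a 16×16×8 tile with TF32 inputs and
# Float32 accumulation. All matrices are column major.
function tf32_mma_kernel(a_dev, b_dev, c_dev, d_dev)
    a_frag = CUDA.WMMA.llvm_wmma_load_a_col_m16n16k8_global_stride_tf32(pointer(a_dev), 16)
    b_frag = CUDA.WMMA.llvm_wmma_load_b_col_m16n16k8_global_stride_tf32(pointer(b_dev), 8)
    c_frag = CUDA.WMMA.llvm_wmma_load_c_col_m16n16k8_global_stride_f32(pointer(c_dev), 16)
    d_frag = CUDA.WMMA.llvm_wmma_mma_col_col_m16n16k8_tf32(a_frag, b_frag, c_frag)
    CUDA.WMMA.llvm_wmma_store_d_col_m16n16k8_global_stride_f32(pointer(d_dev), d_frag, 16)
    return nothing
end

a = CUDA.rand(Float32, 16, 8)
b = CUDA.rand(Float32, 8, 16)
c = CUDA.rand(Float32, 16, 16)
d = CUDA.zeros(Float32, 16, 16)
@cuda threads=32 tf32_mma_kernel(a, b, c, d)
# TF32 truncates the input mantissas, so compare with a Float16-level tolerance:
isapprox(Array(d), Array(a) * Array(b) + Array(c); rtol=Base.rtoldefault(Float16))
```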

@@ -309,7 +322,7 @@ for ops in all_wmma_ops,

# Name of the LLVM intrinsic
# If integer/sub-byte/bit A/B types, name is determined by A/B types
if d_elem_type == "s32"
if d_elem_type == "s32" || a_elem_type == "tf32"
llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$a_elem_type"
# Name of the Julia wrapper function
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type]), "_"))
@@ -647,7 +660,7 @@ mma
_, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T, shape)
d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T, shape)



# Name of the Julia wrapper
wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_arr_str, c_arr_str]), "_"))
test/device/intrinsics/wmma.jl: 26 additions & 10 deletions
@@ -6,8 +6,9 @@ map_ptx_to_jl_frag = Dict(
"u32" => UInt32(42),
"s32" => Int32(42),
"f16" => ntuple(i -> VecElement{Float16}(42), 2),
"f32" => Float32(42)
)
"f32" => Float32(42),
"tf32" => Float32(42)
)
# Return specific matrix shape given operation configuration
function get_array_shape(mat, mnk, layout)
if !(mat in ["a","b","c","d"])
@@ -41,18 +42,23 @@ end
continue
end

if mnk == (16,16,8) && VERSION < v"1.8"
# TensorFloat32 (TF32) tests require at least Julia 1.8
continue
end

shape = CUDA.WMMA.get_hl_shape(mnk[1], mnk[2], mnk[3])

# Type-dependent variables
array_ty = CUDA.WMMA.map_ptx_to_jl_array[elem_type]
expected = map_ptx_to_jl_frag[elem_type]

# Address-space dependent variables
do_shared_test = (addr_space == "_shared")

# Get the function name
func = Symbol("llvm_wmma_load_$(mat)_$(layout)_$(shape)$(addr_space)_stride_$(elem_type)")

input_shape = get_array_shape(mat, mnk, layout)
input = array_ty(42) * ones(array_ty, input_shape)
input_dev = CuArray(input)
@@ -96,12 +102,17 @@ end
elem_type in ops[3],
addr_space in ["", "_global", "_shared"],
stride in ["stride"]

# Skip all but d matrices
if mat != "d"
continue
end

if mnk == (16,16,8) && VERSION < v"1.8"
# TensorFloat32 (TF32) tests require at least Julia 1.8
continue
end

shape = CUDA.WMMA.get_hl_shape(mnk[1], mnk[2], mnk[3])

# Type-dependent variables
@@ -156,6 +167,11 @@ end
d_elem_type in ops[4],
c_elem_type in ops[3]

if mnk == (16,16,8) && VERSION < v"1.8"
# TensorFloat32 (TF32) tests require at least Julia 1.8
continue
end

# Type-dependent variables
d_ty = CUDA.WMMA.map_ptx_to_jl_array[d_elem_type]
c_ty = CUDA.WMMA.map_ptx_to_jl_array[c_elem_type]
@@ -169,9 +185,9 @@ end
ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_$(shape)_global_stride_$(c_elem_type)"))
# Account for half and int/subint mma different naming conventions
# Int/subint mma functions are distinguished by the a/b element type
mma_sym = d_ty == Int32 ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
mma_sym = (d_ty == Int32 || ab_elem_type == "tf32") ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)")
mma_func = getfield(Main, mma_sym)
std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_$(shape)_global_stride_$(d_elem_type)"))

a_shape = get_array_shape("a", mnk, a_layout)
@@ -205,9 +221,9 @@ end
new_a = (a_layout == "col" ? a : transpose(a))
new_b = (b_layout == "col" ? b : transpose(b))
# Alter test depending on a/b element Type
if ab_ty == Float16
if ab_ty == Float16 || ab_elem_type == "tf32"
@test new_a * new_b + c ≈ Array(d_dev) rtol=Base.rtoldefault(Float16)
else # Cast a and b to prevent UInt8 rollover of resultant data
@test Int32.(new_a) * Int32.(new_b) + c == Array(d_dev)
end
end
@@ -344,4 +360,4 @@ end
@test !occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.f32", ptx)
@test occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.shared.f32", ptx)
end
end