Skip to content

Commit c24234d

Browse files
committed
first naive attempt
1 parent ab81b56 commit c24234d

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

src/device/intrinsics/wmma.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ const map_ptx_to_jl_array = Dict(
1414
"s8" => Int8,
1515
"s32" => Int32,
1616
"f16" => Float16,
17+
"tf32" => Float32,
1718
"f32" => Float32
1819
)
1920

@@ -23,6 +24,7 @@ const map_ptx_to_jl_frag = Dict(
2324
"s8" => UInt32,
2425
"s32" => Int32,
2526
"f16" => NTuple{2, VecElement{Float16}},
27+
"tf32" => Float32,
2628
"f32" => Float32
2729
)
2830

@@ -40,6 +42,8 @@ const map_frag_sizes = Dict(
4042
"a.f16.m16n16k16" => 8,
4143
"a.f16.m8n32k16" => 8,
4244
"a.f16.m32n8k16" => 8,
45+
46+
"a.tf32.m16n16k8" => 8,
4347
# B
4448
"b.u8.m16n16k16" => 2,
4549
"b.u8.m8n32k16" => 4,
@@ -52,6 +56,8 @@ const map_frag_sizes = Dict(
5256
"b.f16.m16n16k16" => 8,
5357
"b.f16.m8n32k16" => 8,
5458
"b.f16.m32n8k16" => 8,
59+
60+
"b.tf32.m16n16k8" => 8,
5561
# C
5662
"c.s32.m16n16k16" => 8,
5763
"c.s32.m8n32k16" => 8,
@@ -64,6 +70,8 @@ const map_frag_sizes = Dict(
6470
"c.f32.m16n16k16" => 8,
6571
"c.f32.m8n32k16" => 8,
6672
"c.f32.m32n8k16" => 8,
73+
74+
"c.f32.m16n16k8" => 8,
6775
# D
6876
"d.s32.m16n16k16" => 8,
6977
"d.s32.m8n32k16" => 8,
@@ -76,6 +84,8 @@ const map_frag_sizes = Dict(
7684
"d.f32.m16n16k16" => 8,
7785
"d.f32.m8n32k16" => 8,
7886
"d.f32.m32n8k16" => 8,
87+
88+
"d.f32.m16n16k8" => 8,
7989
)
8090

8191
# Maps PTX AS to CUDA.AS
@@ -87,6 +97,10 @@ const map_ptx_as_to_as_ty = Dict(
8797

8898
# Valid WMMA Operation configurations: Shape (M,N,K), Matrix, Element Type
8999

100+
# TF32-Precision Floating Point
101+
const ldst_tf32_ab_ops = [(16,16,8)], ["a", "b"], ["tf32"]
102+
const ldst_tf32_cd_ops = [(16,16,8)], ["c", "d"], ["f32"]
103+
const wmma_tf32_ops = [(16,16,8)], ["tf32"], ["f32"], ["f32"]
90104
# Half-Precision Floating Point
91105
const ldst_half_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["f16"]
92106
const ldst_half_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["f16", "f32"]
@@ -97,11 +111,12 @@ const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
97111
const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"]
98112

99113
const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
100-
ldst_int_ab_ops, ldst_int_cd_ops)
101-
const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops)
114+
ldst_int_ab_ops, ldst_int_cd_ops,
115+
ldst_tf32_ab_ops, ldst_tf32_cd_ops)
116+
const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_tf32_ops)
102117

103118
# Valid WMMA operation shapes
104-
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
119+
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16), (16,16,8)]
105120

106121
################################################################################
107122
# HELPER FUNCTIONS

0 commit comments

Comments
 (0)