@@ -14,6 +14,7 @@ const map_ptx_to_jl_array = Dict(
14
14
" s8" => Int8,
15
15
" s32" => Int32,
16
16
" f16" => Float16,
17
+ " tf32" => Float32,
17
18
" f32" => Float32
18
19
)
19
20
@@ -23,6 +24,7 @@ const map_ptx_to_jl_frag = Dict(
23
24
" s8" => UInt32,
24
25
" s32" => Int32,
25
26
" f16" => NTuple{2 , VecElement{Float16}},
27
+ " tf32" => Float32,
26
28
" f32" => Float32
27
29
)
28
30
@@ -40,6 +42,8 @@ const map_frag_sizes = Dict(
40
42
" a.f16.m16n16k16" => 8 ,
41
43
" a.f16.m8n32k16" => 8 ,
42
44
" a.f16.m32n8k16" => 8 ,
45
+
46
+ " a.tf32.m16n16k8" => 8 ,
43
47
# B
44
48
" b.u8.m16n16k16" => 2 ,
45
49
" b.u8.m8n32k16" => 4 ,
@@ -52,6 +56,8 @@ const map_frag_sizes = Dict(
52
56
" b.f16.m16n16k16" => 8 ,
53
57
" b.f16.m8n32k16" => 8 ,
54
58
" b.f16.m32n8k16" => 8 ,
59
+
60
+ " b.tf32.m16n16k8" => 8 ,
55
61
# C
56
62
" c.s32.m16n16k16" => 8 ,
57
63
" c.s32.m8n32k16" => 8 ,
@@ -64,6 +70,8 @@ const map_frag_sizes = Dict(
64
70
" c.f32.m16n16k16" => 8 ,
65
71
" c.f32.m8n32k16" => 8 ,
66
72
" c.f32.m32n8k16" => 8 ,
73
+
74
+ " c.f32.m16n16k8" => 8 ,
67
75
# D
68
76
" d.s32.m16n16k16" => 8 ,
69
77
" d.s32.m8n32k16" => 8 ,
@@ -76,6 +84,8 @@ const map_frag_sizes = Dict(
76
84
" d.f32.m16n16k16" => 8 ,
77
85
" d.f32.m8n32k16" => 8 ,
78
86
" d.f32.m32n8k16" => 8 ,
87
+
88
+ " d.f32.m16n16k8" => 8 ,
79
89
)
80
90
81
91
# Maps PTX AS to CUDA.AS
@@ -87,6 +97,10 @@ const map_ptx_as_to_as_ty = Dict(
87
97
88
98
# Valid WMMA Operation configurations: Shape (M,N,K), Matrix, Element Type
89
99
100
# Configuration tuples for the TF32 WMMA variant (lines added by this diff).
# Following the section header ("Shape (M,N,K), Matrix, Element Type"), each
# ldst_* tuple is (shapes, matrices, element types). Only the (16,16,8) shape
# is listed; the a/b matrices use element type tf32 and c/d use f32.
# NOTE(review): map_frag_sizes in this diff gives 8 for "a.tf32.m16n16k8" and
# "b.tf32.m16n16k8"; the PTX ISA wmma.load fragment table appears to describe
# these as four .b32 registers -- TODO confirm the sizes and fragment type.
+ # TF32-Precision Floating Point
101
+ const ldst_tf32_ab_ops = [(16 ,16 ,8 )], [" a" , " b" ], [" tf32" ]
102
+ const ldst_tf32_cd_ops = [(16 ,16 ,8 )], [" c" , " d" ], [" f32" ]
103
# wmma_tf32_ops carries four lists: shapes, then three element-type lists
# (presumably A/B type, C type, D type -- TODO confirm against the consumer).
+ const wmma_tf32_ops = [(16 ,16 ,8 )], [" tf32" ], [" f32" ], [" f32" ]
90
104
# Half-Precision Floating Point
# Load/store configuration tuples for the f16 variant, laid out as
# (shapes, matrices, element types). Three (M,N,K) shapes are listed:
# (16,16,16), (32,8,16) and (8,32,16).
91
105
# a/b matrices: element type f16 only.
const ldst_half_ab_ops = [(16 ,16 ,16 ), (32 ,8 ,16 ), (8 ,32 ,16 )], [" a" , " b" ], [" f16" ]
92
106
# c/d matrices: element type may be f16 or f32.
const ldst_half_cd_ops = [(16 ,16 ,16 ), (32 ,8 ,16 ), (8 ,32 ,16 )], [" c" , " d" ], [" f16" , " f32" ]
@@ -97,11 +111,12 @@ const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
97
111
# Integer MMA configurations: the three (M,N,K) shapes, followed by three
# element-type lists -- s8/u8, then s32, then s32 (presumably the A/B input
# type, the C type and the D type; TODO confirm against the consumer).
const wmma_int_ops = [(16 ,16 ,16 ), (32 ,8 ,16 ), (8 ,32 ,16 )], [" s8" , " u8" ], [" s32" ], [" s32" ]
98
112
99
113
# Aggregated configuration lists built with vcat from the per-type tuples
# defined above: all_ldst_ops gathers the load/store tuples and all_wmma_ops
# gathers the MMA tuples. This hunk shows both the removed ("-") and the
# added ("+") form; the added form appends the tf32 tuples to each list.
const all_ldst_ops = vcat (ldst_half_ab_ops, ldst_half_cd_ops,
100
- ldst_int_ab_ops, ldst_int_cd_ops)
101
- const all_wmma_ops = vcat (wmma_half_ops, wmma_int_ops)
114
+ ldst_int_ab_ops, ldst_int_cd_ops,
115
+ ldst_tf32_ab_ops, ldst_tf32_cd_ops)
116
+ const all_wmma_ops = vcat (wmma_half_ops, wmma_int_ops, wmma_tf32_ops)
102
117
103
118
# Valid WMMA operation shapes
# Each entry is an (M, N, K) tuple. The removed ("-") line lists the three
# K=16 shapes; the added ("+") line additionally accepts (16, 16, 8), the
# shape used by the tf32 configurations introduced in this diff.
104
- const valid_shapes = [(16 , 16 , 16 ), (32 , 8 , 16 ), (8 , 32 , 16 )]
119
+ const valid_shapes = [(16 , 16 , 16 ), (32 , 8 , 16 ), (8 , 32 , 16 ), ( 16 , 16 , 8 ) ]
105
120
106
121
# ###############################################################################
107
122
# HELPER FUNCTIONS
0 commit comments