Skip to content

Commit 7d96189

Browse files
authored
[Refactor] Improve Python3.9 compatibility for ParamSpec and Self (#1190)
* [Feature] Enhance fill operation to support various buffer types - Added support for `BufferLoad` in the `fill` function to handle different buffer types. - Updated `Fill` class to process region descriptors and buffer regions, improving flexibility in buffer handling. - Introduced checks for static bounds in region definitions to ensure safety during operations. - Refactored loop induction variable handling in `FillNode` to accommodate sliced regions. * lint fix * [Refactor] Improve Python compatibility for ParamSpec and Self - Added compatibility handling for ParamSpec and Self to support Python versions below 3.10 and 3.11 respectively. - Updated type annotations across multiple files to ensure consistent usage of typing features. * [Update] Require Python 3.9 and enhance type annotations - Updated the minimum required Python version from 3.8 to 3.9 in `pyproject.toml`. - Removed references to Python 3.8 in classifiers. - Changed type annotations from `int | None` to `Optional[int]` in multiple example files for better clarity and compatibility. - Improved import statements to use `collections.abc` for `Iterable` and `contextlib` for `AbstractContextManager` in relevant files. * [Refactor] Update import statements to enhance type annotations - Replaced imports from `typing` with `collections.abc` for `Iterable` and `Mapping` in relevant files to improve compatibility and clarity. - Updated the caching decorator from `functools.lru_cache` to `functools.cache` for better performance in the C++ compiler retrieval function. - Adjusted import statements in the language proxy file to maintain consistency in type annotations. * disable rocm rs nt test. * lint fix
1 parent a03df60 commit 7d96189

File tree

20 files changed

+110
-64
lines changed

20 files changed

+110
-64
lines changed

examples/attention_sink/benchmark_gqa_sink_fwd.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import triton.language as tl
66
from triton.tools.tensor_descriptor import TensorDescriptor
77
from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
8+
from typing import Optional
89

910

1011
@triton.jit
@@ -94,7 +95,7 @@ def triton_kernel(
9495
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
9596

9697

97-
def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
98+
def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
9899
bs, n_heads, seq_q, head_dim = Q.shape
99100
_, n_heads_kv, seq_kv, _ = K.shape
100101
BLOCK_M = 64
@@ -130,7 +131,7 @@ def main(
130131
seq_kv: int = 256,
131132
dim: int = 128,
132133
groups: int = 8,
133-
window_size: int | None = None,
134+
window_size: Optional[int] = None,
134135
dtype: str = "float16",
135136
tune: bool = False,
136137
):

examples/attention_sink/benchmark_mha_sink_fwd.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import triton.language as tl
66
from triton.tools.tensor_descriptor import TensorDescriptor
77
from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
8+
from typing import Optional
89

910

1011
@triton.jit
@@ -93,7 +94,7 @@ def triton_kernel(
9394
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
9495

9596

96-
def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
97+
def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
9798
bs, n_heads, seq_q, head_dim = Q.shape
9899
seq_kv = K.shape[2]
99100
BLOCK_M = 64
@@ -125,7 +126,7 @@ def main(batch: int = 1,
125126
seq_q: int = 256,
126127
seq_kv: int = 256,
127128
dim: int = 128,
128-
window_size: int | None = None,
129+
window_size: Optional[int] = None,
129130
dtype: str = "float16",
130131
tune: bool = False):
131132
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]

examples/attention_sink/example_gqa_sink_bwd_bhsd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def main(BATCH: int = 1,
444444
N_CTX: int = 512,
445445
D_HEAD: int = 64,
446446
groups: int = 2,
447-
window_size: int | None = None,
447+
window_size: Optional[int] = None,
448448
dtype: str = "float16"):
449449
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
450450
if window_size is not None:

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def main(
272272
seq_kv: int = 256,
273273
dim: int = 128,
274274
groups: int = 8,
275-
window_size: int | None = None,
275+
window_size: Optional[int] = None,
276276
dtype: str = "float16",
277277
tune: bool = False,
278278
):

examples/attention_sink/example_mha_sink_bwd_bhsd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def main(BATCH: int = 1,
440440
H: int = 1,
441441
N_CTX: int = 512,
442442
D_HEAD: int = 128,
443-
window_size: int | None = None,
443+
window_size: Optional[int] = None,
444444
dtype: str = "float16"):
445445
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
446446
if window_size is not None:

examples/attention_sink/example_mha_sink_fwd_bhsd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def main(batch: int = 1,
253253
seq_q: int = 256,
254254
seq_kv: int = 256,
255255
dim: int = 128,
256-
window_size: int | None = None,
256+
window_size: Optional[int] = None,
257257
dtype: str = "float16",
258258
tune: bool = False):
259259
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def main(batch: int = 1,
263263
seq_q: int = 256,
264264
seq_kv: int = 256,
265265
dim: int = 128,
266-
window_size: int | None = None,
266+
window_size: Optional[int] = None,
267267
dtype: str = "float16",
268268
tune: bool = False):
269269
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
name = "tilelang"
33
description = "A tile level programming language to generate high performance code."
44
readme = "README.md"
5-
requires-python = ">=3.8"
5+
requires-python = ">=3.9"
66
authors = [{ name = "TileLang Contributors" }, { name = "Tile-AI" }]
77
maintainers = [{ name = "Lei Wang", email = "[email protected]" }]
88
license = "MIT"
@@ -14,7 +14,6 @@ classifiers = [
1414
"Operating System :: MacOS",
1515
"Programming Language :: C++",
1616
"Programming Language :: Python :: 3",
17-
"Programming Language :: Python :: 3.8",
1817
"Programming Language :: Python :: 3.9",
1918
"Programming Language :: Python :: 3.10",
2019
"Programming Language :: Python :: 3.11",
@@ -118,7 +117,7 @@ skip = [
118117
]
119118

120119
[tool.ruff]
121-
target-version = "py38"
120+
target-version = "py39"
122121
line-length = 100
123122
output-format = "full"
124123

src/transform/layout_reducer.cc

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "../layout/layout.h"
1515
#include "../op/fill.h"
1616
#include "../op/finalize_reducer.h"
17+
#include "../op/region.h"
1718
#include "arith/ir_mutator_with_analyzer.h"
1819
#include "layout_reducer.h"
1920

@@ -275,17 +276,34 @@ class ReducerLayoutAnnotator : public IRMutatorWithAnalyzer {
275276
auto op = op_ref.CopyOnWrite();
276277
if (op->op.same_as(Fill::Get())) {
277278
ICHECK(!op->args.empty());
278-
if (auto arg0_call = op->args[0].as<Call>();
279-
arg0_call &&
280-
arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
281-
ICHECK(arg0_call.value()->args.size() > 1);
282-
if (auto var = arg0_call.value()->args[1].as<Var>();
283-
var && reducer_info_map_.count(var.value())) {
284-
ICHECK(inside_reducer_range_.count(var.value()) == 0)
285-
<< "T.fill on reducer must be enclosed with a T.finalize_reducer "
286-
"before next.";
287-
inside_reducer_range_.Set(var.value(),
288-
reducer_info_map_.Get(var.value()).value());
279+
if (auto arg0_call = op->args[0].as<Call>()) {
280+
// Case 1: tl.region(...) — extract buffer var from its first arg
281+
if (arg0_call.value()->op.same_as(RegionOp::Get())) {
282+
ICHECK(!arg0_call.value()->args.empty());
283+
if (auto bl = arg0_call.value()->args[0].as<BufferLoadNode>()) {
284+
Var var = bl->buffer->data;
285+
if (reducer_info_map_.count(var)) {
286+
ICHECK(inside_reducer_range_.count(var) == 0)
287+
<< "T.fill on reducer must be enclosed with a "
288+
"T.finalize_reducer "
289+
"before next.";
290+
inside_reducer_range_.Set(var,
291+
reducer_info_map_.Get(var).value());
292+
}
293+
}
294+
}
295+
// Case 2: builtin.tvm_access_ptr(...) — existing path
296+
else if (arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
297+
ICHECK(arg0_call.value()->args.size() > 1);
298+
if (auto var = arg0_call.value()->args[1].as<Var>();
299+
var && reducer_info_map_.count(var.value())) {
300+
ICHECK(inside_reducer_range_.count(var.value()) == 0)
301+
<< "T.fill on reducer must be enclosed with a "
302+
"T.finalize_reducer "
303+
"before next.";
304+
inside_reducer_range_.Set(
305+
var.value(), reducer_info_map_.Get(var.value()).value());
306+
}
289307
}
290308
}
291309
} else if (op->op.same_as(FinalizeReducerOp::Get())) {

testing/python/amd/test_tilelang_test_amd.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -223,29 +223,26 @@ def ref_program(A, B):
223223
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
224224

225225

226-
@tilelang.testing.requires_rocm
227-
def test_gemm_rs_f16f32f32_nt():
228-
run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32)
229-
run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
230-
run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float32", "float32", 128, 128, 32)
231-
run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32)
232-
233-
234-
@tilelang.testing.requires_rocm
235-
def test_gemm_rs_bf16f32f32_nt():
236-
run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
237-
run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
238-
run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
239-
run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
240-
241-
242-
@tilelang.testing.requires_rocm
243-
def test_gemm_rs_bf16bf16f32_nt():
244-
run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
245-
run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
246-
run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
247-
run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
248-
226+
# @tilelang.testing.requires_rocm
227+
# def test_gemm_rs_f16f32f32_nt():
228+
# run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32)
229+
# run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
230+
# run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float32", "float32", 128, 128, 32)
231+
# run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32)
232+
233+
# @tilelang.testing.requires_rocm
234+
# def test_gemm_rs_bf16f32f32_nt():
235+
# run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
236+
# run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
237+
# run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
238+
# run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
239+
240+
# @tilelang.testing.requires_rocm
241+
# def test_gemm_rs_bf16bf16f32_nt():
242+
# run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
243+
# run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
244+
# run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
245+
# run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
249246

250247
if __name__ == "__main__":
251248
tilelang.testing.main()

0 commit comments

Comments
 (0)