diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py b/testing/python/language/test_tilelang_language_frontend_v2.py index b4ca94232..da6e8e4b6 100644 --- a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -273,5 +273,35 @@ def foo() -> T.Tensor((128,), T.float32): assert isinstance(foo, T.PrimFunc) +def test_swap_logic(): + + @tilelang.jit + @T.prim_func + def swap_var(A: T.Tensor[(2,), T.float32]): + with T.Kernel(1, threads=1) as _: + a = T.alloc_var(T.float32, A[0]) + b = T.alloc_var(T.float32, A[1]) + a, b = b, a + A[0], A[1] = a, b + + @tilelang.jit + @T.prim_func + def swap_idx(A: T.Tensor[(2,), T.float32]): + with T.Kernel(1, threads=1) as _: + A[0], A[1] = A[1], A[0] + + k_swap_var = swap_var() + data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda() + k_swap_var(data) + ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda() + torch.testing.assert_close(data, ref) + + k_swap_idx = swap_idx() + data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda() + k_swap_idx(data) + ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda() + torch.testing.assert_close(data, ref) + + if __name__ == '__main__': tilelang.testing.main() diff --git a/tilelang/language/v2/ast.py b/tilelang/language/v2/ast.py index 34e74d64b..0e778fbc0 100644 --- a/tilelang/language/v2/ast.py +++ b/tilelang/language/v2/ast.py @@ -346,6 +346,8 @@ def _emit_assign_target(self, span=target, ) else: + + # flatten nested tuple into a list of (tmp_name, target) unpacked = [] def _visit_target(target: ast.expr) -> str: @@ -360,6 +362,9 @@ def _visit_target(target: ast.expr) -> str: res = ast.Tuple(elts=elts, ctx=target.ctx) ast_set_span(res, ast_get_span(target)) return res + else: + s = ast.unparse(target) + raise NotImplementedError(f'Attribute assignment not supported yet, `{s}`') unpack_stmt = ast.Assign( targets=[_visit_target(target)], @@ -376,6 +381,26 @@ def flush_binds(): bind_lvals.clear() bind_rvals.clear() + # the following code generate two phase binding to support swap like semantics + # for example: + # a, b = b, a + # 1 phase: + # _tmp_0, _tmp_1 = b, a + # => _tmp_0: T.int32 = b + # => _tmp_1: T.int32 = a + # 2 phase: + # a, b = _tmp_0, _tmp_1 + # => a = _tmp_0 => a[0] = _tmp_0 + # => b = _tmp_1 => b[0] = _tmp_1 + + # 1 phase: _tmp_0, _tmp_1 = __tb.bind('_', a), __tb.bind('_', b) + for tmp, _target in unpacked: + bind_lvals.append(tmp) + bind_rvals.append(f'__tb.bind("_", {tmp})') + + flush_binds() + + # 2 phase: a, b = __tb.bind('a', _tmp_0), __tb.bind('b', _tmp_1) for tmp, target in unpacked: if isinstance(target, ast.Name): bind_lvals.append(target.id) diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 3bae9ecd1..4ea91f40f 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -313,6 +313,9 @@ def unwrap_value(self, value): return value def bind_immutable(self, name, value): + if name == '_': + # use _tmp to make the generated tir more readable + name = "_tmp" if isinstance(value, tir.meta_var): return value.value elif isinstance(value, tir.frame.IRBuilderFrame): diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/v2/dtypes.py index def59845b..39ea90f81 100644 --- a/tilelang/language/v2/dtypes.py +++ b/tilelang/language/v2/dtypes.py @@ -1,6 +1,5 @@ from tilelang import tvm from tvm import ir -import tvm_ffi import torch import ctypes from typing import TYPE_CHECKING @@ -100,16 +99,17 @@ def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var return call(expr, is_size_var) +__orig_dtype_new = dtype.__new__ + + def __dtype_new__(cls, value: AnyDType) -> dtype: if isinstance(value, str): - val = str.__new__(cls, value) + return __orig_dtype_new(cls, value) elif value in _dtype_py2tvmstr: - val = str.__new__(cls, _dtype_py2tvmstr[value]) + return __orig_dtype_new(cls, _dtype_py2tvmstr[value]) else: expected = set(list(_dtype_py2tvmstr.keys()) + list(_dtype_tvmstr2fficall.values())) raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}") - val.__tvm_ffi_dtype__ = tvm_ffi.core.DataType(val) - return val dtype.__eq__ = __dtype_eq__