30 changes: 30 additions & 0 deletions testing/python/language/test_tilelang_language_frontend_v2.py
@@ -273,5 +273,35 @@ def foo() -> T.Tensor((128,), T.float32):
    assert isinstance(foo, T.PrimFunc)


def test_swap_logic():

    @tilelang.jit
    @T.prim_func
    def swap_var(A: T.Tensor[(2,), T.float32]):
        with T.Kernel(1, threads=1) as _:
            a = T.alloc_var(T.float32, A[0])
            b = T.alloc_var(T.float32, A[1])
            a, b = b, a
            A[0], A[1] = a, b

    @tilelang.jit
    @T.prim_func
    def swap_idx(A: T.Tensor[(2,), T.float32]):
        with T.Kernel(1, threads=1) as _:
            A[0], A[1] = A[1], A[0]

    k_swap_var = swap_var()
    data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda()
    k_swap_var(data)
    ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda()
    torch.testing.assert_close(data, ref)

    k_swap_idx = swap_idx()
    data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda()
    k_swap_idx(data)
    ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda()
    torch.testing.assert_close(data, ref)
Comment on lines +294 to +303
Contributor

⚠️ Potential issue | 🟠 Major

Skip when CUDA is unavailable

Line 294 calls .cuda() without checking torch.cuda.is_available(). On CPU-only machines (typical CI runners), this raises AssertionError: Torch not compiled with CUDA enabled, causing the suite to fail even though the swap logic itself is valid. Please guard the test so it skips when CUDA isn't present.

@@
 def test_swap_logic():
+    if not torch.cuda.is_available():
+        import pytest
+        pytest.skip("CUDA is required for test_swap_logic")
@@
     k_swap_var = swap_var()
     data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda()

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_frontend_v2.py around lines
294 to 303, the test unconditionally calls .cuda() which fails on machines
without CUDA; guard the CUDA-dependent assertions by checking
torch.cuda.is_available() and skip the CUDA-specific section when it's False
(e.g., call pytest.skip("CUDA unavailable") or add a pytest.mark.skipif
decorator), or alternatively run the same tensor ops on CPU when CUDA is not
present so the swap logic is still validated without requiring GPU.
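As the prompt notes, the guard can also live in a decorator rather than in the test body. A minimal sketch of that variant (assumes pytest and torch are importable at module scope; test body elided):

import pytest
import torch

# Skip at collection time on CPU-only machines rather than failing inside .cuda()
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for test_swap_logic")
def test_swap_logic():
    ...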



if __name__ == '__main__':
    tilelang.testing.main()
25 changes: 25 additions & 0 deletions tilelang/language/v2/ast.py
@@ -346,6 +346,8 @@ def _emit_assign_target(self,
                span=target,
            )
        else:

            # flatten nested tuple into a list of (tmp_name, target)
            unpacked = []

            def _visit_target(target: ast.expr) -> str:
@@ -360,6 +362,9 @@ def _visit_target(target: ast.expr) -> str:
                    res = ast.Tuple(elts=elts, ctx=target.ctx)
                    ast_set_span(res, ast_get_span(target))
                    return res
                else:
                    s = ast.unparse(target)
                    raise NotImplementedError(f'Attribute assignment not supported yet, `{s}`')

            unpack_stmt = ast.Assign(
                targets=[_visit_target(target)],
@@ -376,6 +381,26 @@ def flush_binds():
                bind_lvals.clear()
                bind_rvals.clear()

            # the following code generates a two-phase binding to support swap-like semantics
            # for example:
            #   a, b = b, a
            # phase 1:
            #   _tmp_0, _tmp_1 = b, a
            #   => _tmp_0: T.int32 = b
            #   => _tmp_1: T.int32 = a
            # phase 2:
            #   a, b = _tmp_0, _tmp_1
            #   => a = _tmp_0 => a[0] = _tmp_0
            #   => b = _tmp_1 => b[0] = _tmp_1

            # phase 1: _tmp_0, _tmp_1 = __tb.bind('_', a), __tb.bind('_', b)
            for tmp, _target in unpacked:
                bind_lvals.append(tmp)
                bind_rvals.append(f'__tb.bind("_", {tmp})')

            flush_binds()

            # phase 2: a, b = __tb.bind('a', _tmp_0), __tb.bind('b', _tmp_1)
            for tmp, target in unpacked:
                if isinstance(target, ast.Name):
                    bind_lvals.append(target.id)
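The motivation for the two-phase scheme above can be modeled in plain Python. This is an illustrative sketch of the lowering, not the actual emitted TIR:

# Naive one-phase lowering of `a, b = b, a`: each target is bound as soon
# as its right-hand side is evaluated, so `a` is clobbered before `b` reads it.
env = {'a': 1.0, 'b': 2.0}
env['a'] = env['b']  # a <- 2.0
env['b'] = env['a']  # b <- 2.0; the original value of `a` is lost
assert env == {'a': 2.0, 'b': 2.0}

# Two-phase lowering: phase 1 evaluates every right-hand side into a
# temporary, phase 2 assigns the temporaries to the targets.
env = {'a': 1.0, 'b': 2.0}
_tmp_0 = env['b']   # phase 1: _tmp_0 = b
_tmp_1 = env['a']   #          _tmp_1 = a
env['a'] = _tmp_0   # phase 2: a = _tmp_0
env['b'] = _tmp_1   #          b = _tmp_1
assert env == {'a': 2.0, 'b': 1.0}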
3 changes: 3 additions & 0 deletions tilelang/language/v2/builder.py
@@ -313,6 +313,9 @@ def unwrap_value(self, value):
        return value

    def bind_immutable(self, name, value):
        if name == '_':
            # use _tmp to make the generated tir more readable
            name = "_tmp"
        if isinstance(value, tir.meta_var):
            return value.value
        elif isinstance(value, tir.frame.IRBuilderFrame):
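To see what the rename buys, here is a toy stand-in for the naming behavior (ToyBuilder is hypothetical, not the real Builder API): anonymous phase-1 temporaries bound under '_' surface as _tmp-prefixed variables in the generated code.

# Hypothetical model of the '_' -> '_tmp' rename in bind_immutable.
class ToyBuilder:
    def __init__(self):
        self.counter = 0

    def bind(self, name: str, value: str) -> str:
        if name == '_':
            # rename so generated temporaries read `_tmp_0` rather than `__0`
            name = '_tmp'
        var = f'{name}_{self.counter}'
        self.counter += 1
        print(f'{var} = {value}')
        return var

tb = ToyBuilder()
t0 = tb.bind('_', 'b')  # prints: _tmp_0 = b
t1 = tb.bind('_', 'a')  # prints: _tmp_1 = a
tb.bind('a', t0)        # prints: a_2 = _tmp_0
tb.bind('b', t1)        # prints: b_3 = _tmp_1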
10 changes: 5 additions & 5 deletions tilelang/language/v2/dtypes.py
@@ -1,6 +1,5 @@
 from tilelang import tvm
 from tvm import ir
-import tvm_ffi
 import torch
 import ctypes
 from typing import TYPE_CHECKING
@@ -100,16 +99,17 @@ def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var
     return call(expr, is_size_var)


+__orig_dtype_new = dtype.__new__
+
+
 def __dtype_new__(cls, value: AnyDType) -> dtype:
     if isinstance(value, str):
-        val = str.__new__(cls, value)
+        return __orig_dtype_new(cls, value)
     elif value in _dtype_py2tvmstr:
-        val = str.__new__(cls, _dtype_py2tvmstr[value])
+        return __orig_dtype_new(cls, _dtype_py2tvmstr[value])
     else:
         expected = set(list(_dtype_py2tvmstr.keys()) + list(_dtype_tvmstr2fficall.values()))
         raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}")
-    val.__tvm_ffi_dtype__ = tvm_ffi.core.DataType(val)
-    return val


 dtype.__eq__ = __dtype_eq__
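The change above routes construction through the saved __orig_dtype_new instead of calling str.__new__ directly and attaching a tvm_ffi attribute. The underlying pattern—a str subclass whose __new__ normalizes non-string inputs through a lookup table—can be sketched in isolation. A minimal, hypothetical model (ToyDType and PY2STR are illustrative names, not the real tilelang API; assumes torch is installed):

import torch

# Hypothetical mapping from Python/torch dtypes to TVM-style dtype strings.
PY2STR = {torch.float32: 'float32', torch.int32: 'int32', float: 'float32', int: 'int32'}

class ToyDType(str):
    def __new__(cls, value):
        if isinstance(value, str):
            return str.__new__(cls, value)          # already a dtype string
        elif value in PY2STR:
            return str.__new__(cls, PY2STR[value])  # normalize via the table
        else:
            raise TypeError(f'Invalid DataType {value} ({type(value)})')

assert ToyDType(torch.float32) == 'float32'
assert ToyDType('int8') == 'int8'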