Skip to content

Commit 055f850

Browse files
authored
[Feat] Add swap like grammar in tuple assignment (#1185)
* [Feat] add 2 phase binding to allow swap two var * Minor update tvm dtype constructor * fix lint error
1 parent 7d96189 commit 055f850

File tree

3 files changed

+58
-0
lines changed

3 files changed

+58
-0
lines changed

testing/python/language/test_tilelang_language_frontend_v2.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,5 +273,35 @@ def foo() -> T.Tensor((128,), T.float32):
273273
assert isinstance(foo, T.PrimFunc)
274274

275275

276+
def test_swap_logic():
277+
278+
@tilelang.jit
279+
@T.prim_func
280+
def swap_var(A: T.Tensor[(2,), T.float32]):
281+
with T.Kernel(1, threads=1) as _:
282+
a = T.alloc_var(T.float32, A[0])
283+
b = T.alloc_var(T.float32, A[1])
284+
a, b = b, a
285+
A[0], A[1] = a, b
286+
287+
@tilelang.jit
288+
@T.prim_func
289+
def swap_idx(A: T.Tensor[(2,), T.float32]):
290+
with T.Kernel(1, threads=1) as _:
291+
A[0], A[1] = A[1], A[0]
292+
293+
k_swap_var = swap_var()
294+
data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda()
295+
k_swap_var(data)
296+
ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda()
297+
torch.testing.assert_close(data, ref)
298+
299+
k_swap_idx = swap_idx()
300+
data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda()
301+
k_swap_idx(data)
302+
ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda()
303+
torch.testing.assert_close(data, ref)
304+
305+
276306
if __name__ == '__main__':
277307
tilelang.testing.main()

tilelang/language/v2/ast.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,8 @@ def _emit_assign_target(self,
353353
span=target,
354354
)
355355
else:
356+
357+
# flatten nested tuple into a list of (tmp_name, target)
356358
unpacked = []
357359

358360
def _visit_target(target: ast.expr) -> str:
@@ -367,6 +369,9 @@ def _visit_target(target: ast.expr) -> str:
367369
res = ast.Tuple(elts=elts, ctx=target.ctx)
368370
ast_set_span(res, ast_get_span(target))
369371
return res
372+
else:
373+
s = ast.unparse(target)
374+
raise NotImplementedError(f'Attribute assignment not supported yet, `{s}`')
370375

371376
unpack_stmt = ast.Assign(
372377
targets=[_visit_target(target)],
@@ -383,6 +388,26 @@ def flush_binds():
383388
bind_lvals.clear()
384389
bind_rvals.clear()
385390

391+
# the following code generate two phase binding to support swap like semantics
392+
# for example:
393+
# a, b = b, a
394+
# 1 phase:
395+
# _tmp_0, _tmp_1 = b, a
396+
# => _tmp_0: T.int32 = b
397+
# => _tmp_1: T.int32 = a
398+
# 2 phase:
399+
# a, b = _tmp_0, _tmp_1
400+
# => a = _tmp_0 => a[0] = _tmp_0
401+
# => b = _tmp_1 => b[0] = _tmp_1
402+
403+
# 1 phase: _tmp_0, _tmp_1 = __tb.bind('_', a), __tb.bind('_', b)
404+
for tmp, _target in unpacked:
405+
bind_lvals.append(tmp)
406+
bind_rvals.append(f'__tb.bind("_", {tmp})')
407+
408+
flush_binds()
409+
410+
# 2 phase: a, b = __tb.bind('a', _tmp_0), __tb.bind('b', _tmp_1)
386411
for tmp, target in unpacked:
387412
if isinstance(target, ast.Name):
388413
bind_lvals.append(target.id)

tilelang/language/v2/builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,9 @@ def unwrap_value(self, value):
320320
return value
321321

322322
def bind_immutable(self, name, value):
323+
if name == '_':
324+
# use _tmp to make the generated tir more readable
325+
name = "_tmp"
323326
if isinstance(value, tir.meta_var):
324327
return value.value
325328
elif isinstance(value, tir.frame.IRBuilderFrame):

0 commit comments

Comments
 (0)