Skip to content

Conversation

@kurisu6912
Copy link
Collaborator

@kurisu6912 kurisu6912 commented Nov 4, 2025

This pr introduce swap like grammar:

  @tilelang.jit
  @T.prim_func
  def swap_var(A: T.Tensor[(2,), T.float32]):
      with T.Kernel(1, threads=1) as _:
          a = T.alloc_var(T.float32, A[0])
          b = T.alloc_var(T.float32, A[1])
          a, b = b, a
          A[0], A[1] = a, b

  @tilelang.jit
  @T.prim_func
  def swap_idx(A: T.Tensor[(2,), T.float32]):
      with T.Kernel(1, threads=1) as _:
          A[0], A[1] = A[1], A[0]

When enter

a, b = b, a

tilelang generates:

# step 1: to support frame as rhs value, such as a, b = tl.Kernel(...)
tmp_0, tmp_1 = __tb.unwrap_value((b, a))
# step 2: temporally save the value to a let bind, handle when a, b are var
tmp_0, tmp_1 = __tb.bind('_', tmp_0), __tb.bind('_', tmp_1)
# step 3: assign tmp value to a and b
a, b = __tb.bind('a', a), __tb.bind('b', b)

Each step is required for the correctness of binding, for example, if we ignore step 2, and generate the following code:

a, b = __tb.bind('a', b), __tb.bind('b', a)
# __tb.bind('a', b) assign a[0] <= b[0]
# __tb.bind('b', a) assign b[0] <= a[0]

Is identical to the code below, that's wrong.

a[0] = b[0]
b[0] = a[0]

Summary by CodeRabbit

Release Notes

  • Refactor

    • Enhanced tuple unpacking and assignment operations with improved internal handling for complex swap scenarios.
    • Improved generated code readability with optimized temporary variable naming.
    • Streamlined dtype construction logic for more efficient initialization.
  • Tests

    • Added comprehensive kernel-based swap operation verification tests.

@github-actions
Copy link

github-actions bot commented Nov 4, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Nov 4, 2025

Walkthrough

Adds two-phase binding to support multi-target unpacking in the AST, remaps underscore placeholders in the builder for readability, refactors dtype construction delegation, and adds a test validating tuple-swap kernels on CUDA tensors.

Changes

Cohort / File(s) Summary
Test Coverage
testing/python/language/test_tilelang_language_frontend_v2.py
Adds test_swap_logic that defines two tilelang.jit kernels (swap_var, swap_idx), runs them on a 2-element CUDA tensor, and asserts swapped results.
AST Compiler Updates
tilelang/language/v2/ast.py
Implements two-phase binding for multi-target assignments: flattens tuple targets to temporaries, emits Phase 1 binds to temps and Phase 2 binds targets to temps; adds flush_binds; raises NotImplementedError for Attribute targets.
Builder Readability Improvement
tilelang/language/v2/builder.py
In bind_immutable, remaps name "_" to "_tmp" before binding to improve generated TIR readability.
Dtype Construction Change
tilelang/language/v2/dtypes.py
Removes manual string-based dtype construction path and delegates string/mapping inputs to the original dtype constructor via __orig_dtype_new, removing prior tvm_ffi-based branch.

Sequence Diagram

sequenceDiagram
    participant Test as Test / User
    participant AST as AST Compiler
    participant Builder as Builder
    participant TIR as Generated TIR

    Test->>AST: Parse assignment `a, b = b, a`
    
    rect rgb(240, 248, 255)
    Note over AST: Phase 1 — bind RHS to temporaries
    AST->>Builder: bind("_", value_b)
    Builder->>Builder: remap "_" → "_tmp_0"
    Builder->>TIR: __tb.bind("_tmp_0", value_b)
    AST->>Builder: bind("_", value_a)
    Builder->>Builder: remap "_" → "_tmp_1"
    Builder->>TIR: __tb.bind("_tmp_1", value_a)
    end

    rect rgb(240, 255, 240)
    Note over AST: Phase 2 — bind targets to temporaries
    AST->>Builder: bind("a", _tmp_0)
    Builder->>TIR: __tb.bind("a", _tmp_0)
    AST->>Builder: bind("b", _tmp_1)
    Builder->>TIR: __tb.bind("b", _tmp_1)
    end

    TIR->>Test: Executed swap kernel (results returned)
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Pay special attention to tilelang/language/v2/ast.py for correctness across nested tuple targets, ordering of binds, and interactions with subscripts.
  • Verify bind_immutable underscore remapping in builder.py doesn't clash with existing naming or tests.
  • Validate dtypes.py delegation preserves dtype semantics for string/mapping inputs.
  • Ensure the new test exercises both two-phase and single-bind paths and is robust on CUDA targets.

Poem

🐰
Tiny paws tap on keys at night,
Temps hold secrets out of sight,
Values twirl, then find their place,
Quiet swap, a perfect trace.
Hooray — the kernels dance with grace!

Pre-merge checks and finishing touches

✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title '[Feat] Add swap like grammar in tuple assignment' clearly and directly describes the main feature being added—support for swap-like syntax in tuple assignments, which is reflected in all the code changes across ast.py, builder.py, dtypes.py, and the new test.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ea93045 and 422a123.

📒 Files selected for processing (1)
  • tilelang/language/v2/dtypes.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (2)
tilelang/language/v2/dtypes.py (2)

102-103: LGTM! Good practice to preserve the original constructor.

Saving the original dtype.__new__ method before overriding it enables clean delegation and is a well-established pattern when extending built-in types.


105-112: LGTM! Clean refactor that delegates to the original constructor.

The change simplifies the dtype construction logic by delegating to __orig_dtype_new instead of manually constructing instances. Since dtype is tvm.DataType (line 9), the original TVM constructor should properly handle all necessary initialization, including any internal attributes like __tvm_ffi_dtype__. This results in cleaner, more maintainable code while preserving the same behavior.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@kurisu6912 kurisu6912 linked an issue Nov 4, 2025 that may be closed by this pull request
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 778b97d and ea93045.

📒 Files selected for processing (3)
  • testing/python/language/test_tilelang_language_frontend_v2.py (1 hunks)
  • tilelang/language/v2/ast.py (3 hunks)
  • tilelang/language/v2/builder.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_frontend_v2.py (2)
tilelang/jit/__init__.py (3)
  • jit (275-276)
  • jit (280-291)
  • jit (294-361)
tilelang/language/v2/builder.py (2)
  • prim_func (136-140)
  • prim_func (573-666)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint

Comment on lines +294 to +303
data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda()
k_swap_var(data)
ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda()
torch.testing.assert_close(data, ref)

k_swap_idx = swap_idx()
data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda()
k_swap_idx(data)
ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda()
torch.testing.assert_close(data, ref)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Skip when CUDA is unavailable

Line 294 calls .cuda() without checking torch.cuda.is_available(). On CPU-only machines (typical CI runners), this raises AssertionError: Torch not compiled with CUDA, causing the suite to fail even though the swap logic itself is valid. Please guard the test so it skips when CUDA isn’t present.

@@
 def test_swap_logic():
+    if not torch.cuda.is_available():
+        import pytest
+        pytest.skip("CUDA is required for test_swap_logic")
@@
     k_swap_var = swap_var()
     data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda()

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_frontend_v2.py around lines
294 to 303, the test unconditionally calls .cuda() which fails on machines
without CUDA; guard the CUDA-dependent assertions by checking
torch.cuda.is_available() and skip the CUDA-specific section when it's False
(e.g., call pytest.skip("CUDA unavailable") or add a pytest.mark.skipif
decorator), or alternatively run the same tensor ops on CPU when CUDA is not
present so the swap logic is still validated without requiring GPU.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature request] Support T.var assign macro

2 participants