-
Notifications
You must be signed in to change notification settings - Fork 321
[Feat] Add support for T.serial with step and negative step
#1188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
|
Note Other AI code review bot(s) detectedCodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review. WalkthroughThis PR consolidates loop constructors into a new Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant tilelang_language_serial as serial()
participant v2_ctx_for as ctx_for
participant tir_serial as tir.serial / tir.frame
User->>tilelang_language_serial: call serial(start, stop, step, annotations)
activate tilelang_language_serial
alt step is 1 or None
tilelang_language_serial->>tir_serial: return standard tir.serial frame
tir_serial-->>tilelang_language_serial: ForFrame
else step != 1
tilelang_language_serial->>v2_ctx_for: SerialForWithStep(start,stop,step,annotations)
activate v2_ctx_for
v2_ctx_for->>v2_ctx_for: compute real_stop = ceildiv(|stop-start|, |step|)
v2_ctx_for->>tir_serial: tir.serial(real_stop, annotations)
tir_serial-->>v2_ctx_for: serial_frame
v2_ctx_for-->>tilelang_language_serial: transformed loop yield (start + v*step)
deactivate v2_ctx_for
end
tilelang_language_serial-->>User: loop frame (ForFrame or transformed serial)
deactivate tilelang_language_serial
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (2)
testing/python/language/test_tilelang_language_frontend_v2.py (1)
284-284: Use integer literals for int32 tensor assignments.Lines 284 and 287 assign float literals (
1.0,2.0) to anint32tensor. While TVM likely handles the implicit conversion, using integer literals (1,2) improves clarity and avoids potential type confusion.Apply this diff:
for i in range(0, 10, 2): T.device_assert(0 <= i < 10 and i % 2 == 0, "i out of range") - A[i] = 1.0 + A[i] = 1 for i in range(1, 10, 2): T.device_assert(1 <= i < 10 and i % 2 == 1, "i out of range") - A[i] = 2.0 + A[i] = 2Also applies to: 287-287
tilelang/language/__init__.py (1)
26-26: Remove unusednoqadirective.The
noqa: F401directive is unnecessary as the F401 warning is not being raised for this import.Apply this diff:
-from .loop import serial, Parallel, Persistent, Pipelined # noqa: F401 +from .loop import serial, Parallel, Persistent, PipelinedBased on static analysis.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
testing/python/language/test_tilelang_language_frontend_v2.py(1 hunks)tilelang/language/__init__.py(1 hunks)tilelang/language/loop.py(1 hunks)tilelang/language/parallel.py(0 hunks)tilelang/language/persistent.py(0 hunks)tilelang/language/pipeline.py(0 hunks)tilelang/language/v2/builder.py(3 hunks)
💤 Files with no reviewable changes (3)
- tilelang/language/parallel.py
- tilelang/language/pipeline.py
- tilelang/language/persistent.py
🧰 Additional context used
🧬 Code graph analysis (4)
testing/python/language/test_tilelang_language_frontend_v2.py (3)
tilelang/jit/__init__.py (3)
jit(275-276)jit(280-291)jit(294-361)tilelang/language/v2/builder.py (2)
prim_func(144-148)prim_func(594-687)tilelang/language/print.py (1)
device_assert(144-155)
tilelang/language/__init__.py (3)
tilelang/language/loop.py (1)
serial(97-108)tilelang/language/tir/ir.py (1)
serial(10-32)tilelang/language/ast/ir.py (1)
serial(672-700)
tilelang/language/v2/builder.py (1)
tilelang/language/loop.py (1)
serial(97-108)
tilelang/language/loop.py (1)
tilelang/language/v2/builder.py (1)
SerialForWithStep(99-103)
🪛 Ruff (0.14.3)
tilelang/language/__init__.py
26-26: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/language/v2/builder.py
263-265: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (6)
tilelang/language/v2/builder.py (3)
98-103: LGTM!The
SerialForWithStepdataclass is well-defined with appropriate types for representing a serial loop with a custom step.
247-267: LGTM!The
ctx_forlogic correctly handlesSerialForWithStepfor both positive and negative stepping:
- For positive steps:
real_stop = ceildiv(stop - start, step)correctly computes iteration count- For negative constant steps: adjusts
real_stop = ceildiv(start - stop, -step)- Yields the transformed value
start + loop_var * stepto produce the expected iteration sequenceThe warning for non-constant steps is also appropriate.
475-478: LGTM!The updated
overridemethod correctly routes Python'srangetoT.serial, which now supports stepped iteration. The local import avoids potential circular dependency issues.tilelang/language/loop.py (3)
12-32: LGTM!The
Parallelfunction is well-implemented with clear documentation and appropriate handling of the optionalcoalesced_widthannotation.
35-54: LGTM!The
Persistentfunction is straightforward and well-documented.
57-94: LGTM!The
Pipelinedfunction correctly normalizesstart/stopand handles optional list parameters with appropriate defaults.
|
@codex review |
|
To use Codex here, create a Codex account and connect to github. |
|
@codex review |
|
To use Codex here, create a Codex account and connect to github. |
|
@codex review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
… step values and prevent undefined behavior. Added error handling for zero step values and improved logging for non-constant steps.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
♻️ Duplicate comments (1)
tilelang/language/v2/builder.py (1)
254-267: [Duplicate] Zero step validation still missing.This issue was flagged in a previous review. The code computes
tir.ceildiv(it.stop - it.start, it.step)at line 255 without verifying thatit.stepis non-zero. Whenstep=0, this causes division by zero before any loop is created. Python'srange(0, 10, 0)raisesValueError: range() arg 3 must not be zero, so the DSL should validate and reject zero step similarly.Apply this diff to add zero-step validation:
def ctx_for(self, it): self.check_continue_break() it = unwrap_expr(it) if isinstance(it, SerialForWithStep): + # Validate non-zero step + if isinstance(it.step, (int, IntImm)): + value = it.step if isinstance(it.step, int) else it.step.value + if value == 0: + raise ValueError("Serial loop step must not be zero") real_stop = tir.ceildiv(it.stop - it.start, it.step) if isinstance(it.step, (int, IntImm)): value = it.step if isinstance(it.step, int) else it.step.valueAdditionally, verify the ceil division logic for edge cases:
#!/bin/bash # Check for any existing zero-step validation in the codebase rg -nP "step.*==\s*0|step.*zero|ValueError.*step" --type=py # Search for range() calls that might be affected rg -nP "range\s*\([^)]*,\s*[^)]*,\s*[^)]*\)" --type=py -A2 -B2
🧹 Nitpick comments (2)
tilelang/language/v2/builder.py (1)
269-272: Consider extracting the error message to a module-level constant.The error message is quite long and static analysis (Ruff TRY003) suggests extracting it to improve maintainability.
+_INVALID_FOR_LOOP_ERROR = ( + "Invalid for loop, got {it}({type}), expect one of the following: " + "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding" +) + def ctx_for(self, it): ... else: if not isinstance(it, tir.frame.ForFrame): - raise TypeError( - f"Invalid for loop, got {it}({type(it)}), expect one of the following: " - "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding") + raise TypeError(_INVALID_FOR_LOOP_ERROR.format(it=it, type=type(it)))testing/python/language/test_tilelang_language_frontend_v2.py (1)
278-313: Good test coverage, but consider adding zero-step validation test.The tests effectively cover positive steps, negative steps, and frame type assertions. The use of
device_assertwithin loops validates that iteration variables stay in expected ranges.Consider adding a test case for
step=0, which should raise aValueErrorsimilar to Python'srange(0, 10, 0):def test_serial_zero_step_error(): """Verify that zero step is rejected.""" import pytest @T.prim_func def invalid_zero_step(A: T.Tensor((10,), T.int32)): with T.Kernel(1) as _: for i in range(0, 10, 0): # Should raise ValueError A[i] = i with pytest.raises(ValueError, match="step must not be zero"): invalid_zero_step
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
testing/python/language/test_tilelang_language_frontend_v2.py(2 hunks)tilelang/language/loop.py(1 hunks)tilelang/language/v2/builder.py(3 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
testing/python/language/test_tilelang_language_frontend_v2.py (2)
tilelang/language/print.py (1)
device_assert(144-155)tilelang/language/loop.py (1)
serial(97-111)
tilelang/language/loop.py (1)
tilelang/language/v2/builder.py (1)
SerialForWithStep(104-108)
tilelang/language/v2/builder.py (1)
tilelang/language/loop.py (1)
serial(97-111)
🪛 Ruff (0.14.3)
tilelang/language/v2/builder.py
270-272: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (3)
tilelang/language/v2/builder.py (2)
103-108: LGTM! Clean dataclass definition.The
SerialForWithStepdataclass is well-structured and provides the necessary fields to represent a serial loop with a custom step parameter.
485-487: LGTM! Correct routing to the new serial API.The override correctly routes
range()toT.serial, which now supports the step parameter. The import is appropriately scoped within the method.testing/python/language/test_tilelang_language_frontend_v2.py (1)
6-7: LGTM! Necessary imports for serial step testing.The imports support runtime assertions about frame types and symbolic variables in the new test.
| def serial(start: tir.PrimExpr, | ||
| stop: tir.PrimExpr | None = None, | ||
| step: tir.PrimExpr | None = None, | ||
| *, | ||
| annotations: dict[str, Any] | None = None): | ||
| step_is_one = False | ||
| step_is_one |= isinstance(step, int) and step == 1 | ||
| step_is_one |= isinstance(step, IntImm) and step.value == 1 | ||
| if step is None or step_is_one: | ||
| return tb_tir.serial(start, stop, annotations=annotations) | ||
| else: | ||
| if stop is None: | ||
| stop = start | ||
| start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0 | ||
| return SerialForWithStep(start, stop, step, annotations=annotations) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing zero-step validation.
The function correctly routes step values to either tb_tir.serial (for step=None or step=1) or SerialForWithStep, and properly normalizes start/stop. However, zero step is not validated before creating SerialForWithStep, which will cause division by zero in builder.py::ctx_for line 255.
Apply this diff to validate step:
def serial(start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None,
*,
annotations: dict[str, Any] | None = None):
+ # Validate non-zero step for constant values
+ if isinstance(step, int) and step == 0:
+ raise ValueError("Serial loop step must not be zero")
+ if isinstance(step, IntImm) and step.value == 0:
+ raise ValueError("Serial loop step must not be zero")
+
step_is_one = False
step_is_one |= isinstance(step, int) and step == 1
step_is_one |= isinstance(step, IntImm) and step.value == 1Minor: Consider using logical or instead of bitwise |=.
While lines 103-104 work correctly, the bitwise OR pattern is unconventional for boolean accumulation. Consider:
- step_is_one = False
- step_is_one |= isinstance(step, int) and step == 1
- step_is_one |= isinstance(step, IntImm) and step.value == 1
+ step_is_one = (isinstance(step, int) and step == 1) or \
+ (isinstance(step, IntImm) and step.value == 1)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def serial(start: tir.PrimExpr, | |
| stop: tir.PrimExpr | None = None, | |
| step: tir.PrimExpr | None = None, | |
| *, | |
| annotations: dict[str, Any] | None = None): | |
| step_is_one = False | |
| step_is_one |= isinstance(step, int) and step == 1 | |
| step_is_one |= isinstance(step, IntImm) and step.value == 1 | |
| if step is None or step_is_one: | |
| return tb_tir.serial(start, stop, annotations=annotations) | |
| else: | |
| if stop is None: | |
| stop = start | |
| start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0 | |
| return SerialForWithStep(start, stop, step, annotations=annotations) | |
| def serial(start: tir.PrimExpr, | |
| stop: tir.PrimExpr | None = None, | |
| step: tir.PrimExpr | None = None, | |
| *, | |
| annotations: dict[str, Any] | None = None): | |
| # Validate non-zero step for constant values | |
| if isinstance(step, int) and step == 0: | |
| raise ValueError("Serial loop step must not be zero") | |
| if isinstance(step, IntImm) and step.value == 0: | |
| raise ValueError("Serial loop step must not be zero") | |
| step_is_one = (isinstance(step, int) and step == 1) or \ | |
| (isinstance(step, IntImm) and step.value == 1) | |
| if step is None or step_is_one: | |
| return tb_tir.serial(start, stop, annotations=annotations) | |
| else: | |
| if stop is None: | |
| stop = start | |
| start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0 | |
| return SerialForWithStep(start, stop, step, annotations=annotations) |
🤖 Prompt for AI Agents
In tilelang/language/loop.py around lines 97 to 111, the function lacks
validation for a zero step which will cause a division-by-zero downstream; add a
guard before creating SerialForWithStep that checks if step represents zero
(e.g., isinstance(step, int) and step == 0 or isinstance(step, IntImm) and
step.value == 0) and raise a ValueError with a clear message if so; also replace
the bitwise accumulation using '|=' on step_is_one with conventional boolean
logic using 'or' (e.g., step_is_one = (isinstance(step, int) and step == 1) or
(isinstance(step, IntImm) and step.value == 1)) so the intent is clearer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (1)
tilelang/language/v2/builder.py (1)
265-268: Optional: Clarify why non-constant steps cause undefined behavior.The warning correctly alerts users but could be more explicit about the root cause.
Consider enhancing the warning message:
logger.warning( - f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang' + f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang. ' + 'T.ceildiv requires non-negative operands; negative runtime step values will produce incorrect iteration counts.' )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tilelang/language/v2/builder.py(3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/language/v2/builder.py (1)
tilelang/language/loop.py (1)
serial(97-111)
🪛 Ruff (0.14.3)
tilelang/language/v2/builder.py
259-259: Avoid specifying long messages outside the exception class
(TRY003)
275-277: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (4)
tilelang/language/v2/builder.py (4)
103-109: LGTM: Clean dataclass definition for stepped serial loops.The
SerialForWithStepdataclass appropriately encapsulates the loop parameters needed for serial iteration with custom step values.
258-259: LGTM: Zero-step validation correctly prevents division by zero.The check properly rejects zero steps for constant values, addressing the concern from previous reviews.
274-279: LGTM: Existing for-loop handling preserved.The else branch correctly maintains backward compatibility for non-stepped serial loops and other loop constructs.
490-493: Import path verified and correct.The
from tilelang.language import serialat line 490 correctly imports the serial function fromtilelang/language/loop.py, which is properly exported throughtilelang/language/__init__.py. The change integrates the stepped serial loop support as intended.
| if isinstance(it.step, (int, IntImm)): | ||
| step_value = it.step if isinstance(it.step, int) else it.step.value | ||
| if step_value == 0: | ||
| raise ValueError('Invalid stepped serial: step must be non-zero') | ||
| if step_value > 0: | ||
| real_stop = tir.ceildiv(it.stop - it.start, step_value) | ||
| else: | ||
| real_stop = tir.ceildiv(it.start - it.stop, -step_value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Critical: Handle empty iteration range for negative steps.
When step < 0 and start <= stop (e.g., T.serial(0, 10, -3)), line 263 computes ceildiv(0 - 10, 3) = -3, producing a negative real_stop. This causes invalid IR generation when passed to tir.serial at line 269.
Python's range(0, 10, -3) produces an empty sequence; the DSL should yield 0 iterations in this case.
Apply this diff to clamp real_stop to non-negative values:
if step_value > 0:
- real_stop = tir.ceildiv(it.stop - it.start, step_value)
+ real_stop = tir.max(tir.ceildiv(it.stop - it.start, step_value), 0)
else:
- real_stop = tir.ceildiv(it.start - it.stop, -step_value)
+ real_stop = tir.max(tir.ceildiv(it.start - it.stop, -step_value), 0)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if isinstance(it.step, (int, IntImm)): | |
| step_value = it.step if isinstance(it.step, int) else it.step.value | |
| if step_value == 0: | |
| raise ValueError('Invalid stepped serial: step must be non-zero') | |
| if step_value > 0: | |
| real_stop = tir.ceildiv(it.stop - it.start, step_value) | |
| else: | |
| real_stop = tir.ceildiv(it.start - it.stop, -step_value) | |
| if isinstance(it.step, (int, IntImm)): | |
| step_value = it.step if isinstance(it.step, int) else it.step.value | |
| if step_value == 0: | |
| raise ValueError('Invalid stepped serial: step must be non-zero') | |
| if step_value > 0: | |
| real_stop = tir.max(tir.ceildiv(it.stop - it.start, step_value), 0) | |
| else: | |
| real_stop = tir.max(tir.ceildiv(it.start - it.stop, -step_value), 0) |
🧰 Tools
🪛 Ruff (0.14.3)
259-259: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In tilelang/language/v2/builder.py around lines 256 to 263, the computation of
real_stop for negative step values can produce a negative count (e.g., start <=
stop with step < 0), causing invalid IR; after computing real_stop for the
negative-step branch, clamp it to a non-negative integer (e.g., real_stop =
max(0, real_stop)) so empty iteration ranges yield zero iterations before it is
passed to tir.serial.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
♻️ Duplicate comments (1)
tilelang/language/v2/builder.py (1)
256-263: Empty iteration range still produces negative real_stop.When
step < 0andstart <= stop(e.g.,T.serial(0, 10, -3)), line 263 computesceildiv(0 - 10, 3) = -3, producing a negativereal_stopthat causes invalid IR.This issue was already flagged in previous reviews and remains unresolved. The suggested fix is to clamp
real_stopto non-negative values after the computation.
🧹 Nitpick comments (1)
tilelang/language/v2/builder.py (1)
264-268: Clarify the undefined behavior for non-constant steps.The warning mentions "undefined behavior" but doesn't explain that line 268 assumes a positive step. If the step expression evaluates to a negative value at runtime,
ceildiv(it.stop - it.start, it.step)will produce incorrect iteration counts.Consider clarifying the warning message:
else: logger.warning( - f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang' + f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang. ' + f'The implementation assumes a positive step; negative runtime values will produce incorrect iteration counts.' )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tilelang/language/v2/builder.py(3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/language/v2/builder.py (2)
tilelang/carver/roller/hint.py (1)
step(240-243)tilelang/language/loop.py (1)
serial(97-111)
🪛 Ruff (0.14.3)
tilelang/language/v2/builder.py
259-259: Avoid specifying long messages outside the exception class
(TRY003)
275-277: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (3)
tilelang/language/v2/builder.py (3)
103-108: LGTM!The dataclass structure is clean and appropriate for representing stepped serial loops.
269-272: LGTM!The frame creation and value transformation (
it.start + v * it.step) correctly implement the stepped loop logic.
490-493: LGTM!The override correctly returns
tilelang.language.serialforrange, enabling stepped serial loops in the DSL.
|
@codex review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| if isinstance(it.step, (int, IntImm)): | ||
| step_value = it.step if isinstance(it.step, int) else it.step.value | ||
| if step_value == 0: | ||
| raise ValueError('Invalid stepped serial: step must be non-zero') | ||
| if step_value > 0: | ||
| real_stop = tir.ceildiv(it.stop - it.start, step_value) | ||
| else: | ||
| real_stop = tir.ceildiv(it.start - it.stop, -step_value) | ||
| else: | ||
| logger.warning( | ||
| f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang' | ||
| ) | ||
| real_stop = tir.ceildiv(it.stop - it.start, it.step) | ||
| real_frame = tir.serial(real_stop, annotations=it.annotations) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Guard trip count when step cannot reach stop
The new stepped loop path computes real_stop with tir.ceildiv and passes it directly to tir.serial. When the step sign is incompatible with the start/stop ordering (e.g. range(10, 0, 2) or range(0, 10, -2)), (stop - start) and step have opposite signs, so real_stop becomes negative and we attempt to build a serial loop with a negative extent. TVM’s serial builder rejects negative extents, raising an error for loops that should simply execute zero iterations per Python semantics. Consider clamping the trip count to zero when no iterations are expected before creating the tir.serial frame.
Useful? React with 👍 / 👎.
This pr add support step and negative step in
T.serial. But due toT.ceildivallows non-negative variables only, negative step may cause undefined behavior.I check whether it is a constant, and show warning when it may cause undefined behavior:
Summary by CodeRabbit
New Features
Refactor
Tests