-
Notifications
You must be signed in to change notification settings - Fork 322
[Refactor] add support for numpy dtype conversion #1255
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
WalkthroughAdds a large TIR type-stub Changes
Sequence Diagram(s)sequenceDiagram
autonumber
participant User
participant Stubgen as stubgen.py
participant OpModule as tilelang/language/tir/op.py
participant IRStub as tilelang/language/tir/ir.pyi
User->>Stubgen: run
Stubgen->>OpModule: read & parse AST
Stubgen->>Stubgen: extract pseudo-docstring types\nmap to PrimExpr/_T/etc.
Stubgen->>IRStub: generate typed FunctionDef stubs
Stubgen-->>User: write ir.pyi
note right of IRStub `#DDEBF7`: New static type declarations\n(no runtime code)
sequenceDiagram
autonumber
participant Caller
participant dtypes as dtypes.py
participant Mappings as _DTYPE_TO_STR/_STR_TO_TVM_DTYPE_CALL
participant tvm as TVM FFI
Caller->>dtypes: __dtype_new__(value)
dtypes->>Mappings: lookup value -> dtype_str
Mappings-->>dtypes: dtype_str
dtypes-->>Caller: dtype object
Caller->>dtypes: __dtype_call__(dtype_str)
dtypes->>Mappings: lookup dtype_str -> ffi_entry
Mappings->>tvm: call ffi_entry
tvm-->>dtypes: TVM dtype result
dtypes-->>Caller: TVM dtype
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (3)
tilelang/language/v2/dtypes.py (2)
77-96: Harden__dtype_call__for missing FFI symbolsThe new fast path via
_STR_TO_TVM_DTYPE_CALLlooks good, but unlike the fallback logic below, it does not guard against the case where the mapped FFI symbol doesn’t exist ontb_ffi. In that scenario you’d get aNoneType is not callableinstead of the clearer TypeError you emit later.You can align the behaviors like this:
def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var: - if self in _STR_TO_TVM_DTYPE_CALL: - attr = _STR_TO_TVM_DTYPE_CALL[self] - call = getattr(tb_ffi, attr, None) - return call(expr, is_size_var) + if self in _STR_TO_TVM_DTYPE_CALL: + attr = _STR_TO_TVM_DTYPE_CALL[self] + call = getattr(tb_ffi, attr, None) + if call is None: + raise TypeError( + f"Convert to datatype `{self}` is not supported by tvm\n" + f"calling failed on `tvm.script.ir_builder.tir._ffi_api.{attr}`" + ) + return call(expr, is_size_var)This keeps error reporting consistent regardless of whether you hit the explicit map or the derived-name path.
Consider running the existing TIR builder tests across the TVM versions you support to ensure all
_STR_TO_TVM_DTYPE_CALLentries have corresponding FFI functions.Also applies to: 100-103
128-135: Makedtype.__new__idempotent and friendlier to more inputsThe revamped
__dtype_new__nicely centralizes Python/NumPy/Torch mappings, but:
- Passing an existing
dtypeinstance (or potentially anir.Type) intodtype(...)now falls into the error branch, even though such calls are often expected to be idempotent.- The
expectedset mixes type objects and strings, which is fine, but can be noisy in error messages.You can make this more robust and backwards-friendly with a small tweak:
def __dtype_new__(cls, value: AnyDType) -> dtype: - if isinstance(value, str): - return __orig_dtype_new(cls, value) - elif value in _DTYPE_TO_STR: - return __orig_dtype_new(cls, _DTYPE_TO_STR[value]) - else: - expected = set(list(_DTYPE_TO_STR.keys()) + list(_DTYPE_TO_STR.values())) - raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}") + # Already a dtype: keep it idempotent. + if isinstance(value, dtype): + return value + if isinstance(value, str): + return __orig_dtype_new(cls, value) + if value in _DTYPE_TO_STR: + return __orig_dtype_new(cls, _DTYPE_TO_STR[value]) + expected = set(_DTYPE_TO_STR.keys()) | set(_DTYPE_TO_STR.values()) + raise TypeError( + f"Invalid DataType {value}({type(value)}), expect one of {expected}" + )This should reduce surprising TypeErrors for callers that pass an existing
dtypewhile keeping the new NumPy/Torch conversion semantics.Please scan for call sites doing
dtype(existing_dtype_or_type)to confirm this change matches their expectations and doesn’t mask any intentional validation.tilelang/language/tir/ir.pyi (1)
58-59: Alignspanparameter typing for consistency in the stubMost intrinsics here use
span: Span | None = None, but a few keepspan=Noneuntyped (shift_left,shift_right,call_packed*,call_cpacked*). This is harmless at runtime (it’s a stub), but slightly degrades type-checking quality and consistency.You could align these signatures as follows:
-def shift_left(x: _T, y: _T, span=None) -> _T: ... -def shift_right(x: _T, y: _T, span=None) -> _T: ... +def shift_left(x: _T, y: _T, span: Span | None = None) -> _T: ... +def shift_right(x: _T, y: _T, span: Span | None = None) -> _T: ... @@ -def call_packed(*args, span=None) -> _T: ... -def call_cpacked(*args, span=None) -> _T: ... -def call_packed_lowered(*args, span=None) -> _T: ... -def call_cpacked_lowered(*args, span=None) -> _T: ... +def call_packed(*args, span: Span | None = None) -> _T: ... +def call_cpacked(*args, span: Span | None = None) -> _T: ... +def call_packed_lowered(*args, span: Span | None = None) -> _T: ... +def call_cpacked_lowered(*args, span: Span | None = None) -> _T: ...This keeps the stub self-consistent and makes it easier for type checkers to track span usage.
Please confirm against the actual TVM TIR Python API signatures you’re targeting to ensure the annotated types and parameter names (especially
span) are accurate.Also applies to: 75-78
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
tilelang/language/tir/ir.pyi(1 hunks)tilelang/language/v2/dtypes.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/language/tir/ir.pyi (1)
tilelang/language/ast/ir.py (1)
func_name(206-214)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (1)
tilelang/language/v2/dtypes.py (1)
44-65: ****PyTorch exposes
torch.uint16,torch.uint32, andtorch.uint64dtypes, so this module will not raiseAttributeErrorat import time. The code snippet in lines 44–65 is valid as written, and the proposed conditional registration workaround is unnecessary.Likely an incorrect or invalid review comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/language/v2/dtypes.py (1)
135-138: Return type annotation doesn't match implementation.Line 136 returns
valueif it's anir.Type, but the function signature declares-> dtype. This creates a type annotation mismatch.If
ir.Typeis compatible withdtype, update the return annotation:-def get_tvm_dtype(value: AnyDType) -> dtype: +def get_tvm_dtype(value: AnyDType) -> Union[dtype, ir.Type]:Otherwise, convert
ir.Typetodtype:def get_tvm_dtype(value: AnyDType) -> dtype: - if isinstance(value, (dtype, ir.Type)): - return value + if isinstance(value, dtype): + return value + elif isinstance(value, ir.Type): + return dtype(str(value)) # or appropriate conversion return dtype(value)
♻️ Duplicate comments (1)
tilelang/language/v2/dtypes.py (1)
11-11: AnyDType annotation still missing NumPy types despite previous review feedback.The type annotation remains incomplete even though this was flagged in the previous review. Runtime now accepts NumPy dtypes (via
_NUMPY_DTYPE_TO_STRat lines 19-40), but the type hint does not reflect this.Apply this fix as suggested in the previous review:
-AnyDType = Union[ir.Type, str, type, torch.dtype, dtype] +AnyDType = Union[ir.Type, str, type, torch.dtype, dtype, "np.dtype", "np.generic"]
🧹 Nitpick comments (3)
tilelang/language/v2/dtypes.py (3)
92-115: Implicit dtype-to-string comparison may be fragile.Line 93 checks
if self in _STR_TO_TVM_DTYPE_CALL:whereselfis adtypeobject but the dict has string keys. This relies on TVM'sdtype.__eq__implementation to compare with strings, which is an implicit dependency on internal TVM behavior.Make the string conversion explicit for clarity and robustness:
def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var: - if self in _STR_TO_TVM_DTYPE_CALL: - attr = _STR_TO_TVM_DTYPE_CALL[self] + dtype_str = str(self) + if dtype_str in _STR_TO_TVM_DTYPE_CALL: + attr = _STR_TO_TVM_DTYPE_CALL[dtype_str] call = getattr(tb_ffi, attr, None) return call(expr, is_size_var) # try to construct the ffi call - if self.startswith('uint'): - val = 'UInt' + self[4:] - elif self.startswith('int'): - val = 'Int' + self[3:] - elif self.startswith('float'): - val = 'Float' + self[5:] - elif self.startswith('bfloat'): - val = 'BFloat' + self[6:] + if dtype_str.startswith('uint'): + val = 'UInt' + dtype_str[4:] + elif dtype_str.startswith('int'): + val = 'Int' + dtype_str[3:] + elif dtype_str.startswith('float'): + val = 'Float' + dtype_str[5:] + elif dtype_str.startswith('bfloat'): + val = 'BFloat' + dtype_str[6:] else: - raise TypeError(f'Invalid type {self}') + raise TypeError(f'Invalid type {dtype_str}') if '_' in val: first, second = val.split('_', maxsplit=1) val = first + second.upper() call = getattr(tb_ffi, val, None) if call is None: - raise TypeError(f"Convert to datatype `{self}` is not supported by tvm\n" + raise TypeError(f"Convert to datatype `{dtype_str}` is not supported by tvm\n" f"calling failed on `tvm.script.ir_builder.tir._ffi_api.{val}`") return call(expr, is_size_var)
121-128: Minor: Simplify error message construction.The logic is correct, but the error message construction can be slightly cleaner.
else: - expected = set(list(_DTYPE_TO_STR.keys()) + list(_DTYPE_TO_STR.values())) + expected = set(_DTYPE_TO_STR.keys()) | set(_DTYPE_TO_STR.values()) raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}")
64-66: Consider removing or documenting commented code.The reverse mappings are commented out. If these are not needed, remove them to keep the codebase clean. If they're planned for future use, add a TODO comment explaining the intent.
-# _STR_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_STR.items()} - -# _STR_TO_NUMPY_DTYPE = {v: k for k, v in _NUMPY_DTYPE_TO_STR.items()} +# TODO: Add reverse mappings if needed for future dtype conversion utilities +# _STR_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_STR.items()} +# _STR_TO_NUMPY_DTYPE = {v: k for k, v in _NUMPY_DTYPE_TO_STR.items()}Or simply remove them if not planned.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tilelang/language/v2/dtypes.py(2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (3)
tilelang/language/v2/dtypes.py (3)
13-17: LGTM! Python bool support restored.The
boolmapping has been correctly added, addressing the previous review concern about the behavior regression.
42-62: LGTM! Comprehensive Torch dtype mappings.The Torch dtype mappings are thorough and correctly normalize aliases (e.g.,
torch.half→'float16') to canonical names present in_STR_TO_TVM_DTYPE_CALL.
68-68: LGTM! Clean consolidation of dtype mappings.The merge strategy correctly combines all three mapping dictionaries without key collisions, since Python, NumPy, and Torch types are distinct objects.
| _NUMPY_DTYPE_TO_STR = { | ||
| np.bool_: 'bool', | ||
| np.short: 'short', | ||
| np.int_: 'int32', | ||
| np.longlong: 'int64', | ||
| np.half: 'float16', | ||
| np.float_: 'float32', | ||
| np.double: 'float64', | ||
| np.int8: 'int8', | ||
| np.int16: 'int16', | ||
| np.int32: 'int32', | ||
| np.int64: 'int64', | ||
| np.uint8: 'uint8', | ||
| np.uint16: 'uint16', | ||
| np.uint32: 'uint32', | ||
| np.uint64: 'uint64', | ||
| np.float16: 'float16', | ||
| np.float32: 'float32', | ||
| np.float64: 'float64', | ||
| } | ||
|
|
||
| _NUMPY_DTYPE_TO_STR.update({np.dtype(k): v for k, v in _NUMPY_DTYPE_TO_STR.items()}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
Verify 'short' mapping and np.dtype() conversion pattern.
Two concerns:
-
Missing FFI mapping: Line 21 maps
np.shortto'short', but'short'is absent from_STR_TO_TVM_DTYPE_CALL(lines 70-89). Whendtype(np.short)is called, it will hit the fallback path in__dtype_call__attempting to constructtvm.script.ir_builder.tir._ffi_api.Short, which may not exist. -
Unconventional pattern: Line 40's
np.dtype(k)pattern (e.g.,np.dtype(np.int32)) creates dtype objects from type objects. While this should work, it's an unusual approach that warrants testing.
Run this script to verify both concerns:
🏁 Script executed:
#!/bin/bash
# Verify 'short' and other missing type mappings work correctly
cat << 'PYEOF' > /tmp/test_dtype.py
import sys
sys.path.insert(0, '.')
from tilelang.language.v2 import dtypes
import numpy as np
# Test 1: Check if np.short works
try:
result = dtypes.dtype(np.short)
print(f"✓ np.short works: {result}")
except Exception as e:
print(f"✗ np.short failed: {e}")
# Test 2: Check if np.dtype(np.int32) works
try:
result = dtypes.dtype(np.dtype(np.int32))
print(f"✓ np.dtype(np.int32) works: {result}")
except Exception as e:
print(f"✗ np.dtype(np.int32) failed: {e}")
# Test 3: Check various numpy dtypes
test_types = [np.int8, np.int32, np.float32, np.bool_]
for t in test_types:
try:
result = dtypes.dtype(t)
print(f"✓ {t.__name__} -> {result}")
except Exception as e:
print(f"✗ {t.__name__} failed: {e}")
PYEOF
python /tmp/test_dtype.pyLength of output: 666
I'll verify the code structure directly and check NumPy documentation on the np.dtype() pattern.
🏁 Script executed:
#!/bin/bash
# Verify the code structure and mappings
echo "=== Checking line 21 for np.short mapping ==="
sed -n '19,25p' tilelang/language/v2/dtypes.py
echo -e "\n=== Checking _STR_TO_TVM_DTYPE_CALL (lines 70-89) ==="
sed -n '70,89p' tilelang/language/v2/dtypes.py
echo -e "\n=== Checking line 40 and surrounding context ==="
sed -n '38,42p' tilelang/language/v2/dtypes.py
echo -e "\n=== Checking __dtype_call__ implementation ==="
rg -n -A 10 "def __dtype_call__" tilelang/language/v2/dtypes.pyLength of output: 1499
🌐 Web query:
Does numpy.dtype() work with numpy type objects like np.int32?
💡 Result:
Yes — np.dtype accepts NumPy scalar type objects like np.int32. Example: np.dtype(np.int32) returns a dtype('int32'). See the NumPy documentation. [1][2]
Sources
- NumPy — Data type objects (dtype) (reference manual). [1]
- NumPy — Specifying and constructing data types (1.25 docs). [2]
I need to see the complete __dtype_call__ implementation to understand what happens when 'short' doesn't match any pattern.
🏁 Script executed:
#!/bin/bash
# Get full __dtype_call__ implementation
rg -n -A 20 "def __dtype_call__" tilelang/language/v2/dtypes.pyLength of output: 848
Add 'short' to _STR_TO_TVM_DTYPE_CALL mapping.
The code maps np.short to the string 'short' (line 21), but 'short' is missing from _STR_TO_TVM_DTYPE_CALL (lines 70-89). When dtype(np.short) is called, it passes through __dtype_call__ (line 92), which:
- Checks if
'short'exists in the dictionary (line 93) — it doesn't - Falls through to pattern matching (lines 98-105) —
'short'doesn't start with'uint','int','float', or'bfloat' - Raises
TypeError(line 107)
This also affects calls using np.dtype(np.short) due to the .update() on line 40.
Add 'short': 'Short' to _STR_TO_TVM_DTYPE_CALL. The np.dtype() pattern on line 40 is valid NumPy usage and requires no changes.
🤖 Prompt for AI Agents
tilelang/language/v2/dtypes.py lines ~19-40 and mapping usage around lines
70-107: the mapping contains np.short -> 'short' but '_STR_TO_TVM_DTYPE_CALL'
lacks a 'short' entry causing dtype(np.short) or dtype(np.dtype(np.short)) to
fall through and raise; add the missing mapping entry 'short': 'Short' to the
_STR_TO_TVM_DTYPE_CALL dictionary so __dtype_call__ can resolve 'short' directly
(no other changes to the np.dtype() update required).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 4
♻️ Duplicate comments (1)
tilelang/language/v2/dtypes.py (1)
11-17: AlignAnyDTypetyping with the new NumPy support (optional typing polish)Runtime-wise, the new registries
_NUMPY_DTYPE_TO_STRand_DTYPE_TO_STRcorrectly supportnp.bool_,np.int32,np.dtype("float32"), etc. However, theAnyDTypeannotation hasn’t been updated and still reads:AnyDType = Union[ir.Type, str, type, torch.dtype, dtype]For type checkers, that means NumPy dtypes are seen as “wrong” even though they’re accepted at runtime.
If you want the typing story to match behavior, you could extend it to cover NumPy explicitly, for example:
-from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Union +import numpy as np # already imported above -AnyDType = Union[ir.Type, str, type, torch.dtype, dtype] +AnyDType = Union[ir.Type, str, type, torch.dtype, dtype, np.dtype, np.generic](or use
np.typing.DTypeLikeif that fits your typing conventions).Not mandatory for correctness, but it keeps annotations honest with respect to the new NumPy integration.
What is the recommended static typing alias in NumPy (>=1.26) for “dtype-like” inputs that include `np.dtype`, scalar types (`np.float32`), and strings?Also applies to: 19-39, 41-68, 134-137
🧹 Nitpick comments (4)
a.py (1)
1-6: Treat this as an example/benchmark, not import-time codeThis looks like a quick local dtype-conversion experiment rather than production logic. Having it as a top-level module means it runs on import and adds noise to the repo.
Consider either deleting
a.pyor turning it into a proper example/test (with anif __name__ == "__main__":guard) so it doesn’t affect normal imports.testing/python/language/test_tilelang_language_frontend_v2.py (1)
148-204: Avoid permanently commenting outtest_torch_eq; update or replace it and add NumPy testsTurning
test_torch_eqinto a commented block with “not supported now” silently drops coverage for a critical path (torch↔T dtype equality andT.dtype(torch_dtype)conversion), right when this PR is refactoring dtype handling.Prefer to:
- Either fix the expectations in
test_torch_eqto match the new semantics, or mark it asxfailwith a clear reason; and- Add a dedicated test that exercises the new NumPy dtype conversions (
dtype(np.int32),dtype(np.dtype("float32")), etc.), similar to how this test covered torch dtypes.That keeps the behavior explicit and protects against future regressions in dtype dispatch.
stubgen.py (1)
28-32: Tighten exception handling and clean up unused imports in the stub generatorFor a one-off tool this is not blocking, but a couple of small improvements would make
stubgen.pymore robust:
- Lines 76–79 and 83–86:
except Exceptiononly prints and continues. Narrowing toSyntaxError/ValueError, or at least documenting why a broad catch is needed, will quiet static analysis and make failures more intentional.from logging.config import valid_identandfrom argparse import ArgumentParserappear unused and can be dropped.None of this affects runtime behavior of TileLang, but it simplifies future maintenance of the stub generator.
Also applies to: 76-86
triteo_linear.py (1)
12-33: Document or guard the allowedshiftrange inshift_with_zeros
shift_with_zerosassumesabs(shift) <= x.shape[dim]:zeros_shape[dim] = abs(shift) ... x.narrow(dim, 0, x.shape[dim] - shift)If someone accidentally calls it with
abs(shift) > x.shape[dim],x.shape[dim] - shiftbecomes negative andnarrowwill raise. In this file you only useshiftas ±1, so it’s safe, but it would be good to either:
- Explicitly document that precondition in the docstring, or
- Clamp/early-return zeros when
abs(shift) >= x.shape[dim].That makes the helper safer to reuse elsewhere.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
a.py(1 hunks)stubgen.py(1 hunks)testing/python/language/test_tilelang_language_frontend_v2.py(1 hunks)tilelang/language/v2/dtypes.py(2 hunks)triteo_linear.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
a.py (1)
tilelang/language/v2/dtypes.py (1)
float32(199-199)
triteo_linear.py (3)
tilelang/jit/kernel.py (1)
JITKernel(31-727)tilelang/language/allocate.py (1)
alloc_fragment(59-70)tilelang/language/loop.py (1)
Parallel(12-32)
🪛 Ruff (0.14.4)
stubgen.py
78-78: Do not catch blind exception: Exception
(BLE001)
85-85: Do not catch blind exception: Exception
(BLE001)
triteo_linear.py
14-14: Docstring contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF002)
16-16: Docstring contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF002)
16-16: Docstring contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF002)
16-16: Docstring contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF002)
16-16: Docstring contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF002)
16-16: Docstring contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF002)
38-38: Docstring contains ambiguous : (FULLWIDTH COLON). Did you mean : (COLON)?
(RUF002)
43-43: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF003)
43-43: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF003)
47-47: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF003)
47-47: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF003)
49-49: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF003)
49-49: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF003)
90-90: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
92-92: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
stubgen.py
Outdated
| fdef.body = [ast.parse('...')] | ||
| # funcs.append(fdef) | ||
| funcs[fdef.name] = fdef | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion | 🟠 Major
Use a statement node for the stub body instead of an entire Module
fdef.body = [ast.parse('...')] assigns a Module node into a function body, where the AST normally expects stmt nodes (e.g., an Expr).
Safer and more conventional is:
- fdef.body = [ast.parse('...')]
+ fdef.body = [ast.parse('...').body[0]]This keeps the generated stubs structurally valid across CPython versions and tools that walk the AST.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| fdef.body = [ast.parse('...')] | |
| # funcs.append(fdef) | |
| funcs[fdef.name] = fdef | |
| fdef.body = [ast.parse('...').body[0]] | |
| # funcs.append(fdef) | |
| funcs[fdef.name] = fdef |
🤖 Prompt for AI Agents
In stubgen.py around lines 94 to 97, fdef.body currently gets a Module node via
ast.parse('...'), which places a Module into a function body; replace that with
a statement node instead (for example assign fdef.body =
[ast.parse('...').body[0]] or construct an ast.Expr(ast.Constant('...')) with
appropriate lineno/col_offset) so the function body contains a proper stmt node
and the generated stubs remain valid for AST walkers and different CPython
versions.
| _STR_TO_TVM_DTYPE_CALL = { | ||
| 'bool': 'Boolean', | ||
| 'int8': 'Int8', | ||
| 'int32': 'Int32', | ||
| 'int64': 'Int64', | ||
| 'uint8': 'UInt8', | ||
| 'uint16': 'UInt16', | ||
| 'uint32': 'UInt32', | ||
| 'uint64': 'UInt64', | ||
| 'float16': 'Float16', | ||
| 'float32': 'Float32', | ||
| 'float64': 'Float64', | ||
| 'bfloat16': 'BFloat16', | ||
| 'float8_e4m3': 'Float8E4M3', | ||
| 'float8_e4m3fn': 'Float8E4M3FN', | ||
| 'float8_e4m3fnuz': 'Float8E4M3FNUZ', | ||
| 'float8_e5m2': 'Float8E5M2', | ||
| 'float8_e5m2fnuz': 'Float8E5M2FNUZ', | ||
| 'float8_e8m0fnu': 'Float8E8M0FNU' | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
Use str(self) for dispatch in __dtype_call__ so the FFI mapping actually applies
Right now __dtype_call__ does:
def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var:
if self in _STR_TO_TVM_DTYPE_CALL:
attr = _STR_TO_TVM_DTYPE_CALL[self]
call = getattr(tb_ffi, attr, None)
return call(expr, is_size_var)
# try to construct the ffi call
if self.startswith('uint'):
...Two problems:
_STR_TO_TVM_DTYPE_CALLis keyed by strings like'bool','float32', etc., butselfis atvm.DataTypeinstance. The membership test therefore never succeeds, so special cases such as'bool' -> 'Boolean'are effectively dead code.- The subsequent
self.startswith(...)calls assumetvm.DataTypebehaves like a string. It’s much safer (and clearer) to normalize to a string once and operate on that.
A minimal, robust fix is:
def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var:
- if self in _STR_TO_TVM_DTYPE_CALL:
- attr = _STR_TO_TVM_DTYPE_CALL[self]
- call = getattr(tb_ffi, attr, None)
- return call(expr, is_size_var)
- # try to construct the ffi call
- if self.startswith('uint'):
- val = 'UInt' + self[4:]
- elif self.startswith('int'):
- val = 'Int' + self[3:]
- elif self.startswith('float'):
- val = 'Float' + self[5:]
- elif self.startswith('bfloat'):
- val = 'BFloat' + self[6:]
+ key = str(self)
+ if key in _STR_TO_TVM_DTYPE_CALL:
+ attr = _STR_TO_TVM_DTYPE_CALL[key]
+ call = getattr(tb_ffi, attr, None)
+ return call(expr, is_size_var)
+ # try to construct the ffi call
+ if key.startswith('uint'):
+ val = 'UInt' + key[4:]
+ elif key.startswith('int'):
+ val = 'Int' + key[3:]
+ elif key.startswith('float'):
+ val = 'Float' + key[5:]
+ elif key.startswith('bfloat'):
+ val = 'BFloat' + key[6:]
else:
- raise TypeError(f'Invalid type {self}')
+ raise TypeError(f'Invalid type {key}')That makes the explicit mapping actually used (e.g., bool→Boolean) and decouples the logic from the internal representation of tvm.DataType.
Also applies to: 91-115
🏁 Script executed:
#!/bin/bash
# Quick verification: show how dtype objects stringify and confirm mapping keys
python - << 'EOF'
from tilelang import tvm
from tilelang.language.v2 import dtypes
x = tvm.DataType("bool")
print("repr:", repr(x), "str:", str(x))
print("key in _STR_TO_TVM_DTYPE_CALL:", str(x) in dtypes._STR_TO_TVM_DTYPE_CALL)
EOFLength of output: 611
I need to examine the code directly since the sandbox build is incomplete. Let me inspect the actual implementation of __dtype_call__ and the dtype class structure.
🏁 Script executed:
#!/bin/bash
# Locate and show __dtype_call__ implementation
rg -n "__dtype_call__" --type=py tilelang/language/v2/dtypes.py -A 30Length of output: 2088
🏁 Script executed:
#!/bin/bash
# Show dtype class definition
rg -n "class dtype" --type=py tilelang/language/v2/dtypes.py -A 5Length of output: 0
🏁 Script executed:
#!/bin/bash
# Check if dtype inherits from tvm.DataType
rg -n "dtype.*tvm\.DataType" --type=py -i tilelang/language/v2/dtypes.py -B 2 -A 2Length of output: 169
🏁 Script executed:
#!/bin/bash
# Verify _STR_TO_TVM_DTYPE_CALL structure and confirm keys are strings
sed -n '69,88p' tilelang/language/v2/dtypes.pyLength of output: 538
🏁 Script executed:
#!/bin/bash
# Check tvm module or other imports to understand tvm.DataType
rg -n "from.*tvm|import.*tvm" --type=py tilelang/language/v2/dtypes.py -B 1 -A 1 | head -30Length of output: 201
🏁 Script executed:
#!/bin/bash
# Look for any isinstance or type checks involving dtype or DataType
rg -n "isinstance.*dtype|isinstance.*DataType" --type=py tilelang/language/v2/dtypes.py -B 1 -A 1Length of output: 121
🏁 Script executed:
#!/bin/bash
# Search for startswith usage on self or dtype in dtypes.py
rg -n "startswith" --type=py tilelang/language/v2/dtypes.py -B 2 -A 2Length of output: 389
🏁 Script executed:
#!/bin/bash
# Check for __str__ or __repr__ method on dtype class
rg -n "def __str__|def __repr__" --type=py tilelang/language/v2/dtypes.py -A 3Length of output: 0
🏁 Script executed:
#!/bin/bash
# Show lines 91-115 mentioned as "also applies to"
sed -n '91,115p' tilelang/language/v2/dtypes.pyLength of output: 979
🏁 Script executed:
#!/bin/bash
# Check if tvm.DataType is aliased or wrapped somewhere
rg -n "^dtype\s*=|^DataType\s*=" --type=py tilelang/language/v2/dtypes.pyLength of output: 22
🏁 Script executed:
#!/bin/bash
# Look at dtype usage within the file to understand its nature
rg -nP ":\s*dtype" --type=py tilelang/language/v2/dtypes.py | head -20Length of output: 84
🏁 Script executed:
#!/bin/bash
# Check if dict membership checking works on DataType by looking for similar patterns
rg -n " in _STR_TO" --type=py tilelang/language/v2/dtypes.py -B 2 -A 2Length of output: 223
🏁 Script executed:
#!/bin/bash
# Find any definitions or aliases of dtype at module level
rg -n "^dtype" --type=py tilelang/language/v2/dtypes.py -A 1Length of output: 163
🏁 Script executed:
#!/bin/bash
# Examine the entire dtypes.py file to understand structure
wc -l tilelang/language/v2/dtypes.py && head -50 tilelang/language/v2/dtypes.pyLength of output: 1230
🏁 Script executed:
#!/bin/bash
# Check TypeAlias or type annotation for dtype
rg -n "dtype.*=|dtype.*:" --type=py tilelang/language/v2/dtypes.py | grep -E "^[0-9]+:(dtype\s*=|dtype.*:)" | head -10Length of output: 22
🏁 Script executed:
#!/bin/bash
# Look at imports more comprehensively
head -30 tilelang/language/v2/dtypes.pyLength of output: 693
🏁 Script executed:
#!/bin/bash
# Check type annotations and see what dtype refers to
ast-grep --pattern 'def __dtype_call__(self: $_, $_) { $$$ }'Length of output: 0
🏁 Script executed:
#!/bin/bash
# Look for class or TypeVar definitions
grep -n "class\|TypeVar\|dtype\s*=" tilelang/language/v2/dtypes.py | head -30Length of output: 986
🏁 Script executed:
#!/bin/bash
# Simpler approach: show full context around __dtype_call__
sed -n '80,120p' tilelang/language/v2/dtypes.pyLength of output: 1346
🏁 Script executed:
#!/bin/bash
# Show all lines mentioning "__dtype_call__" and surrounding context
grep -n "__dtype_call__" tilelang/language/v2/dtypes.py && sed -n '80,120p' tilelang/language/v2/dtypes.pyLength of output: 1467
🏁 Script executed:
#!/bin/bash
# Get context: what is dtype?
python3 << 'PYEOF'
import ast
with open('tilelang/language/v2/dtypes.py', 'r') as f:
content = f.read()
tree = ast.parse(content)
# Find dtype definition
for node in ast.walk(tree):
if isinstance(node, (ast.ClassDef, ast.TypeAlias)) and 'dtype' in str(getattr(node, 'name', '')):
print(f"Found: {type(node).__name__} - {getattr(node, 'name', 'unknown')}")
if hasattr(node, 'lineno'):
print(f" Line: {node.lineno}")
if isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name) and target.id == 'dtype':
print(f"Assignment to 'dtype' at line {node.lineno}")
PYEOFLength of output: 130
🏁 Script executed:
#!/bin/bash
# Direct inspection of the file to find dtype and __dtype_call__
cat tilelang/language/v2/dtypes.py | grep -n "dtype\|__dtype_call__" | head -40Length of output: 1547
🏁 Script executed:
#!/bin/bash
# Just show the whole file
cat tilelang/language/v2/dtypes.pyLength of output: 18570
🏁 Script executed:
#!/bin/bash
# Final check: find method decorated or defined that handles dtype calls
rg -n "def.*dtype" --type=py tilelang/language/v2/dtypes.pyLength of output: 187
🏁 Script executed:
#!/bin/bash
# Look for where 'key' or string conversion happens with self
rg -n "str(self)" --type=py tilelang/language/v2/dtypes.pyLength of output: 0
🏁 Script executed:
#!/bin/bash
# Search for int membership checks on dicts
rg -n " in _" --type=py tilelang/language/v2/dtypes.py -B 1 -A 1Length of output: 586
Now I need to verify that this code path is actually reachable and understand how tvm.DataType stringifies. Let me check for actual usage patterns:Excellent. From the web search, I found that tvm.DataType (in tvm-ffi) behaves like a Python string and carries an internal FFI representation. This means tvm.DataType can be used with string operations, but the key insight is still that it's a different type than a plain string.
Now let me verify one more thing: check if there are any callers of dtype_call to understand the code path:Perfect. Now I have all the information I need. The code inspection is conclusive. Let me verify one more critical detail - whether the code path is even reachable:
🏁 Script executed:
#!/bin/bash
# Check if there are any calls to the dtype object's __call__ method
# Look for patterns like dtype(...)(expr)
rg -n "dtype\(" --type=py tilelang/language/v2/ | grep -v "test" | head -30Length of output: 2002
🏁 Script executed:
#!/bin/bash
# Look for places where __dtype_call__ would be invoked via dtype(...)(expr) pattern
rg -n "\)(\s*\(" --type=py tilelang/language/v2/dtypes.py -A 1Length of output: 61
Use str(self) for dispatch in __dtype_call__ so the FFI mapping actually applies
The __dtype_call__ method at lines 91–114 has a critical bug:
- Line 92 checks
if self in _STR_TO_TVM_DTYPE_CALL:, butselfis atvm.DataTypeinstance while the dict keys (lines 69–88) are strings like'bool','int8','int32', etc. This membership test will never succeed, making the explicit mapping (lines 92–95) dead code. - Lines 97–104 then use
self.startswith(...), which assumestvm.DataTypebehaves like a string. This coupling is fragile and error-prone.
The fix is to normalize to a string once at the start and use that for all subsequent operations:
def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var:
- if self in _STR_TO_TVM_DTYPE_CALL:
- attr = _STR_TO_TVM_DTYPE_CALL[self]
- call = getattr(tb_ffi, attr, None)
- return call(expr, is_size_var)
- # try to construct the ffi call
- if self.startswith('uint'):
- val = 'UInt' + self[4:]
- elif self.startswith('int'):
- val = 'Int' + self[3:]
- elif self.startswith('float'):
- val = 'Float' + self[5:]
- elif self.startswith('bfloat'):
- val = 'BFloat' + self[6:]
+ key = str(self)
+ if key in _STR_TO_TVM_DTYPE_CALL:
+ attr = _STR_TO_TVM_DTYPE_CALL[key]
+ call = getattr(tb_ffi, attr, None)
+ return call(expr, is_size_var)
+ # try to construct the ffi call
+ if key.startswith('uint'):
+ val = 'UInt' + key[4:]
+ elif key.startswith('int'):
+ val = 'Int' + key[3:]
+ elif key.startswith('float'):
+ val = 'Float' + key[5:]
+ elif key.startswith('bfloat'):
+ val = 'BFloat' + key[6:]
else:
- raise TypeError(f'Invalid type {self}')
+ raise TypeError(f'Invalid type {key}')This ensures the explicit mappings (e.g., bool→Boolean) are actually used and decouples the logic from the internal representation of tvm.DataType.
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
tilelang/language/v2/dtypes.py lines 69-114: The explicit FFI mapping keys are
strings but the method checks and string operations use self (a tvm.DataType),
so the mapping branch never triggers and string methods are fragile; fix by
converting self to a single string once (e.g., s = str(self)) at the start of
__dtype_call__ and then use s for the membership test against
_STR_TO_TVM_DTYPE_CALL and for any startswith checks, returning the mapped value
or falling back to the existing logic using that normalized string.
triteo_linear.py
Outdated
| # n = 2 ** 25 | ||
| B = 8 | ||
| t = 2**11 | ||
| D = 128 | ||
| k = torch.randn(B,t,D, dtype=torch.float32, device='cuda') | ||
| s = torch.softmax(torch.randn(B,t,3, dtype=torch.float32, device='cuda'),dim=-1) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Move heavy tensor allocation, kernel runs, and profiling under a __main__ guard
This module currently executes quite a bit of work on import:
- Allocates large CUDA tensors
kands. - Constructs kernels for multiple
blk_tvalues. - Runs a correctness assertion and a profiling loop with
torch.profiler.
That’s appropriate for a standalone benchmark script, but surprising if this file is ever imported as a library module.
Consider wrapping all of this in an entrypoint block, or moving it into an examples/ or benchmarks/ script:
if __name__ == "__main__":
# set up B, t, D, k, s
# run reference + TileLang kernels
# assert correctness and profileThis keeps imports cheap and side-effect free while preserving the benchmark.
Also applies to: 115-128
🏁 Script executed:
#!/bin/bash
# Quick check: confirm that importing triteo_linear today would execute the profiling path
python - << 'EOF'
import inspect
import triteo_linear
print("Top-level names:", [n for n in dir(triteo_linear) if not n.startswith("_")])
print("Has __name__ guard:", "__main__" in inspect.getsource(triteo_linear))
EOFLength of output: 697
I'll inspect the file directly since importing failed due to build dependencies. Let me read the triteo_linear.py file to check for the __main__ guard and identify what code executes at the module level.
🏁 Script executed:
#!/bin/bash
# Read the triteo_linear.py file to check __main__ guard and module-level code
cat triteo_linear.py | head -150Length of output: 4501
🏁 Script executed:
#!/bin/bash
# Check for __main__ guard presence
grep -n "if __name__" triteo_linear.py || echo "No __main__ guard found"Length of output: 23
🏁 Script executed:
#!/bin/bash
# Count total lines to understand file size
wc -l triteo_linear.pyLength of output: 20
🏁 Script executed:
#!/bin/bash
# Show lines 100-150 to see the mentioned section (115-128)
sed -n '100,150p' triteo_linear.pyLength of output: 1137
Wrap heavy tensor allocation, kernel construction, and profiling in a __main__ guard
Verified: This module executes expensive operations at import time:
- Lines 5–11 allocate large CUDA tensors (
k,s) - Line ~85 computes
o_torchviamake_first_recurrent()loop - Lines 115–127 construct kernels for each
blk_tvalue, run a correctness assertion, and profile 10 iterations withtorch.profiler.profile()
No if __name__ guard exists in the file. This means importing triteo_linear immediately triggers all compute, which is inappropriate for a library module.
Move this section into an entrypoint:
if __name__ == "__main__":
# set up B, t, D, k, s
# compute o_torch reference
# run kernels, assert correctness, profile🤖 Prompt for AI Agents
In triteo_linear.py around lines 5-11 (and also the compute at ~line 85 and the
kernel construction/profiling/assertions at ~lines 115-127), heavy CUDA tensor
allocation, kernel construction and profiling run at import time; move all
runtime work into an entrypoint by wrapping the B, t, D, k, s setup, the call
that computes o_torch (make_first_recurrent loop), the per-blk_t kernel
construction, correctness assertion and torch.profiler.profile runs inside an if
__name__ == "__main__": block so importing the module no longer triggers
expensive computation; keep module-level imports and function/class definitions
at top-level, only relocate the runtime/test code into that guard.
triteo_linear.py
Outdated
| for i0_t in T.serial(blk_t): | ||
| t_local = i0_t*blk_t + i0_t | ||
| #先存第一行也就是栈顶,到输出的o里面 | ||
| o[i_b,t_local,i_d] = S_temp[0] | ||
| #再做三对角,实际上也就是相邻行的加权求和 | ||
| down = s[i_b,t_local,0] | ||
| mid = s[i_b,t_local,1] | ||
| up = s[i_b,t_local,2] | ||
| for i0_d in T.Parallel(d-1): | ||
| S_down[i0_d + 1] = S_temp[i0_d] * down | ||
| for i0_d in T.Parallel(d-1): | ||
| S_up[i0_d] = S_temp[i0_d + 1] * up | ||
| for i0_d in T.Parallel(d): | ||
| S_mid[i0_d] = S_temp[i0_d] * mid | ||
| S_down[0] = 0 | ||
| S_up[d-1] = 0 | ||
| for i0_d in T.Parallel(d): | ||
| S_temp[i0_d] += S_mid[i0_d] | ||
| S_temp[i0_d] += S_down[i0_d] | ||
| S_temp[i0_d] += S_up[i0_d] | ||
| #往栈顶写入当前的k | ||
| S_temp[0] += down * k[i_b,t_local,i_d] | ||
| # 存储当前block最终的状态S,留作未来计算 | ||
| for i0_d in T.Parallel(d): | ||
| S[i_b,i_t,i0_d,i_d] = S_temp[i0_d] | ||
| return inner_chunk_recurrent_fwd_init0_ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
Fix t_local computation and explicitly handle/guard non-multiple-of-blk_t lengths
Inside inner_chunk_recurrent_fwd_init0_:
for i0_t in T.serial(blk_t):
t_local = i0_t*blk_t + i0_t
...
o[i_b, t_local, i_d] = S_temp[0]
...
down = s[i_b, t_local, 0]
...
S_temp[0] += down * k[i_b, t_local, i_d]Issues:
t_localdoesn’t depend oni_t(the block index fromwith T.Kernel(b * d, T.ceildiv(t, blk_t)) as (i_bd, i_t)), so every block iterates over the same subset of time steps. That both duplicates work and leaves part of the[0, t)range untouched.- For general
t, this also assumest % blk_t == 0; otherwiset_localcan exceedt - 1and index out of bounds onk,s, ando. Similarly,Sis shaped(b, t//blk_t, d, d)whilei_tranges up toT.ceildiv(t, blk_t) - 1, soi_tcan overrun the second dimension whentis not an exact multiple.
You probably want something along the lines of:
- with T.Kernel(b * d, T.ceildiv(t, blk_t)) as (i_bd, i_t):
+ with T.Kernel(b * d, T.ceildiv(t, blk_t)) as (i_bd, i_t):
i_b = i_bd // d
i_d = i_bd % d
...
- for i0_t in T.serial(blk_t):
- t_local = i0_t*blk_t + i0_t
+ for i0_t in T.serial(blk_t):
+ t_local = i_t * blk_t + i0_t
+ if t_local >= t:
+ break # or equivalently clamp the loop extent for the last block(Adjust the guard to whatever control flow TileLang’s T.serial supports—e.g., by computing an effective block length instead of using break.)
Additionally, if the kernel is only intended for t % blk_t == 0, it would be good to document that precondition or assert on it at the Python wrapper level.
🏁 Script executed:
#!/bin/bash
# Minimal sanity check: scan for t_local and block usage in this kernel
rg -n "inner_chunk_recurrent_fwd_init0_" -n triteo_linear.py -n -A15 -B5 || cat triteo_linear.pyLength of output: 1691
I can see the critical issues from the partial output. The key observations are:
-
Shape mismatch confirmed: Line 70 shows
S: T.Tensor((b, t//blk_t, d, d))uses integer division, but line 76 shows the kernel grid isT.ceildiv(t, blk_t)(ceiling division). Whent % blk_t ≠ 0, the last block'si_tindex exceeds the valid range. -
Top-level side effects confirmed: Lines 116–127 show that kernel compilation, tensor allocation, and profiling loops execute on module import—this should not happen at import time.
-
t_localformula unconfirmed: The output was truncated (lines 83–107 hidden), so I can't yet see the exactt_localcomputation. Let me retrieve it.
🏁 Script executed:
#!/bin/bash
# Get the full inner_chunk_recurrent_fwd_init0_ function with no truncation
sed -n '69,113p' triteo_linear.pyLength of output: 1890
Fix critical t_local computation and S tensor shape mismatch
The kernel has two critical bugs:
-
Line 89:
t_localformula is wrong. It should bet_local = i_t * blk_t + i0_t, noti0_t*blk_t + i0_t. As currently written, every block processes the same (wrong) time indices, duplicating work and missing parts of the input. -
Line 70 vs. line 76 shape mismatch. S is allocated as
(b, t//blk_t, d, d)using integer division, but the kernel grid dimension isT.ceildiv(t, blk_t). Whent % blk_t ≠ 0, the last block'si_tindex exceeds the valid S dimension, causing out-of-bounds writes at line 112. -
Lines 116–127: Remove top-level kernel execution and profiling. Kernel compilation, tensor allocation, and profiling should not execute on module import; move this code to an example script or guarded test.
Suggested fix:
- Change line 89 to:
t_local = i_t * blk_t + i0_t - Either change S allocation to
(b, T.ceildiv(t, blk_t), d, d)or add a bounds check to skip iterations whent_local >= t. - Move lines 116–127 to a separate example or guard them behind
if __name__ == "__main__":.
🧰 Tools
🪛 Ruff (0.14.4)
90-90: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
92-92: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
🤖 Prompt for AI Agents
In triteo_linear.py around lines 88–113: the local time index calculation uses
the wrong variables so blocks compute incorrect time indices — replace the
current t_local expression so it uses the block index i_t combined with the
inner loop index i0_t (i.e., compute t_local from i_t and i0_t, not from i0_t
twice); fix the S tensor allocation vs. kernel grid mismatch by either
allocating S with ceildiv over t (so its second dim equals the number of blocks)
or add a bounds check that skips iterations when the computed t_local is >= t to
avoid out-of-bounds writes; and remove or guard the top-level kernel
compilation/allocation/profiling code (lines ~116–127) so it only runs under an
example/test entrypoint (for example inside an if __name__ == "__main__" or
moved to a separate script).
This pr introduces conversion from numpy data types into T.dtype
Summary by CodeRabbit
New Features
Refactor
Chores
Tests