-
Notifications
You must be signed in to change notification settings - Fork 321
[Language] Add type stubs for tir op #1239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
WalkthroughA new Python type stub file is introduced, defining comprehensive type signatures for TVM TIR IR operations. The file includes generic-typed function declarations covering arithmetic, trigonometric, bitwise, memory, tensor synchronization, and barrier operations, enabling static type checking for TIR API usage. Changes
Estimated code review effort🎯 2 (Simple) | ⏱️ ~12 minutes
Suggested reviewers
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 7
🧹 Nitpick comments (2)
tilelang/language/tir/ir.pyi (2)
4-4: Consider constraining the TypeVar for better type safety.The unbounded
_TTypeVar allows any type to be passed to functions that semantically expect numeric expressions. Consider constraining it toPrimExpror defining multiple TypeVars for different use cases (e.g., one for expression operations, one for generic returns).Example:
-_T = TypeVar('_T') +_T = TypeVar('_T', bound=PrimExpr)Or define multiple TypeVars for different purposes:
-_T = TypeVar('_T') +_ExprT = TypeVar('_ExprT', bound=PrimExpr) +_T = TypeVar('_T')
100-101: Parameter name shadows built-in function.The parameter name
idshadows Python's built-inid()function, which can cause confusion and potential issues.Consider renaming to
profile_idor similar:-def start_profile_intrinsic(id: int) -> PrimExpr: ... -def end_profile_intrinsic(id: int) -> PrimExpr: ... +def start_profile_intrinsic(profile_id: int) -> PrimExpr: ... +def end_profile_intrinsic(profile_id: int) -> PrimExpr: ...
| def fmod(x: _T, y: _T) -> _T: ... | ||
| def hypot(x1: _T, x2: _T) -> _T: ... | ||
| def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None=None) -> _T: ... | ||
| def infinity(dtype: _T, span: Span | None=None) -> _T: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Incorrect type annotation for dtype parameter.
The dtype parameter is typed as _T, but based on the parameter name and typical usage, it should be str.
Apply this diff:
-def infinity(dtype: _T, span: Span | None=None) -> _T: ...
+def infinity(dtype: str, span: Span | None=None) -> PrimExpr: ...📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def infinity(dtype: _T, span: Span | None=None) -> _T: ... | |
| def infinity(dtype: str, span: Span | None=None) -> PrimExpr: ... |
🤖 Prompt for AI Agents
In tilelang/language/tir/ir.pyi around line 35, the dtype parameter is
incorrectly annotated as _T; change the dtype parameter's type annotation from
_T to str (i.e., make the signature use dtype: str) while keeping the rest of
the signature intact so callers see the correct parameter type.
| def shift_left(x: _T, y: _T, span=None) -> _T: ... | ||
| def shift_right(x: _T, y: _T, span=None) -> _T: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add type annotation for span parameter.
The span parameter lacks a type annotation, unlike all other functions with the same parameter. This breaks consistency and reduces type safety.
Apply this diff:
-def shift_left(x: _T, y: _T, span=None) -> _T: ...
-def shift_right(x: _T, y: _T, span=None) -> _T: ...
+def shift_left(x: _T, y: _T, span: Span | None=None) -> _T: ...
+def shift_right(x: _T, y: _T, span: Span | None=None) -> _T: ...📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def shift_left(x: _T, y: _T, span=None) -> _T: ... | |
| def shift_right(x: _T, y: _T, span=None) -> _T: ... | |
| def shift_left(x: _T, y: _T, span: Span | None=None) -> _T: ... | |
| def shift_right(x: _T, y: _T, span: Span | None=None) -> _T: ... |
🤖 Prompt for AI Agents
In tilelang/language/tir/ir.pyi around lines 58-59, the span parameter is
missing a type annotation; update both function signatures to add an explicit
optional Span type (e.g. span: Span | None = None or span: Optional[Span] = None
to match project typing style) so they match other functions with span and
restore type consistency.
| def trunc(x: _T, span: Span | None=None) -> _T: ... | ||
| def truncdiv(a: _T, b: _T, span: Span | None=None) -> _T: ... | ||
| def truncmod(a: _T, b: _T, span: Span | None=None) -> _T: ... | ||
| def tvm_access_ptr(ptype: PrimExpr, data, offset: int, extent: int, rw_mask: int) -> PrimExpr: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
Add type annotation for the data parameter.
The data parameter lacks a type annotation, reducing type safety and consistency with other parameters in the file.
Verify the expected type from the runtime implementation to add the correct annotation (likely PrimExpr or Var).
🏁 Script executed:
#!/bin/bash
# Find the runtime implementation to determine the correct type for the data parameter
ast-grep --pattern 'def tvm_access_ptr($$$)'Length of output: 1513
---
🏁 Script executed:
```shell
#!/bin/bash
# Search for calls to tvm_access_ptr to see what types are passed for the data parameter
rg -n 'tvm_access_ptr\s*\(' -A 2 -B 2 --type py
Length of output: 15734
🏁 Script executed:
#!/bin/bash
# Check for type imports and similar buffer/pointer-like types in the file
rg -n 'Buffer|Var|Handle' tilelang/language/tir/ir.pyi | head -20Length of output: 1180
🏁 Script executed:
#!/bin/bash
# Look at other function signatures in the stub file for patterns with similar parameters
rg -n 'def.*data.*:' tilelang/language/tir/ir.pyi -A 1 | head -30Length of output: 395
Add type annotation for the data parameter: use PrimExpr.
The data parameter lacks a type annotation. Based on the docstring in the implementation and usage patterns across test files where buffer .data attributes are passed (e.g., A_shared.data), the type should be PrimExpr, consistent with similar functions in the file like tvm_stack_make_array.
🤖 Prompt for AI Agents
In tilelang/language/tir/ir.pyi around line 69, the tvm_access_ptr declaration
is missing a type annotation for the data parameter; update the function
signature so the data parameter is annotated as PrimExpr (matching usage of
buffer .data and other functions like tvm_stack_make_array) so the stub
accurately reflects expected types.
| def tvm_throw_last_error() -> _T: ... | ||
| def tvm_stack_alloca(dtype_str: str, num: int) -> PrimExpr: ... | ||
| def tvm_stack_make_shape(*args) -> _T: ... | ||
| def tvm_stack_make_array(data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset) -> PrimExpr: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
Add type annotation for the elem_offset parameter.
The elem_offset parameter lacks a type annotation.
Verify the expected type from the runtime implementation.
🏁 Script executed:
#!/bin/bash
# Find the runtime implementation
ast-grep --pattern 'def tvm_stack_make_array($$$)'Length of output: 1753
---
🏁 Script executed:
```shell
#!/bin/bash
# Check imports and type definitions in ir.pyi
head -100 tilelang/language/tir/ir.pyi | grep -E "^(import|from|class PrimExpr|class Expr)"
Length of output: 100
🏁 Script executed:
#!/bin/bash
# Search for PrimExpr and Expr type definitions/imports
rg -n "(class PrimExpr|class Expr|from.*import.*PrimExpr|from.*import.*Expr)" tilelang/language/tir/ir.pyi -A2 -B2Length of output: 128
🏁 Script executed:
#!/bin/bash
# Check what type annotations are used in similar functions
rg -n "def.*Expr.*Expr.*->" tilelang/language/tir/ir.pyi | head -20Length of output: 1026
Add type annotation PrimExpr for the elem_offset parameter.
The elem_offset parameter lacks a type annotation. Based on the stub file's import scope (only PrimExpr is imported from tvm.tir.expr) and the consistent pattern used across all other parameters in this function and similar functions in the file, the type should be PrimExpr.
def tvm_stack_make_array(data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset: PrimExpr) -> PrimExpr: ...🤖 Prompt for AI Agents
In tilelang/language/tir/ir.pyi around line 73, the function
tvm_stack_make_array is missing a type annotation for the elem_offset parameter;
add the annotation elem_offset: PrimExpr to match the other parameters and the
file's import usage so the signature becomes def tvm_stack_make_array(data:
PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype:
PrimExpr, elem_offset: PrimExpr) -> PrimExpr: ...
| def call_packed_lowered(*args, span=None) -> _T: ... | ||
| def call_cpacked_lowered(*args, span=None) -> _T: ... | ||
| def tvm_tuple(*value) -> _T: ... | ||
| def tvm_struct_set(arr, index: int, field: int, value: PrimExpr) -> PrimExpr: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
Add type annotation for the arr parameter.
The arr parameter lacks a type annotation.
Verify the expected type from the runtime implementation.
🏁 Script executed:
#!/bin/bash
# Find the runtime implementation
ast-grep --pattern 'def tvm_struct_set($$$)'Length of output: 1309
Wait, that won't work without cloning. Let me check properly:
---
🏁 Script executed:
```shell
#!/bin/bash
# Inspect imports and related type definitions in the stub file
head -100 tilelang/language/tir/ir.pyi | cat
Length of output: 5583
🏁 Script executed:
#!/bin/bash
# Inspect the top of the stub file for imports and type definitions
head -100 tilelang/language/tir/ir.pyiLength of output: 5583
🏁 Script executed:
#!/bin/bash
# Search for StructType or similar types in ir.pyi
rg -n "StructType|class.*Type" tilelang/language/tir/ir.pyi | head -20Length of output: 0
🏁 Script executed:
#!/bin/bash
# Check imports at the top of ir.pyi
rg -n "^from|^import" tilelang/language/tir/ir.pyi | head -20Length of output: 104
🏁 Script executed:
#!/bin/bash
# Search for StructType definition or usage in codebase
rg -n "StructType" tilelang/language/tir/ --type py --type pyiLength of output: 27
🏁 Script executed:
#!/bin/bash
# Examine the implementation context of tvm_struct_set in op.py
sed -n '439,465p' tilelang/language/tir/op.pyLength of output: 556
🏁 Script executed:
#!/bin/bash
# Search for struct-related functions in op.py to find usage patterns
rg -n "struct" tilelang/language/tir/op.py -i | head -20Length of output: 1053
🏁 Script executed:
#!/bin/bash
# Look for tvm_struct_get (companion function) to see if it has type info
rg -n "tvm_struct_get" tilelang/language/tir/ -A 10Length of output: 2115
Add type annotation for the arr parameter.
The arr parameter lacks a type annotation. Based on the runtime implementation and function signature, it should be typed as PrimExpr:
def tvm_struct_set(arr: PrimExpr, index: int, field: int, value: PrimExpr) -> PrimExpr: ...🤖 Prompt for AI Agents
In tilelang/language/tir/ir.pyi around line 80, the tvm_struct_set function
signature is missing a type annotation for the arr parameter; update the
declaration to annotate arr as PrimExpr so the signature reads with arr:
PrimExpr, e.g. change the parameter list to include arr: PrimExpr and keep the
rest unchanged.
| def ptx_arrive_barrier_expect_tx(barrier_id: int, byte_count: int) -> PrimExpr: ... | ||
| def ptx_wait_barrier(barrier_id: int) -> PrimExpr: ... | ||
| def create_barriers(barrier_count: int) -> PrimExpr: ... | ||
| def assume(cond: _T=None) -> _T: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Problematic default value with generic TypeVar.
Using None as a default value with an unbounded _T is problematic because None may not be a valid instance of _T.
Consider one of these fixes:
-def assume(cond: _T=None) -> _T: ...
+def assume(cond: _T | None=None) -> _T: ...Or use Optional:
+from typing import Optional
-def assume(cond: _T=None) -> _T: ...
+def assume(cond: Optional[_T]=None) -> _T: ...📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def assume(cond: _T=None) -> _T: ... | |
| def assume(cond: _T | None=None) -> _T: ... |
🤖 Prompt for AI Agents
In tilelang/language/tir/ir.pyi around line 96, the signature def assume(cond:
_T=None) -> _T: uses None as a default for an unconstrained TypeVar which is
invalid; change the signature so the types match — either remove the default and
keep it typed as def assume(cond: _T) -> _T:, or make the parameter and return
optional by importing typing.Optional and using def assume(cond: Optional[_T] =
None) -> Optional[_T]: depending on whether None is a valid value for callers.
| def anylist_getitem(list_handle, index) -> PrimExpr: ... | ||
| def anylist_resetitem(list_handle, index) -> PrimExpr: ... | ||
| def anylist_setitem_call_packed(list_handle, index, func_name, *args) -> PrimExpr: ... | ||
| def anylist_setitem_call_cpacked(list_handle, index, func_name, *args) -> PrimExpr: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
Add type annotations for parameters in anylist functions.
Multiple parameters (list_handle, index, func_name) lack type annotations, reducing type safety.
Verify the expected types from the runtime implementation.
🏁 Script executed:
#!/bin/bash
# Find the runtime implementations
ast-grep --pattern 'def anylist_$_($$$ )'Length of output: 0
🏁 Script executed:
#!/bin/bash
# Search for anylist function definitions in Python files
rg -n 'def anylist_' --type=pyLength of output: 358
🏁 Script executed:
#!/bin/bash
# Search for anylist references to understand usage patterns
rg -n 'anylist_' --type=py -C 2 | head -50Length of output: 2975
🏁 Script executed:
#!/bin/bash
# Check the exact structure of the .pyi file around the functions
fd -e pyi -name 'ir.pyi' -exec cat {} \;Length of output: 230
🏁 Script executed:
#!/bin/bash
# Get full docstrings from op.py implementations
sed -n '3331,3343p' tilelang/language/tir/op.py
sed -n '3345,3357p' tilelang/language/tir/op.py
sed -n '3359,3375p' tilelang/language/tir/op.py
sed -n '3377,3393p' tilelang/language/tir/op.pyLength of output: 1551
🏁 Script executed:
#!/bin/bash
# View the .pyi stub file
cat tilelang/language/tir/ir.pyiLength of output: 5950
Add type annotations for parameters in anylist functions.
The runtime implementations in tilelang/language/tir/op.py document the expected parameter types via docstrings: list_handle: Var, index: int, and func_name: str. These should be reflected in the stub file. The Var type is already imported. Add the following annotations:
def anylist_getitem(list_handle: Var, index: int) -> PrimExpr: ...
def anylist_resetitem(list_handle: Var, index: int) -> PrimExpr: ...
def anylist_setitem_call_packed(list_handle: Var, index: int, func_name: str, *args) -> PrimExpr: ...
def anylist_setitem_call_cpacked(list_handle: Var, index: int, func_name: str, *args) -> PrimExpr: ...🤖 Prompt for AI Agents
In tilelang/language/tir/ir.pyi around lines 102 to 105, the anylist_* function
stubs lack parameter type annotations; update each signature to annotate
list_handle as Var, index as int, and func_name as str where applicable (keep
*args untyped) so they match the runtime docstrings and existing imports; ensure
the return type remains PrimExpr.
This pr add type hints stubs for tir/op.py. It enables better typing to help user write correct code.