33from typing import Any , List , Optional , Tuple , Union
44
55
6- from .arith import constant
6+ from . import arith as arith_ext
77from .func import FuncBase
88from ... import types as T
99from ...meta import (
1616 make_maybe_no_args_decorator ,
1717 find_ops ,
1818)
19- from ....dialects ._gpu_ops_gen import _Dialect
19+ from ....dialects import _gpu_ops_gen
20+ from ....dialects ._gpu_ops_gen import (
21+ _Dialect ,
22+ GPUFuncOp as _GPUFuncOp ,
23+ LaunchOp as _LaunchOp ,
24+ LaunchFuncOp as _LaunchFuncOp ,
25+ )
2026from ....dialects ._ods_common import (
2127 _cext ,
2228 get_default_loc_context ,
@@ -223,7 +229,8 @@ def __prepare__(cls, name, bases, **kwargs):
223229 return {"ip" : ip , "gpu_module_op" : gpu_module_op }
224230
225231
226- class GPUFuncOp (GPUFuncOp ):
232+ # TODO(max): integrate upstream
233+ class GPUFuncOp (_GPUFuncOp ):
227234 def __init__ (
228235 self ,
229236 sym_name ,
@@ -269,7 +276,8 @@ def __init__(
269276 )
270277
271278
272- class LaunchOp (LaunchOp ):
279+ # TODO(max): integrate upstream
280+ class LaunchOp (_LaunchOp ):
273281 def __init__ (
274282 self ,
275283 grid_size : Tuple [Any , Any , Any ],
@@ -321,7 +329,7 @@ def launch_(
321329 for size in [grid_size , block_size ]:
322330 for i , s in enumerate (size ):
323331 if isinstance (s , int ):
324- size [i ] = constant (s , index = True )
332+ size [i ] = arith_ext . constant (s , index = True )
325333 launch_op = LaunchOp (
326334 grid_size ,
327335 block_size ,
@@ -336,7 +344,8 @@ def launch_(
336344launch = region_op (launch_ , terminator = lambda * _args : TerminatorOp ())
337345
338346
339- class LaunchFuncOp (LaunchFuncOp ):
347+ # TODO(max): integrate upstream
348+ class LaunchFuncOp (_LaunchFuncOp ):
340349 def __init__ (
341350 self ,
342351 kernel : List [str ],
@@ -394,7 +403,7 @@ def __call__(
394403 for size in [grid_size , block_size ]:
395404 for i , s in enumerate (size ):
396405 if isinstance (s , int ):
397- size [i ] = constant (s , index = True )
406+ size [i ] = arith_ext . constant (s , index = True )
398407
399408 if loc is None :
400409 loc = get_user_code_loc ()
@@ -596,15 +605,6 @@ def get_compile_object_bytes(compiled_module):
596605 return objects [- 1 ].object
597606
598607
599- _printf = printf
600-
601-
602- def printf (format , * args , loc = None , ip = None ):
603- if loc is None :
604- loc = get_user_code_loc ()
605- return _printf (format = format , args = args , loc = loc , ip = ip )
606-
607-
608608_dynamic_shared_memory = dynamic_shared_memory
609609
610610
@@ -634,7 +634,7 @@ def memset(dst, value, async_dependencies=None, *, loc=None, ip=None):
634634 if len (async_dependencies ):
635635 async_token = gpu_async_token ()
636636 if isinstance (value , (int , float , bool )):
637- value = constant (value , type = dst .type .element_type )
637+ value = arith_ext . constant (value , type = dst .type .element_type )
638638 return _memset (async_token , async_dependencies , dst , value , loc = loc , ip = ip )
639639
640640
0 commit comments