11import ast
2- import logging
3-
2+ import enum
43import inspect
4+ import logging
55import types
66from abc import ABC
7+ from opcode import opmap
78from types import CodeType
8- from typing import Type , Optional
99
1010import libcst as cst
1111from bytecode import ConcreteBytecode
1212from libcst .matchers import MatcherDecoratableTransformer
13- from libcst .metadata import ParentNodeProvider
1413
1514from mlir_utils .ast .util import get_module_cst , copy_func
1615
1716logger = logging .getLogger (__name__ )
1817
1918
20- class FuncIdentTypeTable (cst .CSTVisitor ):
21- METADATA_DEPENDENCIES = (ParentNodeProvider ,)
22-
23- def __init__ (self , f ):
24- super ().__init__ ()
25- self .ident_type : dict [str , Type ] = {}
26- module_cst = get_module_cst (f )
27- wrapper = cst .MetadataWrapper (module_cst )
28- wrapper .visit (self )
29-
30- def visit_Annotation (self , node : cst .Annotation ) -> Optional [bool ]:
31- parent = self .get_metadata (ParentNodeProvider , node )
32- if isinstance (node .annotation , (cst .Tuple , cst .List )):
33- self .ident_type [parent .target .value ] = [
34- e .value .value for e in node .annotation .elements
35- ]
36- else :
37- self .ident_type [parent .target .value ] = [node .annotation .value ]
38-
39- def __getitem__ (self , ident ):
40- return self .ident_type [ident ]
41-
42-
4319class Transformer (MatcherDecoratableTransformer ):
44- def __init__ (self , context , func_sym_table : FuncIdentTypeTable ):
20+ def __init__ (self , context ):
4521 super ().__init__ ()
4622 self .context = context
47- self .func_sym_table = func_sym_table
4823
4924
5025class StrictTransformer (Transformer ):
5126 def visit_FunctionDef (self , node : cst .FunctionDef ):
5227 return False
5328
5429
55- def transform_cst (f , transformers : list [type (StrictTransformer )] = None ):
30+ def transform_cst (
31+ f , transformers : list [type (Transformer ) | type (StrictTransformer )] = None
32+ ):
5633 if transformers is None :
5734 return f
5835
5936 module_cst = get_module_cst (f )
60- func_sym_table = FuncIdentTypeTable (f )
6137 context = types .SimpleNamespace ()
6238 for transformer in transformers :
6339 func_node = module_cst .body [0 ]
64- replace = transformer (context , func_sym_table )
40+ replace = transformer (context )
6541 new_func = func_node ._visit_and_replace_children (replace )
6642 module_cst = module_cst .deep_replace (func_node , new_func )
6743
@@ -79,6 +55,23 @@ def transform_cst(f, transformers: list[type(StrictTransformer)] = None):
7955 return copy_func (f , new_f_code_o )
8056
8157
58+ # this is like this because i couldn't figure out how to subclass
59+ # Enum and simultaneously pass in opmap
60+ OpCode = enum .Enum ("OpCode" , opmap )
61+
62+
63+ def to_int (self : OpCode ):
64+ return self .value
65+
66+
67+ def to_str (self : OpCode ):
68+ return self .name
69+
70+
71+ setattr (OpCode , "__int__" , to_int )
72+ setattr (OpCode , "__str__" , to_str )
73+
74+
8275class BytecodePatcher (ABC ):
8376 def __init__ (self , context = None ):
8477 self .context = context
@@ -109,10 +102,10 @@ def bytecode_patchers(self) -> list[BytecodePatcher]:
109102 pass
110103
111104
112- def canonicalize (* , with_ : Canonicalizer ):
105+ def canonicalize (* , using : Canonicalizer ):
113106 def wrapper (f ):
114- f = transform_cst (f , with_ .cst_transformers )
115- f = patch_bytecode (f , with_ .bytecode_patchers )
107+ f = transform_cst (f , using .cst_transformers )
108+ f = patch_bytecode (f , using .bytecode_patchers )
116109 return f
117110
118111 return wrapper
0 commit comments