Skip to content

Commit 49c5f5f

Browse files
committed
defer if yield results type setting (ie use the set_type api)
1 parent 0bfe5c8 commit 49c5f5f

File tree

17 files changed

+323
-221
lines changed

17 files changed

+323
-221
lines changed

mlir_utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from ._configuration.configuration import alias_upstream_bindings
21
import atexit
32

3+
from ._configuration.configuration import alias_upstream_bindings
4+
45
if alias_upstream_bindings():
56
from mlir import ir
67

mlir_utils/_configuration/configuration.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from base64 import urlsafe_b64encode
66
from importlib.metadata import distribution, packages_distributions
77
from importlib.resources import files
8-
from importlib.resources.readers import MultiplexedPath
98
from pathlib import Path
109

10+
from importlib.resources.readers import MultiplexedPath
11+
1112
from .module_alias_map import get_meta_path_insertion_index, AliasedModuleFinder
1213

1314
__MLIR_PYTHON_PACKAGE_PREFIX__ = "__MLIR_PYTHON_PACKAGE_PREFIX__"

mlir_utils/ast/canonicalize.py

Lines changed: 28 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,43 @@
11
import ast
2-
import logging
3-
2+
import enum
43
import inspect
4+
import logging
55
import types
66
from abc import ABC
7+
from opcode import opmap
78
from types import CodeType
8-
from typing import Type, Optional
99

1010
import libcst as cst
1111
from bytecode import ConcreteBytecode
1212
from libcst.matchers import MatcherDecoratableTransformer
13-
from libcst.metadata import ParentNodeProvider
1413

1514
from mlir_utils.ast.util import get_module_cst, copy_func
1615

1716
logger = logging.getLogger(__name__)
1817

1918

20-
class FuncIdentTypeTable(cst.CSTVisitor):
21-
METADATA_DEPENDENCIES = (ParentNodeProvider,)
22-
23-
def __init__(self, f):
24-
super().__init__()
25-
self.ident_type: dict[str, Type] = {}
26-
module_cst = get_module_cst(f)
27-
wrapper = cst.MetadataWrapper(module_cst)
28-
wrapper.visit(self)
29-
30-
def visit_Annotation(self, node: cst.Annotation) -> Optional[bool]:
31-
parent = self.get_metadata(ParentNodeProvider, node)
32-
if isinstance(node.annotation, (cst.Tuple, cst.List)):
33-
self.ident_type[parent.target.value] = [
34-
e.value.value for e in node.annotation.elements
35-
]
36-
else:
37-
self.ident_type[parent.target.value] = [node.annotation.value]
38-
39-
def __getitem__(self, ident):
40-
return self.ident_type[ident]
41-
42-
4319
class Transformer(MatcherDecoratableTransformer):
44-
def __init__(self, context, func_sym_table: FuncIdentTypeTable):
20+
def __init__(self, context):
4521
super().__init__()
4622
self.context = context
47-
self.func_sym_table = func_sym_table
4823

4924

5025
class StrictTransformer(Transformer):
5126
def visit_FunctionDef(self, node: cst.FunctionDef):
5227
return False
5328

5429

55-
def transform_cst(f, transformers: list[type(StrictTransformer)] = None):
30+
def transform_cst(
31+
f, transformers: list[type(Transformer) | type(StrictTransformer)] = None
32+
):
5633
if transformers is None:
5734
return f
5835

5936
module_cst = get_module_cst(f)
60-
func_sym_table = FuncIdentTypeTable(f)
6137
context = types.SimpleNamespace()
6238
for transformer in transformers:
6339
func_node = module_cst.body[0]
64-
replace = transformer(context, func_sym_table)
40+
replace = transformer(context)
6541
new_func = func_node._visit_and_replace_children(replace)
6642
module_cst = module_cst.deep_replace(func_node, new_func)
6743

@@ -79,6 +55,23 @@ def transform_cst(f, transformers: list[type(StrictTransformer)] = None):
7955
return copy_func(f, new_f_code_o)
8056

8157

58+
# this is like this because i couldn't figure out how to subclass
59+
# Enum and simultaneously pass in opmap
60+
OpCode = enum.Enum("OpCode", opmap)
61+
62+
63+
def to_int(self: OpCode):
64+
return self.value
65+
66+
67+
def to_str(self: OpCode):
68+
return self.name
69+
70+
71+
setattr(OpCode, "__int__", to_int)
72+
setattr(OpCode, "__str__", to_str)
73+
74+
8275
class BytecodePatcher(ABC):
8376
def __init__(self, context=None):
8477
self.context = context
@@ -109,10 +102,10 @@ def bytecode_patchers(self) -> list[BytecodePatcher]:
109102
pass
110103

111104

112-
def canonicalize(*, with_: Canonicalizer):
105+
def canonicalize(*, using: Canonicalizer):
113106
def wrapper(f):
114-
f = transform_cst(f, with_.cst_transformers)
115-
f = patch_bytecode(f, with_.bytecode_patchers)
107+
f = transform_cst(f, using.cst_transformers)
108+
f = patch_bytecode(f, using.bytecode_patchers)
116109
return f
117110

118111
return wrapper

mlir_utils/dialects/ext/arith.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from copy import deepcopy
2-
from functools import lru_cache, partialmethod, cached_property
2+
from functools import partialmethod, cached_property
33
from typing import Union, Optional
44

55
import numpy as np
@@ -143,7 +143,7 @@ def __call__(cls, *args, **kwargs):
143143
# which by default (through the Python buffer protocol) does not copy;
144144
# see mlir/lib/Bindings/Python/IRAttributes.cpp#L556
145145
arg_copy = deepcopy(arg)
146-
val = constant(arg, dtype).result
146+
return constant(arg_copy, dtype)
147147
else:
148148
raise NotImplementedError(f"{cls.__name__} doesn't support wrapping {arg}.")
149149

@@ -155,7 +155,7 @@ def __call__(cls, *args, **kwargs):
155155
# the Python object protocol; first an object is new'ed and then
156156
# it is init'ed. Note we pass arg_copy here in case a subclass wants to
157157
# inspect the literal.
158-
cls.__init__(cls_obj, val, arg_copy)
158+
cls.__init__(cls_obj, val)
159159
return cls_obj
160160

161161

@@ -276,19 +276,12 @@ class ArithValue(Value, metaclass=ArithValueMeta):
276276
Value.__init__
277277
"""
278278

279-
def __init__(
280-
self,
281-
val,
282-
arg: Optional[Union[int, float, bool, np.ndarray]] = None,
283-
):
284-
self.__arg = arg
279+
def __init__(self, val):
285280
super().__init__(val)
286281

287-
# @lru_cache(maxsize=1)
288282
def __str__(self):
289283
return f"{self.__class__.__name__}({self.get_name()}, {self.type})"
290284

291-
# @lru_cache(maxsize=1)
292285
def __repr__(self):
293286
return str(self)
294287

mlir_utils/dialects/ext/func.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import inspect
2-
from functools import wraps, partial
32

43
from mlir.dialects.func import FuncOp, ReturnOp, CallOp
54
from mlir.ir import (

0 commit comments

Comments
 (0)