|
18 | 18 | ) |
19 | 19 |
|
20 | 20 |
|
21 | | -def func_base( |
22 | | - FuncOp, |
23 | | - ReturnOp, |
24 | | - CallOp, |
25 | | - sym_visibility=None, |
26 | | - arg_attrs=None, |
27 | | - res_attrs=None, |
28 | | - loc=None, |
29 | | - ip=None, |
30 | | -): |
31 | | - ip = ip or InsertionPoint.current |
32 | | - |
33 | | - # if this is set to true then wrapper below won't emit a call op |
34 | | - # it is set below by a def emit fn that is attached to the body_builder |
35 | | - # wrapper; thus you can call wrapped_fn.emit() (i.e., without an operands) |
36 | | - # and the func will be emitted. |
37 | | - _emit = False |
38 | | - |
39 | | - def builder_wrapper(body_builder): |
40 | | - @wraps(body_builder) |
41 | | - def wrapper(*call_args): |
42 | | - # TODO(max): implement constexpr ie enable passing constants that skip being |
43 | | - # part of the signature |
44 | | - sig = inspect.signature(body_builder) |
45 | | - implicit_return = sig.return_annotation is inspect._empty |
46 | | - input_types = [p.annotation for p in sig.parameters.values()] |
47 | | - if not ( |
48 | | - len(input_types) == len(sig.parameters) |
49 | | - and all(isinstance(t, Type) for t in input_types) |
50 | | - ): |
51 | | - input_types = [a.type for a in call_args] |
52 | | - function_type = TypeAttr.get( |
53 | | - FunctionType.get( |
54 | | - inputs=input_types, |
55 | | - results=[] if implicit_return else sig.return_annotation, |
56 | | - ) |
| 21 | +class FuncOpMeta(type): |
| 22 | + def __call__(cls, *args, **kwargs): |
| 23 | + cls_obj = cls.__new__(cls) |
| 24 | + if len(args) == 1 and len(kwargs) == 0 and inspect.isfunction(args[0]): |
| 25 | + return cls.__init__(cls_obj, args[0]) |
| 26 | + else: |
| 27 | + |
| 28 | + def init_wrapper(f): |
| 29 | + cls.__init__(cls_obj, f, *args, **kwargs) |
| 30 | + return cls_obj |
| 31 | + |
| 32 | + return lambda f: init_wrapper(f) |
| 33 | + |
| 34 | + |
| 35 | +class FuncBase(metaclass=FuncOpMeta): |
| 36 | + def __init__( |
| 37 | + self, |
| 38 | + body_builder, |
| 39 | + func_op_ctor, |
| 40 | + return_op_ctor, |
| 41 | + call_op_ctor, |
| 42 | + sym_visibility=None, |
| 43 | + arg_attrs=None, |
| 44 | + res_attrs=None, |
| 45 | + loc=None, |
| 46 | + ip=None, |
| 47 | + ): |
| 48 | + assert inspect.isfunction(body_builder), body_builder |
| 49 | + assert inspect.isclass(func_op_ctor), func_op_ctor |
| 50 | + assert inspect.isclass(return_op_ctor), return_op_ctor |
| 51 | + assert inspect.isclass(call_op_ctor), call_op_ctor |
| 52 | + |
| 53 | + self.body_builder = body_builder |
| 54 | + self.func_name = self.body_builder.__name__ |
| 55 | + |
| 56 | + self.func_op_ctor = func_op_ctor |
| 57 | + self.return_op_ctor = return_op_ctor |
| 58 | + self.call_op_ctor = call_op_ctor |
| 59 | + self.sym_visibility = ( |
| 60 | + StringAttr.get(str(sym_visibility)) if sym_visibility is not None else None |
| 61 | + ) |
| 62 | + self.arg_attrs = arg_attrs |
| 63 | + self.res_attrs = res_attrs |
| 64 | + self.loc = loc |
| 65 | + self.ip = ip or InsertionPoint.current |
| 66 | + self.emitted = False |
| 67 | + |
| 68 | + def __str__(self): |
| 69 | + return str(f"{self.__class__} {self.__dict__}") |
| 70 | + |
| 71 | + def body_builder_wrapper(self, *call_args): |
| 72 | + sig = inspect.signature(self.body_builder) |
| 73 | + implicit_return = sig.return_annotation is inspect._empty |
| 74 | + input_types = [p.annotation for p in sig.parameters.values()] |
| 75 | + if not ( |
| 76 | + len(input_types) == len(sig.parameters) |
| 77 | + and all(isinstance(t, Type) for t in input_types) |
| 78 | + ): |
| 79 | + input_types = [a.type for a in call_args] |
| 80 | + function_type = TypeAttr.get( |
| 81 | + FunctionType.get( |
| 82 | + inputs=input_types, |
| 83 | + results=[] if implicit_return else sig.return_annotation, |
57 | 84 | ) |
58 | | - # FuncOp is extended but we do really want the base |
59 | | - func_name = body_builder.__name__ |
60 | | - func_op = FuncOp( |
61 | | - func_name, |
62 | | - function_type, |
63 | | - sym_visibility=StringAttr.get(str(sym_visibility)) |
64 | | - if sym_visibility is not None |
65 | | - else None, |
66 | | - arg_attrs=arg_attrs, |
67 | | - res_attrs=res_attrs, |
68 | | - loc=loc, |
69 | | - ip=ip, |
| 85 | + ) |
| 86 | + func_op = self.func_op_ctor( |
| 87 | + self.func_name, |
| 88 | + function_type, |
| 89 | + sym_visibility=self.sym_visibility, |
| 90 | + arg_attrs=self.arg_attrs, |
| 91 | + res_attrs=self.res_attrs, |
| 92 | + loc=self.loc, |
| 93 | + ip=self.ip, |
| 94 | + ) |
| 95 | + func_op.regions[0].blocks.append(*input_types) |
| 96 | + with InsertionPoint(func_op.regions[0].blocks[0]): |
| 97 | + results = get_result_or_results( |
| 98 | + self.body_builder(*func_op.regions[0].blocks[0].arguments) |
70 | 99 | ) |
71 | | - func_op.regions[0].blocks.append(*input_types) |
72 | | - with InsertionPoint(func_op.regions[0].blocks[0]): |
73 | | - results = get_result_or_results( |
74 | | - body_builder(*func_op.regions[0].blocks[0].arguments) |
75 | | - ) |
76 | | - if results is not None: |
77 | | - if isinstance(results, (tuple, list)): |
78 | | - results = list(results) |
79 | | - else: |
80 | | - results = [results] |
| 100 | + if results is not None: |
| 101 | + if isinstance(results, (tuple, list)): |
| 102 | + results = list(results) |
81 | 103 | else: |
82 | | - results = [] |
83 | | - ReturnOp(results) |
84 | | - # Recompute the function type. |
85 | | - return_types = [v.type for v in results] |
86 | | - function_type = FunctionType.get(inputs=input_types, results=return_types) |
87 | | - func_op.attributes["function_type"] = TypeAttr.get(function_type) |
88 | | - |
89 | | - if _emit: |
90 | | - return maybe_cast(get_result_or_results(func_op)) |
| 104 | + results = [results] |
91 | 105 | else: |
92 | | - call_op = CallOp( |
93 | | - [r.type for r in results], |
94 | | - FlatSymbolRefAttr.get(func_name), |
95 | | - call_args, |
96 | | - ) |
97 | | - return maybe_cast(get_result_or_results(call_op)) |
| 106 | + results = [] |
| 107 | + self.return_op_ctor(results) |
98 | 108 |
|
99 | | - def emit(): |
100 | | - nonlocal _emit |
101 | | - _emit = True |
102 | | - wrapper() |
| 109 | + return results, input_types, func_op |
103 | 110 |
|
104 | | - wrapper.emit = emit |
105 | | - return wrapper |
| 111 | + def emit(self): |
| 112 | + self.results, input_types, func_op = self.body_builder_wrapper() |
| 113 | + return_types = [v.type for v in self.results] |
| 114 | + function_type = FunctionType.get(inputs=input_types, results=return_types) |
| 115 | + func_op.attributes["function_type"] = TypeAttr.get(function_type) |
| 116 | + self.emitted = True |
| 117 | + # this is the func op itself (funcs never have a resulting ssa value) |
| 118 | + return maybe_cast(get_result_or_results(func_op)) |
106 | 119 |
|
107 | | - return builder_wrapper |
| 120 | + def __call__(self, *call_args): |
| 121 | + if not self.emitted: |
| 122 | + self.emit() |
| 123 | + call_op = CallOp( |
| 124 | + [r.type for r in self.results], |
| 125 | + FlatSymbolRefAttr.get(self.func_name), |
| 126 | + call_args, |
| 127 | + ) |
| 128 | + return maybe_cast(get_result_or_results(call_op)) |
108 | 129 |
|
109 | 130 |
|
110 | | -func = make_maybe_no_args_decorator( |
111 | | - partial(func_base, FuncOp=FuncOp.__base__, ReturnOp=ReturnOp, CallOp=CallOp) |
112 | | -) |
| 131 | +func = FuncBase(FuncOp.__base__, ReturnOp, CallOp.__base__) |
0 commit comments