
Commit bed41ad

Add onnx support (#292)
* onnx support first cut
* incorporated review comments
* add backport, move skl2onnx to test dependency, store onnx model tensors as external data
* fix setup.py tests
* update dependencies
* fix flaky testcase
* Update tests/contrib/test_pandas.py

Co-authored-by: Alexander Guschin <[email protected]>
1 parent 9518156 commit bed41ad

File tree: 8 files changed, +407 -3 lines changed


mlem/contrib/onnx.py

Lines changed: 139 additions & 0 deletions
```python
from typing import Any, ClassVar, List, Optional, Union

import numpy as np
import onnx
import onnxruntime as onnxrt
import pandas as pd
from numpy.typing import DTypeLike
from onnx import ModelProto, ValueInfoProto
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE

from mlem.core.artifacts import Artifacts, Storage
from mlem.core.hooks import IsInstanceHookMixin
from mlem.core.model import ModelHook, ModelIO, ModelType, Signature
from mlem.core.requirements import InstallableRequirement, Requirements
from mlem.utils.backport import cached_property
from mlem.utils.module import get_object_requirements


def convert_to_numpy(
    data: Union[np.ndarray, pd.DataFrame], dtype: DTypeLike
) -> np.ndarray:
    """Converts input data to numpy"""
    if isinstance(data, np.ndarray):
        pass
    elif isinstance(data, pd.DataFrame):
        data = data.to_numpy()
    else:
        raise TypeError(f"input data type: {type(data)} is not supported")
    return data.astype(dtype=dtype)


def get_onnx_to_numpy_type(value_info: ValueInfoProto) -> DTypeLike:
    """Returns numpy equivalent type of onnx value info"""
    onnx_type = value_info.type.tensor_type.elem_type
    return TENSOR_TYPE_TO_NP_TYPE[onnx_type]


class ModelProtoIO(ModelIO):
    """IO for ONNX model object"""

    type: ClassVar[str] = "model_proto"

    def dump(self, storage: Storage, path: str, model) -> Artifacts:
        path = f"{path}/model.onnx"
        with storage.open(path) as (f, art):
            onnx.save_model(
                model,
                f,
                save_as_external_data=True,
                location="tensors",
                size_threshold=0,
                all_tensors_to_one_file=True,
            )
        return {self.art_name: art}

    def load(self, artifacts: Artifacts):
        if len(artifacts) != 1:
            raise ValueError("Invalid artifacts: should be one .onnx file")
        with artifacts[self.art_name].open() as f:
            return onnx.load_model(f)


class ONNXModel(ModelType, ModelHook, IsInstanceHookMixin):
    """
    :class:`mlem.core.model.ModelType` implementation for `onnx` models
    """

    type: ClassVar[str] = "onnx"
    io: ModelIO = ModelProtoIO()
    valid_types: ClassVar = (ModelProto,)

    class Config:
        keep_untouched = (cached_property,)

    @classmethod
    def process(
        cls, obj: Any, sample_data: Optional[Any] = None, **kwargs
    ) -> ModelType:
        model = ONNXModel(io=ModelProtoIO(), methods={}).bind(obj)
        # TODO - use ONNX infer shapes.
        onnxrt_predict = Signature.from_method(
            model.predict, auto_infer=sample_data is not None, data=sample_data
        )
        model.methods = {
            "predict": onnxrt_predict,
        }

        return model

    @cached_property
    def runtime_session(self) -> onnxrt.InferenceSession:
        """Provides onnx runtime inference session"""
        # TODO - add support for runtime providers, options. add support for GPU devices.
        return onnxrt.InferenceSession(self.model.SerializeToString())

    def predict(self, data: Union[List, np.ndarray, pd.DataFrame]) -> Any:
        """Returns inference output for given input data"""
        model_inputs = self.runtime_session.get_inputs()

        if not isinstance(data, list):
            data = [data]

        if len(model_inputs) != len(data):
            raise ValueError(
                f"no of inputs provided: {len(data)}, "
                f"expected: {len(model_inputs)}"
            )

        input_dict = {}
        for model_input, input_data in zip(self.model.graph.input, data):
            input_dict[model_input.name] = convert_to_numpy(
                input_data, get_onnx_to_numpy_type(model_input)
            )

        label_names = [out.name for out in self.runtime_session.get_outputs()]
        pred_onnx = self.runtime_session.run(label_names, input_dict)

        output = []
        for output_data in pred_onnx:
            if isinstance(
                output_data, list
            ):  # TODO - temporary workaround to fix fastapi model issues
                output.append(pd.DataFrame(output_data).to_numpy())
            else:
                output.append(output_data)

        return output

    def get_requirements(self) -> Requirements:
        return (
            super().get_requirements()
            + InstallableRequirement.from_module(onnx)
            + get_object_requirements(self.predict)
            + Requirements.new(
                InstallableRequirement(module="protobuf", version="3.20.1")
            )
        )
        # https://github.com/protocolbuffers/protobuf/issues/10051
```
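For orientation (not part of the diff), here is a minimal sketch of how the new model type could be exercised end to end. It assumes skl2onnx's `to_onnx` helper (skl2onnx is added to the test dependencies in setup.py below) purely as a convenient way to obtain an `onnx.ModelProto`; the estimator and array names are illustrative.

```python
# Hedged sketch, not from this commit: exercise ONNXModel.process/predict.
import numpy as np
from skl2onnx import to_onnx
from sklearn.linear_model import LogisticRegression

from mlem.contrib.onnx import ONNXModel

X = np.random.rand(20, 4).astype(np.float32)
y = (X.sum(axis=1) > 2.0).astype(np.int64)
clf = LogisticRegression().fit(X, y)

# Convert the fitted estimator to an onnx.ModelProto.
onnx_model = to_onnx(clf, X[:1])

# The hook binds the ModelProto and infers the predict signature
# from the sample data.
model_type = ONNXModel.process(onnx_model, sample_data=X)

# predict wraps a single array into a one-element list and maps inputs
# positionally onto model.graph.input; the result is a list of numpy arrays.
outputs = model_type.predict(X)
print([o.shape for o in outputs])
```

Since `predict` checks `len(model_inputs) == len(data)`, a multi-input graph would take a list of arrays/DataFrames in graph-input order.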

mlem/core/objects.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -609,7 +609,9 @@ def from_obj(
         params: Dict[str, str] = None,
     ) -> "MlemModel":
         mt = ModelAnalyzer.analyze(model, sample_data=sample_data)
-        mt.model = model
+        if mt.model is None:
+            mt = mt.bind(model)
+
         return MlemModel(
             model_type=mt,
             requirements=mt.get_requirements().expanded,
```
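The guard here presumably avoids double-binding: a hook such as `ONNXModel.process` above already calls `.bind(obj)` on the model type it returns, so `from_obj` now binds the raw object only when the analyzer left `mt.model` unset.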

mlem/ext.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -92,6 +92,7 @@ class ExtensionLoader:
         Extension("mlem.contrib.numpy", ["numpy"], False),
         Extension("mlem.contrib.pandas", ["pandas"], False),
         Extension("mlem.contrib.sklearn", ["sklearn"], False),
+        Extension("mlem.contrib.onnx", ["onnx"], False),
         Extension("mlem.contrib.tensorflow", ["tensorflow"], False),
         Extension("mlem.contrib.torch", ["torch"], False),
         Extension("mlem.contrib.catboost", ["catboost"], False),
```

mlem/utils/backport.py

Lines changed: 59 additions & 0 deletions
```python
import functools
import sys

if sys.version_info >= (3, 8):
    cached_property = functools.cached_property
else:
    # Code copied from Python 3.8 https://github.com/python/cpython/blob/3.8/Lib/functools.py
    # cached_property is not available in Python versions < 3.8.
    from _thread import RLock

    _NOT_FOUND = object()

    class cached_property:
        def __init__(self, func):
            self.func = func
            self.attrname = None
            self.__doc__ = func.__doc__
            self.lock = RLock()

        def __set_name__(self, owner, name):
            if self.attrname is None:
                self.attrname = name
            elif name != self.attrname:
                raise TypeError(
                    "Cannot assign the same cached_property to two different names "
                    f"({self.attrname!r} and {name!r})."
                )

        def __get__(self, instance, owner=None):
            if instance is None:
                return self
            if self.attrname is None:
                raise TypeError(
                    "Cannot use cached_property instance without calling __set_name__ on it."
                )
            try:
                cache = instance.__dict__
            except AttributeError:  # not all objects have __dict__ (e.g. class defines slots)
                msg = (
                    f"No '__dict__' attribute on {type(instance).__name__!r} "
                    f"instance to cache {self.attrname!r} property."
                )
                raise TypeError(msg) from None
            val = cache.get(self.attrname, _NOT_FOUND)
            if val is _NOT_FOUND:
                with self.lock:
                    # check if another thread filled cache while we awaited lock
                    val = cache.get(self.attrname, _NOT_FOUND)
                    if val is _NOT_FOUND:
                        val = self.func(instance)
                        try:
                            cache[self.attrname] = val
                        except TypeError:
                            msg = (
                                f"The '__dict__' attribute on {type(instance).__name__!r} instance "
                                f"does not support item assignment for caching {self.attrname!r} property."
                            )
                            raise TypeError(msg) from None
            return val
```
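As a quick sanity check (not part of the commit), the backported descriptor behaves like `functools.cached_property` on 3.8+: the first access computes the value and stores it in the instance `__dict__`, so later accesses bypass `__get__` entirely.

```python
# Hedged sketch, not from this commit: first access computes, later
# accesses are served straight from the instance __dict__.
from mlem.utils.backport import cached_property


class Session:
    created = 0

    @cached_property
    def connection(self):
        Session.created += 1
        return object()


s = Session()
assert s.connection is s.connection  # same cached object
assert Session.created == 1          # the function ran exactly once
assert "connection" in s.__dict__    # cache lives on the instance
```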

mlem/utils/module.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -343,7 +343,9 @@ def lstrip_lines(lines: Union[str, List[str]], check=True) -> str:


 _SKIP_CLOSURE_OBJECTS: Dict[str, Dict[str, Set[str]]] = {
-    "globals": {"re": {"_cache"}},
+    # In onnx, "protobuf" module is imported using "google.protobuf" namespace which results in identifying "google"
+    # as possible installable requirement which is incorrect. TODO - see if this can be handled in more correct way
+    "globals": {"re": {"_cache"}, "onnx": {"google"}},
     "nonlocals": {},
 }
```

setup.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -46,6 +46,7 @@
     "xlrd",
     "tables",
     "pyarrow",
+    "skl2onnx",
 ]

 extras = {
@@ -54,6 +55,11 @@
     "pandas": ["pandas"],
     "numpy": ["numpy"],
     "sklearn": ["scikit-learn"],
+    "onnx": ["onnx"],
+    "onnxruntime": [
+        "protobuf==3.20.0",
+        "onnxruntime",
+    ],  # TODO - see if it can be merged with onnx
     "catboost": ["catboost"],
     "xgboost": ["xgboost"],
     "lightgbm": ["lightgbm"],
@@ -151,13 +157,15 @@
         "model_io.lightgbm_io = mlem.contrib.lightgbm:LightGBMModelIO",
         "model_io.pickle = mlem.contrib.callable:PickleModelIO",
         "model_io.xgboost_io = mlem.contrib.xgboost:XGBoostModelIO",
+        "model_io.model_proto = mlem.contrib.onnx:ModelProtoIO",
         "model_io.torch_io = mlem.contrib.torch:TorchModelIO",
         "model_io.tf_keras = mlem.contrib.tensorflow:TFKerasModelIO",
         "model_type.callable = mlem.contrib.callable:CallableModelType",
         "model_type.catboost = mlem.contrib.catboost:CatBoostModel",
         "model_type.lightgbm = mlem.contrib.lightgbm:LightGBMModel",
         "model_type.sklearn = mlem.contrib.sklearn:SklearnModel",
         "model_type.sklearn_pipeline = mlem.contrib.sklearn:SklearnPipelineType",
+        "model_type.onnx = mlem.contrib.onnx:ONNXModel",
         "model_type.xgboost = mlem.contrib.xgboost:XGBoostModel",
         "model_type.torch = mlem.contrib.torch:TorchModel",
         "model_type.tf_keras = mlem.contrib.tensorflow:TFKerasModel",
```
