3
3
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
4
# SPDX-License-Identifier: BSD-3-Clause
5
5
#
6
- # ----------------------------------------------------------------------------
6
+ # -----------------------------------------------------------------------------
7
+
7
8
import os
8
9
from typing import Optional , Tuple
9
10
import numpy as np
10
11
from onnx import ModelProto , external_data_helper , numpy_helper
11
12
from concurrent .futures import ThreadPoolExecutor
12
13
14
+
13
15
class OnnxTransform :
14
16
"""
15
- OnnxTransform is the base class for graph modifications on exported onnx .
17
+ OnnxTransform is the base class for graph modifications on exported ONNX .
16
18
"""
17
19
18
20
def __init__ (self ):
@@ -22,20 +24,22 @@ def __init__(self):
22
24
def apply (cls , model : ModelProto , ** kwargs ) -> Tuple [ModelProto , bool ]:
23
25
"""
24
26
Override this class to apply a transformation.
27
+
25
28
:param model: The model's ONNX graph to transform
26
29
:param kwargs: Parameters needed for specific transforms. All transforms should take **kwargs to ignore unneeded kwargs.
27
-
28
30
:returns: ONNX graph after applying the transform
29
31
:returns: Boolean indicating whether transform was applied
30
32
"""
31
33
raise NotImplementedError ("Use subclasses for ONNX transform" )
32
-
34
+
35
+
33
36
class FP16ClipTransform (OnnxTransform ):
34
37
@classmethod
35
38
def apply (cls , model : ModelProto , * , onnx_base_dir : Optional [str ] = None , ** kwargs ) -> Tuple [ModelProto , bool ]:
36
39
finfo = np .finfo (np .float16 )
37
40
fp16_max = finfo .max
38
41
fp16_min = finfo .min
42
+
39
43
def clip_tensor (tensor ):
40
44
nptensor = numpy_helper .to_array (tensor , onnx_base_dir )
41
45
if nptensor .dtype == np .float32 and (np .any (nptensor > fp16_max ) or np .any (nptensor < fp16_min )):
@@ -51,6 +55,7 @@ def clip_tensor(tensor):
51
55
transformed = any (results )
52
56
return model , transformed
53
57
58
+
54
59
class SplitTensorsTransform (OnnxTransform ):
55
60
@classmethod
56
61
def apply (
@@ -63,14 +68,13 @@ def apply(
63
68
size_threshold : int = 1024 ,
64
69
** kwargs ,
65
70
) -> Tuple [ModelProto , bool ]:
66
-
67
71
external_data_helper .load_external_data_for_model (model , onnx_base_dir )
68
72
tensors = external_data_helper ._get_all_tensors (model )
69
73
file_assignments = []
70
74
file_num = 0
71
75
current_file_size = 0
72
- transformed = False
73
-
76
+ transformed = False
77
+
74
78
for tensor in tensors :
75
79
if tensor .HasField ("raw_data" ) and (tsize := len (tensor .raw_data )) > size_threshold :
76
80
transformed = True
@@ -85,5 +89,5 @@ def process_tensor(args):
85
89
external_data_helper .set_external_data (tensor , file_name )
86
90
87
91
with ThreadPoolExecutor (max_workers = os .cpu_count () * 4 ) as executor :
88
- executor .map (process_tensor , file_assignments )
92
+ list ( executor .map (process_tensor , file_assignments ) )
89
93
return model , transformed
0 commit comments