
Commit b498685

Reformatted onnx_transforms.py for ruff check
Signed-off-by: abhishek-singh591 <[email protected]>
1 parent: 406155b


QEfficient/base/onnx_transforms.py

Lines changed: 12 additions & 8 deletions
@@ -3,16 +3,18 @@
 # Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
 # SPDX-License-Identifier: BSD-3-Clause
 #
-# ----------------------------------------------------------------------------
+# -----------------------------------------------------------------------------
+
 import os
 from typing import Optional, Tuple
 import numpy as np
 from onnx import ModelProto, external_data_helper, numpy_helper
 from concurrent.futures import ThreadPoolExecutor
 
+
 class OnnxTransform:
     """
-    OnnxTransform is the base class for graph modifications on exported onnx.
+    OnnxTransform is the base class for graph modifications on exported ONNX.
     """
 
     def __init__(self):
@@ -22,20 +24,22 @@ def __init__(self):
     def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
         """
         Override this class to apply a transformation.
+
         :param model: The model's ONNX graph to transform
         :param kwargs: Parameters needed for specific transforms. All transforms should take **kwargs to ignore unneeded kwargs.
-
         :returns: ONNX graph after applying the transform
         :returns: Boolean indicating whether transform was applied
         """
         raise NotImplementedError("Use subclasses for ONNX transform")
-
+
+
 class FP16ClipTransform(OnnxTransform):
     @classmethod
     def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwargs) -> Tuple[ModelProto, bool]:
         finfo = np.finfo(np.float16)
         fp16_max = finfo.max
         fp16_min = finfo.min
+
         def clip_tensor(tensor):
             nptensor = numpy_helper.to_array(tensor, onnx_base_dir)
             if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)):
@@ -51,6 +55,7 @@ def clip_tensor(tensor):
         transformed = any(results)
         return model, transformed
 
+
 class SplitTensorsTransform(OnnxTransform):
     @classmethod
     def apply(
@@ -63,14 +68,13 @@ def apply(
         size_threshold: int = 1024,
         **kwargs,
     ) -> Tuple[ModelProto, bool]:
-
         external_data_helper.load_external_data_for_model(model, onnx_base_dir)
         tensors = external_data_helper._get_all_tensors(model)
         file_assignments = []
         file_num = 0
         current_file_size = 0
-        transformed = False
-
+        transformed = False
+
         for tensor in tensors:
             if tensor.HasField("raw_data") and (tsize := len(tensor.raw_data)) > size_threshold:
                 transformed = True
@@ -85,5 +89,5 @@ def process_tensor(args):
             external_data_helper.set_external_data(tensor, file_name)
 
         with ThreadPoolExecutor(max_workers=os.cpu_count() * 4) as executor:
-            executor.map(process_tensor, file_assignments)
+            list(executor.map(process_tensor, file_assignments))
         return model, transformed
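
One functional-looking change in this diff is wrapping executor.map(...) in list(). As general concurrent.futures behavior (independent of this repository), Executor.map returns a lazy iterator, and an exception raised inside a worker only propagates when the corresponding result is consumed; wrapping the call in list() drives every task and surfaces any failure inside the with block. A minimal standalone sketch of that behavior, with illustrative names and values that are not from the commit:

from concurrent.futures import ThreadPoolExecutor

def worker(x):
    # Illustrative task: fails on one input to show where the error surfaces.
    if x == 2:
        raise ValueError("bad input")
    return x * x

with ThreadPoolExecutor(max_workers=4) as executor:
    results = executor.map(worker, [1, 2, 3])  # lazy iterator; nothing is raised here
    try:
        squares = list(results)  # consuming the iterator is what surfaces the ValueError
    except ValueError as exc:
        print(f"worker failed: {exc}")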
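For context, a hedged usage sketch of the transform contract visible in this diff: FP16ClipTransform.apply is a classmethod that takes a ModelProto plus an optional onnx_base_dir and returns the model together with a flag indicating whether anything changed. The file paths below are placeholders, and SplitTensorsTransform follows the same (model, transformed) return contract but its full parameter list is not shown in these hunks.

import onnx
from QEfficient.base.onnx_transforms import FP16ClipTransform

# Placeholder path; in practice this would be the exported model.
model = onnx.load("model.onnx")

# apply() returns the (possibly modified) model and a flag telling whether any
# float32 tensor actually needed clipping into the float16 range.
model, transformed = FP16ClipTransform.apply(model, onnx_base_dir=None)
if transformed:
    onnx.save(model, "model_clipped.onnx")  # placeholder output path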
