diff --git a/onnx2pytorch/convert/attribute.py b/onnx2pytorch/convert/attribute.py
index 4cdcd61..d31adbb 100644
--- a/onnx2pytorch/convert/attribute.py
+++ b/onnx2pytorch/convert/attribute.py
@@ -168,7 +168,10 @@ def extract_attributes(node):
         elif attr.name == "transA":
             kwargs["transpose_activation"] = bool(extract_attr_values(attr))
         elif attr.name == "value":
-            kwargs["constant"] = extract_attr_values(attr)
+            if node.op_type == "Pad":
+                kwargs["value"] = extract_attr_values(attr)
+            else:
+                kwargs["constant"] = extract_attr_values(attr)
         elif attr.name == "value_float":
             kwargs["constant"] = extract_attr_values(attr)
         elif attr.name == "value_floats":
diff --git a/onnx2pytorch/convert/model.py b/onnx2pytorch/convert/model.py
index b3b79eb..a7bc9b8 100644
--- a/onnx2pytorch/convert/model.py
+++ b/onnx2pytorch/convert/model.py
@@ -221,6 +221,12 @@ def forward(self, *input_list, **input_dict):
             for out_op_id, output in zip(node.output, op(*in_activations)):
                 activations[out_op_id] = output
         else:
+            # ONNX Dropout is an identity op at inference time. A freshly
+            # constructed nn.Dropout() defaults to training mode and would
+            # randomly zero activations, so route Dropout nodes through the
+            # already-imported Identity instead. Match on the node's op_type,
+            # not a substring of the (arbitrary) output tensor name.
+            if node.op_type == "Dropout":
+                op = Identity()
             activations[out_op_id] = op(*in_activations)
 
         # Remove activations that are no longer needed
diff --git a/onnx2pytorch/operations/pad.py b/onnx2pytorch/operations/pad.py
index e0c7fd8..92fd51e 100644
--- a/onnx2pytorch/operations/pad.py
+++ b/onnx2pytorch/operations/pad.py
@@ -4,10 +4,13 @@
 class Pad(Operator):
-    def __init__(self, mode="constant", padding=None):
+    def __init__(self, mode="constant", padding=None, value=0):
         self.mode = mode
         self.padding = padding
+        self.value = value
         super().__init__()
 
-    def forward(self, input, pads=None, value=0):
+    def forward(self, input, pads=None, value=None):
+        if value is None:
+            value = self.value
         if self.padding is not None:
             pads = self.padding
         elif pads is None:
diff --git a/requirements.txt b/requirements.txt
index cabe9b7..dcb0208 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ torchvision>=0.9.0
 onnx>=1.6.0
 onnxruntime>=1.5.0
 numpy>=1.18.1
+protobuf>=3.20.0,<4.0