diff --git a/classy_vision/models/classy_model.py b/classy_vision/models/classy_model.py index 0f91c6516..11f3c440a 100644 --- a/classy_vision/models/classy_model.py +++ b/classy_vision/models/classy_model.py @@ -59,7 +59,7 @@ def __call__(self, *args, **kwargs): return ret_val -class ClassyModelWrapper: +class ClassyModelWrapper(torch.nn.Module): """Base ClassyModel wrapper class. This class acts as a thin pass through wrapper which lets users modify the behavior @@ -68,9 +68,8 @@ class ClassyModelWrapper: accessed by the `classy_model` attribute. """ - # TODO: Make this torchscriptable by inheriting from nn.Module / ClassyModel - def __init__(self, classy_model): + super().__init__() self.classy_model = classy_model def __getattr__(self, name):