Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit fa59edb

Browse files
mannatsinghfacebook-github-bot
authored andcommitted
Model state should support PyTorch API
Summary: Classy Models should work like regular PyTorch models. The `{get, set}_classy_state` functions for state are the only blockers which this diff fixes by moving over to `state_dict` and `load_state_dict` Differential Revision: D25213283 fbshipit-source-id: 037572266d13ffde9a3a0c2c87aa9e76c5faeea1
1 parent c4d9725 commit fa59edb

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

classy_vision/models/classy_model.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def from_checkpoint(cls, checkpoint):
201201
model.set_classy_state(checkpoint["classy_state_dict"]["base_model"])
202202
return model
203203

204-
def get_classy_state(self, deep_copy=False):
204+
def state_dict(self, deep_copy=False):
205205
"""Get the state of the ClassyModel.
206206
207207
The returned state is used for checkpointing.
@@ -222,7 +222,7 @@ def get_classy_state(self, deep_copy=False):
222222
# as the trunk state. If the model doesn't have heads attached, all of the
223223
# model's state lives in the trunk.
224224
self.clear_heads()
225-
trunk_state_dict = self.state_dict()
225+
trunk_state_dict = super().state_dict()
226226
self.set_heads(attached_heads)
227227

228228
head_state_dict = {}
@@ -252,7 +252,7 @@ def load_head_states(self, state, strict=True):
252252
for head_name, head_state in head_states.items():
253253
self._heads[block_name][head_name].load_state_dict(head_state, strict)
254254

255-
def set_classy_state(self, state, strict=True):
255+
def load_state_dict(self, state, strict=True):
256256
"""Set the state of the ClassyModel.
257257
258258
Args:
@@ -275,6 +275,12 @@ def set_classy_state(self, state, strict=True):
275275
# set the heads back again
276276
self.set_heads(attached_heads)
277277

278+
def get_classy_state(self, deep_copy=False):
279+
return self.state_dict(deep_copy=deep_copy)
280+
281+
def set_classy_state(self, state, strict=True):
282+
self.load_state_dict(state, strict=strict)
283+
278284
def forward(self, x):
279285
"""
280286
Perform computation of blocks in the order define in get_blocks.

0 commit comments

Comments
 (0)