@@ -3,6 +3,7 @@
 import logging
 import os
 import tempfile
+from collections import OrderedDict

 import numpy as np
 import onnx
@@ -155,8 +156,8 @@ def forward(self, src_tokens, src_lengths):
             encoder_outputs = encoder_out[0]
             outputs.append(encoder_outputs)
             output_names.append(f"encoder_output_{i}")
-            if hasattr(model.decoder, "_init_prev_states"):
-                states.extend(model.decoder._init_prev_states(encoder_out))
+            if hasattr(model.decoder, "get_init_prev_states"):
+                states.extend(model.decoder.get_init_prev_states(encoder_out))

         # underlying assumption is each model has same vocab_reduction_module
         vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
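
(Not part of the diff: the hunk above assumes the decoder now exposes a public `get_init_prev_states` hook in place of the private `_init_prev_states`. A minimal sketch of what such a hook could look like on an LSTM-style decoder; the stub class, shapes, and zero initialization are illustrative assumptions, not the actual implementation.)

```python
import torch

class LSTMDecoderStub:
    """Hypothetical decoder showing the state-export hooks this diff relies on."""

    def __init__(self, num_layers=2, hidden_dim=512):
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

    def get_init_prev_states(self, encoder_out):
        # Flat list of initial recurrent states threaded through the exported
        # graph: one hidden and one cell per layer, then the attention input
        # feed, matching the interleaved order the beam search step expects.
        states = []
        for _ in range(self.num_layers):
            states.append(torch.zeros(1, self.hidden_dim))  # initial hidden
            states.append(torch.zeros(1, self.hidden_dim))  # initial cell
        states.append(torch.zeros(1, self.hidden_dim))  # initial input feed
        return states
```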
@@ -272,9 +273,6 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

         next_state_input = len(self.models)

-        # size of "batch" dimension of input as tensor
-        batch_size = torch.onnx.operators.shape_as_tensor(input_tokens)[0]
-
         # underlying assumption is each model has same vocab_reduction_module
         vocab_reduction_module = self.models[0].decoder.vocab_reduction_module
         if vocab_reduction_module is not None:
@@ -285,20 +283,6 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

         for i, model in enumerate(self.models):
             encoder_output = inputs[i]
-            prev_hiddens = []
-            prev_cells = []
-
-            for _ in range(len(model.decoder.layers)):
-                prev_hiddens.append(inputs[next_state_input])
-                prev_cells.append(inputs[next_state_input + 1])
-                next_state_input += 2
-
-            # ensure previous attention context has batch dimension
-            input_feed_shape = torch.cat((batch_size.view(1), torch.LongTensor([-1])))
-            prev_input_feed = torch.onnx.operators.reshape_from_tensor_shape(
-                inputs[next_state_input], input_feed_shape
-            )
-            next_state_input += 1

             # no batching, we only care about "max" length
             src_length_int = int(encoder_output.size()[0])
@@ -310,8 +294,8 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

             encoder_out = (
                 encoder_output,
-                prev_hiddens,
-                prev_cells,
+                None,
+                None,
                 src_length,
                 src_tokens,
                 src_embeddings,
@@ -321,16 +305,12 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
             model.decoder._is_incremental_eval = True
             model.eval()

-            # placeholder
-            incremental_state = {}
-
-            # cache previous state inputs
-            utils.set_incremental_state(
-                model.decoder,
-                incremental_state,
-                "cached_state",
-                (prev_hiddens, prev_cells, prev_input_feed),
-            )
+            # pass state inputs via incremental_state
+            num_states = model.decoder.get_num_states()
+            prev_states = inputs[next_state_input : next_state_input + num_states]
+            next_state_input += num_states
+            incremental_state = OrderedDict()
+            model.decoder.populate_incremental_state(incremental_state, prev_states)

             decoder_output = model.decoder(
                 input_tokens,
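
(Not part of the diff: `get_num_states` and `populate_incremental_state` are the decoder-side counterparts this hunk assumes. A hedged sketch, using the flat [h0, c0, h1, c1, ..., input_feed] layout implied by the removed manual unpacking; the "cached_state" key comes from the removed code, everything else is illustrative.)

```python
from collections import OrderedDict

class LSTMDecoderStub:
    """Continuation of the hypothetical decoder sketched above."""

    def __init__(self, num_layers=2, hidden_dim=512):
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

    def get_num_states(self):
        # One hidden and one cell tensor per layer, plus the input feed.
        return 2 * self.num_layers + 1

    def populate_incremental_state(self, incremental_state, prev_states):
        # Invert the flat interleaved layout back into the cached tuple
        # that the removed code used to assemble by hand.
        hiddens = list(prev_states[0 : 2 * self.num_layers : 2])
        cells = list(prev_states[1 : 2 * self.num_layers : 2])
        input_feed = prev_states[2 * self.num_layers]
        incremental_state["cached_state"] = (hiddens, cells, input_feed)
```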
@@ -345,13 +325,8 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
             log_probs_per_model.append(log_probs)
             attn_weights_per_model.append(attn_scores)

-            (next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
-                model.decoder, incremental_state, "cached_state"
-            )
-
-            for h, c in zip(next_hiddens, next_cells):
-                state_outputs.extend([h, c])
-            state_outputs.append(next_input_feed)
+            next_states = model.decoder.serialize_incremental_state(incremental_state)
+            state_outputs.extend(next_states)

         average_log_probs = torch.mean(
             torch.cat(log_probs_per_model, dim=1), dim=1, keepdim=True
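
(Not part of the diff: `serialize_incremental_state` is the output-side inverse, flattening the updated cache back into the tensor list that becomes `state_outputs`. Sketch under the same assumptions as above.)

```python
class LSTMDecoderStub:
    """Same hypothetical decoder: the output-side hook."""

    def serialize_incremental_state(self, incremental_state):
        # Flatten (hiddens, cells, input_feed) back into the interleaved
        # [h0, c0, h1, c1, ..., input_feed] order used for the graph outputs.
        hiddens, cells, input_feed = incremental_state["cached_state"]
        flat = []
        for h, c in zip(hiddens, cells):
            flat.extend([h, c])
        flat.append(input_feed)
        return flat
```

With these hooks the beam-search wrappers no longer need to know the decoder is LSTM-based, which is what lets the hunks above replace the layer-by-layer unpacking with a single slice of `inputs`.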
@@ -735,15 +710,6 @@ def forward(self, input_token, target_token, timestep, *inputs):

         for i, model in enumerate(self.models):
             encoder_output = inputs[i]
-            prev_hiddens = []
-            prev_cells = []
-
-            for _ in range(len(model.decoder.layers)):
-                prev_hiddens.append(inputs[next_state_input])
-                prev_cells.append(inputs[next_state_input + 1])
-                next_state_input += 2
-            prev_input_feed = inputs[next_state_input].view(1, -1)
-            next_state_input += 1

             # no batching, we only care about "max" length
             src_length_int = int(encoder_output.size()[0])
@@ -755,8 +721,8 @@ def forward(self, input_token, target_token, timestep, *inputs):

             encoder_out = (
                 encoder_output,
-                prev_hiddens,
-                prev_cells,
+                None,
+                None,
                 src_length,
                 src_tokens,
                 src_embeddings,
@@ -766,16 +732,12 @@ def forward(self, input_token, target_token, timestep, *inputs):
             model.decoder._is_incremental_eval = True
             model.eval()

-            # placeholder
-            incremental_state = {}
-
-            # cache previous state inputs
-            utils.set_incremental_state(
-                model.decoder,
-                incremental_state,
-                "cached_state",
-                (prev_hiddens, prev_cells, prev_input_feed),
-            )
+            # pass state inputs via incremental_state
+            num_states = model.decoder.get_num_states()
+            prev_states = inputs[next_state_input : next_state_input + num_states]
+            next_state_input += num_states
+            incremental_state = OrderedDict()
+            model.decoder.populate_incremental_state(incremental_state, prev_states)

             decoder_output = model.decoder(
                 input_token.view(1, 1),
@@ -789,13 +751,8 @@ def forward(self, input_token, target_token, timestep, *inputs):

             log_probs_per_model.append(log_probs)

-            (next_hiddens, next_cells, next_input_feed) = utils.get_incremental_state(
-                model.decoder, incremental_state, "cached_state"
-            )
-
-            for h, c in zip(next_hiddens, next_cells):
-                state_outputs.extend([h, c])
-            state_outputs.append(next_input_feed)
+            next_states = model.decoder.serialize_incremental_state(incremental_state)
+            state_outputs.extend(next_states)

         average_log_probs = torch.mean(
             torch.cat(log_probs_per_model, dim=0), dim=0, keepdim=True
@@ -1020,8 +977,8 @@ def forward(self, src_tokens, src_lengths, char_inds, word_lengths):
             outputs.append(encoder_outputs)
             output_names.append(f"encoder_output_{i}")

-            if hasattr(model.decoder, "_init_prev_states"):
-                states.extend(model.decoder._init_prev_states(encoder_out))
+            if hasattr(model.decoder, "get_init_prev_states"):
+                states.extend(model.decoder.get_init_prev_states(encoder_out))

         # underlying assumption is each model has same vocab_reduction_module
         vocab_reduction_module = self.models[0].decoder.vocab_reduction_module