1919from caffe2 .python .onnx import backend as caffe2_backend
2020from caffe2 .python .predictor import predictor_exporter
2121from fairseq import tasks , utils
22+ from fairseq .iterative_refinement_generator import DecoderOut
2223from fairseq .models import ARCH_MODEL_REGISTRY
23- from fairseq .models .model_utils import script_skip_tensor , script_skip_tensor_list
24+ from fairseq .models .model_utils import script_skip_tensor
25+ from fairseq .models .transformer import EncoderOut
2426from pytorch_translate .beam_decode import BeamDecode
2527from pytorch_translate .data import dictionary
2628from pytorch_translate .research .knowledge_distillation import (
@@ -2194,42 +2196,43 @@ def generate(self, models, src_tokens, src_lengths, prefix_tokens=None):
21942196
21952197 # initialize buffers (very model specific, with length prediction or not)
21962198 prev_decoder_out = model .initialize_output_tokens (encoder_out , src_tokens )
2197- prev_output_tokens = prev_decoder_out [ 0 ] .clone ()
2199+ prev_output_tokens = prev_decoder_out . output_tokens .clone ()
21982200
21992201 finalized_tokens_list = [torch .tensor (0 ) for _ in range (bsz )]
22002202 finalized_scores_list = [torch .tensor (0 ) for _ in range (bsz )]
22012203 finalized_attns_list = [torch .tensor (0 ) for _ in range (bsz )]
22022204 finalized_alignments_list = [torch .tensor (0 ) for _ in range (bsz )]
2203- prev_decoder_out [ 4 ] = self .max_iter + 1
2205+ prev_decoder_out . _replace ( max_step = self .max_iter + 1 )
22042206
22052207 for step in range (self .max_iter + 1 ):
2206- prev_decoder_out [ 3 ] = step
2208+ prev_decoder_out . _replace ( step = step )
22072209 decoder_out = model .forward_decoder (
22082210 prev_decoder_out ,
22092211 encoder_out ,
22102212 eos_penalty = self .eos_penalty ,
22112213 max_ratio = self .max_ratio if step == 0 else None ,
22122214 decoding_format = self .decoding_format ,
22132215 )
2214-
22152216 terminated , output_tokens , output_scores , output_attn = is_a_loop (
22162217 self .pad ,
22172218 prev_output_tokens ,
2218- decoder_out [0 ],
2219- decoder_out [1 ],
2220- decoder_out [2 ],
2219+ decoder_out .output_tokens ,
2220+ decoder_out .output_scores ,
2221+ decoder_out .attn ,
2222+ )
2223+ decoder_out ._replace (
2224+ output_tokens = output_tokens ,
2225+ output_scores = output_scores ,
2226+ attn = output_attn ,
22212227 )
2222- decoder_out [0 ] = output_tokens
2223- decoder_out [1 ] = output_scores
2224- decoder_out [2 ] = output_attn
22252228
22262229 terminated = last_step (step , self .max_iter , terminated )
22272230 # collect finalized sentences
22282231 finalized_idxs = sent_idxs [terminated ]
2229- finalized_tokens = decoder_out [ 0 ] [terminated ]
2230- finalized_scores = decoder_out [ 1 ] [terminated ]
2232+ finalized_tokens = decoder_out . output_tokens [terminated ]
2233+ finalized_scores = decoder_out . output_scores [terminated ]
22312234 finalized_attn = (
2232- None if decoder_out [ 2 ] is None else decoder_out [ 2 ] [terminated ]
2235+ None if decoder_out . attn is None else decoder_out . attn [terminated ]
22332236 )
22342237 finalized_tokens_list = finalize_hypos_loop_tokens (
22352238 finalized_tokens_list ,
@@ -2256,17 +2259,30 @@ def generate(self, models, src_tokens, src_lengths, prefix_tokens=None):
22562259 )
22572260
22582261 # for next step
2259- prev_decoder_out = [
2260- script_skip_tensor (decoder_out [0 ], ~ terminated ),
2261- script_skip_tensor (decoder_out [1 ], ~ terminated ),
2262- decoder_out [2 ],
2263- decoder_out [3 ],
2264- decoder_out [4 ],
2265- ]
2266- encoder_out = script_skip_tensor_list (encoder_out , ~ terminated )
2262+ prev_decoder_out = DecoderOut (
2263+ output_tokens = script_skip_tensor (
2264+ decoder_out .output_tokens , ~ terminated
2265+ ),
2266+ output_scores = script_skip_tensor (
2267+ decoder_out .output_scores , ~ terminated
2268+ ),
2269+ attn = decoder_out .attn ,
2270+ step = decoder_out .step ,
2271+ max_step = decoder_out .max_step ,
2272+ history = None ,
2273+ )
2274+
2275+ encoder_out = EncoderOut (
2276+ encoder_out = script_skip_tensor (encoder_out .encoder_out , ~ terminated ),
2277+ encoder_padding_mask = None ,
2278+ encoder_embedding = script_skip_tensor (
2279+ encoder_out .encoder_embedding , ~ terminated
2280+ ),
2281+ encoder_states = None ,
2282+ )
22672283 sent_idxs = script_skip_tensor (sent_idxs , ~ terminated )
22682284
2269- prev_output_tokens = prev_decoder_out [ 0 ] .clone ()
2285+ prev_output_tokens = prev_decoder_out . output_tokens .clone ()
22702286
22712287 return (
22722288 finalized_tokens_list ,
0 commit comments