diff --git a/pytorch_translate/ensemble_export.py b/pytorch_translate/ensemble_export.py index 5b0de249..55e36afb 100644 --- a/pytorch_translate/ensemble_export.py +++ b/pytorch_translate/ensemble_export.py @@ -2073,7 +2073,7 @@ def finalize_hypos_loop_attns( class IterativeRefinementGenerateAndDecode(torch.jit.ScriptModule): def __init__( - self, checkpoint_files, src_dict_filename, tgt_dict_filename, max_iter=2 + self, checkpoint_files, src_dict_filename, tgt_dict_filename, max_iter=1 ): super().__init__() self.models, _, tgt_dict = load_models_from_checkpoints( @@ -2209,21 +2209,16 @@ def generate(self, models, src_tokens, src_lengths, prefix_tokens=None): decoding_format=self.decoding_format, ) - if self.adaptive: - # terminate if looping. - terminated, output_tokens, output_scores, output_attn = is_a_loop( - self.pad, - prev_output_tokens, - decoder_out[0], - decoder_out[1], - decoder_out[2], - ) - decoder_out[0] = output_tokens - decoder_out[1] = output_scores - decoder_out[2] = output_attn - - else: - terminated = torch.zeros_like(decoder_out[0]).bool() + terminated, output_tokens, output_scores, output_attn = is_a_loop( + self.pad, + prev_output_tokens, + decoder_out[0], + decoder_out[1], + decoder_out[2], + ) + decoder_out[0] = output_tokens + decoder_out[1] = output_scores + decoder_out[2] = output_attn terminated = last_step(step, self.max_iter, terminated) # collect finalized sentences @@ -2257,10 +2252,6 @@ def generate(self, models, src_tokens, src_lengths, prefix_tokens=None): finalized_attn, ) - # check if all terminated - if terminated.sum() == terminated.size(0): - break - # for next step prev_decoder_out = [ script_skip_tensor(decoder_out[0], ~terminated),