Skip to content
This repository was archived by the owner on Aug 1, 2023. It is now read-only.

Commit 0f03af5

Browse files
cndn authored and facebook-github-bot committed
Fix export code after model "hard fork" (#667)
Summary: Pull Request resolved: #667 After internal version is introduced and transformer uses namedtuple instead of dict, need to update the export code accordingly. Also we used torch.empty([0]) to represent initial attention instead of None, to unify the data type in TorchScript. Reviewed By: jhcross Differential Revision: D18648224 fbshipit-source-id: efea4252675eae98de24029f18790a7cda8ecdbd
1 parent 5bc2c48 commit 0f03af5

File tree

1 file changed

+39
-23
lines changed

1 file changed

+39
-23
lines changed

pytorch_translate/ensemble_export.py

Lines changed: 39 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,10 @@
1919
from caffe2.python.onnx import backend as caffe2_backend
2020
from caffe2.python.predictor import predictor_exporter
2121
from fairseq import tasks, utils
22+
from fairseq.iterative_refinement_generator import DecoderOut
2223
from fairseq.models import ARCH_MODEL_REGISTRY
23-
from fairseq.models.model_utils import script_skip_tensor, script_skip_tensor_list
24+
from fairseq.models.model_utils import script_skip_tensor
25+
from fairseq.models.transformer import EncoderOut
2426
from pytorch_translate.beam_decode import BeamDecode
2527
from pytorch_translate.data import dictionary
2628
from pytorch_translate.research.knowledge_distillation import (
@@ -2194,42 +2196,43 @@ def generate(self, models, src_tokens, src_lengths, prefix_tokens=None):
21942196

21952197
# initialize buffers (very model specific, with length prediction or not)
21962198
prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
2197-
prev_output_tokens = prev_decoder_out[0].clone()
2199+
prev_output_tokens = prev_decoder_out.output_tokens.clone()
21982200

21992201
finalized_tokens_list = [torch.tensor(0) for _ in range(bsz)]
22002202
finalized_scores_list = [torch.tensor(0) for _ in range(bsz)]
22012203
finalized_attns_list = [torch.tensor(0) for _ in range(bsz)]
22022204
finalized_alignments_list = [torch.tensor(0) for _ in range(bsz)]
2203-
prev_decoder_out[4] = self.max_iter + 1
2205+
prev_decoder_out._replace(max_step=self.max_iter + 1)
22042206

22052207
for step in range(self.max_iter + 1):
2206-
prev_decoder_out[3] = step
2208+
prev_decoder_out._replace(step=step)
22072209
decoder_out = model.forward_decoder(
22082210
prev_decoder_out,
22092211
encoder_out,
22102212
eos_penalty=self.eos_penalty,
22112213
max_ratio=self.max_ratio if step == 0 else None,
22122214
decoding_format=self.decoding_format,
22132215
)
2214-
22152216
terminated, output_tokens, output_scores, output_attn = is_a_loop(
22162217
self.pad,
22172218
prev_output_tokens,
2218-
decoder_out[0],
2219-
decoder_out[1],
2220-
decoder_out[2],
2219+
decoder_out.output_tokens,
2220+
decoder_out.output_scores,
2221+
decoder_out.attn,
2222+
)
2223+
decoder_out._replace(
2224+
output_tokens=output_tokens,
2225+
output_scores=output_scores,
2226+
attn=output_attn,
22212227
)
2222-
decoder_out[0] = output_tokens
2223-
decoder_out[1] = output_scores
2224-
decoder_out[2] = output_attn
22252228

22262229
terminated = last_step(step, self.max_iter, terminated)
22272230
# collect finalized sentences
22282231
finalized_idxs = sent_idxs[terminated]
2229-
finalized_tokens = decoder_out[0][terminated]
2230-
finalized_scores = decoder_out[1][terminated]
2232+
finalized_tokens = decoder_out.output_tokens[terminated]
2233+
finalized_scores = decoder_out.output_scores[terminated]
22312234
finalized_attn = (
2232-
None if decoder_out[2] is None else decoder_out[2][terminated]
2235+
None if decoder_out.attn is None else decoder_out.attn[terminated]
22332236
)
22342237
finalized_tokens_list = finalize_hypos_loop_tokens(
22352238
finalized_tokens_list,
@@ -2256,17 +2259,30 @@ def generate(self, models, src_tokens, src_lengths, prefix_tokens=None):
22562259
)
22572260

22582261
# for next step
2259-
prev_decoder_out = [
2260-
script_skip_tensor(decoder_out[0], ~terminated),
2261-
script_skip_tensor(decoder_out[1], ~terminated),
2262-
decoder_out[2],
2263-
decoder_out[3],
2264-
decoder_out[4],
2265-
]
2266-
encoder_out = script_skip_tensor_list(encoder_out, ~terminated)
2262+
prev_decoder_out = DecoderOut(
2263+
output_tokens=script_skip_tensor(
2264+
decoder_out.output_tokens, ~terminated
2265+
),
2266+
output_scores=script_skip_tensor(
2267+
decoder_out.output_scores, ~terminated
2268+
),
2269+
attn=decoder_out.attn,
2270+
step=decoder_out.step,
2271+
max_step=decoder_out.max_step,
2272+
history=None,
2273+
)
2274+
2275+
encoder_out = EncoderOut(
2276+
encoder_out=script_skip_tensor(encoder_out.encoder_out, ~terminated),
2277+
encoder_padding_mask=None,
2278+
encoder_embedding=script_skip_tensor(
2279+
encoder_out.encoder_embedding, ~terminated
2280+
),
2281+
encoder_states=None,
2282+
)
22672283
sent_idxs = script_skip_tensor(sent_idxs, ~terminated)
22682284

2269-
prev_output_tokens = prev_decoder_out[0].clone()
2285+
prev_output_tokens = prev_decoder_out.output_tokens.clone()
22702286

22712287
return (
22722288
finalized_tokens_list,

0 commit comments

Comments (0)