55 changes: 54 additions & 1 deletion onmt/decoders/layer_stack_decoder.py
@@ -17,10 +17,11 @@ def __init__(self, embeddings, decoders):
self._active: List[str] = []

@classmethod
def from_opt(cls, opt, embeddings, task_queue_manager):
def from_opt(cls, opt, embeddings, task_queue_manager, is_on_top=False):
"""Alternate constructor for use during training."""
decoders = nn.ModuleList()
for layer_stack_index, n_layers in enumerate(opt.dec_layers):
is_on_top = layer_stack_index == len(opt.dec_layers) - 1
stacks = nn.ModuleDict()
for module_id in task_queue_manager.get_decoders(layer_stack_index):
if module_id in stacks:
@@ -46,6 +47,10 @@ def from_opt(cls, opt, embeddings, task_queue_manager):
opt.alignment_layer,
alignment_heads=opt.alignment_heads,
pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
layer_norm_module=(
nn.LayerNorm(opt.dec_rnn_size, eps=1e-6) if is_on_top
else nn.Identity()
),
)
decoders.append(stacks)
return cls(embeddings, decoders)
@@ -56,6 +61,7 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
decoders = nn.ModuleList()
for layer_stack_index, n_layers in enumerate(model_opt.dec_layers):
stacks = nn.ModuleDict()
is_on_top = layer_stack_index == len(model_opt.dec_layers) - 1
module_opts = opt_stack['decoder'][layer_stack_index]
module_id = module_opts['id']
stacks[module_id] = AdaptedTransformerDecoder(
@@ -78,6 +84,10 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
model_opt.alignment_layer,
alignment_heads=model_opt.alignment_heads,
pos_ffn_activation_fn=model_opt.pos_ffn_activation_fn,
layer_norm_module=(
nn.LayerNorm(model_opt.dec_rnn_size, eps=1e-6) if is_on_top
else nn.Identity()
),
)
decoders.append(stacks)
return cls(embeddings, decoders)
@@ -200,3 +210,46 @@ def activate(self, metadata: DatasetMetadata):
for layer_stack_index, adapter_group, sub_id in metadata.decoder_adapter_ids:
module_id = metadata.decoder_id[layer_stack_index]
self.activate_adapter(module_id, adapter_group, sub_id)

def make_shallow(self, module_keys: List[str], model_opt):
"""Utility function for HF port.
Simplifies the structure of the layerstack for easier statedict mapping
"""

assert len(module_keys) == self.n_layer_stacks, \
"Need all module keys for a given task to make the encoder shallow"

shallow_decoder = AdaptedTransformerDecoder(
0,
model_opt.dec_rnn_size,
model_opt.heads,
model_opt.transformer_ff,
model_opt.copy_attn,
model_opt.self_attn_type,
model_opt.dropout[0] if type(model_opt.dropout) is list else model_opt.dropout,
(
model_opt.attention_dropout[0]
if type(model_opt.attention_dropout) is list
else model_opt.attention_dropout
),
None, # embeddings,
model_opt.max_relative_positions,
model_opt.aan_useffn,
model_opt.full_context_alignment,
model_opt.alignment_layer,
alignment_heads=model_opt.alignment_heads,
pos_ffn_activation_fn=model_opt.pos_ffn_activation_fn,
layer_norm_module=nn.LayerNorm(model_opt.dec_rnn_size, eps=1e-6),
)
shallow_decoder.layer_norm = self.decoders[-1][module_keys[-1]].layer_norm

for idx, key in enumerate(module_keys):
stack = self.decoders[idx][key].transformer
shallow_decoder.transformer.extend(stack)

wrapped = nn.ModuleList([
nn.ModuleDict({
'shallow': shallow_decoder
})
])
return self.__class__(self.embeddings, wrapped)
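
For orientation, here is a minimal sketch of how the new `make_shallow` utility might be driven from an export script. The `model.decoder` attribute, the `module_keys` values, and the surrounding loading logic are assumptions for illustration, not part of this diff.

```python
from typing import List


def export_shallow_decoder_state(model, model_opt, module_keys: List[str]) -> dict:
    """Flatten a task's decoder layer stacks and return its state_dict.

    Assumes `model.decoder` is the LayerStackDecoder built above and that
    `module_keys` lists one module id per layer stack for the chosen task
    (both are illustrative assumptions, not defined in this PR).
    """
    shallow = model.decoder.make_shallow(module_keys, model_opt)
    # All transformer layers are concatenated under the single key 'shallow',
    # and the original top stack's LayerNorm is reattached, so the resulting
    # state_dict keys follow one predictable pattern for the HF mapping.
    return shallow.state_dict()
```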
13 changes: 9 additions & 4 deletions onmt/decoders/transformer.py
@@ -266,7 +266,7 @@ def _forward(


class TransformerDecoderBase(DecoderBase):
def __init__(self, d_model, copy_attn, embeddings, alignment_layer):
def __init__(self, d_model, copy_attn, embeddings, alignment_layer, layer_norm_module):
super(TransformerDecoderBase, self).__init__()

self.embeddings = embeddings
@@ -278,12 +278,12 @@ def __init__(self, d_model, copy_attn, embeddings, alignment_layer):
# attention. But it was never actually used -- the "copy" attention
# just reuses the context attention.
self._copy = copy_attn
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.layer_norm = layer_norm_module

self.alignment_layer = alignment_layer

@classmethod
def from_opt(cls, opt, embeddings):
def from_opt(cls, opt, embeddings, is_on_top=False):
"""Alternate constructor."""
return cls(
opt.dec_layers,
@@ -301,6 +301,10 @@ def from_opt(cls, opt, embeddings):
opt.alignment_layer,
alignment_heads=opt.alignment_heads,
pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
layer_norm_module=(
nn.LayerNorm(opt.dec_rnn_size, eps=1e-6) if is_on_top
else nn.Identity()
),
)

def init_state(self, src, memory_bank, enc_hidden):
@@ -391,8 +395,9 @@ def __init__(
alignment_layer,
alignment_heads,
pos_ffn_activation_fn=ActivationFunction.relu,
layer_norm_module=None,
):
super(TransformerDecoder, self).__init__(d_model, copy_attn, embeddings, alignment_layer)
super(TransformerDecoder, self).__init__(d_model, copy_attn, embeddings, alignment_layer, layer_norm_module)

self.transformer_layers = nn.ModuleList(
[
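
Taken together, these changes thread a `layer_norm_module` argument through the decoder constructors so that only the topmost layer stack applies a real `nn.LayerNorm`, while inner stacks receive `nn.Identity` and pass activations through unchanged. A standalone sketch of the selection logic, with illustrative sizes:

```python
import torch
import torch.nn as nn

d_model, n_stacks = 512, 3  # illustrative values, not taken from the diff

# Mirror of the is_on_top selection used above: a real LayerNorm for the top
# stack, nn.Identity for every stack below it.
norms = nn.ModuleList([
    nn.LayerNorm(d_model, eps=1e-6) if i == n_stacks - 1 else nn.Identity()
    for i in range(n_stacks)
])

x = torch.randn(4, d_model)
for norm in norms:
    x = norm(x)  # only the final iteration normalizes; the rest are no-ops
```

This keeps the final normalization applied exactly once when stacks are chained, instead of once per stack.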
47 changes: 47 additions & 0 deletions onmt/encoders/layer_stack_encoder.py
@@ -22,6 +22,7 @@ def from_opt(cls, opt, embeddings, task_queue_manager):
encoders = nn.ModuleList()
for layer_stack_index, n_layers in enumerate(opt.enc_layers):
stacks = nn.ModuleDict()
is_on_top = layer_stack_index == len(opt.enc_layers) - 1
for module_id in task_queue_manager.get_encoders(layer_stack_index):
if module_id in stacks:
# several tasks using the same layer stack
@@ -40,6 +41,10 @@ def from_opt(cls, opt, embeddings, task_queue_manager):
None, # embeddings,
opt.max_relative_positions,
pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
layer_norm_module=(
nn.LayerNorm(opt.enc_rnn_size, eps=1e-6) if is_on_top
else nn.Identity()
)
)
encoders.append(stacks)
return cls(embeddings, encoders)
@@ -52,6 +57,7 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
stacks = nn.ModuleDict()
module_opts = opt_stack['encoder'][layer_stack_index]
module_id = module_opts['id']
is_on_top = layer_stack_index == len(model_opt.enc_layers) - 1
stacks[module_id] = AdaptedTransformerEncoder(
n_layers,
model_opt.enc_rnn_size,
@@ -66,6 +72,10 @@ def from_trans_opt(cls, model_opt, embeddings, opt_stack):
None, # embeddings,
model_opt.max_relative_positions,
pos_ffn_activation_fn=model_opt.pos_ffn_activation_fn,
layer_norm_module=(
nn.LayerNorm(model_opt.enc_rnn_size, eps=1e-6) if is_on_top
else nn.Identity()
)
)
encoders.append(stacks)
return cls(embeddings, encoders)
@@ -163,3 +173,40 @@ def activate(self, metadata: DatasetMetadata):
for layer_stack_index, adapter_group, sub_id in metadata.encoder_adapter_ids:
module_id = metadata.encoder_id[layer_stack_index]
self.activate_adapter(module_id, adapter_group, sub_id)

def make_shallow(self, module_keys: List[str], model_opt):
"""Utility function for HF port.
Simplifies the structure of the layerstack for easier statedict mapping
"""

assert len(module_keys) == self.n_layer_stacks, \
"Need all module keys for a given task to make the encoder shallow"

shallow_encoder = AdaptedTransformerEncoder(
0,
model_opt.enc_rnn_size,
model_opt.heads,
model_opt.transformer_ff,
model_opt.dropout[0] if type(model_opt.dropout) is list else model_opt.dropout,
(
model_opt.attention_dropout[0]
if type(model_opt.attention_dropout) is list
else model_opt.attention_dropout
),
None, # embeddings,
model_opt.max_relative_positions,
pos_ffn_activation_fn=model_opt.pos_ffn_activation_fn,
layer_norm_module=nn.LayerNorm(model_opt.enc_rnn_size, eps=1e-6)
)
shallow_encoder.layer_norm = self.encoders[-1][module_keys[-1]].layer_norm

for idx, key in enumerate(module_keys):
stack = self.encoders[idx][key].transformer
shallow_encoder.transformer.extend(stack)

wrapped = nn.ModuleList([
nn.ModuleDict({
'shallow': shallow_encoder
})
])
return self.__class__(self.embeddings, wrapped)
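
To make the effect on parameter names concrete, a rough sketch of inspecting the flattened encoder. The key pattern shown in the comments follows from wrapping the merged stack under the fixed id `'shallow'`; the exact tail of each key depends on the layer internals and is only indicative.

```python
def show_shallow_encoder_keys(encoder, module_keys, model_opt, limit=10):
    """Print the first few state_dict keys of a flattened LayerStackEncoder.

    `encoder` is assumed to be a trained LayerStackEncoder and `module_keys`
    the per-stack module ids of one task (illustrative assumptions).
    """
    shallow = encoder.make_shallow(module_keys, model_opt)
    # Keys are now roughly of the form
    #   encoders.0.shallow.transformer.<i>...
    #   encoders.0.shallow.layer_norm.weight / .bias
    # i.e. a single stack under one fixed id, which makes the rename to
    # Hugging Face parameter names an index-based mapping.
    for name in list(shallow.state_dict())[:limit]:
        print(name)
```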
9 changes: 7 additions & 2 deletions onmt/encoders/transformer.py
@@ -111,6 +111,7 @@ def __init__(
embeddings,
max_relative_positions,
pos_ffn_activation_fn=ActivationFunction.relu,
layer_norm_module=None,
):
super(TransformerEncoder, self).__init__()

@@ -129,10 +130,10 @@
for i in range(num_layers)
]
)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.layer_norm = layer_norm_module

@classmethod
def from_opt(cls, opt, embeddings):
def from_opt(cls, opt, embeddings, is_on_top=False):
"""Alternate constructor."""
return cls(
opt.enc_layers,
Expand All @@ -144,6 +145,10 @@ def from_opt(cls, opt, embeddings):
embeddings,
opt.max_relative_positions,
pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
layer_norm_module=(
nn.LayerNorm(opt.enc_rnn_size, eps=1e-6) if is_on_top
else nn.Identity()
)
)

def forward(self, src, lengths=None, skip_embedding=False, mask=None):
Empty file added tools/huggingface/__init__.py