
Commit da43c27

William Fedus authored and Mesh TensorFlow Team committed
Option to use mtf.Print to log which tokens are sent to which experts when run on CPU.
PiperOrigin-RevId: 368137313
1 parent 57ed401 commit da43c27
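
For context, the new flag defaults to off. Below is a minimal usage sketch of enabling it when constructing the MoE layer directly; the class name MoE1D and the other keyword arguments are assumptions about the surrounding file, and only token_logging itself comes from this commit. Because the logging relies on mtf.Print, the output is only visible when running on CPU.

# Hedged sketch (not part of the diff): enable per-token expert logging.
# MoE1D, num_experts and hidden_size are assumed names/arguments here; only
# token_logging is introduced by this commit.
from mesh_tensorflow.transformer import moe

moe_layer = moe.MoE1D(
    num_experts=32,       # illustrative value
    hidden_size=4096,     # illustrative value
    token_logging=True)   # print which tokens are routed to which experts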

3 files changed: +85 -11 lines

mesh_tensorflow/transformer/moe.py

Lines changed: 66 additions & 6 deletions
@@ -65,7 +65,8 @@ def __init__(self,
                word_embed_mode=None,
                use_second_place_expert_prob=None,
                use_second_place_expert_prob_temp=None,
-               top_n_num_experts_per_token=3):
+               top_n_num_experts_per_token=3,
+               token_logging=False):
     self._hparams = HParams(
         moe_gating=moe_gating,
         moe_num_experts=num_experts,
@@ -97,6 +98,7 @@ def __init__(self,
             use_second_place_expert_prob_temp),
         moe_top_n_num_experts_per_token=top_n_num_experts_per_token)
     self._activation = activation
+    self.token_logging = token_logging
 
   def call(self, context, x, losses=None):
     """Call the layer."""
@@ -116,7 +118,13 @@ def call(self, context, x, losses=None):
       output_dim = self._hparams.moe_output_dim
     else:
       output_dim = context.model.model_dim
-    y, loss = transformer_moe_layer_v1(
+    if self.token_logging:
+      tokens = _detokenize(context.inputs, context.model.vocabulary)
+      x = mtf.Print(x, [tokens], "tokens:", summarize=1000)
+      extras = _windows(context.inputs, context.length_dim)
+    else:
+      extras = None
+    y, loss, extras = transformer_moe_layer_v1(
         x,
         output_dim,
         self._hparams,
@@ -127,7 +135,16 @@ def call(self, context, x, losses=None):
         nonpadding=context.nonpadding,
         activation=self._activation,
         num_microbatches=context.num_microbatches,
-        token_embeddings=context.input_embeddings)
+        token_embeddings=context.input_embeddings,
+        extras=extras)
+
+    if extras:
+      extras = _detokenize(extras, context.model.vocabulary)
+      experts_dim = mtf.Dimension("experts", self._hparams.moe_num_experts)
+      extras = mtf.unstack(extras, experts_dim)
+      for i, t in enumerate(extras):
+        y = mtf.Print(y, [t], "EXPERT %s:" % i, summarize=1000)
+
     if context.losses is not None:
       context.losses.append(loss)
     if not has_length_dim:
@@ -139,6 +156,23 @@ def call(self, context, x, losses=None):
     return y
 
 
+@gin.configurable
+def _windows(ids, length_dim, window_start=0, window_end=0):
+  to_stack = []
+  for offset in range(window_start, window_end + 1):
+    to_stack.append(mtf.shift(ids, -offset, length_dim, wrap=False))
+  return mtf.stack(to_stack, "window", axis=ids.shape.ndims)
+
+
+def _detokenize(ids, vocabulary):
+  return mtf.slicewise(
+      vocabulary.decode_tf,
+      [ids],
+      output_shape=mtf.Shape(ids.shape.dims[:-1]),
+      output_dtype=tf.string,
+      splittable_dims=ids.shape.dims[:-1])
+
+
 class MoE2D(transformer.TransformerLayer):
   """Mixture of Experts Layer."""
 
@@ -202,7 +236,7 @@ def call(self, context, x, losses=None):
 def transformer_moe_layer_v1(
     inputs, output_dim, hparams, train, variable_dtype,
     layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
-    num_microbatches=None, token_embeddings=None):
+    num_microbatches=None, token_embeddings=None, extras=None):
   """Local mixture of experts that works well on TPU.
 
   Adapted from the paper https://arxiv.org/abs/1701.06538
@@ -281,6 +315,7 @@ def transformer_moe_layer_v1(
       [batch_dim(s), length_dim, input_dim]. These are the word embeddings for
       that correspond to the inputs. These can optionally be used to make
       routing decisions.
+    extras: a tensor to dispatch (for debugging purposes)
 
   Returns:
     outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
@@ -344,6 +379,10 @@ def transformer_moe_layer_v1(
   # over which those groups are split.
   batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                       orig_inputs.shape.dims[-1])
+
+  if extras:
+    extras_dims = extras.shape.dims[len(batch_and_length_dims):]
+
   # Hack: we assume that
   #   "outer_batch" == replication of experts
   #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
@@ -381,6 +420,11 @@ def transformer_moe_layer_v1(
     token_embeddings = mtf.cast(
         mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)
 
+  if extras:
+    extras = mtf.reshape(
+        extras,
+        [outer_batch_dim, num_groups_dim, group_size_dim] + extras_dims)
+
   # Each sequence sends expert_capacity positions to each expert.
   if train:
     capacity_factor = hparams.moe_capacity_factor_train
@@ -503,6 +547,17 @@ def transformer_moe_layer_v1(
           input_dim
       ]))
 
+  if extras:
+    extras = mtf.einsum([extras, mtf.cast(dispatch_tensor, extras.dtype)],
+                        mtf.Shape([
+                            outer_batch_dim, experts_dim_unsplit,
+                            num_groups_dim, expert_capacity_dim] + extras_dims))
+    extras = mtf.reshape(
+        extras,
+        mtf.Shape([
+            outer_batch_dim, experts_dim, batch_dim_unsplit,
+            expert_capacity_dim] + extras_dims))
+
   # Now feed the expert inputs through the experts.
   h = mtf.layers.dense_product(
       expert_inputs,
@@ -559,10 +614,15 @@ def _compute_output(hidden, layer_name):
     k = _compute_output(k_h, layer_name="k_wo")
     outputs.append(q)
     outputs.append(k)
-    return outputs, loss * hparams.moe_loss_coef
+    return outputs, loss * hparams.moe_loss_coef, None
   else:
     output = _compute_output(h, layer_name="wo")
-    return output, loss * hparams.moe_loss_coef
+    loss *= hparams.moe_loss_coef
+
+    if extras:
+      return output, loss, extras
+    else:
+      return output, loss, None
 
 
 def transformer_moe_layer_v2(
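
Read outside of mtf, the new _windows helper stacks shifted copies of the token ids so that each dispatched position carries a short window of token ids, which transformer_moe_layer_v1 then routes through the same dispatch_tensor as the real inputs; the layer detokenizes the result and prints it per expert via mtf.Print. A plain-Python sketch of the windowing step follows, with lists standing in for mtf Tensors and zero padding standing in for mtf.shift(..., wrap=False); the function name and pad value are assumptions for illustration.

# Plain-Python analogue of _windows, for intuition only.  The real helper
# operates on mtf.Tensors via mtf.shift / mtf.stack; here a sequence is a
# list of token ids and out-of-range positions are padded with `pad`.
def windows_1d(ids, window_start=0, window_end=0, pad=0):
  length = len(ids)
  per_offset = []
  for offset in range(window_start, window_end + 1):
    # mtf.shift(ids, -offset, length_dim, wrap=False): position i picks up
    # the id at i + offset, with padding beyond the sequence boundary.
    per_offset.append(
        [ids[i + offset] if 0 <= i + offset < length else pad
         for i in range(length)])
  # Stack along a trailing "window" axis: one small window per position.
  return [list(w) for w in zip(*per_offset)]

# Example: windows_1d([7, 8, 9, 10], window_start=0, window_end=2)
# -> [[7, 8, 9], [8, 9, 10], [9, 10, 0], [10, 0, 0]]

With the gin defaults (window_start=0, window_end=0) each position carries only its own token; widening the window via gin adds a little surrounding context to each routed token in the printed "EXPERT i:" output.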

mesh_tensorflow/transformer/transformer.py

Lines changed: 4 additions & 1 deletion
@@ -722,7 +722,8 @@ def __init__(self,
                input_full_attention=False,
                loss_on_targets_only=False,
                loss_denominator=None,
-               token_dropout_rate=0.0):
+               token_dropout_rate=0.0,
+               vocabulary=None):
     """Create a Unitransformer.
 
     Args:
@@ -767,6 +768,7 @@ def __init__(self,
         same denominator as was used for the pretraining. This complication
         might be avoided by always using loss_denominator = 1.0.
       token_dropout_rate: an optional floating point value
+      vocabulary: an optional vocabularies.Vocabulary
     """
     self.layer_stack = layer_stack
     self.model_dim = mtf.Dimension("d_model", d_model)
@@ -807,6 +809,7 @@ def __init__(self,
       raise ValueError(
           "input_full_attention only makes sense with autoregressive")
     self.token_dropout_rate = token_dropout_rate
+    self.vocabulary = vocabulary
 
   @property
   def fully_autoregressive(self):

mesh_tensorflow/transformer/utils.py

Lines changed: 15 additions & 4 deletions
@@ -172,7 +172,9 @@ def build_model(model_type="bitransformer",
                 input_vocab_size=gin.REQUIRED,
                 output_vocab_size=gin.REQUIRED,
                 layout_rules=None,
-                mesh_shape=None):
+                mesh_shape=None,
+                input_vocabulary=None,
+                target_vocabulary=None):
   """Build a transformer model.
 
   Currently, four types of models are supported:
@@ -214,15 +216,21 @@ def build_model(model_type="bitransformer",
     output_vocab_size: an integer
     layout_rules: optional, input to mtf.convert_to_layout_rules
     mesh_shape: optional, an input to mtf.convert_to_shape()
+    input_vocabulary: optional, a vocabularies.Vocabulary
+    target_vocabulary: optional, a vocabularies.Vocabulary
+
   Returns:
     a Unitransformer or Bitransformer
   """
   if model_type == "bitransformer":
-    return transformer.make_bitransformer(
+    ret = transformer.make_bitransformer(
         input_vocab_size=input_vocab_size,
         output_vocab_size=output_vocab_size,
         mesh_shape=mesh_shape,
         layout=layout_rules)
+    ret.encoder.vocabulary = input_vocabulary
+    ret.decoder.vocabulary = target_vocabulary
+    return ret
   elif model_type == "bi_student_teacher":
     return transformer.make_bi_student_teacher(
         input_vocab_size=input_vocab_size,
@@ -236,7 +244,8 @@ def build_model(model_type="bitransformer",
         input_vocab_size=input_vocab_size,
         output_vocab_size=output_vocab_size,
         mesh_shape=mesh_shape,
-        layout=layout_rules)
+        layout=layout_rules,
+        vocabulary=input_vocabulary)
   else:
     raise ValueError("unknown model_type")
 
@@ -2067,7 +2076,9 @@ def get_estimator(model_type, vocabulary, mesh_shape,
       input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
       output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
       layout_rules=layout_rules,
-      mesh_shape=mesh_shape)
+      mesh_shape=mesh_shape,
+      input_vocabulary=inputs_vocabulary(vocabulary),
+      target_vocabulary=targets_vocabulary(vocabulary))
 
   model_fn = tpu_estimator_model_fn(
       model_type=model_type,
