Adding a new Gradient Estimator for Routing using REINFORCE with a leave-one-out baseline. #374

Open · wants to merge 1 commit into master
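For context on the estimator this PR wires in: the router samples one expert per token, and REINFORCE turns that non-differentiable choice into a gradient of the form (loss - baseline) * grad log p(chosen expert). The leave-one-out baseline comes from running a second, independently routed copy of every example in the same batch, so each copy's twin supplies its baseline. Below is a minimal numpy sketch of the resulting surrogate loss; it mirrors the logic added to _compute_loss further down, but the names are hypothetical and it is not the PR's API.

import numpy as np

def rloo_surrogate(per_example_loss, route_log_prob):
  # Both arrays have shape [batch]; the batch holds two independently routed
  # copies of the same half-batch, with the copies of example i sitting at
  # positions i and i + batch // 2.
  half = per_example_loss.shape[0] // 2
  # With two samples, the leave-one-out baseline for each copy is simply the
  # loss of its twin (same tokens, independent routing sample).
  baseline = np.concatenate(
      [per_example_loss[half:], per_example_loss[:half]])
  advantage = per_example_loss - baseline  # held constant, as with stop_gradient
  # Differentiating (advantage * route_log_prob) w.r.t. router parameters gives
  # the REINFORCE estimate; the ordinary loss term differentiates as usual.
  return per_example_loss + advantage * route_log_prob

losses = np.array([1.2, 0.7, 1.0, 0.9])       # two copies of a 2-example batch
log_probs = np.array([-0.3, -1.1, -0.5, -0.8])
print(rloo_surrogate(losses, log_probs))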
85 changes: 73 additions & 12 deletions mesh_tensorflow/transformer/moe.py
@@ -25,6 +25,7 @@
from __future__ import division
from __future__ import print_function

import math
import gin

import mesh_tensorflow as mtf
@@ -65,7 +66,10 @@ def __init__(self,
word_embed_mode=None,
use_second_place_expert_prob=None,
use_second_place_expert_prob_temp=None,
top_n_num_experts_per_token=3):
top_n_num_experts_per_token=3,
rloo=False,
loss_type="load_balance",
p_dot_e=True):
self._hparams = HParams(
moe_gating=moe_gating,
moe_num_experts=num_experts,
@@ -95,7 +99,10 @@ def __init__(self,
use_second_place_expert_prob),
moe_use_second_place_expert_prob_temp=(
use_second_place_expert_prob_temp),
moe_top_n_num_experts_per_token=top_n_num_experts_per_token)
moe_top_n_num_experts_per_token=top_n_num_experts_per_token,
moe_rloo=rloo,
loss_type=loss_type,
p_dot_e=p_dot_e)
self._activation = activation

def call(self, context, x, losses=None):
@@ -127,7 +134,8 @@ def call(self, context, x, losses=None):
nonpadding=context.nonpadding,
activation=self._activation,
num_microbatches=context.num_microbatches,
token_embeddings=context.input_embeddings)
token_embeddings=context.input_embeddings,
context=context)
if context.losses is not None:
context.losses.append(loss)
if not has_length_dim:
@@ -202,7 +210,7 @@ def call(self, context, x, losses=None):
def transformer_moe_layer_v1(
inputs, output_dim, hparams, train, variable_dtype,
layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
num_microbatches=None, token_embeddings=None):
num_microbatches=None, token_embeddings=None, context=None):
"""Local mixture of experts that works well on TPU.

Adapted from the paper https://arxiv.org/abs/1701.06538
@@ -281,6 +289,8 @@ def transformer_moe_layer_v1(
[batch_dim(s), length_dim, input_dim]. These are the word embeddings for
that correspond to the inputs. These can optionally be used to make
routing decisions.
context: a Context object containing extra information that layers need
at call time, as defined in transformer.py.

Returns:
outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
@@ -436,7 +446,8 @@ def transformer_moe_layer_v1(
variable_dtype=variable_dtype,
importance=nonpadding,
num_microbatches=num_microbatches,
token_embeddings=token_embeddings)
token_embeddings=token_embeddings,
context=context)
elif hparams.moe_gating == "ntlb":
dispatch_tensor, combine_tensor, loss = _ntlb_gating(
inputs=inputs,
@@ -1303,7 +1314,8 @@ def _expert_selection_gating(
def _switch_gating(
inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
hparams, train, variable_dtype, importance=None, name="switch_gating",
num_microbatches=None, token_embeddings=None):
num_microbatches=None, token_embeddings=None,
context=None):
"""Compute Switch gating."""
# SELECT EXPERT
if train:
@@ -1351,6 +1363,11 @@ def _switch_gating(
expert_gate = mtf.gather(raw_gates, expert_index, dim=experts_dim)
else:
raise ValueError("Unknown Switch gating policy %s" % policy)
full_expert_gate_log_probs = gate_logits / hparams.moe_switch_temperature
full_expert_gate_log_probs -= mtf.reduce_logsumexp(full_expert_gate_log_probs,
reduced_dim=experts_dim)
expert_gate_log_probs = mtf.gather(full_expert_gate_log_probs, expert_index,
dim=experts_dim)

expert_mask = mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype)

@@ -1363,21 +1380,40 @@ def _switch_gating(
expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
density_1_proxy *= mtf.cast(
mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
loss = (
load_balance_loss = (
mtf.reduce_mean(density_1_proxy * density_1) *
float(experts_dim.size * experts_dim.size))

kl_with_uniform = (
- math.log(float(experts_dim.size))
- mtf.reduce_logsumexp(full_expert_gate_log_probs,
reduced_dim=group_size_dim)
+ math.log(float(group_size_dim.size)))
if importance is not None:
kl_with_uniform *= mtf.cast(mtf.equal(importance, 1.0),
dtype=raw_gates.dtype)
kl_with_uniform = mtf.reduce_mean(kl_with_uniform)

if hparams.loss_type.lower() == "kl":
loss = kl_with_uniform
else:
loss = load_balance_loss

if num_microbatches and num_microbatches > 1:
tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
num_microbatches))
loss /= num_microbatches

# Logging
if train:
entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
reduced_dim=experts_dim)
entropy = mtf.reduce_sum(
-mtf.exp(full_expert_gate_log_probs) * full_expert_gate_log_probs,
reduced_dim=experts_dim)
batch_entropy = mtf.reduce_mean(entropy)
mtf.scalar_summary(name + "/entropy", batch_entropy)
mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))
mtf.scalar_summary("tempered_expert_gate",
mtf.reduce_mean(mtf.exp(expert_gate_log_probs)))

mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
total_routed = mtf.reduce_sum(mask_count_experts)
@@ -1389,7 +1425,25 @@ def _switch_gating(
for fraction in split_fractions:
mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
mtf.reduce_mean(fraction))
mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))
dead_expert_fraction = mtf.reduce_mean(
mtf.cast(mtf.equal(mask_count_experts, 0.),
dtype=raw_gates.dtype))
mtf.scalar_summary("dead_expert_fraction",
dead_expert_fraction)
mtf.scalar_summary("load_balancing_loss",
mtf.reduce_mean(load_balance_loss))
mtf.scalar_summary("kl_with_uniform",
mtf.reduce_mean(kl_with_uniform))

split_expert_index = mtf.rename_dimension(
expert_index, "batch", "batch_split")
first_expert_index, second_expert_index = mtf.split(
split_expert_index,
split_expert_index.shape.get_dim_by_name("batch_split"), 2)
duplicate_sample = mtf.reduce_mean(
mtf.cast(mtf.equal(first_expert_index, second_expert_index),
dtype=raw_gates.dtype))
mtf.scalar_summary("duplicate_sample_fraction", duplicate_sample)

# Add in the z_loss for router.
if train and hparams.moe_z_loss is not None:
@@ -1421,9 +1475,16 @@ def _switch_gating(
# Mask out the experts that have overflowed expert capacity. Sparsify the
# expert_gate.
expert_gate *= expert_mask_flat
if hparams.moe_rloo:
expert_gate_log_probs *= expert_mask_flat
context.expert_gate_log_probs.append(expert_gate_log_probs)

combine_tensor = (
expert_gate * expert_mask_flat *
if hparams.p_dot_e:
combine_tensor = expert_gate
else:
combine_tensor = expert_mask_flat

combine_tensor *= (
mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype) *
mtf.one_hot(
mtf.to_int32(position_in_expert),
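On the new loss_type="kl" option above: instead of the usual load-balancing proxy, the auxiliary loss is the KL divergence between a uniform distribution over experts and the router distribution averaged over the group, and the logsumexp expression in _switch_gating is that KL written in log space. A small numpy check of the identity follows; it is illustrative only, with assumed shapes, and is not the PR's code path.

import numpy as np
from scipy.special import logsumexp

def kl_with_uniform(gate_log_probs):
  # gate_log_probs: [group_size, num_experts], each row already log-softmaxed.
  group_size, num_experts = gate_log_probs.shape
  per_expert = (-np.log(num_experts)
                - logsumexp(gate_log_probs, axis=0)
                + np.log(group_size))
  return per_expert.mean()  # mean over experts, matching the mtf.reduce_mean

logits = np.random.randn(8, 4)
log_p = logits - logsumexp(logits, axis=-1, keepdims=True)
mean_p = np.exp(log_p).mean(axis=0)            # router probs averaged over group
direct_kl = np.sum((1.0 / 4) * (np.log(1.0 / 4) - np.log(mean_p)))
assert np.isclose(kl_with_uniform(log_p), direct_kl)

Because only the group-averaged probability enters, the penalty is zero whenever the average routing mass is uniform, even if individual tokens are routed with high confidence.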
43 changes: 41 additions & 2 deletions mesh_tensorflow/transformer/transformer.py
@@ -144,7 +144,8 @@ def __init__(self,
read_priority=None,
inputs=None,
encoder_inputs=None,
num_microbatches=1):
num_microbatches=1,
expert_gate_log_probs=None):
"""Create a context.

Args:
@@ -201,6 +202,8 @@ def __init__(self,
decoder.
num_microbatches: integer - greater than one if the step has been
serialized into multiple microbatches to save memory.
expert_gate_log_probs: an optional list of Tensors of expert-gate log
probabilities, used to compute REINFORCE gradients for the router.
"""
self.model = model
self.mesh = mesh
@@ -235,6 +238,7 @@ def __init__(self,
self.encoder_inputs = encoder_inputs
self.num_microbatches = num_microbatches
self.input_embeddings = None
self.expert_gate_log_probs = expert_gate_log_probs

@property
def train(self):
@@ -848,6 +852,19 @@ def _compute_loss(self, context, logits, targets, output_vocab_dim):
if self.loss_on_targets_only:
weights *= mtf.cast(mtf.logical_not(delimited_lm_inputs_mask(targets)),
dtype=context.activation_dtype)

# Compute the REINFORCE surrogate loss, using the duplicated copy's loss
# as a leave-one-out baseline.
if context.expert_gate_log_probs:
log_probs = mtf.reshape(
mtf.add_n(context.expert_gate_log_probs), loss.shape)
split_loss = mtf.rename_dimension(loss, "batch", "batch_unsplit")
first_loss, second_loss = mtf.split(
split_loss, split_loss.shape.get_dim_by_name("batch_unsplit"), 2)
baseline = mtf.concat([second_loss, first_loss], "batch_unsplit")
baseline = mtf.rename_dimension(baseline, "batch_unsplit", "batch")
loss += mtf.stop_gradient(loss - baseline) * mtf.cast(
log_probs, loss.dtype)

return (mtf.reduce_sum(loss * weights) /
self.loss_denominator(targets, context.num_microbatches))

@@ -1007,6 +1024,27 @@ def call_simple(self,
logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
loss: an optional Scalar (if compute_loss=True)
"""
if mode == tf.estimator.ModeKeys.TRAIN:

def duplicate_batch(t, batch_dim_name="batch"):
if t is not None:
# Assumes that the batch size is divisible by 2
half_batch_size = t.shape.get_dim_by_name(batch_dim_name).size // 2
t = mtf.rename_dimension(t, batch_dim_name, batch_dim_name + "_slice")
half_batch = mtf.slice(t, 0, half_batch_size,
batch_dim_name + "_slice")
t = mtf.concat([half_batch, half_batch], batch_dim_name + "_slice")
return mtf.rename_dimension(t, batch_dim_name + "_slice",
batch_dim_name)
else:
return t

inputs = duplicate_batch(inputs)
targets = duplicate_batch(targets)
sequence_id = duplicate_batch(sequence_id)
position = duplicate_batch(position)
encoder_sequence_id = duplicate_batch(encoder_sequence_id)

batch_dims = inputs.shape.dims[:-1]
length_dim = inputs.shape.dims[-1]
length_range = mtf.range(inputs.mesh, length_dim, dtype=tf.int32)
@@ -1061,7 +1099,8 @@ def call_simple(self,
read_priority=read_priority,
inputs=inputs,
encoder_inputs=encoder_inputs,
num_microbatches=num_microbatches)
num_microbatches=num_microbatches,
expert_gate_log_probs=[])
with tf.variable_scope(self.name):
logits = self._call_internal(context, inputs, targets)
if compute_loss:
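For reference, what duplicate_batch does to each TRAIN-mode feature: the second half of the incoming batch is dropped and the first half is repeated, so every retained example is routed twice and the two copies act as each other's baselines in _compute_loss. A plain-numpy illustration (the real code operates on mtf Tensors with named dimensions):

import numpy as np

def duplicate_batch(t):
  if t is None:
    return t
  half = t.shape[0] // 2  # assumes the batch size is divisible by 2
  return np.concatenate([t[:half], t[:half]], axis=0)

tokens = np.arange(8).reshape(4, 2)  # [batch=4, length=2]
print(duplicate_batch(tokens))
# [[0 1]
#  [2 3]
#  [0 1]
#  [2 3]]

Rows i and i + batch // 2 are identical inputs with independent routing samples, which is what makes the swapped-half baseline in _compute_loss a leave-one-out baseline.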