WeightNormalization layer throws error with TPUStrategy #1703

Open

Description

@sourcecode369

Environment: Google Colab
Hardware Accelerator: TPU
TensorFlow version: 2.2.0-rc3
TensorFlow Addons: 0.9.1

Describe the bug

The WeightNormalization layer (tfa.layers.WeightNormalization) throws an error when training under TPUStrategy.

Code to reproduce the issue

import tensorflow as tf
print(f"tf.__version__: {tf.__version__}")
tf.config.optimizer.set_jit(True)
import tensorflow_addons as tfa
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import mnist
import tensorflow_datasets as tfds
import os

try:
    # TPU detection; COLAB_TPU_ADDR is only set on a Colab TPU runtime
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except (KeyError, ValueError):
    raise RuntimeError('ERROR: Not connected to a TPU runtime.')

# Connect to the TPU cluster, initialize it, and create the distribution strategy
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("REPLICAS: ", tpu_strategy.num_replicas_in_sync)

def get_dataset(batch_size=200):
  datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True,
                             try_gcs=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255.0

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(10000).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
  test_dataset = mnist_test.map(scale).cache().batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

  return train_dataset, test_dataset

def create_model():
  return tf.keras.Sequential(
      [tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(32, 3, activation="elu", input_shape=(28, 28, 1))),
       tf.keras.layers.Flatten(),
       tfa.layers.WeightNormalization(tf.keras.layers.Dense(128, "elu")),
       tfa.layers.WeightNormalization(tf.keras.layers.Dense(10))])
  
train_dataset, test_dataset = get_dataset()

with tpu_strategy.scope():
  model = create_model()
  model.compile(optimizer=tfa.optimizers.Lookahead(tfa.optimizers.RectifiedAdam(
                          learning_rate=1e-3,
                          total_steps=10000,
                          warmup_proportion=0.1,
                          min_lr=1e-5,
                      ), sync_period=6, slow_step_size=0.5),
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['sparse_categorical_accuracy'])
  
model.fit(train_dataset, epochs=10, validation_data=test_dataset, callbacks=[tfa.callbacks.TQDMProgressBar()], verbose=0)
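
For reference, the underlying constraint can be demonstrated standalone: Variable.numpy() is an eager-only call, so any code path that reaches it while a tf.function (such as Keras' train_function) is being traced will fail. Minimal sketch:

import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def traced_read():
  # .numpy() is an eager-only operation; invoking it while this function is
  # being traced raises an error (for TPU-distributed variables the failure
  # surfaces as the "Graph tensor" TypeError shown in the log below).
  return v.numpy()

# traced_read()  # fails at trace time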

Other info / logs


TypeError                                 Traceback (most recent call last)
<ipython-input-1-84d7c4d8b77c> in <module>()
     57                 metrics=['sparse_categorical_accuracy'])
     58 
---> 59 model.fit(train_dataset,epochs=10,validation_data=test_dataset, callbacks=[tfa.callbacks.TQDMProgressBar()],verbose=0)

10 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
     64   def _method_wrapper(self, *args, **kwargs):
     65     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
---> 66       return method(self, *args, **kwargs)
     67 
     68     # Running inside `run_distribute_coordinator` already.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    849                 batch_size=batch_size):
    850               callbacks.on_train_batch_begin(step)
--> 851               tmp_logs = train_function(iterator)
    852               # Catch OutOfRangeError for Datasets of unknown size.
    853               # This blocks until the batch has finished executing.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    578         xla_context.Exit()
    579     else:
--> 580       result = self._call(*args, **kwds)
    581 
    582     if tracing_count == self._get_tracing_count():

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    625       # This is the first call of __call__, so we have to initialize.
    626       initializers = []
--> 627       self._initialize(args, kwds, add_initializers_to=initializers)
    628     finally:
    629       # At this point we know that the initialization is complete (or less

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    504     self._concrete_stateful_fn = (
    505         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 506             *args, **kwds))
    507 
    508     def invalid_creator_scope(*unused_args, **unused_kwds):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2444       args, kwargs = None, None
   2445     with self._lock:
-> 2446       graph_function, _, _ = self._maybe_define_function(args, kwargs)
   2447     return graph_function
   2448 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2775 
   2776       self._function_cache.missed.add(call_context_key)
-> 2777       graph_function = self._create_graph_function(args, kwargs)
   2778       self._function_cache.primary[cache_key] = graph_function
   2779       return graph_function, args, kwargs

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2665             arg_names=arg_names,
   2666             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2667             capture_by_value=self._capture_by_value),
   2668         self._function_attributes,
   2669         # Tell the ConcreteFunction to clean up its graph once it goes out of

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    979         _, original_func = tf_decorator.unwrap(python_func)
    980 
--> 981       func_outputs = python_func(*func_args, **func_kwargs)
    982 
    983       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    439         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    440         # the function a weak reference to itself to avoid a reference cycle.
--> 441         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    442     weak_wrapped_fn = weakref.ref(wrapped_fn)
    443 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

TypeError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:571 train_function  *
        outputs = self.distribute_strategy.run(
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_strategy.py:170 run  **
        return self.extended.tpu_run(fn, args, kwargs, options)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_strategy.py:863 tpu_run
        return func(args, kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_strategy.py:930 tpu_function
        padding_spec=padding_spec)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/tpu/tpu.py:893 replicate
        padding_spec=padding_spec)[1]
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/tpu/tpu.py:1280 split_compile_and_replicate
        outputs = computation(*computation_inputs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_strategy.py:892 replicated_fn
        result[0] = fn(*replica_args, **replica_kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:531 train_step  **
        y_pred = self(x, training=True)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:927 __call__
        outputs = call_fn(cast_inputs, *args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/sequential.py:291 call
        outputs = layer(inputs, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:897 __call__
        self._maybe_build(inputs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:2416 _maybe_build
        self.build(input_shapes)  # pylint:disable=not-callable
    /usr/local/lib/python3.6/dist-packages/tensorflow_addons/layers/wrappers.py:119 build
        self._naked_clone_layer.set_weights(self.layer.get_weights())
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:1588 get_weights
        return backend.batch_get_value(output_weights)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:3327 batch_get_value
        return [x.numpy() for x in tensors]
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py:3327 <listcomp>
        return [x.numpy() for x in tensors]
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_values.py:102 numpy
        return self.read_value().numpy()
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_values.py:135 read_value
        return self._read_variable_op()
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/tpu_values.py:129 _read_variable_op
        return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_resource_variable_ops.py:475 read_variable_op
        resource, dtype=dtype, name=name, ctx=_ctx)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_resource_variable_ops.py:502 read_variable_op_eager_fallback
        attrs=_attrs, ctx=ctx, name=name)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py:75 quick_execute
        raise e
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py:60 quick_execute
        inputs, attrs, num_outputs)

    TypeError: An op outside of the function building code is being passed
    a "Graph" tensor. It is possible to have Graph tensors
    leak out of the function building context by including a
    tf.init_scope in your function building code.
    For example, the following function will fail:
      @tf.function
      def has_init_scope():
        my_constant = tf.constant(1.)
        with tf.init_scope():
          added = my_constant * 2
    The graph tensor has name: sequential/weight_normalization/sequential/weight_normalization/kernel_140486508184800/handle:0
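
Per the traceback, the failure originates in tensorflow_addons/layers/wrappers.py:119, where WeightNormalization.build() calls self._naked_clone_layer.set_weights(self.layer.get_weights()). Under TPUStrategy this reaches numpy() on a TPU-distributed variable while the train step is still being traced, which leaks the graph tensor named above. A possible workaround (untested sketch; (None, 28, 28, 1) is just MNIST's input shape from the reproduction) is to build the model eagerly before fit(), so the wrapper's weight copy runs outside tracing:

train_dataset, test_dataset = get_dataset()

with tpu_strategy.scope():
  model = create_model()
  # Building here runs WeightNormalization.build() in eager mode, where
  # reading TPU variables via .numpy() is permitted.
  model.build(input_shape=(None, 28, 28, 1))
  model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['sparse_categorical_accuracy'])

model.fit(train_dataset, epochs=1, validation_data=test_dataset)

If the eager build succeeds, the traced train_function no longer needs to run build(), so the eager-only calls are never hit inside the graph.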

Labels: bug