diff --git a/tf2_imdb.py b/tf2_imdb.py index 8276cf1..1ffbe63 100644 --- a/tf2_imdb.py +++ b/tf2_imdb.py @@ -13,7 +13,15 @@ print('Using Tensorflow version: {}, and Keras version: {}.'.format( tf.__version__, tf.keras.__version__)) +print(tf.config.get_visible_devices()) +# create a distribution strategy +if tf.config.list_physical_devices('GPU'): + strategy = tf.distribute.MirroredStrategy() +else: # a default fallback strategy + strategy = tf.distribute.get_strategy() + +print('Number of devices: {}'.format(strategy.num_replicas_in_sync)) class DetectSentiment: def __init__(self): @@ -40,7 +48,9 @@ def predict(self, text): idx = oov_idx v[0, i+1] = idx - p = self.model.predict(v, batch_size=1) + with strategy.scope(): + p = self.model.predict(v, batch_size=1) + return float(p[0, 0])