diff --git a/tools/train.py b/tools/train.py index 9fdf56cad..b28f8c547 100644 --- a/tools/train.py +++ b/tools/train.py @@ -42,6 +42,7 @@ def set_default_flags(flags): if __name__ == "__main__": + set_default_flags({'FLAGS_enable_cublas_tensor_op_math': True, }) args = config.parse_args() cfg = config.get_config(args.config, overrides=args.override, show=False)