diff --git a/learned_optimization/__init__.py b/learned_optimization/__init__.py index c99b0bf..47fb2b9 100644 --- a/learned_optimization/__init__.py +++ b/learned_optimization/__init__.py @@ -14,3 +14,7 @@ # limitations under the License. """learned_optimizer module.""" + +from .py_utils import patch_os_path_get_sep + +_OLD_OS_PATH_GET_SEP = patch_os_path_get_sep() diff --git a/learned_optimization/py_utils.py b/learned_optimization/py_utils.py index 840cb95..011b72f 100644 --- a/learned_optimization/py_utils.py +++ b/learned_optimization/py_utils.py @@ -14,6 +14,7 @@ # limitations under the License. """Common python utilities.""" +import os from concurrent import futures from typing import Any, Callable, Sequence import tqdm @@ -26,3 +27,26 @@ def threaded_tqdm_map(threads: int, func: Callable[[Any], Any], for l in tqdm.tqdm(data): future_list.append(executor.submit(func, l)) return [x.result() for x in tqdm.tqdm(future_list)] + + +def patch_os_path_get_sep(): + old_get_sep = os.path._get_sep + + def new_get_sep(path): + """Return the OS separator for the given path. + + If `path` starts with "gs://", "/" is used as the separator. + """ + if isinstance(path, bytes): + gs_prefix = b'gs://' + sep = b'/' + else: + gs_prefix = 'gs://' + sep = '/' + + if not path.startswith(gs_prefix): + sep = old_get_sep(path) + return sep + + os.path._get_sep = new_get_sep + return old_get_sep