Skip to content

Support Colocated Python Checkpointing #97

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions pathwaysutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,31 @@ def _is_persistence_enabled() -> bool:
return False


def _is_colocated_python_enabled() -> bool:
  """Reports whether colocated python checkpointing has been switched on.

  The decision is driven solely by the ENABLE_COLOCATED_PYTHON_CHECKPOINTING
  environment variable: "1" turns the feature on, while "0" or leaving the
  variable unset keeps it off. Any other value is rejected.

  Returns:
    True if colocated python checkpointing is enabled, False otherwise.

  Raises:
    ValueError: If the variable is set to anything other than "1" or "0"
      (this includes the empty string).
  """
  flag = os.environ.get("ENABLE_COLOCATED_PYTHON_CHECKPOINTING")

  # Unset means disabled.
  if flag is None:
    return False
  if flag == "1":
    return True
  if flag == "0":
    return False

  # The variable is set, but to something that is neither "1" nor "0".
  raise ValueError(
      "ENABLE_COLOCATED_PYTHON_CHECKPOINTING must be set to 1/0 or"
      f" unset, got: {flag}"
  )


def initialize() -> None:
"""Initializes pathwaysutils.

Expand All @@ -93,8 +118,16 @@ def initialize() -> None:
proxy_backend.register_backend_factory()
profiling.monkey_patch_jax()
# TODO: b/365549911 - Remove when OCDBT-compatible
if _is_persistence_enabled():
orbax_handler.register_pathways_handlers(datetime.timedelta(hours=1))
if _is_persistence_enabled() ^ _is_colocated_python_enabled():
orbax_handler.register_pathways_handlers(
datetime.timedelta(hours=1),
use_colocated_python=_is_colocated_python_enabled(),
)
elif _is_persistence_enabled() and _is_colocated_python_enabled():
raise ValueError(
"Invalid configuration: ENABLE_PATHWAYS_PERSISTENCE and"
" ENABLE_COLOCATED_PYTHON_CHECKPOINTING cannot both be enabled."
)

# Turn off JAX compilation cache because Pathways handles its own
# compilation cache.
Expand All @@ -103,4 +136,4 @@ def initialize() -> None:
else:
_logger.debug(
"Did not detect Pathways-on-Cloud backend. No changes applied."
)
)
32 changes: 22 additions & 10 deletions pathwaysutils/persistence/orbax_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
import typing

import jax
from orbax.checkpoint import experimental
from orbax.checkpoint import future
from orbax.checkpoint import type_handlers
from pathwaysutils.persistence import helper


ColocatedPythonArrayHandler = experimental.ColocatedPythonArrayHandler

logger = logging.getLogger(__name__)

ParamInfo = type_handlers.ParamInfo
Expand Down Expand Up @@ -192,15 +195,24 @@ async def deserialize(

def register_pathways_handlers(
read_timeout: datetime.timedelta | None = None,
use_colocated_python: bool = False,
):
"""Function that must be called before saving or restoring with Pathways."""
logger.debug(
"Registering CloudPathwaysArrayHandler (Pathways Persistence API)."
)
type_handlers.register_type_handler(
jax.Array,
CloudPathwaysArrayHandler(
read_timeout=read_timeout,
),
override=True,
)
if use_colocated_python:
logger.debug("Registering ColocatedPythonArrayHandler.")
type_handlers.register_type_handler(
jax.Array,
ColocatedPythonArrayHandler(),
override=True,
)
else:
logger.debug(
"Registering CloudPathwaysArrayHandler (Pathways Persistence API)."
)
type_handlers.register_type_handler(
jax.Array,
CloudPathwaysArrayHandler(
read_timeout=read_timeout,
),
override=True,
)
85 changes: 67 additions & 18 deletions pathwaysutils/test/pathwaysutils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,44 @@

class PathwaysutilsTest(parameterized.TestCase):

def test_first_initialize(self):
@parameterized.named_parameters(
("persistence", "ENABLE_PATHWAYS_PERSISTENCE"),
("colocated_python", "ENABLE_COLOCATED_PYTHON_CHECKPOINTING"),
)
def test_first_initialize(self, flag):
jax.config.update("jax_platforms", "proxy")
pathwaysutils._initialization_count = 0

with self.assertLogs(pathwaysutils._logger, level="DEBUG") as logs:
pathwaysutils.initialize()
with mock.patch.dict(os.environ, {flag: "1"}, clear=True):
with self.assertLogs("pathwaysutils", level="DEBUG") as logs:
pathwaysutils.initialize()

self.assertLen(logs.output, 2)
self.assertIn(
"Starting initialize.", logs.output[0]
)
self.assertLen(logs.output, 3)
self.assertIn("Starting initialize.", logs.output[0])
self.assertIn(
"Detected Pathways-on-Cloud backend. Applying changes.", logs.output[1]
)
if flag == "ENABLE_PATHWAYS_PERSISTENCE":
self.assertIn(
"Registering CloudPathwaysArrayHandler", logs.output[2]
)
else:
self.assertIn("Registering ColocatedPythonArrayHandler", logs.output[2])

def test_initialize_with_both_enabled_raises_error(self):
  """initialize() raises ValueError when both checkpointing flags are "1"."""
  jax.config.update("jax_platforms", "proxy")
  # Reset the counter so this call is treated as a first initialization.
  pathwaysutils._initialization_count = 0

  conflicting_env = {
      "ENABLE_PATHWAYS_PERSISTENCE": "1",
      "ENABLE_COLOCATED_PYTHON_CHECKPOINTING": "1",
  }
  with mock.patch.dict(
      os.environ, conflicting_env, clear=True
  ), self.assertRaises(ValueError):
    pathwaysutils.initialize()

@parameterized.named_parameters(
("initialization_count 1", 1),
Expand Down Expand Up @@ -78,17 +102,42 @@ def test_is_pathways_backend_used(self, platform: str):
self.assertTrue(pathwaysutils.is_pathways_backend_used())

def test_persistence_enabled(self):
os.environ["ENABLE_PATHWAYS_PERSISTENCE"] = "1"
self.assertTrue(pathwaysutils._is_persistence_enabled())

os.environ["ENABLE_PATHWAYS_PERSISTENCE"] = "0"
self.assertFalse(pathwaysutils._is_persistence_enabled())

os.environ["ENABLE_PATHWAYS_PERSISTENCE"] = ""
self.assertRaises(ValueError, pathwaysutils._is_persistence_enabled)

del os.environ["ENABLE_PATHWAYS_PERSISTENCE"]
self.assertFalse(pathwaysutils._is_persistence_enabled())
with mock.patch.dict(
os.environ, {"ENABLE_PATHWAYS_PERSISTENCE": "1"}, clear=True
):
self.assertTrue(pathwaysutils._is_persistence_enabled())

with mock.patch.dict(
os.environ, {"ENABLE_PATHWAYS_PERSISTENCE": "0"}, clear=True
):
self.assertFalse(pathwaysutils._is_persistence_enabled())

with mock.patch.dict(
os.environ, {"ENABLE_PATHWAYS_PERSISTENCE": ""}, clear=True
):
self.assertRaises(ValueError, pathwaysutils._is_persistence_enabled)

with mock.patch.dict(os.environ, {}, clear=True):
self.assertFalse(pathwaysutils._is_persistence_enabled())

def test_colocated_python_enabled(self):
  """Covers the "1", "0", empty and unset cases of the colocated flag."""
  env_var = "ENABLE_COLOCATED_PYTHON_CHECKPOINTING"

  with mock.patch.dict(os.environ, {env_var: "1"}, clear=True):
    self.assertTrue(pathwaysutils._is_colocated_python_enabled())

  with mock.patch.dict(os.environ, {env_var: "0"}, clear=True):
    self.assertFalse(pathwaysutils._is_colocated_python_enabled())

  # An empty string counts as "set" and is therefore an invalid value.
  with mock.patch.dict(os.environ, {env_var: ""}, clear=True):
    with self.assertRaises(ValueError):
      pathwaysutils._is_colocated_python_enabled()

  # With the variable absent the feature is off.
  with mock.patch.dict(os.environ, {}, clear=True):
    self.assertFalse(pathwaysutils._is_colocated_python_enabled())


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions pathwaysutils/test/proxy_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from jax.lib.xla_extension import ifrt_proxy
from pathwaysutils import proxy_backend


from absl.testing import absltest


Expand Down
Loading