Skip to content

Commit 7a89ce9

Browse files
committed
CheckpointServer: fast streaming parallel transfers
1 parent 9533676 commit 7a89ce9

File tree

6 files changed

+855
-40
lines changed

6 files changed

+855
-40
lines changed

torchft/_serialization.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import pickle
2+
from dataclasses import dataclass
3+
from io import BufferedIOBase
4+
from typing import Any, Dict, List, Tuple
5+
6+
import torch
7+
import torch._weights_only_unpickler as _weights_only_unpickler
8+
from torch.serialization import _load, _save, DEFAULT_PROTOCOL, MAP_LOCATION
9+
10+
11+
__all__: List[str] = []
12+
13+
14+
@dataclass
class _Entry:
    """Metadata for one record in the streamed pseudo-zip format.

    A pickled list of these entries forms the index that precedes the raw
    payload bytes in the stream (see ``_PseudoZipFile.write_to``).
    """

    key: str  # record name within the pseudo zip file
    is_storage: bool  # payload is a torch.UntypedStorage (vs raw bytes/str)
    length: int  # payload size in bytes


# Register _Entry so the weights-only unpickler will accept it when the
# stream index is read back in _PseudoZipFile.read_from.
_weights_only_unpickler._add_safe_globals([_Entry])
22+
23+
24+
class _PseudoZipFile:
25+
def __init__(self) -> None:
26+
self.records: Dict[str, Tuple[object, int]] = {}
27+
28+
def write_record(self, key: str, data: object, length: int) -> None:
29+
self.records[key] = (data, length)
30+
31+
def write_to(self, f: BufferedIOBase) -> None:
32+
entries = []
33+
for key, (data, length) in self.records.items():
34+
entries.append(
35+
_Entry(
36+
key=key,
37+
is_storage=isinstance(data, torch.UntypedStorage),
38+
length=length,
39+
)
40+
)
41+
42+
pickle.dump(entries, f, protocol=DEFAULT_PROTOCOL)
43+
44+
for key, (data, length) in self.records.items():
45+
if isinstance(data, bytes):
46+
f.write(data)
47+
elif isinstance(data, str):
48+
f.write(data.encode("utf-8"))
49+
elif isinstance(data, torch.UntypedStorage):
50+
data._write_file(f, False, False, 1)
51+
else:
52+
raise TypeError(f"unknown type: {type(data)}")
53+
54+
def read_from(self, f: BufferedIOBase) -> None:
55+
entries = _weights_only_unpickler.load(f)
56+
57+
for entry in entries:
58+
data = f.read(entry.length)
59+
if entry.is_storage:
60+
storage = torch.frombuffer(
61+
data,
62+
dtype=torch.uint8,
63+
).untyped_storage()
64+
65+
self.records[entry.key] = (
66+
storage,
67+
entry.length,
68+
)
69+
else:
70+
self.records[entry.key] = (data, entry.length)
71+
72+
def has_record(self, key: str) -> bool:
73+
return key in self.records
74+
75+
def get_record(self, key: str) -> object:
76+
return self.records[key][0]
77+
78+
def get_storage_from_record(
79+
self, key: str, _length: int, _type: int
80+
) -> torch.Tensor:
81+
return torch.tensor(self.records[key][0], dtype=torch.uint8)
82+
83+
def serialization_id(self) -> str:
84+
return "torchft"
85+
86+
87+
def _streaming_save(
    obj: object,
    f: BufferedIOBase,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
) -> None:
    """
    Save ``obj`` to a file-like object in a streaming fashion compatible with
    network sockets.

    This behaves similarly to :func:`torch.save` with a few notable differences:

    * A non-seekable file-like object can be used when loading.
    * No forwards/backwards compatibility is provided for the serialization
      format. This is only intended to be used with a single version of PyTorch
      with transient storage (i.e. sockets or temp files).
    * mmap is not supported

    See :func:`torch.save` for more details on specific arguments.
    """
    # Let torch's serializer collect all records into the in-memory pseudo
    # zip file, then emit them as a single self-describing byte stream.
    pseudo_zip = _PseudoZipFile()
    _save(
        obj,
        zip_file=pseudo_zip,
        pickle_module=pickle_module,
        pickle_protocol=pickle_protocol,
        _disable_byteorder_record=False,
    )
    pseudo_zip.write_to(f)
117+
118+
119+
def _streaming_load(
    f: BufferedIOBase,
    map_location: MAP_LOCATION = None,
    pickle_module: Any = None,
    *,
    weights_only: bool = True,
    **pickle_load_args: Any,
) -> object:
    """
    Load an object from a file-like object in a streaming fashion compatible
    with network sockets.

    See :func:`_streaming_save` for more details about the streaming behavior.

    See :func:`torch.load` for more details on specific arguments.

    Args:
        f: readable stream positioned at data written by ``_streaming_save``.
        map_location: forwarded to ``torch.serialization._load``.
        pickle_module: module used for unpickling; mutually exclusive with
            ``weights_only=True``.
        weights_only: restrict unpickling to a safe subset of types.

    Raises:
        RuntimeError: if ``weights_only`` is True and ``pickle_module`` is
            also specified.
    """
    if weights_only:
        if pickle_module is not None:
            raise RuntimeError(
                "Can not safely load weights when explicit pickle_module is specified"
            )
        pickle_module = _weights_only_unpickler
    elif pickle_module is None:
        pickle_module = pickle

    # Mirror torch.load's default text encoding for legacy str payloads.
    pickle_load_args.setdefault("encoding", "utf-8")

    zip_file = _PseudoZipFile()
    zip_file.read_from(f)
    return _load(
        zip_file=zip_file,
        map_location=map_location,
        pickle_module=pickle_module,
        **pickle_load_args,
    )

0 commit comments

Comments
 (0)