diff --git a/torchft/futures.py b/torchft/futures.py
index 3ff9801..54d72fe 100644
--- a/torchft/futures.py
+++ b/torchft/futures.py
@@ -1,9 +1,10 @@
 import asyncio
 import threading
 from datetime import timedelta
-from typing import Optional, TypeVar
+from typing import Callable, Optional, TypeVar
 from unittest.mock import Mock
 
+import torch
 from torch.futures import Future
 
 T = TypeVar("T")
@@ -17,7 +18,6 @@ def __init__(self) -> None:
 
     def set_timer(self, timer_handle: asyncio.TimerHandle) -> None:
         assert self._lock.locked()
 
-        self._timer_handle = timer_handle
         self._lock.release()
 
@@ -99,6 +99,18 @@ def callback(fut: Future[T]) -> None:
         fut.add_done_callback(callback)
         return timed_fut
 
+    def stream_timeout(self, callback: Callable[[], None], timeout: timedelta) -> None:
+        loop = self._maybe_start_event_loop()
+
+        event = torch.cuda.Event()
+        event.record()
+
+        def handler() -> None:
+            if not event.query():
+                callback()
+
+        loop.call_soon_threadsafe(self._register_handler, loop, handler, timeout)
+
     @classmethod
     def _register(
         cls,
@@ -116,6 +128,18 @@ def _register(
         )
         handle.set_timer(timer_handle)
 
+    @classmethod
+    def _register_handler(
+        cls,
+        loop,
+        handler: Callable[[], None],
+        timeout: timedelta,
+    ) -> None:
+        loop.call_later(
+            timeout.total_seconds(),
+            handler,
+        )
+
 
 _TIMEOUT_MANAGER = _TimeoutManager()
 
@@ -163,3 +187,18 @@ def callback(fut: Future[T]) -> T:
         raise TimeoutError(f"future did not complete within {timeout}")
 
     return fut.wait()
+
+
+def stream_timeout(callback: Callable[[], None], timeout: timedelta) -> None:
+    """
+    Registers a callback that will be called after the specified timeout if
+    the current stream doesn't complete in time.
+
+    This uses a cuda Event to track the completion of the current stream. If
+    the stream is not complete after the timeout, the callback is called.
+
+    Args:
+        callback: The callback to call if the stream doesn't complete in time.
+        timeout: The timeout to wait for the stream to complete.
+ """ + _TIMEOUT_MANAGER.stream_timeout(callback, timeout) diff --git a/torchft/manager.py b/torchft/manager.py index 668189c..8576481 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -33,7 +33,7 @@ from concurrent.futures import ThreadPoolExecutor from datetime import timedelta from enum import Enum -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast +from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, TypeVar import torch from torch.distributed import ReduceOp, TCPStore @@ -477,7 +477,7 @@ def _async_quorum( self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size) self._quorum_id = quorum_id - if allow_heal: + if allow_heal and False: if quorum.recover_dst_ranks: self._logger.info( f"peers need recovery from us {quorum.recover_dst_ranks}" diff --git a/torchft/process_group.py b/torchft/process_group.py index 0b7507d..7f92f29 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -17,23 +17,25 @@ """ import logging +import queue import threading from contextlib import contextmanager, nullcontext from dataclasses import dataclass from datetime import timedelta +from multiprocessing import Process from multiprocessing.connection import Connection from typing import ( - TYPE_CHECKING, Any, Callable, + cast, Dict, Generator, List, Optional, Tuple, + TYPE_CHECKING, TypeVar, Union, - cast, ) import torch @@ -44,14 +46,14 @@ # pyre-fixme[21]: no attribute ProcessGroupGloo from torch.distributed import ( DeviceMesh, + get_rank, + init_device_mesh, PrefixStore, ProcessGroup as BaseProcessGroup, ProcessGroupGloo as BaseProcessGroupGloo, ProcessGroupNCCL as BaseProcessGroupNCCL, Store, TCPStore, - get_rank, - init_device_mesh, ) from torch.distributed.distributed_c10d import ( AllgatherOptions, @@ -66,6 +68,7 @@ ) from torch.futures import Future from torch.utils._pytree import tree_any +from torchft.futures import stream_timeout from torchft.multiprocessing import _MonitoredPipe @@ -349,16 +352,21 @@ def __init__(self, pg: Optional[ProcessGroup] = None) -> None: super().__init__(0, 1) self._pg: Optional[BaseProcessGroup] = pg + def abort(self) -> None: + if self._pg is not None: + self._pg.abort() + self._pg = None + + def _wrap_work(self, work: Work, opts: object) -> Work: + return work + def configure(self, store_addr: str, rank: int, world_size: int) -> None: pg = self._pg if isinstance(pg, ProcessGroup): pg.configure(store_addr, rank, world_size) return - if pg is not None: - if hasattr(pg, "abort"): - pg.abort() # pyre-fixme[16]: no attribute abort - self._pg = None + self.abort() store = create_store_client(store_addr) @@ -377,7 +385,9 @@ def allgather( input_tensor: List[torch.Tensor], opts: AllgatherOptions, ) -> Work: - return self.parent.allgather(output_tensors, input_tensor, opts) + return self._wrap_work( + self.parent.allgather(output_tensors, input_tensor, opts), opts + ) def allgather_into_tensor_coalesced( self, @@ -385,17 +395,20 @@ def allgather_into_tensor_coalesced( input_tensors: List[torch.Tensor], opts: AllgatherOptions, ) -> Work: - return self.parent.allgather_into_tensor_coalesced( - output_tensors, input_tensors, opts + return self._wrap_work( + self.parent.allgather_into_tensor_coalesced( + output_tensors, input_tensors, opts + ), + opts, ) def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: - return self.parent.allreduce(tensors, opts) + return self._wrap_work(self.parent.allreduce(tensors, opts), opts) def allreduce_coalesced( self, tensors: 
List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp] ) -> Work: - return self.parent.allreduce_coalesced(tensors, opts) + return self._wrap_work(self.parent.allreduce_coalesced(tensors, opts), opts) def alltoall_base( self, @@ -405,18 +418,21 @@ def alltoall_base( input_split_sizes: List[int], opts: AllToAllOptions, ) -> Work: - return self.parent.alltoall_base( - output_buffer, input_buffer, output_split_sizes, input_split_sizes, opts + return self._wrap_work( + self.parent.alltoall_base( + output_buffer, input_buffer, output_split_sizes, input_split_sizes, opts + ), + opts, ) def barrier(self, opts: BarrierOptions) -> Work: - return self.parent.barrier(opts) + return self._wrap_work(self.parent.barrier(opts), opts) def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work: - return self.parent.broadcast(tensor_list, opts) + return self._wrap_work(self.parent.broadcast(tensor_list, opts), opts) def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work: - return self.parent.recv(tensors, src_rank, tag) + return self._wrap_work(self.parent.recv(tensors, src_rank, tag), None) def reduce_scatter( self, @@ -424,7 +440,10 @@ def reduce_scatter( input_tensors: List[List[torch.Tensor]], opts: object, ) -> Work: - return self.parent.reduce_scatter(output_tensors, input_tensors, opts) + return self._wrap_work( + self.parent.reduce_scatter(output_tensors, input_tensors, opts), + opts, + ) def reduce_scatter_tensor_coalesced( self, @@ -432,12 +451,15 @@ def reduce_scatter_tensor_coalesced( input_tensors: List[torch.Tensor], opts: ReduceScatterOptions, ) -> Work: - return self.parent.reduce_scatter_tensor_coalesced( - output_tensors, input_tensors, opts + return self._wrap_work( + self.parent.reduce_scatter_tensor_coalesced( + output_tensors, input_tensors, opts + ), + opts, ) def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work: - return self.parent.send(tensors, dst_rank, tag) + return self._wrap_work(self.parent.send(tensors, dst_rank, tag), None) def size(self) -> int: return self.parent.size() @@ -516,6 +538,28 @@ def reduce_scatter_tensor_coalesced( ) +class _WorkCUDATimeout(Work): + def __init__(self, pg: ProcessGroup, work: Work, timeout: timedelta) -> None: + super().__init__() + self._pg = pg + self._work = work + self._timeout = timeout + + def wait(self, timeout: Optional[timedelta] = None) -> bool: + if timeout is not None: + return self._work.wait(timeout) + + self._work.wait() + # free work and tensors + del self._work + + def callback() -> None: + logger.error(f"aborting after {self._timeout}!") + self._pg.abort() + + stream_timeout(callback, self._timeout) + + class ProcessGroupNCCL(ProcessGroupWrapper): """ This is a reconfigurable version of ProcessGroupNCCL. @@ -527,6 +571,16 @@ class ProcessGroupNCCL(ProcessGroupWrapper): abort when reconfiguring, we need to ensure this is safe. 
""" + def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None: + super().__init__() + self._timeout = timeout + + def _wrap_work(self, work: Work, opts: object) -> Work: + timeout = self._timeout + if hasattr(opts, "timeout") and opts.timeout.total_seconds() > 0: + timeout = opts.timeout + return _WorkCUDATimeout(self, work, timeout) + def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup: pg = BaseProcessGroup(store, rank, world_size) pg._set_default_backend(ProcessGroup.BackendType.NCCL) @@ -883,6 +937,7 @@ def _is_any_cuda(obj: object) -> bool: class _OpMetadata: work: Work stream: Optional[torch.cuda.Stream] + event: Optional[torch.cuda.Event] @contextmanager def set_stream(self) -> Generator[None, None, None]: @@ -913,6 +968,188 @@ def _assert_list(tensors: Union[List[torch.Tensor], List[List[torch.Tensor]]]) - raise TypeError(f"expected list but got {type(tensors)}") +@dataclass +class _StreamID: + device: int + stream_id: int + + +def _current_stream_id() -> Optional[_StreamID]: + if torch.cuda.is_available(): + device: int = torch.cuda.current_stream().device + stream_id: int = torch.cuda.current_stream().stream_id + return _StreamID(device=device, stream_id=stream_id) + + +class _ProcessGroupBabyWorker: + def __init__( + self, + pg: ProcessGroup, + req_pipe: "Connection[object, object]", + future_pipe: "Connection[object, object]", + ) -> None: + self._pg = pg + self._req_pipe = req_pipe + self._future_pipe = future_pipe + + self._streams: Dict[str, torch.cuda.Stream] = {} + self._work: Dict[int, _OpMetadata] = {} + + self._del_queue = queue.SimpleQueue() + self._del_threads = [] + for _ in range(10): + t = threading.Thread(target=self._del_worker, daemon=True) + t.start() + self._del_threads.append(t) + + @classmethod + def run( + cls, + create_pg: Callable[[str, int, int], ProcessGroup], + store_addr: str, + rank: int, + world_size: int, + req_pipe: "Connection[object, object]", + future_pipe: "Connection[object, object]", + ) -> None: + try: + store = create_store_client(store_addr) + + try: + pg = create_pg(store, rank, world_size) + except Exception as e: + logger.exception(f"got exception in worker: {e}") + req_pipe.send(e) + return + req_pipe.send(None) + + handler = cls(pg, req_pipe, future_pipe) + + while True: + op = cast(list[object], req_pipe.recv()) + cmd = cast(str, op[0]) + + key = "handle_" + cmd + if not hasattr(handler, key): + raise RuntimeError(f"unknown command: {cmd}") + + getattr(handler, key)(*op[1:]) + + except Exception as e: + logger.exception(f"worker errored: {e}") + req_pipe.send(e) + raise + + def _del_worker(self) -> None: + while True: + metadata = self._del_queue.get() + with ( + torch.cuda.stream(metadata.stream) + if metadata.stream is not None + else nullcontext() + ): + # if metadata.event is not None: + # metadata.event.synchronize() + del metadata + # get and immediately free object + + def get_stream( + self, + stream_id: Optional[_StreamID], + ) -> Optional[torch.cuda.Stream]: + if stream_id is None: + return None + + # To avoid potential deadlocks we need to preserve the + # stream/synchronization behavior of the parent process. + # We allocate one Stream per stream_id to make sure that we + # don't accidentally introduce cross stream synchronization + # points. 
+ stream_key = f"{stream_id.device}/{stream_id.stream_id}" + if stream_key not in self._streams: + self._streams[stream_key] = torch.cuda.Stream(device=stream_id.device) + return self._streams[stream_key] + + def handle_func( + self, + op_id: int, + stream_id: Optional[_StreamID], + event: Optional[torch.cuda.Event], + ) -> None: + stream = self.get_stream(stream_id) + + with torch.cuda.stream(stream) if stream is not None else nullcontext(): + # Make the stream wait on the cuda event to make sure we + # don't start the operation until the tensor is ready. + if event is not None: + event.wait() + + func_name, args, kwargs = cast( + tuple[str, list[object], dict[str, object]], + self._req_pipe.recv(), + ) + + args = _PickleSafeOptions.unsafe_args(args) + fn = getattr(self._pg, func_name) + work = fn(*args, **kwargs) + assert work is not None, f"got None for work for {func_name}" + self._work[op_id] = _OpMetadata( + work=work, + stream=stream, + event=None, + ) + + def handle_wait( + self, + op_id: int, + stream_id: Optional[_StreamID], + timeout: timedelta, + ) -> None: + metadata = self._work[op_id] + + stream = self.get_stream(stream_id) if metadata.stream is not None else None + + with torch.cuda.stream(stream) if stream is not None else nullcontext(): + # With WorkNCCL this makes the stream wait not the CPU when + # no timeout is passed. + if timeout is not None: + metadata.work.wait(timeout) + else: + metadata.work.wait() # timedelta(seconds=10.0)) + + # Register event on the stream that we can pass to the main + # process. + # event = None + event = ( + torch.cuda.current_stream().record_event( + torch.cuda.Event(interprocess=True) + ) + if metadata.stream is not None + else None + ) + metadata.event = event + + self._req_pipe.send((op_id, event)) + + def handle_del(self, op_id: int) -> None: + metadata = self._work[op_id] + del self._work[op_id] + self._del_queue.put(metadata) + + def handle_future(self, op_id: int) -> None: + def callback(fut: Future[object]) -> None: + try: + fut.wait() + self._future_pipe.send((op_id, _FUTURE_RESULT, None)) + except Exception as e: + self._future_pipe.send((op_id, _FUTURE_EXCEPTION, e)) + + self._work[op_id].work.get_future().add_done_callback(callback) + + def handle_num_active_work(self) -> None: + self._req_pipe.send(len(self._work)) + + class ProcessGroupBaby(ProcessGroup): """ This is a process group that runs the underlying process group in a @@ -982,8 +1219,9 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None: self._future_pipe = future_local = _MonitoredPipe(future_local) self._p = p = ctx.Process( - target=self._worker, + target=_ProcessGroupBabyWorker.run, args=( + self.__class__._create_pg, store_addr, rank, world_size, @@ -1016,128 +1254,6 @@ def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGrou """ raise NotImplementedError("not implemented") - @classmethod - def _worker( - cls, - store_addr: str, - rank: int, - world_size: int, - req_pipe: "Connection[object, object]", - future_pipe: "Connection[object, object]", - ) -> None: - try: - store = create_store_client(store_addr) - - try: - pg = cls._create_pg(store, rank, world_size) - except Exception as e: - logger.exception(f"got exception in worker: {e}") - req_pipe.send(e) - return - req_pipe.send(None) - - streams: Dict[str, torch.cuda.Stream] = {} - work: Dict[int, _OpMetadata] = {} - - while True: - op = cast(list[object], req_pipe.recv()) - cmd = op[0] - if cmd == "func": - op_id: int - op_id, func_name, args, kwargs, 
stream_device, stream_id, event = ( - cast( - Tuple[ - int, - str, - list[object], - dict[str, object], - int, - int, - Optional[torch.cuda.Event], - ], - op[1:], - ) - ) - - # To avoid potential deadlocks we need to preserve the - # stream/synchronization behavior of the parent process. - # We allocate one Stream per stream_id to make sure that we - # don't accidentally introduce cross stream synchronization - # points. - if stream_id is not None: - stream_key = f"{stream_device}/{stream_id}" - if stream_key not in streams: - streams[stream_key] = torch.cuda.Stream( - device=stream_device - ) - stream = streams[stream_key] - else: - stream = None - - with ( - torch.cuda.stream(stream) - if stream is not None - else nullcontext() - ): - # Make the stream wait on the cuda event to make sure we - # don't start the operation until the tensor is ready. - if event is not None: - event.wait() - - args = _PickleSafeOptions.unsafe_args(args) - fn = getattr(pg, func_name) - work[op_id] = _OpMetadata( - work=fn(*args, **kwargs), - stream=stream, - ) - elif cmd == "wait": - op_id, timeout = cast(tuple[int, timedelta], op[1:]) - - metadata = work[op_id] - - with metadata.set_stream(): - # With WorkNCCL this makes the stream wait not the CPU when - # no timeout is passed. - if timeout is not None: - metadata.work.wait(timeout) - else: - metadata.work.wait() - - # Register event on the stream that we can pass to the main - # process. - event = ( - torch.cuda.current_stream().record_event( - torch.cuda.Event(interprocess=True) - ) - if metadata.stream is not None - else None - ) - - req_pipe.send((op_id, event)) - elif cmd == "del": - op_id: int = cast(int, op[1]) - del work[op_id] - elif cmd == "future": - op_id: int = cast(int, op[1]) - - def callback(fut: Future[object]) -> None: - try: - fut.wait() - future_pipe.send((op_id, _FUTURE_RESULT, None)) - except Exception as e: - future_pipe.send((op_id, _FUTURE_EXCEPTION, e)) - - work[op_id].work.get_future().add_done_callback(callback) - elif cmd == "num_active_work": - req_pipe.send(len(work)) - else: - raise ValueError(f"unknown cmd: {cmd}") - - except Exception as e: - logger.exception(f"worker errored: {e}") - req_pipe.send(e) - raise - def _future_handler(self, future_pipe: _MonitoredPipe) -> None: try: while True: @@ -1171,7 +1287,10 @@ def _get_future(self, op_id: int) -> Future[object]: def _wait(self, op_id: int, timeout: Optional[timedelta] = None) -> bool: assert self._pipe is not None - self._pipe.send(("wait", op_id, timeout)) + + stream_id = _current_stream_id() + + self._pipe.send(("wait", op_id, stream_id, timeout)) assert self._pipe is not None op_id, event = cast( @@ -1194,8 +1313,7 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work: is_cuda = _is_any_cuda(args) - stream_device = torch.cuda.current_stream().device if is_cuda else None - stream_id = torch.cuda.current_stream().stream_id if is_cuda else None + stream_id = _current_stream_id() if is_cuda else None event = ( torch.cuda.current_stream().record_event( torch.cuda.Event(interprocess=True) @@ -1211,12 +1329,15 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work: ( "func", op_id, + stream_id, + event, + ), + ) + pipe.send( + ( func, _PickleSafeOptions.safe_args(args), kwargs, - stream_device, - stream_id, - event, ), )
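
The new public helper `torchft.futures.stream_timeout` pairs naturally with the `abort()` method this patch adds to `ProcessGroupWrapper`. A hedged usage sketch, assuming an already-configured `ProcessGroupNCCL` and a CUDA tensor; `allreduce_with_watchdog` and the 60-second value are illustrative and not part of the patch:

```python
# Illustrative only -- not part of the patch. Arms a stream watchdog around a
# collective that was enqueued on the current CUDA stream.
from datetime import timedelta

import torch
from torch.distributed import ReduceOp
from torch.distributed.distributed_c10d import AllreduceOptions

from torchft.futures import stream_timeout
from torchft.process_group import ProcessGroupNCCL


def allreduce_with_watchdog(pg: ProcessGroupNCCL, tensor: torch.Tensor) -> None:
    # Enqueue the collective on the current stream first.
    opts = AllreduceOptions()
    opts.reduceOp = ReduceOp.SUM
    work = pg.allreduce([tensor], opts)

    # stream_timeout() records a CUDA event on the current stream now; if that
    # event still hasn't completed when the timeout fires, the callback runs on
    # the timeout manager's event loop and aborts the process group rather than
    # letting the step hang forever.
    stream_timeout(pg.abort, timeout=timedelta(seconds=60))

    # With this patch, wait() with no explicit timeout also arms the same kind
    # of watchdog via _WorkCUDATimeout before returning.
    work.wait()
```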
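The worker side of `ProcessGroupBaby` now dispatches pipe commands by name to `handle_*` methods instead of walking a single if/elif chain in `_worker`. A minimal, self-contained sketch of that dispatch pattern under assumed names (`_ToyWorker` and its command set are hypothetical, not torchft APIs):

```python
from typing import Dict, Tuple


class _ToyWorker:
    """Routes ("cmd", *args) tuples to handle_<cmd> methods, like _ProcessGroupBabyWorker.run."""

    def __init__(self) -> None:
        self._active: Dict[int, str] = {}

    def handle_func(self, op_id: int, name: str) -> None:
        # Mirrors handle_func registering an _OpMetadata entry for op_id.
        self._active[op_id] = name

    def handle_del(self, op_id: int) -> None:
        # Mirrors handle_del dropping a finished op.
        del self._active[op_id]

    def handle_num_active_work(self) -> None:
        print(len(self._active))

    def dispatch(self, op: Tuple[object, ...]) -> None:
        cmd = str(op[0])
        key = "handle_" + cmd
        if not hasattr(self, key):
            raise RuntimeError(f"unknown command: {cmd}")
        getattr(self, key)(*op[1:])


if __name__ == "__main__":
    worker = _ToyWorker()
    for op in [("func", 0, "allreduce"), ("num_active_work",), ("del", 0), ("num_active_work",)]:
        worker.dispatch(op)  # prints 1, then 0
```

Unknown commands fail loudly in the worker, matching the old `raise ValueError(f"unknown cmd: ...")` behavior.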
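Timeout selection in `ProcessGroupNCCL._wrap_work`, restated as a standalone sketch for clarity; `pick_timeout` and `_FakeOpts` are hypothetical names used only here, and the branch simply mirrors the `hasattr`/`total_seconds()` check in the diff:

```python
from datetime import timedelta


def pick_timeout(default: timedelta, opts: object) -> timedelta:
    # Same rule as ProcessGroupNCCL._wrap_work: a positive per-op opts.timeout
    # overrides the group's default watchdog timeout.
    if hasattr(opts, "timeout") and opts.timeout.total_seconds() > 0:
        return opts.timeout
    return default


class _FakeOpts:
    timeout = timedelta(seconds=5)


assert pick_timeout(timedelta(seconds=60), _FakeOpts()) == timedelta(seconds=5)
assert pick_timeout(timedelta(seconds=60), object()) == timedelta(seconds=60)
# Note: recv/send pass opts=None to _wrap_work, so they fall back to the default.
```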