Commit 55f9501

Merge pull request #11 from opendilab/dev/dist

dev(hansbug): add torch dist support && add log information inside

2 parents: b02aa2d + bf8ed68

File tree: 12 files changed, +595 −8 lines

.github/workflows/test.yml

Lines changed: 27 additions & 1 deletion

@@ -21,6 +21,23 @@ jobs:
           - '3.11'
           - '3.12'
           - '3.13'
+        torch-version:
+          - 'none'
+          - '2.4'
+          - '2.7'
+        exclude:
+          - python-version: '3.8'
+            torch-version: '2.7'
+          - python-version: '3.9'
+            torch-version: '2.4'
+          - python-version: '3.10'
+            torch-version: '2.4'
+          - python-version: '3.11'
+            torch-version: '2.4'
+          - python-version: '3.12'
+            torch-version: '2.4'
+          - python-version: '3.13'
+            torch-version: '2.4'

     steps:
       - name: Get system version for Linux
@@ -77,8 +94,17 @@ jobs:
         uses: actions/setup-python@v6
         with:
           python-version: ${{ matrix.python-version }}
-      - name: Install dependencies
+      - name: Install dependencies With Torch
         shell: bash
+        if: ${{ matrix.torch-version != 'none' }}
+        run: |
+          python -m pip install --upgrade pip
+          pip install --upgrade flake8 setuptools wheel twine
+          pip install 'torch==${{ matrix.torch-version }}' -r requirements.txt
+          pip install -r requirements-test.txt
+      - name: Install dependencies Without Torch
+        shell: bash
+        if: ${{ matrix.torch-version == 'none' }}
         run: |
           python -m pip install --upgrade pip
           pip install --upgrade flake8 setuptools wheel twine
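
For reference, a minimal sketch of how this matrix expands (assuming the python-version list covers 3.8 through 3.13; only part of it is visible in this hunk): torch 2.4 pairs only with Python 3.8, torch 2.7 pairs with Python 3.9 through 3.13, and the 'none' entry keeps a torch-free job on every Python version.

    from itertools import product

    # Hypothetical reconstruction of the effective CI matrix; the python-version
    # list is assumed, since only part of it appears in the visible hunk.
    python_versions = ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13']
    torch_versions = ['none', '2.4', '2.7']
    excludes = {('3.8', '2.7')} | {(py, '2.4') for py in python_versions if py != '3.8'}

    jobs = [combo for combo in product(python_versions, torch_versions) if combo not in excludes]
    print(f"{len(jobs)} jobs: {jobs}")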

ditk/distributed/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .env import is_main_process, is_distributed, get_rank, get_world_size
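
This re-export lets callers import the helpers from the package root rather than from ditk.distributed.env; a minimal usage sketch:

    # Import through the package root, as enabled by the re-export above.
    from ditk.distributed import is_distributed, is_main_process

    if is_main_process():
        print("rank 0, or a non-distributed run")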

ditk/distributed/env.py

Lines changed: 110 additions & 0 deletions

@@ -0,0 +1,110 @@
+"""
+Distributed training utilities for PyTorch.
+
+This module provides utility functions to handle distributed training scenarios in PyTorch.
+It offers convenient methods to check distributed status, get process information, and
+determine the main process. The functions gracefully handle both distributed and
+non-distributed environments.
+
+Example::
+    >>> # Check if distributed training is active
+    >>> if is_distributed():
+    ...     print(f"Running on rank {get_rank()} of {get_world_size()}")
+    >>>
+    >>> # Execute code only on main process
+    >>> if is_main_process():
+    ...     print("This runs only on the main process")
+"""
+
+
+def is_distributed() -> bool:
+    """
+    Check if distributed training is available and initialized.
+
+    This function verifies whether PyTorch distributed training is both available
+    (compiled with distributed support) and properly initialized. It handles cases
+    where PyTorch or its distributed module might not be installed.
+
+    :return: True if distributed training is available and initialized, False otherwise.
+    :rtype: bool
+
+    Example::
+        >>> if is_distributed():
+        ...     print("Distributed training is active")
+        ... else:
+        ...     print("Running in single-process mode")
+    """
+    try:
+        import torch
+        import torch.distributed as dist
+    except (ImportError, ModuleNotFoundError):
+        return False
+
+    # Check if distributed is available (compiled with distributed support) and is initialized
+    return dist.is_available() and dist.is_initialized()
+
+
+def get_rank() -> int:
+    """
+    Get the global rank of the current process.
+
+    Returns the global rank (process ID) of the current process in distributed training.
+    In non-distributed environments, this function returns 0, making it safe to use
+    in both distributed and single-process scenarios.
+
+    :return: Global rank of the current process. Returns 0 if distributed training is not active.
+    :rtype: int
+
+    Example::
+        >>> rank = get_rank()
+        >>> print(f"Current process rank: {rank}")
+    """
+    if is_distributed():
+        import torch.distributed as dist
+        return dist.get_rank()
+    else:
+        return 0


+def get_world_size() -> int:
+    """
+    Get the total number of processes across all nodes.
+
+    Returns the total number of processes participating in distributed training.
+    In non-distributed environments, this function returns 1, ensuring consistent
+    behavior across different training setups.
+
+    :return: Total number of processes in the distributed training. Returns 1 if distributed training is not active.
+    :rtype: int
+
+    Example::
+        >>> world_size = get_world_size()
+        >>> print(f"Total number of processes: {world_size}")
+    """
+    if is_distributed():
+        import torch.distributed as dist
+        return dist.get_world_size()
+    else:
+        return 1
+
+
+# Utility functions for easier usage
+def is_main_process() -> bool:
+    """
+    Check if the current process is the main process (global rank 0).
+
+    This function is useful for executing code that should only run once across
+    all processes, such as logging, saving checkpoints, or printing progress.
+    In non-distributed environments, it always returns True.
+
+    :return: True if current process is main process (rank 0) or if distributed is not available.
+    :rtype: bool
+
+    Example::
+        >>> if is_main_process():
+        ...     print("Saving model checkpoint...")
+        ...     # Save checkpoint logic here
+    """
+    if not is_distributed():
+        return True
+    return get_rank() == 0
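
Taken together, these helpers make rank-aware code paths easy to write. A minimal usage sketch (the log_progress helper below is hypothetical, not part of this commit):

    import logging

    from ditk.distributed import get_rank, get_world_size, is_distributed, is_main_process


    def log_progress(step: int, loss: float) -> None:
        # Hypothetical helper: every process logs its own view of training,
        # but only rank 0 (or a single-process run) performs side effects.
        if is_distributed():
            logging.info("step %d | loss %.4f | rank %d of %d", step, loss, get_rank(), get_world_size())
        else:
            logging.info("step %d | loss %.4f | single-process run", step, loss)

        if is_main_process():
            # Safe place for checkpointing, metric export, etc.
            logging.info("main process: writing checkpoint / metrics")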

ditk/logging/rich.py

Lines changed: 141 additions & 3 deletions

@@ -1,38 +1,176 @@
+"""
+Rich logging utilities for distributed training environments.
+
+This module provides enhanced logging capabilities using the Rich library,
+with special support for distributed training scenarios. It includes utilities
+for creating properly formatted console outputs with terminal width detection,
+distributed rank information, and rich text formatting.
+
+The module automatically detects distributed training environments and can
+include rank information in log messages to help distinguish between different
+processes in multi-GPU or multi-node training setups.
+"""
+
 import logging
+import os
 import shutil
 from functools import lru_cache
+from typing import Optional

 from rich.console import Console
 from rich.logging import RichHandler

 import ditk
 from .base import _LogLevelType
+from ..distributed import is_distributed, get_rank, get_world_size

 # This value is set due the requirement of displaying the tables
 _DEFAULT_WIDTH = 170


 @lru_cache()
 def _get_terminal_width() -> int:
+    """
+    Get the current terminal width with caching for performance.
+
+    This function detects the terminal width and caches the result to avoid
+    repeated system calls. It falls back to a default width if terminal
+    size detection fails.
+
+    :return: The terminal width in characters.
+    :rtype: int
+
+    Example::
+        >>> width = _get_terminal_width()
+        >>> print(f"Terminal width: {width}")
+        Terminal width: 170
+    """
     width, _ = shutil.get_terminal_size(fallback=(_DEFAULT_WIDTH, 24))
     return width


 @lru_cache()
 def _get_rich_console(use_stdout: bool = False) -> Console:
+    """
+    Create and cache a Rich Console instance with appropriate configuration.
+
+    This function creates a Rich Console with the detected terminal width
+    and configures output to stderr by default (or stdout if specified).
+    The result is cached to ensure consistent console usage across the application.
+
+    :param use_stdout: Whether to use stdout instead of stderr for output.
+    :type use_stdout: bool
+
+    :return: A configured Rich Console instance.
+    :rtype: Console
+
+    Example::
+        >>> console = _get_rich_console()
+        >>> console.print("Hello, World!")
+        Hello, World!
+    """
     return Console(width=_get_terminal_width(), stderr=not use_stdout)


-_RICH_FMT = logging.Formatter(fmt="%(message)s", datefmt="[%m-%d %H:%M:%S]")
+def _get_log_format(
+        include_distributed: bool = True,
+        distributed_format: str = "[Rank {rank}/{world_size}][PID: {pid}]"
+) -> str:
+    """
+    Get the appropriate log format based on distributed training status.
+
+    This function generates a logging format string that optionally includes
+    distributed training information such as rank and world size. When
+    distributed training is detected and enabled, it prepends rank information
+    to log messages to help identify which process generated each log entry.
+
+    :param include_distributed: Whether to include distributed information in the format.
+    :type include_distributed: bool
+    :param distributed_format: Format string template for distributed info, should contain
+        {rank} and {world_size} placeholders.
+    :type distributed_format: str

+    :return: Format string for logging that includes distributed info if applicable.
+    :rtype: str
+
+    Example::
+        >>> # In a distributed environment
+        >>> format_str = _get_log_format(include_distributed=True)
+        >>> print(format_str)
+        [Rank 0/4] %(message)s
+
+        >>> # Without distributed info
+        >>> format_str = _get_log_format(include_distributed=False)
+        >>> print(format_str)
+        %(message)s
+    """
+    if include_distributed and is_distributed():
+        rank = get_rank()
+        world_size = get_world_size()
+        prefix = distributed_format.format(rank=rank, world_size=world_size, pid=os.getpid())
+        return f"{prefix} %(message)s"
+    else:
+        return "%(message)s"
+
+
+def _create_rich_handler(
+        use_stdout: bool = False,
+        level: _LogLevelType = logging.NOTSET,
+        include_distributed: bool = True,
+        distributed_format: Optional[str] = None
+) -> RichHandler:
+    """
+    Create a Rich handler with optional distributed training information.
+
+    This function creates a fully configured RichHandler that provides
+    enhanced logging output with rich text formatting, traceback highlighting,
+    and optional distributed training rank information. The handler is
+    configured with appropriate formatters and console settings.
+
+    :param use_stdout: Whether to use stdout instead of stderr for log output.
+    :type use_stdout: bool
+    :param level: Logging level threshold for this handler.
+    :type level: _LogLevelType
+    :param include_distributed: Whether to include distributed rank information
+        in log messages when running in distributed mode.
+    :type include_distributed: bool
+    :param distributed_format: Custom format template for distributed info.
+        If None, uses a default Rich markup format with
+        bold blue styling. Should contain {rank} and
+        {world_size} placeholders.
+    :type distributed_format: Optional[str]
+
+    :return: A configured RichHandler instance ready for use with Python logging.
+    :rtype: RichHandler
+
+    Example::
+        >>> # Create a basic rich handler
+        >>> handler = _create_rich_handler()
+        >>> logger = logging.getLogger("my_logger")
+        >>> logger.addHandler(handler)
+        >>> logger.info("This will be beautifully formatted!")
+
+        >>> # Create handler with custom distributed format
+        >>> handler = _create_rich_handler(
+        ...     distributed_format="[Process {rank}]",
+        ...     level=logging.INFO
+        ... )
+    """
+    if distributed_format is None:
+        distributed_format = "[bold blue]\\[Rank {rank}/{world_size}][/bold blue][bold blue]\\[PID: {pid}][/bold blue]"  # Rich markup support
+
+    # Dynamically create formatter with distributed information
+    rich_fmt = logging.Formatter(
+        fmt=_get_log_format(include_distributed, distributed_format),
+        datefmt="[%m-%d %H:%M:%S]"
+    )

-def _create_rich_handler(use_stdout: bool = False, level: _LogLevelType = logging.NOTSET) -> RichHandler:
     handler = RichHandler(
         level=level,
         console=_get_rich_console(use_stdout),
         rich_tracebacks=True,
         markup=True,
         tracebacks_suppress=[ditk],
     )
-    handler.setFormatter(_RICH_FMT)
+    handler.setFormatter(rich_fmt)
     return handler