|  | 
|  | 1 | +""" | 
|  | 2 | +Rich logging utilities for distributed training environments. | 
|  | 3 | +
 | 
|  | 4 | +This module provides enhanced logging capabilities using the Rich library, | 
|  | 5 | +with special support for distributed training scenarios. It includes utilities | 
|  | 6 | +for creating properly formatted console outputs with terminal width detection, | 
|  | 7 | +distributed rank information, and rich text formatting. | 
|  | 8 | +
 | 
|  | 9 | +The module automatically detects distributed training environments and can | 
|  | 10 | +include rank information in log messages to help distinguish between different | 
|  | 11 | +processes in multi-GPU or multi-node training setups. | 
|  | 12 | +""" | 
|  | 13 | + | 
| 1 | 14 | import logging | 
|  | 15 | +import os | 
| 2 | 16 | import shutil | 
| 3 | 17 | from functools import lru_cache | 
|  | 18 | +from typing import Optional | 
| 4 | 19 | 
 | 
| 5 | 20 | from rich.console import Console | 
| 6 | 21 | from rich.logging import RichHandler | 
| 7 | 22 | 
 | 
| 8 | 23 | import ditk | 
| 9 | 24 | from .base import _LogLevelType | 
|  | 25 | +from ..distributed import is_distributed, get_rank, get_world_size | 
| 10 | 26 | 
 | 
| 11 | 27 | # This value is set due the requirement of displaying the tables | 
| 12 | 28 | _DEFAULT_WIDTH = 170 | 
| 13 | 29 | 
 | 
| 14 | 30 | 
 | 
| 15 | 31 | @lru_cache() | 
| 16 | 32 | def _get_terminal_width() -> int: | 
|  | 33 | +    """ | 
|  | 34 | +    Get the current terminal width with caching for performance. | 
|  | 35 | +
 | 
|  | 36 | +    This function detects the terminal width and caches the result to avoid | 
|  | 37 | +    repeated system calls. It falls back to a default width if terminal | 
|  | 38 | +    size detection fails. | 
|  | 39 | +
 | 
|  | 40 | +    :return: The terminal width in characters. | 
|  | 41 | +    :rtype: int | 
|  | 42 | +
 | 
|  | 43 | +    Example:: | 
|  | 44 | +        >>> width = _get_terminal_width() | 
|  | 45 | +        >>> print(f"Terminal width: {width}") | 
|  | 46 | +        Terminal width: 170 | 
|  | 47 | +    """ | 
| 17 | 48 |     width, _ = shutil.get_terminal_size(fallback=(_DEFAULT_WIDTH, 24)) | 
| 18 | 49 |     return width | 
| 19 | 50 | 
 | 
| 20 | 51 | 
 | 
| 21 | 52 | @lru_cache() | 
| 22 | 53 | def _get_rich_console(use_stdout: bool = False) -> Console: | 
|  | 54 | +    """ | 
|  | 55 | +    Create and cache a Rich Console instance with appropriate configuration. | 
|  | 56 | +
 | 
|  | 57 | +    This function creates a Rich Console with the detected terminal width | 
|  | 58 | +    and configures output to stderr by default (or stdout if specified). | 
|  | 59 | +    The result is cached to ensure consistent console usage across the application. | 
|  | 60 | +
 | 
|  | 61 | +    :param use_stdout: Whether to use stdout instead of stderr for output. | 
|  | 62 | +    :type use_stdout: bool | 
|  | 63 | +
 | 
|  | 64 | +    :return: A configured Rich Console instance. | 
|  | 65 | +    :rtype: Console | 
|  | 66 | +
 | 
|  | 67 | +    Example:: | 
|  | 68 | +        >>> console = _get_rich_console() | 
|  | 69 | +        >>> console.print("Hello, World!") | 
|  | 70 | +        Hello, World! | 
|  | 71 | +    """ | 
| 23 | 72 |     return Console(width=_get_terminal_width(), stderr=not use_stdout) | 
| 24 | 73 | 
 | 
| 25 | 74 | 
 | 
| 26 |  | -_RICH_FMT = logging.Formatter(fmt="%(message)s", datefmt="[%m-%d %H:%M:%S]") | 
|  | 75 | +def _get_log_format( | 
|  | 76 | +        include_distributed: bool = True, | 
|  | 77 | +        distributed_format: str = "[Rank {rank}/{world_size}][PID: {pid}]" | 
|  | 78 | +) -> str: | 
|  | 79 | +    """ | 
|  | 80 | +    Get the appropriate log format based on distributed training status. | 
|  | 81 | +
 | 
|  | 82 | +    This function generates a logging format string that optionally includes | 
|  | 83 | +    distributed training information such as rank and world size. When | 
|  | 84 | +    distributed training is detected and enabled, it prepends rank information | 
|  | 85 | +    to log messages to help identify which process generated each log entry. | 
|  | 86 | +
 | 
|  | 87 | +    :param include_distributed: Whether to include distributed information in the format. | 
|  | 88 | +    :type include_distributed: bool | 
|  | 89 | +    :param distributed_format: Format string template for distributed info, should contain | 
|  | 90 | +                              {rank} and {world_size} placeholders. | 
|  | 91 | +    :type distributed_format: str | 
| 27 | 92 | 
 | 
|  | 93 | +    :return: Format string for logging that includes distributed info if applicable. | 
|  | 94 | +    :rtype: str | 
|  | 95 | +
 | 
|  | 96 | +    Example:: | 
|  | 97 | +        >>> # In a distributed environment | 
|  | 98 | +        >>> format_str = _get_log_format(include_distributed=True) | 
|  | 99 | +        >>> print(format_str) | 
|  | 100 | +        [Rank 0/4] %(message)s | 
|  | 101 | +
 | 
|  | 102 | +        >>> # Without distributed info | 
|  | 103 | +        >>> format_str = _get_log_format(include_distributed=False) | 
|  | 104 | +        >>> print(format_str) | 
|  | 105 | +        %(message)s | 
|  | 106 | +    """ | 
|  | 107 | +    if include_distributed and is_distributed(): | 
|  | 108 | +        rank = get_rank() | 
|  | 109 | +        world_size = get_world_size() | 
|  | 110 | +        prefix = distributed_format.format(rank=rank, world_size=world_size, pid=os.getpid()) | 
|  | 111 | +        return f"{prefix} %(message)s" | 
|  | 112 | +    else: | 
|  | 113 | +        return "%(message)s" | 
|  | 114 | + | 
|  | 115 | + | 
|  | 116 | +def _create_rich_handler( | 
|  | 117 | +        use_stdout: bool = False, | 
|  | 118 | +        level: _LogLevelType = logging.NOTSET, | 
|  | 119 | +        include_distributed: bool = True, | 
|  | 120 | +        distributed_format: Optional[str] = None | 
|  | 121 | +) -> RichHandler: | 
|  | 122 | +    """ | 
|  | 123 | +    Create a Rich handler with optional distributed training information. | 
|  | 124 | +
 | 
|  | 125 | +    This function creates a fully configured RichHandler that provides | 
|  | 126 | +    enhanced logging output with rich text formatting, traceback highlighting, | 
|  | 127 | +    and optional distributed training rank information. The handler is | 
|  | 128 | +    configured with appropriate formatters and console settings. | 
|  | 129 | +
 | 
|  | 130 | +    :param use_stdout: Whether to use stdout instead of stderr for log output. | 
|  | 131 | +    :type use_stdout: bool | 
|  | 132 | +    :param level: Logging level threshold for this handler. | 
|  | 133 | +    :type level: _LogLevelType | 
|  | 134 | +    :param include_distributed: Whether to include distributed rank information | 
|  | 135 | +                               in log messages when running in distributed mode. | 
|  | 136 | +    :type include_distributed: bool | 
|  | 137 | +    :param distributed_format: Custom format template for distributed info. | 
|  | 138 | +                              If None, uses a default Rich markup format with | 
|  | 139 | +                              bold blue styling. Should contain {rank} and | 
|  | 140 | +                              {world_size} placeholders. | 
|  | 141 | +    :type distributed_format: Optional[str] | 
|  | 142 | +
 | 
|  | 143 | +    :return: A configured RichHandler instance ready for use with Python logging. | 
|  | 144 | +    :rtype: RichHandler | 
|  | 145 | +
 | 
|  | 146 | +    Example:: | 
|  | 147 | +        >>> # Create a basic rich handler | 
|  | 148 | +        >>> handler = _create_rich_handler() | 
|  | 149 | +        >>> logger = logging.getLogger("my_logger") | 
|  | 150 | +        >>> logger.addHandler(handler) | 
|  | 151 | +        >>> logger.info("This will be beautifully formatted!") | 
|  | 152 | +
 | 
|  | 153 | +        >>> # Create handler with custom distributed format | 
|  | 154 | +        >>> handler = _create_rich_handler( | 
|  | 155 | +        ...     distributed_format="[Process {rank}]", | 
|  | 156 | +        ...     level=logging.INFO | 
|  | 157 | +        ... ) | 
|  | 158 | +    """ | 
|  | 159 | +    if distributed_format is None: | 
|  | 160 | +        distributed_format = "[bold blue]\\[Rank {rank}/{world_size}][/bold blue][bold blue]\\[PID: {pid}][/bold blue]"  # Rich markup support | 
|  | 161 | + | 
|  | 162 | +    # Dynamically create formatter with distributed information | 
|  | 163 | +    rich_fmt = logging.Formatter( | 
|  | 164 | +        fmt=_get_log_format(include_distributed, distributed_format), | 
|  | 165 | +        datefmt="[%m-%d %H:%M:%S]" | 
|  | 166 | +    ) | 
| 28 | 167 | 
 | 
| 29 |  | -def _create_rich_handler(use_stdout: bool = False, level: _LogLevelType = logging.NOTSET) -> RichHandler: | 
| 30 | 168 |     handler = RichHandler( | 
| 31 | 169 |         level=level, | 
| 32 | 170 |         console=_get_rich_console(use_stdout), | 
| 33 | 171 |         rich_tracebacks=True, | 
| 34 | 172 |         markup=True, | 
| 35 | 173 |         tracebacks_suppress=[ditk], | 
| 36 | 174 |     ) | 
| 37 |  | -    handler.setFormatter(_RICH_FMT) | 
|  | 175 | +    handler.setFormatter(rich_fmt) | 
| 38 | 176 |     return handler | 
0 commit comments