diff --git a/src/con_duct/__main__.py b/src/con_duct/__main__.py
index 464e650..8b572af 100755
--- a/src/con_duct/__main__.py
+++ b/src/con_duct/__main__.py
@@ -22,18 +22,39 @@
 import threading
 import time
 from types import FrameType
-from typing import IO, Any, Optional, TextIO
+from typing import IO, Any, Callable, Dict, Iterable, List, Optional, TextIO, Tuple
 
 __version__ = version("con-duct")
 __schema_version__ = "0.2.2"
 
-lgr = logging.getLogger("con-duct")
-DEFAULT_LOG_LEVEL = os.environ.get("DUCT_LOG_LEVEL", "INFO").upper()
+ABOUT_DUCT = """
+duct is a lightweight wrapper that collects execution data for an arbitrary
+command. Execution data includes execution time, system information, and
+resource usage statistics of the command and all its child processes. It is
+intended to simplify the problem of recording the resources necessary to
+execute a command, particularly in an HPC environment.
 
-DUCT_OUTPUT_PREFIX = os.getenv(
-    "DUCT_OUTPUT_PREFIX", ".duct/logs/{datetime_filesafe}-{pid}_"
-)
+Resource usage is determined by polling (at a sample-interval).
+During execution, duct produces a JSON lines (see https://jsonlines.org) file
+with one data point recorded for each report (at a report-interval).
+
+limitations:
+    Duct uses session id to track the command process and its children, so it
+    cannot handle the situation where a process creates a new session.
+    If a command spawns child processes, duct will collect data on them, but
+    duct exits as soon as the primary process exits.
+
+configuration:
+    Many options can be configured via JSON config file (--config), environment
+    variables, or command line arguments. Precedence: built-in defaults < config
+    file < environment variables < command line arguments.
+
+    Default values shown below reflect the current configuration (built-in defaults
+    or loaded from config file). Environment variables are listed with each option.
+"""
+
+DEFAULT_CONFIG_PATHS = "/etc/duct/config.json:${XDG_CONFIG_HOME:-~/.config}/duct/config.json:.duct/config.json"  # noqa: B950
 ENV_PREFIXES = ("PBS_", "SLURM_", "OSG")
 SUFFIXES = {
     "stdout": "stdout",
@@ -41,7 +62,7 @@
     "usage": "usage.json",
     "info": "info.json",
 }
-EXECUTION_SUMMARY_FORMAT = (
+_EXECUTION_SUMMARY_FORMAT = (
     "Summary:\n"
     "Exit Code: {exit_code!E}\n"
     "Command: {command}\n"
@@ -58,41 +79,7 @@
 )
 
 
-ABOUT_DUCT = """
-duct is a lightweight wrapper that collects execution data for an arbitrary
-command. Execution data includes execution time, system information, and
-resource usage statistics of the command and all its child processes. It is
-intended to simplify the problem of recording the resources necessary to
-execute a command, particularly in an HPC environment.
-
-Resource usage is determined by polling (at a sample-interval).
-During execution, duct produces a JSON lines (see https://jsonlines.org) file
-with one data point recorded for each report (at a report-interval).
-
-limitations:
-    Duct uses session id to track the command process and its children, so it
-    cannot handle the situation where a process creates a new session.
-    If a command spawns child processes, duct will collect data on them, but
-    duct exits as soon as the primary process exits.
-
-environment variables:
-    Many duct options can be configured by environment variables (which are
-    overridden by command line options).
- - DUCT_LOG_LEVEL: see --log-level - DUCT_OUTPUT_PREFIX: see --output-prefix - DUCT_SUMMARY_FORMAT: see --summary-format - DUCT_SAMPLE_INTERVAL: see --sample-interval - DUCT_REPORT_INTERVAL: see --report-interval - DUCT_CAPTURE_OUTPUTS: see --capture-outputs - DUCT_MESSAGE: see --message -""" - - -class CustomHelpFormatter(argparse.ArgumentDefaultsHelpFormatter): - def _fill_text(self, text: str, width: int, _indent: str) -> str: - # Override _fill_text to respect the newlines and indentation in descriptions - return "\n".join([textwrap.fill(line, width) for line in text.splitlines()]) +lgr = logging.getLogger("con-duct") def assert_num(*values: Any) -> None: @@ -139,6 +126,150 @@ def __str__(self) -> str: return self.value +@dataclass(frozen=True) +class FieldSpec: + """Specification for a configuration field.""" + + kind: str # "bool" | "value" + default: Any + cast: Callable[[Any], Any] + help: str + config_key: str # Hyphenated key used in config files and CLI flags + env_var: Optional[str] = None # Environment variable name (optional) + choices: Optional[Iterable[Any]] = None + validate: Optional[Callable[[Any], Any]] = None + file_configurable: bool = True # Whether this can be set in config files + alt_flag_names: Optional[List[str]] = None + # Additional metadata for argparse + metavar: Optional[str] = None + nargs: Optional[Any] = None + + +class CustomHelpFormatter(argparse.HelpFormatter): + """Custom help formatter that shows defaults and environment variables.""" + + def _fill_text(self, text: str, width: int, _indent: str) -> str: + # Override _fill_text to respect the newlines and indentation in descriptions + return "\n".join([textwrap.fill(line, width) for line in text.splitlines()]) + + def _get_help_string(self, action): + # Start with the base help text + help_text = action.help or "" + + # Add default value if available + if hasattr(action, "spec_default") and action.spec_default is not None: + if help_text: + help_text = help_text.rstrip() + 
help_text += f" (default: {str(action.spec_default)})" + + # Add environment variable info if available + if hasattr(action, "env_var") and action.env_var: + if help_text: + help_text = help_text.rstrip() + help_text += f" [env: {action.env_var}]" + + # Escape percent signs to prevent argparse formatting errors + help_text = help_text.replace("%", "%%") + + return help_text + + +def build_parser() -> argparse.ArgumentParser: + """Build argparse parser from FIELD_SPECS without injecting defaults.""" + parser = argparse.ArgumentParser( + allow_abbrev=False, + description=ABOUT_DUCT, + formatter_class=CustomHelpFormatter, + ) + + # Add --version + parser.add_argument( + "--version", action="version", version=f"%(prog)s {__version__}" + ) + # Add config manually (not in FIELD_SPECS) + config_action = parser.add_argument( + "-C", + "--config", + default=argparse.SUPPRESS, + help="Configuration file path", + ) + # Store default for help text + config_action.spec_default = DEFAULT_CONFIG_PATHS + + # Add --dump-config + parser.add_argument( + "--dump-config", + action="store_true", + help="Print the final merged config with value sources and exit", + ) + + # Add command and command_args as positional arguments (not in FIELD_SPECS) + parser.add_argument( + "command", + help="The command to execute", + metavar="command [command_args ...]", + ) + parser.add_argument( + "command_args", + nargs=argparse.REMAINDER, + help="Arguments for the command", + ) + + # Add all fields from specs + for name, spec in FIELD_SPECS.items(): + + if spec.kind == "bool": + # TODO there is no way to override clobber: True in config without --no-flag + # Simple boolean flags (just --flag, no --no-flag) + names = [] + if spec.alt_flag_names: + names.extend(spec.alt_flag_names) + names.append(f"--{spec.config_key}") + + action = parser.add_argument( + *names, + dest=name, + action="store_true", + default=argparse.SUPPRESS, + help=spec.help, + ) + + # Store default and env_var for help text + 
action.spec_default = spec.default + if spec.env_var: + action.env_var = spec.env_var + + else: # kind == "value" + # Regular value arguments - use SUPPRESS so only explicit values override config + kwargs = { + "default": argparse.SUPPRESS, # Don't inject defaults + "help": spec.help, + "type": spec.cast, + } + if spec.choices is not None: + kwargs["choices"] = list(spec.choices) + + # Build argument names using alt_flag_names if available + names = [] + if spec.alt_flag_names: + names.extend(spec.alt_flag_names) + names.append(f"--{spec.config_key}") + + # Special handling for quiet (deprecated boolean action) + if spec.config_key == "quiet": + kwargs.pop("type") # quiet is action store_true + kwargs["action"] = "store_true" + + action = parser.add_argument(*names, dest=name, **kwargs) + + # Store default and env_var for help text + action.spec_default = spec.default + if spec.env_var: + action.env_var = spec.env_var + + return parser + + @dataclass class SystemInfo: cpu_total: int @@ -730,189 +861,380 @@ def format_field(self, value: Any, format_spec: str) -> Any: return value_ -@dataclass -class Arguments: - command: str - command_args: list[str] - output_prefix: str - sample_interval: float - report_interval: float - fail_time: float - clobber: bool - capture_outputs: Outputs - outputs: Outputs - record_types: RecordTypes - summary_format: str - colors: bool - log_level: str - quiet: bool - session_mode: SessionMode - message: str = "" +class Config: + """Configuration management for duct. - def __post_init__(self) -> None: + This class loads configuration from multiple sources (files, env vars, CLI) + and provides validated access to all configuration values. 
+ """ + + @staticmethod + def bool_from_str(x: Any) -> bool: + """Convert various string representations to boolean.""" + if isinstance(x, bool): + return x + s = str(x).strip().lower() + if s in {"1", "true", "yes", "on"}: + return True + if s in {"0", "false", "no", "off"}: + return False + raise ValueError(f"invalid boolean: {x!r}") + + @staticmethod + def validate_positive(v: float) -> float: + """Validate that a value is positive.""" + if v <= 0: + raise ValueError("must be greater than 0") + return v + + @staticmethod + def validate_sample_report_interval(sample: float, report: float) -> None: + """Validate that report interval >= sample interval.""" + if report < sample: + raise ValueError( + "report-interval must be greater than or equal to sample-interval" + ) + + def __init__(self, cli_args: Dict[str, Any]): + """Initialize and load configuration from all sources. + + Args: + cli_args: Parsed CLI arguments dictionary from argparse + (with command, command_args, and config already removed) + + Raises: + SystemExit: If configuration validation fails + """ + self._cli_args = cli_args + self._load_and_validate() + + def _load_and_validate(self) -> None: + """Load configuration from all sources and validate it.""" + config_paths_str = ( + self._cli_args.get("config") + or os.environ.get("DUCT_CONFIG") + or DEFAULT_CONFIG_PATHS + ) + config_paths = self._expand_config_paths(config_paths_str) + file_layers = self._load_files(config_paths) + + env_vals, env_src = self._load_env() + merged, provenance = self._merge_with_provenance( + file_layers=file_layers, + env_vals=env_vals, + env_src=env_src, + cli_vals=self._cli_args, + ) + final, errors = self._coerce_and_validate(merged, provenance) + + if errors: + print("Configuration errors:", file=sys.stderr) + for error in errors: + # TODO logger + print(error, file=sys.stderr) + sys.exit(1) + + # Set all configuration values as instance attributes + for name, spec in FIELD_SPECS.items(): + if name in final: + 
setattr(self, name, final[name]) + setattr( + self, + f"_source_{name}", + provenance.get(spec.config_key, f"default ({name})"), + ) + elif spec.default is not None: + setattr(self, name, spec.default) + setattr(self, f"_source_{name}", f"default ({name})") + + # Validate cross-field constraints + self._validate_constraints() + + def _validate_constraints(self) -> None: + """Validate cross-field constraints.""" if self.report_interval < self.sample_interval: + # TODO is ArgumentError appropriate if this was config value fail? raise argparse.ArgumentError( None, "--report-interval must be greater than or equal to --sample-interval.", ) - @classmethod - def from_argv( - cls, cli_args: Optional[list[str]] = None, **cli_kwargs: Any - ) -> Arguments: - parser = argparse.ArgumentParser( - allow_abbrev=False, - description=ABOUT_DUCT, - formatter_class=CustomHelpFormatter, - ) - parser.add_argument( - "command", - metavar="command [command_args ...]", - help="The command to execute, along with its arguments.", - ) - parser.add_argument( - "--version", action="version", version=f"%(prog)s {__version__}" - ) - parser.add_argument( - "command_args", nargs=argparse.REMAINDER, help="Arguments for the command." - ) - parser.add_argument( - "-p", - "--output-prefix", - type=str, - default=DUCT_OUTPUT_PREFIX, - help="File string format to be used as a prefix for the files -- the captured " - "stdout and stderr and the resource usage logs. The understood variables are " - "{datetime}, {datetime_filesafe}, and {pid}. " - "Leading directories will be created if they do not exist. " - "You can also provide value via DUCT_OUTPUT_PREFIX env variable. ", - ) - parser.add_argument( - "--summary-format", - type=str, - default=os.getenv("DUCT_SUMMARY_FORMAT", EXECUTION_SUMMARY_FORMAT), - help="Output template to use when printing the summary following execution. " - "Accepts custom conversion flags: " - "!S: Converts filesizes to human readable units, green if measured, red if None. 
" - "!E: Colors exit code, green if falsey, red if truthy, and red if None. " - "!X: Colors green if truthy, red if falsey. " - "!N: Colors green if not None, red if None", - ) - parser.add_argument( - "--colors", - action="store_true", - default=os.getenv("DUCT_COLORS", False), - help="Use colors in duct output.", - ) - parser.add_argument( - "--clobber", - action="store_true", - help="Replace log files if they already exist.", - ) - parser.add_argument( - "-l", - "--log-level", - default=DEFAULT_LOG_LEVEL, - type=str.upper, - choices=("NONE", "CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"), - help="Level of log output to stderr, use NONE to entirely disable.", - ) - parser.add_argument( - "-q", - "--quiet", - action="store_true", - help="[deprecated, use log level NONE] Disable duct logging output (to stderr)", - ) - parser.add_argument( - "--sample-interval", - "--s-i", - type=float, - default=float(os.getenv("DUCT_SAMPLE_INTERVAL", "1.0")), - help="Interval in seconds between status checks of the running process. " - "Sample interval must be less than or equal to report interval, and it achieves the " - "best results when sample is significantly less than the runtime of the process.", - ) - parser.add_argument( - "--report-interval", - "--r-i", - type=float, - default=float(os.getenv("DUCT_REPORT_INTERVAL", "60.0")), - help="Interval in seconds at which to report aggregated data.", - ) - parser.add_argument( - "--fail-time", - "--f-t", - type=float, - default=float(os.getenv("DUCT_FAIL_TIME", "3.0")), - help="If command fails in less than this specified time (seconds), duct would remove logs. " - "Set to 0 if you would like to keep logs for a failing command regardless of its run time. " - "Set to negative (e.g. 
-1) if you would like to not keep logs for any failing command.", + def _expand_config_paths(self, paths_str: str) -> List[str]: + """Expand environment variables and user paths in config path string.""" + expanded = paths_str.replace( + "${XDG_CONFIG_HOME:-~/.config}", + os.getenv("XDG_CONFIG_HOME", "~/.config"), ) + expanded = os.path.expandvars(expanded) + return [os.path.expanduser(p.strip()) for p in expanded.split(":") if p.strip()] - parser.add_argument( - "-c", - "--capture-outputs", - default=os.getenv("DUCT_CAPTURE_OUTPUTS", "all"), - choices=list(Outputs), - type=Outputs, - help="Record stdout, stderr, all, or none to log files. " - "You can also provide value via DUCT_CAPTURE_OUTPUTS env variable.", - ) - parser.add_argument( - "-o", - "--outputs", - default="all", - choices=list(Outputs), - type=Outputs, - help="Print stdout, stderr, all, or none to stdout/stderr respectively.", - ) - parser.add_argument( - "-t", - "--record-types", - default="all", - choices=list(RecordTypes), - type=RecordTypes, - help="Record system-summary, processes-samples, or all", - ) - parser.add_argument( - "-m", - "--message", - type=str, - default=os.getenv("DUCT_MESSAGE", ""), - help="Record a descriptive message about the purpose of this execution. " - "You can also provide value via DUCT_MESSAGE env variable.", - ) - parser.add_argument( - "--mode", - default="new-session", - choices=list(SessionMode), - type=SessionMode, - help="Session mode: 'new-session' creates a new session for the command (default), " - "'current-session' tracks the current session instead of starting a new one. 
" - "Useful for tracking slurm jobs or other commands that should run in the current session.", - ) - args = parser.parse_args( - args=cli_args, - namespace=cli_kwargs and argparse.Namespace(**cli_kwargs) or None, - ) - return cls( - command=args.command, - command_args=args.command_args, - output_prefix=args.output_prefix, - sample_interval=args.sample_interval, - report_interval=args.report_interval, - fail_time=args.fail_time, - capture_outputs=args.capture_outputs, - outputs=args.outputs, - record_types=args.record_types, - summary_format=args.summary_format, - clobber=args.clobber, - colors=args.colors, - log_level=args.log_level, - quiet=args.quiet, - session_mode=args.mode, - message=args.message, - ) + def _load_files(self, paths: List[str]) -> List[Tuple[Dict[str, Any], str]]: + """Load config files and return list of (dict, source_label).""" + out: List[Tuple[Dict[str, Any], str]] = [] + for path in paths: + try: + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + out.append((data, f"config file: {path}")) + except FileNotFoundError: + continue + except (json.JSONDecodeError, OSError) as e: + raise SystemExit(f"Error loading config from {path}: {e}") + return out + + def _load_env(self) -> Tuple[Dict[str, Any], Dict[str, str]]: + """Load configuration from environment variables.""" + vals: Dict[str, Any] = {} + prov: Dict[str, str] = {} + for _name, spec in FIELD_SPECS.items(): + var = spec.env_var + if var and var in os.environ: + vals[spec.config_key] = os.environ[var] + prov[spec.config_key] = f"env var: {var}" + return vals, prov + + def _merge_with_provenance( + self, + file_layers: List[Tuple[Dict[str, Any], str]], + env_vals: Dict[str, Any], + env_src: Dict[str, str], + cli_vals: Dict[str, Any], + ) -> Tuple[Dict[str, Any], Dict[str, str]]: + """Merge configuration from all sources with provenance tracking.""" + merged: Dict[str, Any] = {} + src: Dict[str, str] = {} + + # Defaults first + for name, spec in FIELD_SPECS.items(): + 
if spec.default is not None: + merged[spec.config_key] = spec.default + src[spec.config_key] = f"default ({name})" + + # Files in order + for data, label in file_layers: + for k, v in data.items(): + # Convert both hyphen and underscore formats to underscore to match FIELD_SPECS + spec_key = k.replace("-", "_") + if spec_key in FIELD_SPECS and FIELD_SPECS[spec_key].file_configurable: + # Use the spec's config_key for consistency + config_key = FIELD_SPECS[spec_key].config_key + merged[config_key] = v + src[config_key] = label + + # Environment variables + for k, v in env_vals.items(): + merged[k] = v + src[k] = env_src[k] + + # CLI arguments + for k, v in cli_vals.items(): + if k in FIELD_SPECS: + config_key = FIELD_SPECS[k].config_key + merged[config_key] = v + src[config_key] = f"CLI: --{config_key}" + + return merged, src + + def _coerce_and_validate( + self, raw: Dict[str, Any], provenance: Dict[str, str] + ) -> Tuple[Dict[str, Any], List[str]]: + """Coerce types and validate all configuration values.""" + clean: Dict[str, Any] = {} + errors: List[str] = [] + + for name, spec in FIELD_SPECS.items(): + val = raw.get(spec.config_key, spec.default) + if val is None and spec.default is None: + continue + + src_label = provenance.get(spec.config_key, f"default ({name})") + try: + val = spec.cast(val) + except (ValueError, TypeError) as e: + errors.append( + f"- {spec.config_key}: {e} (value {val!r} from {src_label})" + ) + continue + + if spec.choices is not None and val not in spec.choices: + errors.append( + f"- {spec.config_key}: must be one of {list(spec.choices)} " + f"(value {val!r} from {src_label})" + ) + continue + + if spec.validate is not None: + try: + val = spec.validate(val) + except ValueError as e: + errors.append( + f"- {spec.config_key}: {e} (value {val!r} from {src_label})" + ) + continue + clean[name] = val + + # Cross-field validation + if "sample_interval" in clean and "report_interval" in clean: + try: + self.validate_sample_report_interval( + 
clean["sample_interval"], clean["report_interval"] + ) + except ValueError as e: + errors.append(f"- interval validation: {e}") + + return clean, errors + + def dump_config(self) -> None: + """Print the final merged config with value sources.""" + config_dump = {} + + # TODO since we are doing like this probably dont need to pop cmd and cmd args before calling + for name, spec in FIELD_SPECS.items(): + value = getattr(self, name, spec.default) + source = getattr(self, f"_source_{name}", "default") + + config_dump[spec.config_key] = { + "value": value, + "source": source, + "type": type( + value + ).__name__, # TODO should this come from spec not value type? + } + + print(json.dumps(config_dump, indent=2, default=str)) + + +FIELD_SPECS: Dict[str, FieldSpec] = { + "output_prefix": FieldSpec( + kind="value", + default=".duct/logs/{datetime_filesafe}-{pid}_", + cast=str, + help="File string format prefix for output files", + config_key="output-prefix", + env_var="DUCT_OUTPUT_PREFIX", + alt_flag_names=["-p"], + ), + "summary_format": FieldSpec( + kind="value", + default=_EXECUTION_SUMMARY_FORMAT, + cast=str, + help="Output template for execution summary", + config_key="summary-format", + env_var="DUCT_SUMMARY_FORMAT", + ), + "colors": FieldSpec( + kind="bool", + default=False, + cast=Config.bool_from_str, + help="Use colors in duct output", + config_key="colors", + env_var="DUCT_COLORS", + ), + "log_level": FieldSpec( + kind="value", + default="INFO", + cast=str.upper, + help="Level of log output to stderr", + config_key="log-level", + env_var="DUCT_LOG_LEVEL", + choices=["NONE", "CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], + alt_flag_names=["-l"], + ), + "clobber": FieldSpec( + kind="bool", + default=False, + cast=Config.bool_from_str, + help="Replace log files if they already exist", + config_key="clobber", + env_var="DUCT_CLOBBER", + ), + "sample_interval": FieldSpec( + kind="value", + default=1.0, + cast=float, + help="Interval in seconds between status 
checks", + config_key="sample-interval", + env_var="DUCT_SAMPLE_INTERVAL", + validate=Config.validate_positive, + alt_flag_names=["--s-i"], + ), + "report_interval": FieldSpec( + kind="value", + default=60.0, + cast=float, + help="Interval in seconds for reporting aggregated data", + config_key="report-interval", + env_var="DUCT_REPORT_INTERVAL", + validate=Config.validate_positive, + alt_flag_names=["--r-i"], + ), + "fail_time": FieldSpec( + kind="value", + default=3.0, + cast=float, + help="Time threshold for keeping logs of failing commands", + config_key="fail-time", + env_var="DUCT_FAIL_TIME", + alt_flag_names=["--f-t"], + ), + "capture_outputs": FieldSpec( + kind="value", + default=Outputs.ALL, + cast=Outputs, + help="Record stdout, stderr, all, or none to log files", + config_key="capture-outputs", + env_var="DUCT_CAPTURE_OUTPUTS", + choices=list(Outputs), + alt_flag_names=["-c"], + ), + "outputs": FieldSpec( + kind="value", + default=Outputs.ALL, + cast=Outputs, + help="Print stdout, stderr, all, or none", + config_key="outputs", + env_var="DUCT_OUTPUTS", + choices=list(Outputs), + alt_flag_names=["-o"], + ), + "record_types": FieldSpec( + kind="value", + default=RecordTypes.ALL, + cast=RecordTypes, + help="Record system-summary, processes-samples, or all", + config_key="record-types", + env_var="DUCT_RECORD_TYPES", + choices=list(RecordTypes), + alt_flag_names=["-t"], + ), + "mode": FieldSpec( + kind="value", + default=SessionMode.NEW_SESSION, + cast=SessionMode, + help="Session mode for command execution", + config_key="mode", + env_var="DUCT_MODE", + choices=list(SessionMode), + ), + "message": FieldSpec( + kind="value", + default="", + cast=str, + help="Descriptive message about this execution", + config_key="message", + env_var="DUCT_MESSAGE", + alt_flag_names=["-m"], + ), + "quiet": FieldSpec( + kind="bool", + default=False, + cast=Config.bool_from_str, + help="[deprecated] Disable duct logging output", + config_key="quiet", + env_var="DUCT_QUIET", + 
alt_flag_names=["-q"], + ), +} def monitor_process( @@ -1059,14 +1381,43 @@ def remove_files(log_paths: LogPaths, assert_empty: bool = False) -> None: os.remove(file_path) +def handle_dump_config(parser: argparse.ArgumentParser) -> None: + """Handle --dump-config option and exit if present.""" + if "--dump-config" not in sys.argv: + return + + # Add dummy command to avoid missing required positional args, then parse + argv_with_dummy = [arg for arg in sys.argv if arg != "--dump-config"] + [ + "--dump-config", + "dummy", + ] + cli_args = vars(parser.parse_args(argv_with_dummy[1:])) # Skip script name + # Remove non FieldSpec Args + cli_args.pop("dump_config", False) + cli_args.pop("command", None) + cli_args.pop("command_args", None) + config = Config(cli_args) + config.dump_config() + sys.exit(0) + + def main() -> None: + # TODO lets make this logger.Error instead so we can show configfile load issues + # Set up basic logging configuration (level will be set properly in execute) logging.basicConfig( format="%(asctime)s [%(levelname)-8s] %(name)s: %(message)s", datefmt="%Y-%m-%dT%H:%M:%S%z", - level=getattr(logging, DEFAULT_LOG_LEVEL), + level=logging.INFO, # Use default level initially ) - args = Arguments.from_argv() - sys.exit(execute(args)) + parser = build_parser() + handle_dump_config(parser) + cli_args = vars(parser.parse_args()) + + # Extract positional args and special flags (not part of FieldSpec) + command = cli_args.pop("command", "") + command_args = cli_args.pop("command_args", []) + config = Config(cli_args) + sys.exit(execute(config, command, command_args)) class ProcessSignalHandler: @@ -1090,18 +1441,24 @@ def handle_signal(self, _sig: int, _frame: Optional[FrameType]) -> None: os._exit(1) -def execute(args: Arguments) -> int: +def execute(config: Config, command: str, command_args: List[str]) -> int: """A wrapper to execute a command, monitor and log the process details. - Returns exit code of the executed process. 
+ Args: + config: Configuration object with all settings + command: The command to execute + command_args: Arguments for the command + + Returns: + Exit code of the executed process. """ - if args.log_level == "NONE" or args.quiet: + if config.log_level == "NONE" or config.quiet: lgr.disabled = True else: - lgr.setLevel(args.log_level) - log_paths = LogPaths.create(args.output_prefix, pid=os.getpid()) - log_paths.prepare_paths(args.clobber, args.capture_outputs) - stdout, stderr = prepare_outputs(args.capture_outputs, args.outputs, log_paths) + lgr.setLevel(config.log_level) + log_paths = LogPaths.create(config.output_prefix, pid=os.getpid()) + log_paths.prepare_paths(config.clobber, config.capture_outputs) + stdout, stderr = prepare_outputs(config.capture_outputs, config.outputs, log_paths) stdout_file: TextIO | IO[bytes] | int | None if isinstance(stdout, TailPipe): stdout_file = open(stdout.file_path, "wb") @@ -1114,28 +1471,28 @@ def execute(args: Arguments) -> int: stderr_file = stderr working_directory = os.getcwd() - full_command = " ".join([str(args.command)] + args.command_args) + full_command = " ".join([str(command)] + command_args) files_to_close = [stdout_file, stdout, stderr_file, stderr] report = Report( - args.command, - args.command_args, + command, + command_args, log_paths, - args.summary_format, + config.summary_format, working_directory, - args.colors, - args.clobber, - message=args.message, + config.colors, + config.clobber, + message=config.message, ) files_to_close.append(report.usage_file) report.start_time = time.time() try: report.process = process = subprocess.Popen( - [str(args.command)] + args.command_args, + [str(command)] + command_args, stdout=stdout_file, stderr=stderr_file, - start_new_session=(args.session_mode == SessionMode.NEW_SESSION), + start_new_session=(config.mode == SessionMode.NEW_SESSION), cwd=report.working_directory, ) except FileNotFoundError: @@ -1145,7 +1502,7 @@ def execute(args: Arguments) -> int: 
safe_close_files(files_to_close) remove_files(log_paths, assert_empty=True) # mimicking behavior of bash and zsh. - lgr.error("%s: command not found", args.command) + lgr.error("%s: command not found", command) return 127 # seems what zsh and bash return then handler = ProcessSignalHandler(process.pid) @@ -1153,7 +1510,7 @@ def execute(args: Arguments) -> int: lgr.info("duct %s is executing %r...", __version__, full_command) lgr.info("Log files will be written to %s", log_paths.prefix) try: - if args.session_mode == SessionMode.NEW_SESSION: + if config.mode == SessionMode.NEW_SESSION: report.session_id = os.getsid( process.pid ) # Get session ID of the new process @@ -1165,12 +1522,12 @@ def execute(args: Arguments) -> int: # TODO: log this at least. pass stop_event = threading.Event() - if args.record_types.has_processes_samples(): + if config.record_types.has_processes_samples(): monitoring_args = [ report, process, - args.report_interval, - args.sample_interval, + config.report_interval, + config.sample_interval, stop_event, ] monitoring_thread = threading.Thread( @@ -1180,7 +1537,7 @@ def execute(args: Arguments) -> int: else: monitoring_thread = None - if args.record_types.has_system_summary(): + if config.record_types.has_system_summary(): env_thread = threading.Thread(target=report.collect_environment) env_thread.start() sys_info_thread = threading.Thread(target=report.get_system_info) @@ -1212,17 +1569,17 @@ def execute(args: Arguments) -> int: sys_info_thread.join() lgr.debug("System information collection finished") - if args.record_types.has_system_summary(): + if config.record_types.has_system_summary(): with open(log_paths.info, "w") as system_logs: report.run_time_seconds = f"{report.end_time - report.start_time}" system_logs.write(report.dump_json()) safe_close_files(files_to_close) if process.returncode != 0 and ( - report.elapsed_time < args.fail_time or args.fail_time < 0 + report.elapsed_time < config.fail_time or config.fail_time < 0 ): 
lgr.info( "Removing log files since command failed%s.", - f" in less than {args.fail_time} seconds" if args.fail_time > 0 else "", + f" in less than {config.fail_time} seconds" if config.fail_time > 0 else "", ) remove_files(log_paths) else: diff --git a/src/con_duct/suite/main.py b/src/con_duct/suite/main.py index a05839d..d7a99bd 100644 --- a/src/con_duct/suite/main.py +++ b/src/con_duct/suite/main.py @@ -49,6 +49,7 @@ def main(argv: Optional[List[str]] = None) -> None: # Subcommand: pp parser_pp = subparsers.add_parser("pp", help="Pretty print a JSON log.") parser_pp.add_argument("file_path", help="JSON file to pretty print.") + # TODO(add env var and config) parser_pp.add_argument( "-H", "--humanize", @@ -81,6 +82,7 @@ def main(argv: Optional[List[str]] = None) -> None: "ls", help="Print execution information for all matching runs.", ) + # TODO(add env var DUCT_LS_FORMAT and config) parser_ls.add_argument( "-f", "--format", @@ -89,6 +91,7 @@ def main(argv: Optional[List[str]] = None) -> None: help="Output format. TODO Fixme. 
'auto' chooses 'pyout' if pyout library is installed," " 'summaries' otherwise.", ) + # TODO(add env var DUCT_LS_FIELDS and config) parser_ls.add_argument( "-F", "--fields", @@ -104,12 +107,14 @@ def main(argv: Optional[List[str]] = None) -> None: "peak_rss", ], ) + # TODO(already exists todo in __main__ we should use here too) parser_ls.add_argument( "--colors", action="store_true", default=os.getenv("DUCT_COLORS", False), help="Use colors in duct output.", ) + # TODO(config: not now but we should consider adding parser_ls.add_argument( "paths", nargs="*", diff --git a/test/test_arg_parsing.py b/test/test_arg_parsing.py index a90fd3c..23ee8ba 100644 --- a/test/test_arg_parsing.py +++ b/test/test_arg_parsing.py @@ -1,9 +1,7 @@ -import os import re import subprocess -from unittest import mock import pytest -from con_duct.__main__ import Arguments +from con_duct.__main__ import build_parser def test_duct_help() -> None: @@ -87,13 +85,17 @@ def test_abreviation_disabled() -> None: ], ) def test_mode_argument_parsing(mode_arg: list, expected_mode: str) -> None: - """Test that --mode argument is parsed correctly with both long and short forms.""" - # Import here to avoid module loading issues in tests - from con_duct.__main__ import Arguments - + """Test that --mode argument is parsed correctly.""" + parser = build_parser() cmd_args = mode_arg + ["echo", "test"] - args = Arguments.from_argv(cmd_args) - assert str(args.session_mode) == expected_mode + args = parser.parse_args(cmd_args) + # When no mode is provided, it won't be in args due to argparse.SUPPRESS + if mode_arg: + assert hasattr(args, "mode") + assert str(args.mode) == expected_mode + else: + # Default case - mode won't be in args namespace + assert not hasattr(args, "mode") def test_mode_invalid_value() -> None: @@ -105,34 +107,99 @@ def test_mode_invalid_value() -> None: pytest.fail("Command should have failed with invalid mode value") except subprocess.CalledProcessError as e: assert e.returncode == 2 + # 
Enum shows class name but argparse includes the choices assert "invalid SessionMode value: 'invalid-mode'" in str(e.stdout) -def test_message_parsing() -> None: - """Test that -m/--message flag is correctly parsed.""" - # Test short flag - args = Arguments.from_argv(["-m", "test message", "echo", "hello"]) - assert args.message == "test message" - assert args.command == "echo" - assert args.command_args == ["hello"] +@pytest.mark.parametrize( + "cli_args,expected_key,expected_value", + [ + # Message parsing tests + (["-m", "test message", "echo", "hello"], "message", "test message"), + (["-m", "test message", "echo", "hello"], "command", "echo"), + (["-m", "test message", "echo", "hello"], "command_args", ["hello"]), + (["--message", "another message", "ls"], "message", "another message"), + (["--message", "another message", "ls"], "command", "ls"), + (["--message", "another message", "ls"], "command_args", []), + # Comprehensive argument parsing tests + ( + ["--output-prefix", "/tmp/test", "echo", "test"], + "output_prefix", + "/tmp/test", + ), + (["--sample-interval", "2", "echo", "test"], "sample_interval", 2), + (["--report-interval", "10", "echo", "test"], "report_interval", 10), + (["--log-level", "DEBUG", "echo", "test"], "log_level", "DEBUG"), + ( + ["--config", "/path/to/config.json", "echo", "test"], + "config", + "/path/to/config.json", + ), + # Basic command parsing + (["echo", "hello"], "command", "echo"), + (["echo", "hello"], "command_args", ["hello"]), + ], +) +def test_argument_parsing(cli_args: list, expected_key: str, expected_value) -> None: + """Test that parser correctly handles various argument combinations.""" + parser = build_parser() + args = parser.parse_args(cli_args) - # Test long flag - args = Arguments.from_argv(["--message", "another message", "ls"]) - assert args.message == "another message" - assert args.command == "ls" + assert getattr(args, expected_key) == expected_value - # Test without message (should be empty string) - args = 
Arguments.from_argv(["echo", "hello"]) - assert args.message == "" +@pytest.mark.parametrize( + "cli_args,expected_key,expected_value", + [ + (["--capture-outputs", "all", "echo", "test"], "capture_outputs", "all"), + (["--capture-outputs", "stderr", "echo", "test"], "capture_outputs", "stderr"), + (["--outputs", "stdout", "echo", "test"], "outputs", "stdout"), + (["--outputs", "none", "echo", "test"], "outputs", "none"), + ( + ["--record-types", "system-summary", "echo", "test"], + "record_types", + "system-summary", + ), + ( + ["--record-types", "processes-samples", "echo", "test"], + "record_types", + "processes-samples", + ), + (["--mode", "new-session", "echo", "test"], "mode", "new-session"), + (["--mode", "current-session", "echo", "test"], "mode", "current-session"), + ], +) +def test_enum_argument_parsing( + cli_args: list, expected_key: str, expected_value: str +) -> None: + """Test that parser correctly handles enum arguments.""" + parser = build_parser() + args = parser.parse_args(cli_args) + # For enums, compare string representation + assert str(getattr(args, expected_key)) == expected_value -def test_message_env_variable() -> None: - """Test that DUCT_MESSAGE environment variable is used as default.""" - with mock.patch.dict(os.environ, {"DUCT_MESSAGE": "env message"}): - args = Arguments.from_argv(["echo", "hello"]) - assert args.message == "env message" - # Command line should override env variable - with mock.patch.dict(os.environ, {"DUCT_MESSAGE": "env message"}): - args = Arguments.from_argv(["-m", "cli message", "echo", "hello"]) - assert args.message == "cli message" +@pytest.mark.parametrize( + "attribute_name", + [ + "message", + "output_prefix", + "sample_interval", + "report_interval", + "capture_outputs", + "outputs", + "record_types", + "mode", + "log_level", + "colors", + "clobber", + "fail_time", + "summary_format", + "quiet", + ], +) +def test_optional_attributes_absent_without_flags(attribute_name: str) -> None: + """Test that 
optional attributes are absent when their flags aren't provided.""" + parser = build_parser() + args = parser.parse_args(["echo", "hello"]) + assert not hasattr(args, attribute_name) diff --git a/test/test_config.py b/test/test_config.py new file mode 100644 index 0000000..a2ddf7c --- /dev/null +++ b/test/test_config.py @@ -0,0 +1,303 @@ +"""Tests for the Config class and validation functions.""" + +import json +import os +import tempfile +from unittest import mock +import pytest +from con_duct.__main__ import FIELD_SPECS, Config, build_parser, handle_dump_config + + +class TestValidationFunctions: + """Test validation helper functions.""" + + def test_bool_from_str(self): + """Test string to boolean conversion.""" + # True values + assert Config.bool_from_str("true") is True + assert Config.bool_from_str("True") is True + assert Config.bool_from_str("TRUE") is True + assert Config.bool_from_str("yes") is True + assert Config.bool_from_str("Yes") is True + assert Config.bool_from_str("1") is True + + # False values + assert Config.bool_from_str("false") is False + assert Config.bool_from_str("False") is False + assert Config.bool_from_str("FALSE") is False + assert Config.bool_from_str("no") is False + assert Config.bool_from_str("No") is False + assert Config.bool_from_str("0") is False + + # Invalid values + with pytest.raises(ValueError): + Config.bool_from_str("maybe") + with pytest.raises(ValueError): + Config.bool_from_str("2") + with pytest.raises(ValueError): + Config.bool_from_str("") + + def test_validate_positive(self): + """Test positive number validation.""" + # Valid positive numbers + assert Config.validate_positive(1) == 1 + assert Config.validate_positive(0.5) == 0.5 + assert Config.validate_positive(100) == 100 + + # Invalid non-positive numbers + with pytest.raises(ValueError, match="must be greater than 0"): + Config.validate_positive(0) + with pytest.raises(ValueError, match="must be greater than 0"): + Config.validate_positive(-1) + with 
pytest.raises(ValueError, match="must be greater than 0"): + Config.validate_positive(-0.5) + + @pytest.mark.parametrize( + "sample_interval,report_interval,should_raise,expected_message", + [ + # Valid cases: report >= sample + (1.0, 5.0, False, None), + (2.0, 2.0, False, None), # Equal is valid + (0.5, 1.0, False, None), + # Invalid cases: report < sample + ( + 5.0, + 4.0, + True, + "report-interval must be greater than or equal to sample-interval", + ), + ( + 10.0, + 5.0, + True, + "report-interval must be greater than or equal to sample-interval", + ), + ( + 3.5, + 2.1, + True, + "report-interval must be greater than or equal to sample-interval", + ), + ], + ) + def test_validate_sample_report_interval( + self, sample_interval, report_interval, should_raise, expected_message + ): + """Test sample/report interval validation.""" + if should_raise: + with pytest.raises(ValueError, match=expected_message): + Config.validate_sample_report_interval(sample_interval, report_interval) + else: + Config.validate_sample_report_interval( + sample_interval, report_interval + ) # Should not raise + + +class TestConfig: + """Test Config class functionality.""" + + def test_config_default_initialization(self): + """Test Config initialization with defaults.""" + # Clear env vars that conftest.py sets to test true defaults + with mock.patch.dict(os.environ, {}, clear=True): + config = Config({}) + + # Check some default values + assert config.sample_interval == FIELD_SPECS["sample_interval"].default + assert config.report_interval == FIELD_SPECS["report_interval"].default + assert config.capture_outputs == FIELD_SPECS["capture_outputs"].default + assert config.mode == FIELD_SPECS["mode"].default + assert config.message == FIELD_SPECS["message"].default + assert config.log_level == FIELD_SPECS["log_level"].default + + def test_config_with_cli_args(self): + """Test Config initialization with CLI arguments.""" + cli_args = { + "sample_interval": 2.0, + "report_interval": 120.0, + 
"message": "test message", + "log_level": "DEBUG", + } + config = Config(cli_args) + + assert config.sample_interval == 2.0 + assert config.report_interval == 120.0 + assert config.message == "test message" + assert config.log_level == "DEBUG" + + def test_config_with_env_vars(self): + """Test Config initialization with environment variables.""" + with mock.patch.dict( + os.environ, + { + "DUCT_SAMPLE_INTERVAL": "3.0", + "DUCT_REPORT_INTERVAL": "180.0", + "DUCT_MESSAGE": "env message", + "DUCT_LOG_LEVEL": "WARNING", + }, + ): + config = Config({}) + + assert config.sample_interval == 3.0 + assert config.report_interval == 180.0 + assert config.message == "env message" + assert config.log_level == "WARNING" + + def test_config_precedence(self): + """Test configuration precedence: defaults < config file < env vars < CLI.""" + # Create a temporary config file + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + config_data = { + "sample-interval": 10.0, + "report-interval": 100.0, + "message": "file message", + "log-level": "ERROR", + } + json.dump(config_data, f) + config_file = f.name + + try: + # Test with config file (clear env vars to prevent conftest interference) + with mock.patch.dict(os.environ, {"DUCT_CONFIG": config_file}, clear=True): + config = Config({}) + assert config.sample_interval == 10.0 + assert config.message == "file message" + + # Test env vars override config file + with mock.patch.dict( + os.environ, + { + "DUCT_CONFIG": config_file, + "DUCT_MESSAGE": "env message", + "DUCT_SAMPLE_INTERVAL": "15.0", + }, + clear=True, + ): + config = Config({}) + assert config.sample_interval == 15.0 + assert config.message == "env message" + + # Test CLI args override everything + with mock.patch.dict( + os.environ, + { + "DUCT_CONFIG": config_file, + "DUCT_MESSAGE": "env message", + }, + clear=True, + ): + cli_args = {"message": "cli message", "sample_interval": 20.0} + config = Config(cli_args) + assert config.sample_interval 
== 20.0 + assert config.message == "cli message" + + finally: + os.unlink(config_file) + + def test_config_invalid_json_file(self): + """Test Config handles invalid JSON files gracefully.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + f.write("not valid json") + config_file = f.name + + try: + with mock.patch.dict(os.environ, {"DUCT_CONFIG": config_file}, clear=True): + with pytest.raises(SystemExit): + Config({}) + finally: + os.unlink(config_file) + + def test_config_nonexistent_file(self): + """Test Config handles nonexistent config files gracefully.""" + # Nonexistent file in DUCT_CONFIG should cause error + # Clear env vars to test true defaults, then set only DUCT_CONFIG + with mock.patch.dict( + os.environ, {"DUCT_CONFIG": "/nonexistent/config.json"}, clear=True + ): + # Should log warning but not fail (uses defaults) + config = Config({}) + assert ( + config.sample_interval == FIELD_SPECS["sample_interval"].default + ) # default value + + def test_config_dump(self, capsys): + """Test Config.dump_config() method.""" + config = Config({"message": "test"}) + config.dump_config() + + captured = capsys.readouterr() + output = json.loads(captured.out) + + assert "message" in output + assert output["message"]["value"] == "test" + assert "CLI:" in output["message"]["source"] + assert output["message"]["type"] == "str" + + +class TestHandleDumpConfig: + """Test handle_dump_config function.""" + + def test_handle_dump_config_exits(self): + """Test that --dump-config causes early exit.""" + parser = build_parser() + + with pytest.raises(SystemExit) as exc_info: + with mock.patch("sys.argv", ["duct", "--dump-config", "echo", "test"]): + handle_dump_config(parser) + + assert exc_info.value.code == 0 + + def test_handle_dump_config_no_flag(self): + """Test that without --dump-config, function returns normally.""" + parser = build_parser() + + with mock.patch("sys.argv", ["duct", "echo", "test"]): + result = 
handle_dump_config(parser) + assert result is None # Should return without raising + + +class TestBuildParser: + """Test build_parser function.""" + + def test_build_parser_creates_parser(self): + """Test that build_parser creates a valid ArgumentParser.""" + parser = build_parser() + + # Check parser has expected arguments + args = parser.parse_args(["echo", "test"]) + assert args.command == "echo" + assert args.command_args == ["test"] + + def test_build_parser_includes_all_field_specs(self): + """Test that parser includes all FieldSpec arguments.""" + parser = build_parser() + + # Test various arguments are accepted + args = parser.parse_args( + [ + "--output-prefix", + "/tmp/test", + "--sample-interval", + "2", + "--report-interval", + "10", + "--capture-outputs", + "none", + "--mode", + "current-session", + "-m", + "test message", + "--log-level", + "DEBUG", + "echo", + "hello", + ] + ) + + assert args.output_prefix == "/tmp/test" + assert args.sample_interval == 2.0 + assert args.report_interval == 10.0 + assert args.capture_outputs == "none" + assert str(args.mode) == "current-session" + assert args.message == "test message" + assert args.log_level == "DEBUG"