From 596206f12e3b7de01230e6d51c566f7501b12ef9 Mon Sep 17 00:00:00 2001 From: martin de la gorce Date: Wed, 20 Sep 2023 20:41:27 +0100 Subject: [PATCH 1/2] making add_dataclass_options public --- argparse_dataclass.py | 8 ++++---- tests/test_functional.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/argparse_dataclass.py b/argparse_dataclass.py index 37f7ab7..c823af8 100644 --- a/argparse_dataclass.py +++ b/argparse_dataclass.py @@ -298,7 +298,7 @@ def parse_args( ) -> OptionsType: """Parse arguments and return as the dataclass type.""" parser = argparse.ArgumentParser() - _add_dataclass_options(options_class, parser) + add_dataclass_options(options_class, parser) kwargs = _get_kwargs(parser.parse_args(args)) return options_class(**kwargs) @@ -310,13 +310,13 @@ def parse_known_args( and list of remaining arguments. """ parser = argparse.ArgumentParser() - _add_dataclass_options(options_class, parser) + add_dataclass_options(options_class, parser) namespace, others = parser.parse_known_args(args=args) kwargs = _get_kwargs(namespace) return options_class(**kwargs), others -def _add_dataclass_options( +def add_dataclass_options( options_class: typing.Type[OptionsType], parser: argparse.ArgumentParser ) -> None: if not is_dataclass(options_class): @@ -420,7 +420,7 @@ class ArgumentParser(argparse.ArgumentParser, typing.Generic[OptionsType]): def __init__(self, options_class: typing.Type[OptionsType], *args, **kwargs): super().__init__(*args, **kwargs) self._options_type: typing.Type[OptionsType] = options_class - _add_dataclass_options(options_class, self) + add_dataclass_options(options_class, self) def parse_args(self, args: ArgsType = None, namespace=None) -> OptionsType: """Parse arguments and return as the dataclass type.""" diff --git a/tests/test_functional.py b/tests/test_functional.py index 70ec48f..e3327e5 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -53,6 +53,20 @@ class Opt: self.assertRaises(TypeError, parse_args, Opt, []) + def test_add_dataclass_options(self): + @dataclass + class Opt: + x: int = 42 + y: bool = False + argpument_parser = argparse.ArgumentParser() + add_dataclass_options(argpument_parser, Opt) + params = argpument_parser.parse_args() + self.assertEqual(42, params.x) + self.assertEqual(False, params.y) + params = params = argpument_parser.parse_args(["--x=10", "--y"]) + self.assertEqual(10, params.x) + self.assertEqual(True, params.y) + def test_bool_no_default(self): @dataclass class Opt: From e4026244911bd4993e92159764687bd26685e6b7 Mon Sep 17 00:00:00 2001 From: martin de la gorce Date: Wed, 20 Sep 2023 21:03:54 +0100 Subject: [PATCH 2/2] fixing tests --- argparse_dataclass.py | 2 +- tests/test_functional.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/argparse_dataclass.py b/argparse_dataclass.py index c823af8..7ad37c9 100644 --- a/argparse_dataclass.py +++ b/argparse_dataclass.py @@ -359,7 +359,7 @@ def add_dataclass_options( if field.default == field.default_factory == MISSING and not positional: kwargs["required"] = True else: - kwargs["default"] = MISSING + kwargs["default"] = field.default if field.type is bool: _handle_bool_type(field, args, kwargs) diff --git a/tests/test_functional.py b/tests/test_functional.py index e3327e5..4e99a29 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,3 +1,4 @@ +import argparse import sys import unittest import datetime as dt @@ -5,7 +6,7 @@ from typing import List, Optional, Union -from argparse_dataclass import parse_args, parse_known_args +from argparse_dataclass import add_dataclass_options, parse_args, parse_known_args class NegativeTestHelper: @@ -59,8 +60,9 @@ class Opt: x: int = 42 y: bool = False argpument_parser = argparse.ArgumentParser() - add_dataclass_options(argpument_parser, Opt) - params = argpument_parser.parse_args() + add_dataclass_options(Opt, argpument_parser) + params = argpument_parser.parse_args([]) + print(params) self.assertEqual(42, params.x) self.assertEqual(False, params.y) params = params = argpument_parser.parse_args(["--x=10", "--y"])