diff --git a/craft_application/application.py b/craft_application/application.py index c92b5d154..bc2ba528c 100644 --- a/craft_application/application.py +++ b/craft_application/application.py @@ -34,8 +34,10 @@ import craft_cli import craft_parts import craft_providers +import pydantic from craft_parts.plugins.plugins import PluginType from platformdirs import user_cache_path +from pydantic.v1.utils import deep_update from craft_application import _config, commands, errors, grammar, models, secrets, util from craft_application.errors import PathInvalidError @@ -379,7 +381,7 @@ def get_project( craft_cli.emit.debug(f"Loading project file '{project_path!s}'") with project_path.open() as file: - yaml_data = util.safe_yaml_load(file) + yaml_data = util.safe_yaml_load_with_lines(file) host_arch = util.get_host_architecture() build_planner = self.app.BuildPlannerClass.from_yaml_data( @@ -404,13 +406,27 @@ def get_project( build_for = self._build_plan[0].build_for # validate project grammar - GrammarAwareProject.validate_grammar(yaml_data) + try: + GrammarAwareProject.validate_grammar(yaml_data) + except pydantic.ValidationError as err: + raise errors.CraftValidationError.from_pydantic( + err, + file_name=project_path.name, + doc_slug="common/craft-parts/reference/part_properties", + logpath_report=False, + validated_object=yaml_data, + ) from None build_on = host_arch # Setup partitions, some projects require the yaml data, most will not self._partitions = self._setup_partitions(yaml_data) - yaml_data = self._transform_project_yaml(yaml_data, build_on, build_for) + + # Apply transformations to base yaml, then update to preserve line numbers + yaml_base = util.remove_yaml_lines(yaml_data) + yaml_update = self._transform_project_yaml(yaml_base, build_on, build_for) + yaml_data = deep_update(yaml_data, yaml_update) + self.__project = self.app.ProjectClass.from_yaml_data(yaml_data, project_path) # check if mandatory adoptable fields exist if adopt-info not used diff --git a/craft_application/errors.py b/craft_application/errors.py index 7a637a9e6..4b7e10eb2 100644 --- a/craft_application/errors.py +++ b/craft_application/errors.py @@ -22,7 +22,7 @@ import os from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import yaml from craft_cli import CraftError @@ -72,6 +72,7 @@ def from_pydantic( error: pydantic.ValidationError, *, file_name: str = "yaml file", + validated_object: dict[str, Any] | None = None, **kwargs: str | bool | int | None, ) -> Self: """Convert this error from a pydantic ValidationError. @@ -81,7 +82,9 @@ def from_pydantic( :param doc_slug: The optional slug to this error's docs. :param kwargs: additional keyword arguments get passed to CraftError """ - message = format_pydantic_errors(error.errors(), file_name=file_name) + message = format_pydantic_errors( + error.errors(), file_name=file_name, validated_object=validated_object + ) return cls(message, **kwargs) # type: ignore[arg-type] diff --git a/craft_application/grammar.py b/craft_application/grammar.py index 3af035b81..ffa86d29b 100644 --- a/craft_application/grammar.py +++ b/craft_application/grammar.py @@ -121,6 +121,9 @@ def self_check(value: Any) -> bool: # noqa: ANN401 processor = GrammarProcessor(arch=arch, target_arch=target_arch, checker=self_check) for part_name, part_data in parts_yaml_data.items(): + # Ignore line numbers coming from yaml reader + if part_name.startswith("__line__"): + continue parts_yaml_data[part_name] = process_part( part_yaml_data=part_data, processor=processor ) diff --git a/craft_application/models/base.py b/craft_application/models/base.py index e0eaff5d9..caa614481 100644 --- a/craft_application/models/base.py +++ b/craft_application/models/base.py @@ -21,6 +21,7 @@ from typing import Any import pydantic +from pydantic import model_validator from typing_extensions import Self from craft_application import errors, util @@ -42,6 +43,11 @@ class CraftBaseModel(pydantic.BaseModel): coerce_numbers_to_str=True, ) + @model_validator(mode="before") + @classmethod + def _flatten(cls, values: dict[str, Any]) -> dict[str, Any]: + return util.remove_yaml_lines(values) + def marshal(self) -> dict[str, str | list[str] | dict[str, Any]]: """Convert to a dictionary.""" return self.model_dump(mode="json", by_alias=True, exclude_unset=True) @@ -84,6 +90,7 @@ def from_yaml_data(cls, data: dict[str, Any], filepath: pathlib.Path) -> Self: file_name=filepath.name, doc_slug=cls.model_reference_slug(), logpath_report=False, + validated_object=data, ) from None def to_yaml_file(self, path: pathlib.Path) -> None: diff --git a/craft_application/models/grammar.py b/craft_application/models/grammar.py index 9338e1538..3b3342ef0 100644 --- a/craft_application/models/grammar.py +++ b/craft_application/models/grammar.py @@ -22,6 +22,7 @@ from craft_grammar.models import Grammar # type: ignore[import-untyped] from pydantic import ConfigDict +from craft_application import util from craft_application.models.base import alias_generator from craft_application.models.constraints import SingleEntryDict @@ -92,6 +93,7 @@ def _ensure_parts(cls, data: dict[str, Any]) -> dict[str, Any]: item defined, set it to an empty dictionary. This is distinct from having `parts` be invalid, which is not coerced here. """ + data = util.remove_yaml_lines(data) data.setdefault("parts", {}) return data diff --git a/craft_application/util/__init__.py b/craft_application/util/__init__.py index c95cb71f0..87ab102be 100644 --- a/craft_application/util/__init__.py +++ b/craft_application/util/__init__.py @@ -35,7 +35,12 @@ ) from craft_application.util.string import humanize_list, strtobool from craft_application.util.system import get_parallel_build_count -from craft_application.util.yaml import dump_yaml, safe_yaml_load +from craft_application.util.yaml import ( + dump_yaml, + safe_yaml_load, + safe_yaml_load_with_lines, + remove_yaml_lines, +) from craft_application.util.cli import format_timestamp __all__ = [ @@ -55,6 +60,8 @@ "get_host_base", "dump_yaml", "safe_yaml_load", + "safe_yaml_load_with_lines", + "remove_yaml_lines", "retry", "get_parallel_build_count", "get_hostname", diff --git a/craft_application/util/error_formatting.py b/craft_application/util/error_formatting.py index 1aa7acd36..61e3c854d 100644 --- a/craft_application/util/error_formatting.py +++ b/craft_application/util/error_formatting.py @@ -17,8 +17,8 @@ from __future__ import annotations -from collections.abc import Iterable -from typing import NamedTuple +from collections.abc import Iterable, Sequence +from typing import Any, NamedTuple from pydantic import error_wrappers @@ -45,7 +45,11 @@ def from_str(cls, loc_str: str) -> FieldLocationTuple: return cls(field, location) -def format_pydantic_error(loc: Iterable[str | int], message: str) -> str: +def format_pydantic_error( + loc: Sequence[str | int], + message: str, + validated_object: dict[str, Any] | None = None, +) -> str: """Format a single pydantic ErrorDict as a string. :param loc: An iterable of strings and integers determining the error location. @@ -56,9 +60,12 @@ def format_pydantic_error(loc: Iterable[str | int], message: str) -> str: """ field_path = _format_pydantic_error_location(loc) message = _format_pydantic_error_message(message) + line_num = _get_line_number(loc, validated_object) field_name, location = FieldLocationTuple.from_str(field_path) if location != "top-level": location = repr(location) + if line_num is not None: + location += f" - line {line_num}" if message == "field required": return f"- field {field_name!r} required in {location} configuration" @@ -68,11 +75,16 @@ def format_pydantic_error(loc: Iterable[str | int], message: str) -> str: return f"- duplicate {field_name!r} entry not permitted in {location} configuration" if field_path in ("__root__", ""): return f"- {message}" - return f"- {message} (in field {field_path!r})" + return f"- {message} (in field {field_path!r}" + ( + f" - line {line_num})" if line_num else ")" + ) def format_pydantic_errors( - errors: Iterable[error_wrappers.ErrorDict], *, file_name: str = "yaml file" + errors: Iterable[error_wrappers.ErrorDict], + *, + file_name: str = "yaml file", + validated_object: dict[str, Any] | None = None, ) -> str: """Format errors. @@ -87,7 +99,10 @@ def format_pydantic_errors( - field: reason: . """ - messages = (format_pydantic_error(error["loc"], error["msg"]) for error in errors) + messages = ( + format_pydantic_error(error["loc"], error["msg"], validated_object) + for error in errors + ) return "\n".join((f"Bad {file_name} content:", *messages)) @@ -115,3 +130,24 @@ def _format_pydantic_error_message(msg: str) -> str: if msg: msg = msg[0].lower() + msg[1:] return msg + + +def _get_line_number( + loc: Sequence[str | int], validated_object: dict[str, Any] | None +) -> int | None: + """Return the line number of a key based on its location.""" + if validated_object is None: + return None + + object_value: dict[str, Any] | Sequence[Any] = validated_object + line_number: int | None = None + + for i, location in enumerate(loc): + if isinstance(location, int) and isinstance(object_value, Sequence): + object_value = object_value[location] # type: ignore[arg-type] + elif isinstance(location, str) and isinstance(object_value, dict): + if i == len(loc) - 1 and f"__line__{location}" in object_value: + line_number = object_value[f"__line__{location}"] + elif location in object_value: + object_value = object_value[location] + return line_number diff --git a/craft_application/util/yaml.py b/craft_application/util/yaml.py index 2b4c17991..646f17505 100644 --- a/craft_application/util/yaml.py +++ b/craft_application/util/yaml.py @@ -22,6 +22,10 @@ from typing import TYPE_CHECKING, Any, TextIO, cast, overload import yaml +from yaml.composer import Composer +from yaml.constructor import Constructor +from yaml.nodes import MappingNode, Node, ScalarNode +from yaml.resolver import BaseResolver from craft_application import errors @@ -97,6 +101,42 @@ def __init__(self, stream: TextIO) -> None: ) +class _SafeLineNoLoader(_SafeYamlLoader): + def __init__(self, stream: TextIO) -> None: + super().__init__(stream) + + self.add_constructor( + yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _dict_constructor + ) + + def compose_node(self, parent: Node | None, index: int) -> Node | None: + # the line number where the previous token has ended (plus empty lines) + line = self.line + node = Composer.compose_node(self, parent, index) + setattr(node, "__line__", line + 1) # noqa: B010 - used internally, prevent mypy error + return node + + def construct_mapping( + self, + node: MappingNode, + deep: bool = False, # noqa: FBT001, FBT002 - used internally by yaml.SafeLoader + ) -> dict[Hashable, Any]: + node_pair_lst = node.value + node_pair_lst_for_appending = [] + + for key_node, _ in node_pair_lst: + shadow_key_node = ScalarNode( + tag=BaseResolver.DEFAULT_SCALAR_TAG, value="__line__" + key_node.value + ) + shadow_value_node = ScalarNode( + tag=BaseResolver.DEFAULT_SCALAR_TAG, value=key_node.__line__ + ) + node_pair_lst_for_appending.append((shadow_key_node, shadow_value_node)) + + node.value = node_pair_lst + node_pair_lst_for_appending + return Constructor.construct_mapping(self, node, deep=deep) # type: ignore[arg-type] + + def safe_yaml_load(stream: TextIO) -> Any: # noqa: ANN401 - The YAML could be anything """Equivalent to pyyaml's safe_load function, but constraining duplicate keys. @@ -112,6 +152,21 @@ def safe_yaml_load(stream: TextIO) -> Any: # noqa: ANN401 - The YAML could be a raise errors.YamlError.from_yaml_error(filename, error) from error +def safe_yaml_load_with_lines(stream: TextIO) -> Any: # noqa: ANN401 - The YAML could be anything + """Equivalent to pyyaml's safe_load function, but constraining duplicate keys and including line numbers. + + :param stream: Any text-like IO object. + :returns: A dict object mapping the yaml. + """ + try: + # Silencing S506 ("probable use of unsafe loader") because we override it by + # using our own safe loader. + return yaml.load(stream, Loader=_SafeLineNoLoader) # noqa: S506 + except yaml.YAMLError as error: + filename = pathlib.Path(stream.name).name + raise errors.YamlError.from_yaml_error(filename, error) from error + + @overload def dump_yaml( data: Any, # noqa: ANN401 # Any gets passed to pyyaml @@ -146,3 +201,17 @@ def dump_yaml(data: Any, stream: TextIO | None = None, **kwargs: Any) -> str | N return cast( # This cast is needed for pyright but not mypy str | None, yaml.dump(data, stream, Dumper=yaml.SafeDumper, **kwargs) ) + + +def remove_yaml_lines(data: dict[str, Any] | list[Any]) -> dict[str, Any]: + """Recursively flattens a nested dictionary by removing the '__line__' fields.""" + if type(data) is list: + return [remove_yaml_lines(v) for v in data] # type: ignore[return-value] + if type(data) is not dict: + return data # type: ignore[return-value] + # k is only None in one test case + return { + k: remove_yaml_lines(v) + for k, v in data.items() + if k is None or "__line__" not in k # type: ignore[reportUnnecessaryComparison] + } diff --git a/tests/unit/test_application.py b/tests/unit/test_application.py index 898f1f75f..f78291383 100644 --- a/tests/unit/test_application.py +++ b/tests/unit/test_application.py @@ -2141,8 +2141,8 @@ def test_build_planner_errors(tmp_path, monkeypatch, fake_services): expected = ( "Bad testcraft.yaml content:\n" - "- bad value1: 10 (in field 'value1')\n" - "- bad value2: banana (in field 'value2')" + "- bad value1: 10 (in field 'value1' - line 3)\n" + "- bad value2: banana (in field 'value2' - line 4)" ) assert str(err.value) == expected