diff --git a/google/cloud/bigquery_v2/services/job_service/pagers.py b/google/cloud/bigquery_v2/services/job_service/pagers.py index 83f723917..b976c1d75 100644 --- a/google/cloud/bigquery_v2/services/job_service/pagers.py +++ b/google/cloud/bigquery_v2/services/job_service/pagers.py @@ -67,7 +67,6 @@ def __init__( retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), - ): """Instantiate the pager. diff --git a/scripts/microgenerator/config.yaml b/scripts/microgenerator/config.yaml new file mode 100644 index 000000000..a36e7c728 --- /dev/null +++ b/scripts/microgenerator/config.yaml @@ -0,0 +1,47 @@ +# config.yaml + +# The name of the service, used for variable names and comments. +service_name: "bigquery" + +# A list of paths to the source code files to be parsed. +# Globs are supported. +source_files: + - "autogen/google/cloud/bigquery_v2/services/dataset_service/client.py" + - "autogen/google/cloud/bigquery_v2/services/job_service/client.py" + - "autogen/google/cloud/bigquery_v2/services/model_service/client.py" + - "autogen/google/cloud/bigquery_v2/services/project_service/client.py" + - "autogen/google/cloud/bigquery_v2/services/routine_service/client.py" + - "autogen/google/cloud/bigquery_v2/services/row_access_policy_service/client.py" + - "autogen/google/cloud/bigquery_v2/services/table_service/client.py" + +# Filtering rules for classes and methods. +filter: + classes: + # Only include classes with these suffixes. + include_suffixes: + - "DatasetServiceClient" + - "JobServiceClient" + - "ModelServiceClient" + methods: + # Include methods with these prefixes. + include_prefixes: + # - "batch_delete_" + # - "cancel_" + # - "create_" + # - "delete_" + - "get_" + - "insert_" + - "list_" + - "patch_" + # - "undelete_" + # - "update_" + # Exclude methods with these prefixes. + exclude_prefixes: + - "get_mtls_endpoint_and_cert_source" + +# A list of templates to render and their corresponding output files. +templates: + - template: "autogen/scripts/microgenerator/templates/client.py.j2" + output: "autogen/google/cloud/bigquery_v2/services/centralized_service/client.py" + # - template: "test_bigqueryclient.py.j2" + # output: "tests/unit/test_bigqueryclient.py" diff --git a/scripts/microgenerator/generate.py b/scripts/microgenerator/generate.py new file mode 100644 index 000000000..6c394f3f4 --- /dev/null +++ b/scripts/microgenerator/generate.py @@ -0,0 +1,510 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A dual-purpose module for Python code analysis and BigQuery client generation. + +When run as a script, it generates the BigQueryClient source code. +When imported, it provides utility functions for parsing and exploring +any Python codebase using the `ast` module. 
+""" + +import ast +import os +import argparse +import glob +import re +from collections import defaultdict +from typing import List, Dict, Any, Iterator + +import name_utils +import utils + +# ============================================================================= +# Section 1: Generic AST Analysis Utilities +# ============================================================================= + +class CodeAnalyzer(ast.NodeVisitor): + """ + A node visitor to traverse an AST and extract structured information + about classes, methods, and their arguments. + """ + + def __init__(self): + self.structure: List[Dict[str, Any]] = [] + self.imports: set[str] = set() + self.types: set[str] = set() + self._current_class_info: Dict[str, Any] | None = None + self._is_in_method: bool = False + + def _get_type_str(self, node: ast.AST | None) -> str | None: + """Recursively reconstructs a type annotation string from an AST node.""" + if node is None: + return None + # Handles simple names like 'str', 'int', 'HttpRequest' + if isinstance(node, ast.Name): + return node.id + # Handles dotted names like 'service.GetDatasetRequest' + if isinstance(node, ast.Attribute): + # Attempt to reconstruct the full dotted path + parts = [] + curr = node + while isinstance(curr, ast.Attribute): + parts.append(curr.attr) + curr = curr.value + if isinstance(curr, ast.Name): + parts.append(curr.id) + return ".".join(reversed(parts)) + # Handles subscripted types like 'list[str]', 'Optional[...]' + if isinstance(node, ast.Subscript): + value_str = self._get_type_str(node.value) + slice_str = self._get_type_str(node.slice) + return f"{value_str}[{slice_str}]" + # Handles tuples inside subscripts, e.g., 'dict[str, int]' + if isinstance(node, ast.Tuple): + return ", ".join( + [s for s in (self._get_type_str(e) for e in node.elts) if s] + ) + # Handles forward references as strings, e.g., '"Dataset"' + if isinstance(node, ast.Constant): + return str(node.value) + return None # Fallback for unhandled types + + def _collect_types_from_node(self, node: ast.AST | None) -> None: + """Recursively traverses an annotation node to find and collect all type names.""" + if node is None: + return + + if isinstance(node, ast.Name): + self.types.add(node.id) + elif isinstance(node, ast.Attribute): + type_str = self._get_type_str(node) + if type_str: + self.types.add(type_str) + elif isinstance(node, ast.Subscript): + self._collect_types_from_node(node.value) + self._collect_types_from_node(node.slice) + elif isinstance(node, (ast.Tuple, ast.List)): + for elt in node.elts: + self._collect_types_from_node(elt) + elif isinstance(node, ast.Constant) and isinstance(node.value, str): + self.types.add(node.value) + elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): # For | union type + self._collect_types_from_node(node.left) + self._collect_types_from_node(node.right) + + def visit_Import(self, node: ast.Import) -> None: + """Catches 'import X' and 'import X as Y' statements.""" + for alias in node.names: + if alias.asname: + self.imports.add(f"import {alias.name} as {alias.asname}") + else: + self.imports.add(f"import {alias.name}") + self.generic_visit(node) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + """Catches 'from X import Y' statements.""" + module = node.module or "" + if not module: + module = "." * node.level + else: + module = "." 
* node.level + module + + names = [] + for alias in node.names: + if alias.asname: + names.append(f"{alias.name} as {alias.asname}") + else: + names.append(alias.name) + + if names: + self.imports.add(f"from {module} import {', '.join(names)}") + self.generic_visit(node) + + def visit_ClassDef(self, node: ast.ClassDef) -> None: + """Visits a class definition node.""" + class_info = { + "class_name": node.name, + "methods": [], + "attributes": [], + } + self.structure.append(class_info) + self._current_class_info = class_info + self.generic_visit(node) + self._current_class_info = None + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + """Visits a function/method definition node.""" + if self._current_class_info: # This is a method + + args_info = [] + for arg in node.args.args: + type_str = self._get_type_str(arg.annotation) + args_info.append({"name": arg.arg, "type": type_str}) + self._collect_types_from_node(arg.annotation) + + # Collect return type + return_type = self._get_type_str(node.returns) + self._collect_types_from_node(node.returns) + + method_info = { + "method_name": node.name, + "args": args_info, + "return_type": return_type, + } + self._current_class_info["methods"].append(method_info) + + # Visit nodes inside the method to find instance attributes. + self._is_in_method = True + self.generic_visit(node) + self._is_in_method = False + + def _add_attribute(self, attr_name: str): + """Adds a unique attribute to the current class context.""" + if self._current_class_info: + if attr_name not in self._current_class_info["attributes"]: + self._current_class_info["attributes"].append(attr_name) + + def visit_Assign(self, node: ast.Assign) -> None: + """Handles attribute assignments: `x = ...` and `self.x = ...`.""" + if self._current_class_info: + for target in node.targets: + # Instance attribute: self.x = ... + if ( + isinstance(target, ast.Attribute) + and isinstance(target.value, ast.Name) + and target.value.id == "self" + ): + self._add_attribute(target.attr) + # Class attribute: x = ... (only if not inside a method) + elif isinstance(target, ast.Name) and not self._is_in_method: + self._add_attribute(target.id) + self.generic_visit(node) + + def visit_AnnAssign(self, node: ast.AnnAssign) -> None: + """Handles annotated assignments: `x: int = ...` and `self.x: int = ...`.""" + if self._current_class_info: + target = node.target + # Instance attribute: self.x: int = ... + if ( + isinstance(target, ast.Attribute) + and isinstance(target.value, ast.Name) + and target.value.id == "self" + ): + self._add_attribute(target.attr) + # Class attribute: x: int = ... (only if not inside a method) + elif isinstance(target, ast.Name) and not self._is_in_method: + self._add_attribute(target.id) + self.generic_visit(node) + + +def parse_code(code: str) -> tuple[List[Dict[str, Any]], set[str], set[str]]: + """ + Parses a string of Python code into a structured list of classes, a set of imports, + and a set of all type annotations found. + + Args: + code: A string containing Python code. + + Returns: + A tuple containing: + - A list of dictionaries, where each dictionary represents a class. + - A set of strings, where each string is an import statement. + - A set of strings, where each string is a type annotation. 
+ """ + tree = ast.parse(code) + analyzer = CodeAnalyzer() + analyzer.visit(tree) + return analyzer.structure, analyzer.imports, analyzer.types + + +def parse_file(file_path: str) -> tuple[List[Dict[str, Any]], set[str], set[str]]: + """ + Parses a Python file into a structured list of classes, a set of imports, + and a set of all type annotations found. + + Args: + file_path: The absolute path to the Python file. + + Returns: + A tuple containing the class structure, a set of import statements, + and a set of type annotations. + """ + with open(file_path, "r", encoding="utf-8") as source: + code = source.read() + return parse_code(code) + + +def list_code_objects( + path: str, + show_methods: bool = False, + show_attributes: bool = False, + show_arguments: bool = False, +) -> Any: + """ + Lists classes and optionally their methods, attributes, and arguments + from a given Python file or directory. + + This function consolidates the functionality of the various `list_*` functions. + + Args: + path (str): The absolute path to a Python file or directory. + show_methods (bool): Whether to include methods in the output. + show_attributes (bool): Whether to include attributes in the output. + show_arguments (bool): If True, includes method arguments. Implies show_methods. + + Returns: + - If `show_methods` and `show_attributes` are both False, returns a + sorted `List[str]` of class names (mimicking `list_classes`). + - Otherwise, returns a `Dict[str, Dict[str, Any]]` containing the + requested details about each class. + """ + # If show_arguments is True, we must show methods. + if show_arguments: + show_methods = True + + results = defaultdict(dict) + all_class_keys = [] + + def process_structure(structure: List[Dict[str, Any]], file_name: str | None = None): + """Populates the results dictionary from the parsed AST structure.""" + for class_info in structure: + key = class_info["class_name"] + if file_name: + key = f"{key} (in {file_name})" + + all_class_keys.append(key) + + # Skip filling details if not needed for the dictionary. + if not show_methods and not show_attributes: + continue + + if show_attributes: + results[key]["attributes"] = sorted(class_info["attributes"]) + + if show_methods: + if show_arguments: + method_details = {} + # Sort methods by name for consistent output + for method in sorted(class_info["methods"], key=lambda m: m["method_name"]): + method_details[method["method_name"]] = method["args"] + results[key]["methods"] = method_details + else: + results[key]["methods"] = sorted( + [m["method_name"] for m in class_info["methods"]] + ) + + # Determine if the path is a file or directory and process accordingly + if os.path.isfile(path) and path.endswith(".py"): + structure, _, _ = parse_file(path) + process_structure(structure) + elif os.path.isdir(path): + # This assumes `utils.walk_codebase` is defined elsewhere. 
+ for file_path in utils.walk_codebase(path): + structure, _, _ = parse_file(file_path) + process_structure(structure, file_name=os.path.basename(file_path)) + + # Return the data in the desired format based on the flags + if not show_methods and not show_attributes: + return sorted(all_class_keys) + else: + return dict(results) + + +# ============================================================================= +# Section 2: Source file data gathering +# ============================================================================= + + +def _should_include_class(class_name: str, class_filters: Dict[str, Any]) -> bool: + """Checks if a class should be included based on filter criteria.""" + if class_filters.get("include_suffixes"): + if not class_name.endswith(tuple(class_filters["include_suffixes"])): + return False + if class_filters.get("exclude_suffixes"): + if class_name.endswith(tuple(class_filters["exclude_suffixes"])): + return False + return True + + +def _should_include_method(method_name: str, method_filters: Dict[str, Any]) -> bool: + """Checks if a method should be included based on filter criteria.""" + if method_filters.get("include_prefixes"): + if not any(method_name.startswith(p) for p in method_filters["include_prefixes"]): + return False + if method_filters.get("exclude_prefixes"): + if any(method_name.startswith(p) for p in method_filters["exclude_prefixes"]): + return False + return True + + +def analyze_source_files(config: Dict[str, Any]) -> tuple[Dict[str, Any], set[str], set[str]]: + """ + Analyzes source files per the configuration to extract class and method info, + as well as information on imports and typehints. + + Args: + config: The generator's configuration dictionary. + + Returns: + A tuple containing: + - A dictionary containing the data needed for template rendering. + - A set of all import statements required by the parsed methods. + - A set of all type annotations found in the parsed methods. + """ + parsed_data = defaultdict(dict) + all_imports: set[str] = set() + all_types: set[str] = set() + + source_patterns = config.get("source_files", []) + filter_rules = config.get("filter", {}) + class_filters = filter_rules.get("classes", {}) + method_filters = filter_rules.get("methods", {}) + + source_files = [] + for pattern in source_patterns: + source_files.extend(glob.glob(pattern, recursive=True)) + + for file_path in source_files: + structure, imports, types = parse_file(file_path) + all_imports.update(imports) + all_types.update(types) + + for class_info in structure: + class_name = class_info["class_name"] + if not _should_include_class(class_name, class_filters): + continue + + parsed_data[class_name] # Ensure class is in dict + + for method in class_info["methods"]: + method_name = method["method_name"] + if not _should_include_method(method_name, method_filters): + continue + parsed_data[class_name][method_name] = method + + return parsed_data, all_imports, all_types + + +def _format_class_name(method_name: str, suffix: str = "Request") -> str: + """Formats a class name from a method name.""" + return "".join(word.capitalize() for word in method_name.split("_")) + suffix + +# ============================================================================= +# Section 3: Code Generation +# ============================================================================= + +def _generate_import_statement(context: List[Dict[str, Any]], key: str, path: str) -> str: + """Generates a formatted import statement from a list of context dictionaries. 
+ + Args: + context: A list of dictionaries containing the data. + key: The key to extract from each dictionary in the context. + path: The base import path (e.g., "google.cloud.bigquery_v2.services"). + + Returns: + A formatted, multi-line import statement string. + """ + names = sorted(list(set([item[key] for item in context]))) + names_str = ",\n ".join(names) + return f"from {path} import (\n {names_str}\n)" + + +def generate_code(config: Dict[str, Any], analysis_results: tuple) -> None: + """ + Generates source code files using Jinja2 templates. + """ + data, all_imports, all_types = analysis_results + + templates_config = config.get("templates", []) + for item in templates_config: + template_path = item["template"] + output_path = item["output"] + + template = utils.load_template(template_path) + methods_context = [] + for class_name, methods in data.items(): + for method_name, method_info in methods.items(): + + methods_context.append( + { + "name": method_name, + "class_name": class_name, + "return_type": method_info["return_type"], + } + ) + + # Prepare imports for the template + services_context = [] + client_class_names = sorted(list(set([m['class_name'] for m in methods_context]))) + + for class_name in client_class_names: + service_name_cluster = name_utils.generate_service_names(class_name) + services_context.append(service_name_cluster) + + # Also need to update methods_context to include the service_name and module_name + # so the template knows which client to use for each method. + class_to_service_map = {s['service_client_class']: s for s in services_context} + for method in methods_context: + service_info = class_to_service_map.get(method['class_name']) + if service_info: + method['service_name'] = service_info['service_name'] + method['service_module_name'] = service_info['service_module_name'] + + # Prepare new imports + service_imports = [ + _generate_import_statement( + services_context, "service_module_name", "google.cloud.bigquery_v2.services" + ) + ] + + # Prepare type imports + type_imports = [ + _generate_import_statement( + services_context, "service_name", "google.cloud.bigquery_v2.types" + ) + ] + + final_code = template.render( + service_name=config.get("service_name"), + methods=methods_context, + services=services_context, + service_imports=service_imports, + type_imports=type_imports, + ) + + utils.write_code_to_file(output_path, final_code) + + + +# ============================================================================= +# Section 4: Main Execution +# ============================================================================= + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="A generic Python code generator for clients." + ) + parser.add_argument( + "config", help="Path to the YAML configuration file." + ) + args = parser.parse_args() + + config = utils.load_config(args.config) + analysis_results = analyze_source_files(config) + generate_code(config, analysis_results) + + # TODO: Ensure blacken gets called on the generated source files as a final step. diff --git a/scripts/microgenerator/name_utils.py b/scripts/microgenerator/name_utils.py new file mode 100644 index 000000000..560a18751 --- /dev/null +++ b/scripts/microgenerator/name_utils.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A utility module for handling name transformations."""
+
+import re
+from typing import Dict
+
+def to_snake_case(name: str) -> str:
+    """Converts a PascalCase name to snake_case."""
+    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
+
+def generate_service_names(class_name: str) -> Dict[str, str]:
+    """
+    Generates various name formats for a service based on its client class name.
+
+    Args:
+        class_name: The PascalCase name of the service client class
+            (e.g., 'DatasetServiceClient').
+
+    Returns:
+        A dictionary containing different name variations.
+    """
+    snake_case_name = to_snake_case(class_name)
+    module_name = snake_case_name.replace("_client", "")
+    service_name = module_name.replace("_service", "")
+
+    return {
+        "service_name": service_name,
+        "service_module_name": module_name,
+        "service_client_class": class_name,
+        "property_name": snake_case_name,  # Direct use of snake_case_name
+    }
diff --git a/scripts/microgenerator/templates/client.py.j2 b/scripts/microgenerator/templates/client.py.j2
new file mode 100644
index 000000000..98c2b4e99
--- /dev/null
+++ b/scripts/microgenerator/templates/client.py.j2
@@ -0,0 +1,103 @@
+# TODO: Add a header if needed.
+
+# ======== 🦕 HERE THERE BE DINOSAURS 🦖 =========
+# This content is subject to significant change. Not for review yet.
+# Included as a proof of concept for context or testing ONLY.
+# ================================================
+
+# Imports
+import os
+
+from typing import (
+    Dict,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
+
+
+{% for imp in service_imports %}
+{{ imp }}
+{% endfor %}
+from google.cloud.bigquery_v2.services.centralized_service import _helpers
+
+{% for imp in type_imports %}
+{{ imp }}
+{% endfor %}
+from google.cloud.bigquery_v2.types import dataset_reference
+
+from google.api_core import client_options as client_options_lib
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
+from google.auth import credentials as auth_credentials
+
+# Create type aliases
+try:
+    OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None]
+except AttributeError:  # pragma: NO COVER
+    OptionalRetry = Union[retries.Retry, object, None]  # type: ignore
+
+DatasetIdentifier = Union[str, dataset_reference.DatasetReference]
+
+DEFAULT_RETRY: OptionalRetry = gapic_v1.method.DEFAULT
+DEFAULT_TIMEOUT: Union[float, object] = gapic_v1.method.DEFAULT
+DEFAULT_METADATA: Sequence[Tuple[str, Union[str, bytes]]] = ()
+
+
+class BigQueryClient:
+    def __init__(self, credentials=None, client_options=None):
+        self._clients = {}
+        self._credentials = credentials
+        self._client_options = client_options
+
+    # --- *METHOD SECTION ---
+{% for method in methods %}
+    def {{ method.name }}(
+        self,
+        *,
+        request: Optional["{{ method.service_module_name.replace('_service', '') }}.{{ method.name.replace('_', ' ').title().replace(' ', '') }}Request"] = None,
+        retry: OptionalRetry = DEFAULT_RETRY,
+        timeout: Union[float, object] = DEFAULT_TIMEOUT,
+        metadata: Sequence[Tuple[str, Union[str, bytes]]] = DEFAULT_METADATA,
+    ) -> "{{ method.return_type }}":
+        """
+        TODO: Docstring is purposefully blank. microgenerator will add automatically.
+ """ + + return self.{{ method.service_module_name }}_client.{{ method.name }}( + request=request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) +{% endfor %} + +{#- *ServiceClient Properties Section: methods to get/set service clients -#} + # --- *SERVICECLIENT PROPERTIES --- +{% for service in services %} + @property + def {{ service.property_name }}(self): + if "{{ service.service_name }}" not in self._clients: + self._clients["{{ service.service_name }}"] = {{ service.service_module_name }}.{{ service.service_client_class }}( + credentials=self._credentials, client_options=self._client_options + ) + return self._clients["{{ service.service_name }}"] + + @{{ service.property_name }}.setter + def {{ service.property_name }}(self, value): + if not isinstance(value, {{ service.service_module_name }}.{{ service.service_client_class }}): + raise TypeError( + "Expected an instance of {{ service.service_module_name }}.{{ service.service_client_class }}." + ) + self._clients["{{ service.service_name }}"] = value +{% endfor %} + +{#- Helper Section: methods included from partial template -#} + {%- include "partials/_client_helpers.j2" %} + + +# ======== 🦕 HERE THERE WERE DINOSAURS 🦖 ========= +# The above content is subject to significant change. Not for review yet. +# Included as a proof of concept for context or testing ONLY. +# ================================================ \ No newline at end of file diff --git a/scripts/microgenerator/templates/partials/_client_helpers.j2 b/scripts/microgenerator/templates/partials/_client_helpers.j2 new file mode 100644 index 000000000..a6e6343b5 --- /dev/null +++ b/scripts/microgenerator/templates/partials/_client_helpers.j2 @@ -0,0 +1,49 @@ + {# + This is a partial template file intended to be included in other templates. + It contains helper methods for the BigQueryClient class. + #} + + # --- HELPER METHODS --- + def _parse_dataset_path(self, dataset_path: str) -> Tuple[Optional[str], str]: + """ + Helper to parse project_id and/or dataset_id from a string identifier. + + Args: + dataset_path: A string in the format 'project_id.dataset_id' or + 'dataset_id'. + + Returns: + A tuple of (project_id, dataset_id). + """ + if "." in dataset_path: + # Use rsplit to handle legacy paths like `google.com:my-project.my_dataset`. + project_id, dataset_id = dataset_path.rsplit(".", 1) + return project_id, dataset_id + return self.project, dataset_path + + def _parse_dataset_id_to_dict(self, dataset_id: "DatasetIdentifier") -> dict: + """ + Helper to create a dictionary from a project_id and dataset_id to pass + internally between helper functions. + + Args: + dataset_id: A string or DatasetReference. + + Returns: + A dict of {"project_id": project_id, "dataset_id": dataset_id_str }. 
+ """ + if isinstance(dataset_id, str): + project_id, dataset_id_str = self._parse_dataset_path(dataset_id) + return {"project_id": project_id, "dataset_id": dataset_id_str} + elif isinstance(dataset_id, dataset_reference.DatasetReference): + return { + "project_id": dataset_id.project_id, + "dataset_id": dataset_id.dataset_id, + } + else: + raise TypeError(f"Invalid type for dataset_id: {type(dataset_id)}") + + def _parse_project_id_to_dict(self, project_id: Optional[str] = None) -> dict: + """Helper to create a request dictionary from a project_id.""" + final_project_id = project_id or self.project + return {"project_id": final_project_id} \ No newline at end of file diff --git a/scripts/microgenerator/utils.py b/scripts/microgenerator/utils.py new file mode 100644 index 000000000..324c20662 --- /dev/null +++ b/scripts/microgenerator/utils.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Utility functions for the microgenerator.""" + +import os +import sys +import yaml +import jinja2 +from typing import Dict, Any, Iterator, Callable + + +def _load_resource( + loader_func: Callable, + path: str, + not_found_exc: type, + parse_exc: type, + resource_type_name: str, +) -> Any: + """ + Generic resource loader with common error handling. + + Args: + loader_func: A callable that performs the loading and returns the resource. + It should raise appropriate exceptions on failure. + path: The path/name of the resource for use in error messages. + not_found_exc: The exception type to catch for a missing resource. + parse_exc: The exception type to catch for a malformed resource. + resource_type_name: A human-readable name for the resource type. + """ + try: + return loader_func() + except not_found_exc: + print(f"Error: {resource_type_name} '{path}' not found.", file=sys.stderr) + sys.exit(1) + except parse_exc as e: + print( + f"Error: Could not load {resource_type_name.lower()} from '{path}': {e}", + file=sys.stderr, + ) + sys.exit(1) + + +def load_template(template_path: str) -> jinja2.Template: + """ + Loads a Jinja2 template from a given file path. 
+ """ + template_dir = os.path.dirname(template_path) + template_name = os.path.basename(template_path) + + def _loader() -> jinja2.Template: + env = jinja2.Environment( + loader=jinja2.FileSystemLoader(template_dir or "."), + trim_blocks=True, + lstrip_blocks=True, + ) + return env.get_template(template_name) + + return _load_resource( + loader_func=_loader, + path=template_path, + not_found_exc=jinja2.exceptions.TemplateNotFound, + parse_exc=jinja2.exceptions.TemplateError, + resource_type_name="Template file", + ) + + +def load_config(config_path: str) -> Dict[str, Any]: + """Loads the generator's configuration from a YAML file.""" + + def _loader() -> Dict[str, Any]: + with open(config_path, "r", encoding="utf-8") as f: + return yaml.safe_load(f) + + return _load_resource( + loader_func=_loader, + path=config_path, + not_found_exc=FileNotFoundError, + parse_exc=yaml.YAMLError, + resource_type_name="Configuration file", + ) + + +def walk_codebase(path: str) -> Iterator[str]: + """Yields all .py file paths in a directory.""" + for root, _, files in os.walk(path): + for file in files: + if file.endswith(".py"): + yield os.path.join(root, file) + + +def write_code_to_file(output_path: str, content: str): + """Ensures the output directory exists and writes content to the file.""" + output_dir = os.path.dirname(output_path) + + # An empty output_dir means the file is in the current directory. + if output_dir: + print(f" Ensuring output directory exists: {os.path.abspath(output_dir)}") + os.makedirs(output_dir, exist_ok=True) + if not os.path.isdir(output_dir): + print(f" Error: Output directory was not created.", file=sys.stderr) + sys.exit(1) + + print(f" Writing generated code to: {os.path.abspath(output_path)}") + with open(output_path, "w", encoding="utf-8") as f: + f.write(content) + print(f"Successfully generated {output_path}") \ No newline at end of file
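
A few usage sketches for the pieces above. First, driving the generator end to end; this mirrors the `__main__` block in `generate.py` and assumes a checkout root where the relative `autogen/...` paths in `config.yaml` resolve.

```python
# CLI form (Python puts the script's directory on sys.path, so the sibling
# imports of utils and name_utils inside generate.py resolve automatically):
#
#   python scripts/microgenerator/generate.py scripts/microgenerator/config.yaml
#
# Programmatic equivalent:
import sys

sys.path.insert(0, "scripts/microgenerator")  # make utils/name_utils importable

import utils  # noqa: E402
from generate import analyze_source_files, generate_code  # noqa: E402

config = utils.load_config("scripts/microgenerator/config.yaml")
analysis_results = analyze_source_files(config)  # (parsed_data, all_imports, all_types)
generate_code(config, analysis_results)  # renders client.py.j2 to the configured output
```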
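
`parse_code` is easy to sanity-check in isolation. The expected values below are read off the `CodeAnalyzer` visitor logic, not captured output, and the snippet assumes `scripts/microgenerator` is on `sys.path`:

```python
from generate import parse_code

sample = '''
from typing import Optional

class DatasetServiceClient:
    def get_dataset(self, request: Optional[str] = None) -> str:
        self._transport = None
        return ""
'''

structure, imports, types = parse_code(sample)
assert structure[0]["class_name"] == "DatasetServiceClient"
assert structure[0]["attributes"] == ["_transport"]  # from visit_Assign on self.*
assert structure[0]["methods"][0]["return_type"] == "str"
assert structure[0]["methods"][0]["args"] == [
    {"name": "self", "type": None},
    {"name": "request", "type": "Optional[str]"},
]
assert imports == {"from typing import Optional"}
assert types == {"Optional", "str"}  # subscripted annotations are split into parts
```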
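
The filter helpers apply the rules from `config.yaml`; note that `exclude_prefixes` wins even when an include prefix matches:

```python
from generate import _should_include_class, _should_include_method

class_filters = {
    "include_suffixes": ["DatasetServiceClient", "JobServiceClient", "ModelServiceClient"],
}
method_filters = {
    "include_prefixes": ["get_", "insert_", "list_", "patch_"],
    "exclude_prefixes": ["get_mtls_endpoint_and_cert_source"],
}

assert _should_include_class("DatasetServiceClient", class_filters)
assert not _should_include_class("TableServiceClient", class_filters)

assert _should_include_method("get_dataset", method_filters)
assert not _should_include_method("cancel_job", method_filters)  # no include prefix
assert not _should_include_method("get_mtls_endpoint_and_cert_source", method_filters)
```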
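
Assuming the usual PascalCase-to-snake_case regex in `to_snake_case`, `generate_service_names` maps one client class name to the four name forms the template consumes:

```python
from name_utils import generate_service_names

assert generate_service_names("DatasetServiceClient") == {
    "service_name": "dataset",                 # module name minus "_service"
    "service_module_name": "dataset_service",  # snake_case minus "_client"
    "service_client_class": "DatasetServiceClient",
    "property_name": "dataset_service_client",
}
```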
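
`_generate_import_statement` dedupes, sorts, and wraps the names it is given; for two services it produces:

```python
from generate import _generate_import_statement

services_context = [
    {"service_module_name": "job_service"},
    {"service_module_name": "dataset_service"},
    {"service_module_name": "dataset_service"},  # duplicates are collapsed by set()
]
print(
    _generate_import_statement(
        services_context, "service_module_name", "google.cloud.bigquery_v2.services"
    )
)
# from google.cloud.bigquery_v2.services import (
#     dataset_service,
#     job_service
# )
```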
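
For orientation, roughly what `client.py.j2` should render for a single entry (`get_dataset` on `DatasetServiceClient`, with the return type as parsed from the GAPIC client). This is a hand-written sketch of the expected output, not captured generator output:

```python
# Excerpt of the expected render; the referenced names (dataset, dataset_service,
# OptionalRetry, DEFAULT_*) come from the generated module's own imports/aliases.
class BigQueryClient:
    def get_dataset(
        self,
        *,
        request: Optional["dataset.GetDatasetRequest"] = None,
        retry: OptionalRetry = DEFAULT_RETRY,
        timeout: Union[float, object] = DEFAULT_TIMEOUT,
        metadata: Sequence[Tuple[str, Union[str, bytes]]] = DEFAULT_METADATA,
    ) -> "dataset.Dataset":
        return self.dataset_service_client.get_dataset(
            request=request, retry=retry, timeout=timeout, metadata=metadata
        )

    @property
    def dataset_service_client(self):
        # Lazily construct and cache the underlying GAPIC client.
        if "dataset" not in self._clients:
            self._clients["dataset"] = dataset_service.DatasetServiceClient(
                credentials=self._credentials, client_options=self._client_options
            )
        return self._clients["dataset"]
```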
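
Finally, the helpers in `_client_helpers.j2`. `_parse_dataset_path` splits on the last dot so legacy domain-scoped project IDs survive, and the bare-dataset form falls back to `self.project`, a property the generated class presumably picks up via the `_helpers` import, since `__init__` does not set it:

```python
# Expected behavior, assuming client.project == "default-project":
#
#   client._parse_dataset_path("my-project.my_dataset")
#       -> ("my-project", "my_dataset")
#   client._parse_dataset_path("google.com:my-project.my_dataset")
#       -> ("google.com:my-project", "my_dataset")   # rsplit keeps the colon form intact
#   client._parse_dataset_path("my_dataset")
#       -> ("default-project", "my_dataset")
#
# _parse_dataset_id_to_dict normalizes either identifier form to the same dict:
from google.cloud.bigquery_v2.types import dataset_reference

ref = dataset_reference.DatasetReference(project_id="my-project", dataset_id="my_dataset")
# client._parse_dataset_id_to_dict(ref)
#     -> {"project_id": "my-project", "dataset_id": "my_dataset"}
# client._parse_dataset_id_to_dict("my-project.my_dataset")
#     -> {"project_id": "my-project", "dataset_id": "my_dataset"}
```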