diff --git a/aws/lambda/benchmark_regression_summary_report/common/benchmark_time_series_api_model.py b/aws/lambda/benchmark_regression_summary_report/common/benchmark_time_series_api_model.py
new file mode 100644
index 0000000000..fe7705a6ea
--- /dev/null
+++ b/aws/lambda/benchmark_regression_summary_report/common/benchmark_time_series_api_model.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass, field
+from typing import Any, Dict, List
+
+import requests
+
+
+# Data classes modeling the API response of the get_time_series API
+
+
+@dataclass
+class TimeRange:
+    start: str
+    end: str
+
+
+@dataclass
+class BenchmarkTimeSeriesItem:
+    group_info: Dict[str, Any]
+    num_of_dp: int
+    data: List[Dict[str, Any]] = field(default_factory=list)
+
+
+@dataclass
+class BenchmarkTimeSeriesApiData:
+    time_series: List[BenchmarkTimeSeriesItem]
+    time_range: TimeRange
+
+
+@dataclass
+class BenchmarkTimeSeriesApiResponse:
+    data: BenchmarkTimeSeriesApiData
+
+    @classmethod
+    def from_request(
+        cls, url: str, query: dict, timeout: int = 180
+    ) -> "BenchmarkTimeSeriesApiResponse":
+        """
+        Send a POST request and parse the result into a BenchmarkTimeSeriesApiResponse.
+
+        Args:
+            url: API endpoint
+            query: JSON payload to send in the request body
+            timeout: max seconds to wait for connect + response (default: 180)
+        Returns:
+            BenchmarkTimeSeriesApiResponse
+        Raises:
+            requests.exceptions.RequestException: on network/timeout/HTTP error
+            RuntimeError: if the API returns an "error" field or malformed data
+        """
+        resp = requests.post(url, json=query, timeout=timeout)
+        resp.raise_for_status()
+        payload = resp.json()
+
+        if "error" in payload:
+            raise RuntimeError(f"API error: {payload['error']}")
+        try:
+            tr = TimeRange(**payload["data"]["time_range"])
+            ts = [
+                BenchmarkTimeSeriesItem(**item)
+                for item in payload["data"]["time_series"]
+            ]
+        except Exception as e:
+            raise RuntimeError(f"Malformed API payload: {e}") from e
+        return cls(data=BenchmarkTimeSeriesApiData(time_series=ts, time_range=tr))
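For reference, a minimal sketch of the shape these dataclasses model. The group_info keys and the data fields below are made-up placeholder values, not something this patch defines:

from common.benchmark_time_series_api_model import (
    BenchmarkTimeSeriesApiData,
    BenchmarkTimeSeriesApiResponse,
    BenchmarkTimeSeriesItem,
    TimeRange,
)

# A parsed response, populated with illustrative values only.
resp = BenchmarkTimeSeriesApiResponse(
    data=BenchmarkTimeSeriesApiData(
        time_series=[
            BenchmarkTimeSeriesItem(
                group_info={"compiler": "inductor", "metric": "passrate"},  # assumed keys
                num_of_dp=1,
                data=[{"value": 0.92, "granularity_bucket": "2024-01-07T00:00:00"}],  # assumed fields
            )
        ],
        time_range=TimeRange(start="2024-01-01T00:00:00", end="2024-01-08T00:00:00"),
    )
)
assert resp.data.time_series[0].num_of_dp == 1

from_request() fills the same structure from a live POST to the endpoint; the sketch above only shows the parsed result.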
diff --git a/aws/lambda/benchmark_regression_summary_report/common/config.py b/aws/lambda/benchmark_regression_summary_report/common/config.py
new file mode 100644
index 0000000000..ef0586758f
--- /dev/null
+++ b/aws/lambda/benchmark_regression_summary_report/common/config.py
@@ -0,0 +1,94 @@
+from common.config_model import (
+    BenchmarkApiSource,
+    BenchmarkConfig,
+    BenchmarkRegressionConfigBook,
+    DayRangeWindow,
+    Frequency,
+    Policy,
+    RangeConfig,
+    RegressionPolicy,
+)
+
+
+# Compiler benchmark regression config
+# todo(elainewy): eventually each team should configure their own benchmark
+# regression config; it currently lives here for the lambda
+
+
+COMPILER_BENCHMARK_CONFIG = BenchmarkConfig(
+    name="Compiler Benchmark Regression",
+    id="compiler_regression",
+    source=BenchmarkApiSource(
+        api_query_url="https://hud.pytorch.org/api/benchmark/get_time_series",
+        type="benchmark_time_series_api",
+        # currently we only detect regressions for h100 with dtype bfloat16 and mode inference;
+        # we can extend this to other devices, dtypes, and modes in the future
+        api_endpoint_params_template="""
+        {
+            "name": "compiler_precompute",
+            "query_params": {
+                "commits": [],
+                "compilers": [],
+                "arch": "h100",
+                "device": "cuda",
+                "dtype": "bfloat16",
+                "granularity": "hour",
+                "mode": "inference",
+                "startTime": "{{ startTime }}",
+                "stopTime": "{{ stopTime }}",
+                "suites": ["torchbench", "huggingface", "timm_models"],
+                "workflowId": 0,
+                "branches": ["main"]
+            }
+        }
+        """,
+    ),
+    # build the baseline from the past 7 days (aggregated with "max") and
+    # compare it against the last 2 days
+    policy=Policy(
+        frequency=Frequency(value=1, unit="days"),
+        range=RangeConfig(
+            baseline=DayRangeWindow(value=7),
+            comparison=DayRangeWindow(value=2),
+        ),
+        metrics={
+            "passrate": RegressionPolicy(
+                name="passrate",
+                condition="greater_equal",
+                threshold=0.9,
+                baseline_aggregation="max",
+            ),
+            "geomean": RegressionPolicy(
+                name="geomean",
+                condition="greater_equal",
+                threshold=0.95,
+                baseline_aggregation="max",
+            ),
+            "compression_ratio": RegressionPolicy(
+                name="compression_ratio",
+                condition="greater_equal",
+                threshold=0.9,
+                baseline_aggregation="max",
+            ),
+        },
+        notification_config={
+            "type": "github",
+            "repo": "pytorch/test-infra",
+            "issue": "7081",
+        },
+    ),
+)
+
+BENCHMARK_REGRESSION_CONFIG = BenchmarkRegressionConfigBook(
+    configs={
+        "compiler_regression": COMPILER_BENCHMARK_CONFIG,
+    }
+)
+
+
+def get_benchmark_regression_config(config_id: str) -> BenchmarkConfig:
+    """Get benchmark regression config by config id."""
+    try:
+        return BENCHMARK_REGRESSION_CONFIG[config_id]
+    except KeyError as e:
+        raise ValueError(f"Invalid config id: {config_id}") from e
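A hedged sketch of how this config is typically consumed: look up the config by id and render the query-params template for a concrete window. The timestamps are placeholders; the expected outputs in the comments follow from the template and Frequency defined in this patch:

from common.config import get_benchmark_regression_config

config = get_benchmark_regression_config("compiler_regression")

# Render the Jinja2 query-params template with a concrete time window.
query = config.source.render(
    ctx={"startTime": "2024-01-01T00:00:00", "stopTime": "2024-01-08T00:00:00"}
)
print(query["query_params"]["startTime"])  # "2024-01-01T00:00:00"
print(config.policy.frequency.get_text())  # "1_days"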
diff --git a/aws/lambda/benchmark_regression_summary_report/common/config_model.py b/aws/lambda/benchmark_regression_summary_report/common/config_model.py
new file mode 100644
index 0000000000..c262b35939
--- /dev/null
+++ b/aws/lambda/benchmark_regression_summary_report/common/config_model.py
@@ -0,0 +1,254 @@
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass, field
+from datetime import timedelta
+from typing import Any, ClassVar, Dict, Literal, Optional
+
+import requests
+from jinja2 import Environment, meta, Template
+
+
+# -------- Frequency --------
+@dataclass(frozen=True)
+class Frequency:
+    """
+    The frequency at which the report should be generated.
+    The minimum frequency we support is 1 day.
+    Attributes:
+        value: Number of units (e.g., 7 for 7 days).
+        unit: Unit of time, either "days" or "weeks".
+
+    Methods:
+        to_timedelta: Convert the frequency into a datetime.timedelta.
+        get_text: Return the frequency as a text label (e.g. "7_days").
+    """
+
+    value: int
+    unit: Literal["days", "weeks"]
+
+    def to_timedelta(self) -> timedelta:
+        """Convert a frequency of N days or M weeks into a datetime.timedelta."""
+        if self.unit == "days":
+            return timedelta(days=self.value)
+        elif self.unit == "weeks":
+            return timedelta(weeks=self.value)
+        else:
+            raise ValueError(f"Unsupported unit: {self.unit}")
+
+    def to_timedelta_s(self) -> int:
+        return int(self.to_timedelta().total_seconds())
+
+    def get_text(self) -> str:
+        return f"{self.value}_{self.unit}"
+
+
+# -------- Source --------
+_JINJA_ENV = Environment(autoescape=False)
+
+
+@dataclass
+class BenchmarkApiSource:
+    """
+    Defines the source of the benchmark data we want to query.
+    api_query_url: the url of the api to query
+    api_endpoint_params_template: the Jinja2 template for the api endpoint's query params
+    default_ctx: the default context used when rendering api_endpoint_params_template
+    """
+
+    api_query_url: str
+    api_endpoint_params_template: str
+    type: Literal["benchmark_time_series_api", "other"] = "benchmark_time_series_api"
+    default_ctx: Dict[str, Any] = field(default_factory=dict)
+
+    def required_template_vars(self) -> set[str]:
+        ast = _JINJA_ENV.parse(self.api_endpoint_params_template)
+        return set(meta.find_undeclared_variables(ast))
+
+    def render(self, ctx: Dict[str, Any], strict: bool = True) -> dict:
+        """Render with caller-supplied context (no special casing for start/end)."""
+        merged = {**self.default_ctx, **ctx}
+
+        if strict:
+            required = self.required_template_vars()
+            missing = required - merged.keys()
+            if missing:
+                raise ValueError(f"Missing required vars: {missing}")
+        rendered = Template(self.api_endpoint_params_template).render(**merged)
+        return json.loads(rendered)
+
+
+# -------- Policy: range windows --------
+@dataclass
+class DayRangeWindow:
+    value: int
+    # "raw" indicates the window is fetched directly from the source data
+    source: Literal["raw"] = "raw"
+
+
+@dataclass
+class RangeConfig:
+    """
+    Defines the baseline and comparison windows for a given policy.
+    - baseline: the window used to build the baseline value
+    - comparison: the window whose data is compared against the baseline value
+    """
+
+    baseline: DayRangeWindow
+    comparison: DayRangeWindow
+
+    def total_timedelta(self) -> timedelta:
+        return timedelta(days=self.baseline.value + self.comparison.value)
+
+    def total_timedelta_s(self) -> int:
+        return int(self.total_timedelta().total_seconds())
+
+    def comparison_timedelta(self) -> timedelta:
+        return timedelta(days=self.comparison.value)
+
+    def comparison_timedelta_s(self) -> int:
+        return int(self.comparison_timedelta().total_seconds())
+
+    def baseline_timedelta(self) -> timedelta:
+        return timedelta(days=self.baseline.value)
+
+    def baseline_timedelta_s(self) -> int:
+        return int(self.baseline_timedelta().total_seconds())
+
+
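A quick sanity-check sketch of the window arithmetic above, using the compiler config's values (7-day baseline, 2-day comparison); the end_time is a placeholder, and the window layout mirrors how the lambda derives its ranges:

from common.config_model import DayRangeWindow, RangeConfig

rc = RangeConfig(baseline=DayRangeWindow(value=7), comparison=DayRangeWindow(value=2))
assert rc.total_timedelta_s() == 9 * 24 * 3600
assert rc.comparison_timedelta_s() == 2 * 24 * 3600

# Given an end_time in epoch seconds, the report roughly uses:
#   baseline window:   [end_time - total, end_time - comparison]
#   comparison window: [end_time - comparison, end_time]
end_time = 1_700_000_000  # placeholder timestamp
baseline_window = (end_time - rc.total_timedelta_s(), end_time - rc.comparison_timedelta_s())
comparison_window = (end_time - rc.comparison_timedelta_s(), end_time)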
+# -------- Policy: metrics --------
+@dataclass
+class RegressionPolicy:
+    """
+    Defines the regression policy for a given metric.
+    The new value is checked against a target of baseline * threshold:
+    - "greater_than": higher is better; the new value must be strictly greater than the target
+    - "less_than": lower is better; the new value must be strictly lower than the target
+    - "equal_to": the new value should be ~= target, within rel_tol
+    - "greater_equal": higher is better; the new value must be greater than or equal to the target
+    - "less_equal": lower is better; the new value must be less than or equal to the target
+    """
+
+    name: str
+    condition: Literal[
+        "greater_than", "less_than", "equal_to", "greater_equal", "less_equal"
+    ]
+    threshold: float
+    baseline_aggregation: Literal["max", "min", "latest", "earliest"] = "max"
+    rel_tol: float = 1e-3  # used only for "equal_to"
+
+    def is_violation(self, value: float, baseline: float) -> bool:
+        target = baseline * self.threshold
+
+        if self.condition == "greater_than":
+            # value must be strictly greater than target
+            return value <= target
+
+        if self.condition == "greater_equal":
+            # value must be greater than or equal to target
+            return value < target
+
+        if self.condition == "less_than":
+            # value must be strictly less than target
+            return value >= target
+
+        if self.condition == "less_equal":
+            # value must be less than or equal to target
+            return value > target
+
+        if self.condition == "equal_to":
+            # |value - target| should be within rel_tol * max(1, |target|)
+            denom = max(1.0, abs(target))
+            return abs(value - target) > self.rel_tol * denom
+
+        raise ValueError(f"Unknown condition: {self.condition}")
+
+
+@dataclass
+class BaseNotificationConfig:
+    # subclasses override this
+    type_tag: ClassVar[str] = ""
+
+    @classmethod
+    def matches(cls, d: Dict[str, Any]) -> bool:
+        return d.get("type") == cls.type_tag
+
+
+@dataclass
+class GitHubNotificationConfig(BaseNotificationConfig):
+    type_tag: ClassVar[str] = "github"
+
+    # actual fields
+    type: str = "github"
+    repo: str = ""  # e.g. "owner/repo"
+    issue_number: str = ""  # stored as str for simplicity
+
+    @classmethod
+    def from_dict(cls, d: Dict[str, Any]) -> "GitHubNotificationConfig":
+        # support the 'issue' alias
+        issue = d.get("issue_number") or d.get("issue") or ""
+        return cls(
+            type="github",
+            repo=d.get("repo", ""),
+            issue_number=str(issue),
+        )
+
+    def create_github_comment(self, body: str, github_token: str) -> Dict[str, Any]:
+        url = f"https://api.github.com/repos/{self.repo}/issues/{self.issue_number}/comments"
+        headers = {
+            "Authorization": f"token {github_token}",
+            "Accept": "application/vnd.github+json",
+            "User-Agent": "bench-reporter/1.0",
+        }
+        resp = requests.post(url, headers=headers, json={"body": body})
+        resp.raise_for_status()
+        return resp.json()
+
+
+@dataclass
+class Policy:
+    frequency: "Frequency"
+    range: "RangeConfig"
+    metrics: Dict[str, "RegressionPolicy"]
+
+    notification_config: Optional[dict[str, Any]] = None
+
+    def get_github_notification_config(self) -> Optional[GitHubNotificationConfig]:
+        if not self.notification_config:
+            return None
+        if self.notification_config.get("type") != "github":
+            return None
+        return GitHubNotificationConfig.from_dict(self.notification_config)
+
+
+# -------- Top-level benchmark regression config --------
+@dataclass
+class BenchmarkConfig:
+    """
+    Represents a single benchmark regression configuration.
+    - source: defines the source of the benchmark data we want to query
+    - policy: defines the policy for benchmark regressions, including the frequency at which
+      the report is generated, the ranges of the baseline and new values, and the regression
+      thresholds for metrics
+    - name: the name of the benchmark
+    - id: the id of the benchmark; it must be unique for each benchmark and cannot be changed once set
+    """
+
+    name: str
+    id: str
+    source: BenchmarkApiSource
+    policy: Policy
+
+
+@dataclass
+class BenchmarkRegressionConfigBook:
+    configs: Dict[str, BenchmarkConfig] = field(default_factory=dict)
+
+    def __getitem__(self, key: str) -> BenchmarkConfig:
+        config = self.configs.get(key)
+        if not config:
+            raise KeyError(f"Config {key} not found")
+        return config
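To illustrate the is_violation semantics with the passrate policy from config.py (condition "greater_equal", threshold 0.9); the baseline and values below are illustrative:

from common.config_model import RegressionPolicy

passrate = RegressionPolicy(name="passrate", condition="greater_equal", threshold=0.9)

# With a baseline passrate of 0.95, the target is 0.95 * 0.9 = 0.855.
assert passrate.is_violation(value=0.80, baseline=0.95) is True   # dropped below target
assert passrate.is_violation(value=0.90, baseline=0.95) is False  # still at/above target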
diff --git a/aws/lambda/benchmark_regression_summary_report/lambda_function.py b/aws/lambda/benchmark_regression_summary_report/lambda_function.py
new file mode 100644
index 0000000000..d32ebf33b8
--- /dev/null
+++ b/aws/lambda/benchmark_regression_summary_report/lambda_function.py
@@ -0,0 +1,456 @@
+#!/usr/bin/env python
+import argparse
+import datetime as dt
+import logging
+import os
+import threading
+import time
+from typing import Any, Optional
+
+import clickhouse_connect
+import requests
+from common.benchmark_time_series_api_model import BenchmarkTimeSeriesApiResponse
+from common.config import get_benchmark_regression_config
+from common.config_model import BenchmarkApiSource, BenchmarkConfig, Frequency
+from dateutil.parser import isoparse
+
+
+# TODO(elainewy): change this to benchmark.benchmark_regression_report once the table is created
+BENCHMARK_REGRESSION_REPORT_TABLE = "fortesting.benchmark_regression_report"
+
+logging.basicConfig(
+    level=logging.INFO,
+)
+logger = logging.getLogger()
+logger.setLevel("INFO")
+
+ENVS = {
+    "GITHUB_TOKEN": os.getenv("GITHUB_TOKEN", ""),
+    "CLICKHOUSE_ENDPOINT": os.getenv("CLICKHOUSE_ENDPOINT", ""),
+    "CLICKHOUSE_PASSWORD": os.getenv("CLICKHOUSE_PASSWORD", ""),
+    "CLICKHOUSE_USERNAME": os.getenv("CLICKHOUSE_USERNAME", ""),
+}
+
+BENCHMARK_REGRESSION_TRACKING_CONFIG_IDS = ["compiler_regression"]
+
+
+def format_ts_with_t(ts: int) -> str:
+    return dt.datetime.fromtimestamp(ts, tz=dt.timezone.utc).strftime(
+        "%Y-%m-%dT%H:%M:%S"
+    )
+
+
+def truncate_to_hour(ts: dt.datetime) -> dt.datetime:
+    return ts.replace(minute=0, second=0, microsecond=0)
+
+
+def get_clickhouse_client(
+    host: str, user: str, password: str
+) -> clickhouse_connect.driver.client.Client:
+    return clickhouse_connect.get_client(
+        host=host, user=user, password=password, secure=True
+    )
+
+
+def get_clickhouse_client_environment() -> clickhouse_connect.driver.client.Client:
+    for name, env_val in ENVS.items():
+        if not env_val:
+            raise ValueError(f"Missing environment variable {name}")
+    return get_clickhouse_client(
+        host=ENVS["CLICKHOUSE_ENDPOINT"],
+        user=ENVS["CLICKHOUSE_USERNAME"],
+        password=ENVS["CLICKHOUSE_PASSWORD"],
+    )
+
+
+class BenchmarkSummaryProcessor:
+    def __init__(
+        self,
+        config_id: str,
+        end_time: int,
+        is_dry_run: bool = False,
+    ) -> None:
+        self.is_dry_run = is_dry_run
+        self.config_id = config_id
+        self.end_time = end_time
+
+    def log_info(self, msg: str):
+        logger.info("[%s][%s] %s", self.end_time, self.config_id, msg)
+
+    def log_error(self, msg: str):
+        logger.error("[%s][%s] %s", self.end_time, self.config_id, msg)
+
+    def process(
+        self,
+        cc: Optional[clickhouse_connect.driver.client.Client] = None,
+        args: Optional[argparse.Namespace] = None,
+    ):
+        # the clickhouse client is not thread-safe, so when no client is passed
+        # in, create a dedicated one for this call instead of sharing one.
+        self.log_info("start process, getting clickhouse client")
+        if cc is None:
+            if args:
+                cc = get_clickhouse_client(
+                    args.clickhouse_endpoint,
+                    args.clickhouse_username,
+                    args.clickhouse_password,
+                )
+            else:
+                cc = get_clickhouse_client_environment()
+        self.log_info("done. got clickhouse client")
+        try:
+            config = get_benchmark_regression_config(self.config_id)
+            self.log_info(f"found config with config_id: `{self.config_id}`")
+        except ValueError as e:
+            self.log_error(f"Skip process, invalid config: {e}")
+            return
+        except Exception as e:
+            self.log_error(
+                f"Unexpected error from get_benchmark_regression_config: {e}"
+            )
+            return
+
+        # check if the current time is > policy's time_delta + previous record_ts from the summary table
+        report_freq = config.policy.frequency
+
+        should_generate = self._should_generate_report(
+            cc, self.end_time, self.config_id, report_freq
+        )
+        if not should_generate:
+            self.log_info("Skip generate report")
+            return
+        else:
+            self.log_info(
+                f"Plan to generate report for time: {format_ts_with_t(self.end_time)} "
+                f"with frequency {report_freq.get_text()}..."
+            )
+
+        self.log_info("get target data")
+        target, ls, le = self.get_target(config, self.end_time)
+        if not target:
+            self.log_info(
+                f"no target data found for time range [{ls},{le}] with frequency {report_freq.get_text()}..."
+            )
+            return
+        baseline, bs, be = self.get_baseline(config, self.end_time)
+        if not baseline:
+            self.log_info(
+                f"no baseline data found for time range [{bs},{be}] with frequency {report_freq.get_text()}..."
+            )
+            return
+        return
+
+    def get_target(self, config: BenchmarkConfig, end_time: int):
+        data_range = config.policy.range
+        target_s = end_time - data_range.comparison_timedelta_s()
+        target_e = end_time
+        self.log_info(
+            f"get target data for time range [{format_ts_with_t(target_s)},{format_ts_with_t(target_e)}]"
+        )
+        target_data = self._fetch_from_benchmark_ts_api(
+            config_id=config.id,
+            start_time=target_s,
+            end_time=target_e,
+            source=config.source,
+        )
+        self.log_info(
+            f"found {len(target_data.time_series)} # of data points, with time range {target_data.time_range}",
+        )
+        if not target_data.time_range or not target_data.time_range.end:
+            return None, target_s, target_e
+
+        target_ts = int(isoparse(target_data.time_range.end).timestamp())
+        if not self.should_use_data(target_ts, end_time):
+            return None, target_s, target_e
+        return target_data, target_s, target_e
+
+    def get_baseline(self, config: BenchmarkConfig, end_time: int):
+        data_range = config.policy.range
+        baseline_s = end_time - data_range.total_timedelta_s()
+        baseline_e = end_time - data_range.comparison_timedelta_s()
+        self.log_info(
+            f"get baseline data for time range [{format_ts_with_t(baseline_s)},{format_ts_with_t(baseline_e)}]"
+        )
+        # fetch baseline from api
+        raw_data = self._fetch_from_benchmark_ts_api(
+            config_id=config.id,
+            start_time=baseline_s,
+            end_time=baseline_e,
+            source=config.source,
+        )
+
+        self.log_info(
+            f"found {len(raw_data.time_series)} # of data points, with time range {raw_data.time_range}",
+        )
+
+        if not raw_data.time_range or not raw_data.time_range.end:
+            return None, baseline_s, baseline_e
+
+        baseline_latest_ts = int(isoparse(raw_data.time_range.end).timestamp())
+
+        if not self.should_use_data(baseline_latest_ts, baseline_e):
+            self.log_info(
+                "[get_baseline] Skip generate report, no data found during "
+                f"[{format_ts_with_t(baseline_s)},{format_ts_with_t(baseline_e)}]"
+            )
+            return None, baseline_s, baseline_e
+        return raw_data, baseline_s, baseline_e
+
+    def should_use_data(
+        self,
+        latest_ts: int,
+        end_time: int,
+        min_delta: Optional[dt.timedelta] = None,
+    ) -> bool:
+        # default: the latest data point must be no older than 2 days before end_time
+        if not min_delta:
+            min_delta = dt.timedelta(days=2)
+
+        if not latest_ts:
+            return False
+
+        cutoff = end_time - min_delta.total_seconds()
+
+        if latest_ts >= cutoff:
+            return True
+        self.log_info(f"expect latest data to be after {cutoff}, but got {latest_ts}")
+        return False
+
+    def _fetch_from_benchmark_ts_api(
+        self,
+        config_id: str,
+        end_time: int,
+        start_time: int,
+        source: BenchmarkApiSource,
+    ):
+        str_end_time = format_ts_with_t(end_time)
+        str_start_time = format_ts_with_t(start_time)
+        query = source.render(
+            ctx={
+                "startTime": str_start_time,
+                "stopTime": str_end_time,
+            }
+        )
+        url = source.api_query_url
+
+        self.log_info(f"trying to call {url}")
+        t0 = time.perf_counter()
+        try:
+            resp: BenchmarkTimeSeriesApiResponse = (
+                BenchmarkTimeSeriesApiResponse.from_request(url, query)
+            )
+
+            elapsed_ms = (time.perf_counter() - t0) * 1000.0
+            logger.info(
+                "[%s] call OK in %.1f ms (query_len=%d)",
+                config_id,
+                elapsed_ms,
+                len(query),
+            )
+            return resp.data
+        except requests.exceptions.HTTPError as e:
+            elapsed_ms = (time.perf_counter() - t0) * 1000.0
+            # Try to extract a useful server message safely
+            try:
+                err_msg = (
+                    e.response.json().get("error") if e.response is not None else str(e)
+                )
+            except Exception:
+                err_msg = (
+                    e.response.text
+                    if (e.response is not None and hasattr(e.response, "text"))
+                    else str(e)
+                )
+            self.log_error(
+                f"[{config_id}] call FAILED in {elapsed_ms:.1f} ms: {err_msg}",
+            )
+            raise
+
+        except Exception as e:
+            elapsed_ms = (time.perf_counter() - t0) * 1000.0
+            self.log_error(f"call CRASHED in {elapsed_ms:.1f} ms: {e}")
+            raise RuntimeError(f"[{config_id}] Fetch failed: {e}") from e
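A small worked example of the freshness check in should_use_data above; the timestamps are illustrative:

import datetime as dt

# end of the baseline window, e.g. 2024-01-08T00:00:00 UTC as epoch seconds
end_time = 1_704_672_000
min_delta = dt.timedelta(days=2)
cutoff = end_time - min_delta.total_seconds()  # 2024-01-06T00:00:00

# data whose latest timestamp is newer than the cutoff is considered usable
assert (end_time - 3600) >= cutoff                 # 1 hour old  -> use it
assert not ((end_time - 3 * 86400) >= cutoff)      # 3 days old  -> too stale, skip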
+
+    def _should_generate_report(
+        self,
+        cc: clickhouse_connect.driver.client.Client,
+        end_time: int,
+        config_id: str,
+        f: Frequency,
+    ) -> bool:
+        def _get_latest_record_ts(
+            cc: clickhouse_connect.driver.client.Client,
+            config_id: str,
+        ) -> Optional[int]:
+            table = BENCHMARK_REGRESSION_REPORT_TABLE
+            res = cc.query(
+                f"""
+                SELECT toUnixTimestamp(max(last_record_ts))
+                FROM {table}
+                WHERE report_id = {{config_id:String}}
+                """,
+                parameters={"config_id": config_id},
+            )
+
+            if not res.result_rows or res.result_rows[0][0] is None:
+                return None
+            return int(res.result_rows[0][0])
+
+        freq_delta = f.to_timedelta_s()
+        latest_record_ts = _get_latest_record_ts(cc, config_id)
+        # No report exists yet, generate one
+        if not latest_record_ts:
+            self.log_info(
+                f"no latest record ts found in db for the config_id, got {latest_record_ts}"
+            )
+            return True
+        self.log_info(f"found latest record ts from db {latest_record_ts}")
+        time_boundary = latest_record_ts + freq_delta
+        should_generate = end_time > time_boundary
+
+        if not should_generate:
+            self.log_info(
+                f"[{f.get_text()}] skip generating report. end_time({format_ts_with_t(end_time)})"
+                f" must be greater than time_boundary({format_ts_with_t(time_boundary)}) "
+                f"based on latest_record_ts({format_ts_with_t(latest_record_ts)})",
+            )
+        else:
+            self.log_info(
+                f"[{f.get_text()}] plan to generate report. end_time({format_ts_with_t(end_time)}) is greater than "
+                f"time_boundary({format_ts_with_t(time_boundary)}) "
+                f"based on latest_record_ts({format_ts_with_t(latest_record_ts)})",
+            )
+        return should_generate
+
+
+def main(
+    config_id: str,
+    github_access_token: str = "",
+    args: Optional[argparse.Namespace] = None,
+    *,
+    is_dry_run: bool = False,
+):
+    if not github_access_token:
+        raise ValueError("Missing github access token (GITHUB_TOKEN)")
+
+    if not config_id:
+        raise ValueError("Missing required parameter: config_id")
+
+    end_time = dt.datetime.now(dt.timezone.utc).replace(
+        minute=0, second=0, microsecond=0
+    )
+    end_time_ts = int(end_time.timestamp())
+    logger.info(
+        "[Main] current time with hour granularity (utc) %s with unix timestamp %s",
+        end_time,
+        end_time_ts,
+    )
+    logger.info("[Main] start work ....")
+
+    # caution: raising an exception here may cause the lambda to be retried
+    try:
+        processor = BenchmarkSummaryProcessor(
+            config_id=config_id, end_time=end_time_ts, is_dry_run=is_dry_run
+        )
+        processor.process(args=args)
+    except Exception as e:
+        logger.error(f"[Main] failed to process config_id {config_id}, error: {e}")
+        raise
+    logger.info("[Main] Done. Work completed.")
+
+
+def lambda_handler(event: Any, context: Any) -> None:
+    """
+    Entry point when running in the AWS Lambda environment.
+    """
+    config_id = event.get("config_id")
+    if not config_id:
+        raise ValueError("Missing required parameter: config_id")
+
+    main(
+        config_id=config_id,
+        github_access_token=ENVS["GITHUB_TOKEN"],
+    )
+    return
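A hedged sketch of exercising the handler outside AWS: the event only carries a config_id, and the run assumes the GITHUB_TOKEN and CLICKHOUSE_* environment variables are exported and the backing services are reachable:

from lambda_function import lambda_handler

# AWS invokes the handler with an event like this; context is unused here.
lambda_handler({"config_id": "compiler_regression"}, context=None)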
+
+
+def parse_args() -> argparse.Namespace:
+    """
+    Parse command line args; this is mainly used for the local test environment.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dry-run",
+        dest="dry_run",
+        action="store_true",
+        help="Enable dry-run mode",
+    )
+    parser.add_argument(
+        "--no-dry-run",
+        dest="dry_run",
+        action="store_false",
+        help="Disable dry-run mode",
+    )
+    parser.add_argument(
+        "--config-id",
+        type=str,
+        help="the config id to run",
+    )
+    parser.add_argument(
+        "--clickhouse-endpoint",
+        default=ENVS["CLICKHOUSE_ENDPOINT"],
+        type=str,
+        help="the clickhouse endpoint; the full url is "
+        + "https://{clickhouse_endpoint}:{port}",
+    )
+    parser.add_argument(
+        "--clickhouse-username",
+        type=str,
+        default=ENVS["CLICKHOUSE_USERNAME"],
+        help="the clickhouse username",
+    )
+    parser.add_argument(
+        "--clickhouse-password",
+        type=str,
+        default=ENVS["CLICKHOUSE_PASSWORD"],
+        help="the clickhouse password for the username",
+    )
+    parser.add_argument(
+        "--github-access-token",
+        type=str,
+        default=ENVS["GITHUB_TOKEN"],
+        help="the github access token to access the github api",
+    )
+    parser.set_defaults(dry_run=True)  # default is True
+    args, _ = parser.parse_known_args()
+    return args
+
+
+def local_run() -> None:
+    """
+    Entry point for running in a local test environment.
+    """
+
+    args = parse_args()
+    # forward CLI args (which default to the environment variables) to main
+    main(
+        config_id=args.config_id,
+        github_access_token=args.github_access_token,
+        args=args,
+        is_dry_run=args.dry_run,
+    )
+
+
+if __name__ == "__main__":
+    local_run()
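For completeness, a hedged sketch of driving the same parse_args/local_run path programmatically; credentials fall back to the CLICKHOUSE_* and GITHUB_TOKEN environment variables, which are assumed to be set:

import sys

from lambda_function import local_run

# Equivalent to:
#   python lambda_function.py --config-id compiler_regression --dry-run
sys.argv = [
    "lambda_function.py",
    "--config-id", "compiler_regression",
    "--dry-run",
]
local_run()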