Skip to content

Commit eb3befa

Browse files
committed
regression_db
ghstack-source-id: 9de515e Pull-Request: #7089
1 parent b55fed9 commit eb3befa

File tree

7 files changed

+818
-0
lines changed

7 files changed

+818
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
*.zip
2+
deployment/
3+
venv/
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from dataclasses import dataclass, field
2+
from typing import Optional, List, Dict, Any
3+
import requests
4+
5+
# Data classes modeling the response of the get_time_series API
6+
7+
@dataclass
class TimeRange:
    """Start/end pair describing the time window of a time-series response."""

    # Both bounds are stored as plain strings, exactly as supplied.
    start: str
    end: str
11+
12+
13+
@dataclass
class BenchmarkTimeSeriesItem:
    """One element of the ``time_series`` list in the API response."""

    # Free-form mapping identifying the group this series belongs to.
    group_info: Dict[str, Any]
    # presumably the number of data points in ``data`` -- TODO confirm with the API
    num_of_dp: int
    # Raw data-point dicts; each instance gets its own fresh list by default.
    data: List[Dict[str, Any]] = field(default_factory=list)
18+
19+
20+
@dataclass
class BenchmarkTimeSeriesApiData:
    """Contents of the ``data`` field of the API response."""

    # One item per returned series.
    time_series: List[BenchmarkTimeSeriesItem]
    # Time window reported alongside the series.
    time_range: TimeRange
24+
25+
26+
@dataclass
class BenchmarkTimeSeriesApiResponse:
    """Typed view of a successful response from the time-series API."""

    data: BenchmarkTimeSeriesApiData

    @classmethod
    def from_request(
        cls, url: str, query: dict, timeout: int = 180
    ) -> "BenchmarkTimeSeriesApiResponse":
        """
        Send a POST request and parse the reply into a BenchmarkTimeSeriesApiResponse.

        Args:
            url: API endpoint
            query: payload sent as the JSON request body
            timeout: max seconds to wait for connect + response (default: 180)
        Returns:
            BenchmarkTimeSeriesApiResponse built from the ``data`` field of the reply
        Raises:
            requests.exceptions.RequestException if network/timeout/HTTP error
            RuntimeError if the API returns an "error" field or malformed data
        """
        resp = requests.post(url, json=query, timeout=timeout)
        resp.raise_for_status()
        payload = resp.json()

        # A JSON body that is not an object (null, string, list, ...) cannot
        # carry "data"/"error"; report it as malformed instead of letting a
        # TypeError escape from the membership test below.
        if not isinstance(payload, dict):
            raise RuntimeError(
                "Malformed API payload: expected a JSON object, "
                f"got {type(payload).__name__}"
            )

        if "error" in payload:
            raise RuntimeError(f"API error: {payload['error']}")

        try:
            tr = TimeRange(**payload["data"]["time_range"])
            ts = [
                BenchmarkTimeSeriesItem(**item)
                for item in payload["data"]["time_series"]
            ]
        except (KeyError, TypeError) as e:
            # KeyError: a required section is missing.
            # TypeError: a section has the wrong shape or unexpected fields.
            raise RuntimeError(f"Malformed API payload: {e}") from e
        return cls(data=BenchmarkTimeSeriesApiData(time_series=ts, time_range=tr))
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from common.config_model import (
2+
BenchmarkApiSource,
3+
BenchmarkConfig,
4+
BenchmarkRegressionConfigBook,
5+
DayRangeWindow,
6+
Frequency,
7+
RegressionPolicy,
8+
Policy,
9+
RangeConfig,
10+
)
11+
12+
# Compiler benchmark regression config
13+
# TODO(elainewy): eventually each team should configure its own benchmark regression config; currently placed here for the lambda
14+
15+
16+
COMPILER_BENCHMARK_CONFIG = BenchmarkConfig(
    name="Compiler Benchmark Regression",
    id="compiler_regression",
    source=BenchmarkApiSource(
        api_query_url="http://localhost:3000/api/benchmark/get_time_series",
        type="benchmark_time_series_api",
        # currently we only detect the regression for h100 with dtype bfloat16, and mode inference
        # we can extend this to other devices, dtypes and mode in the future
        api_endpoint_params_template="""
        {
          "name": "compiler_precompute",
          "query_params": {
            "commits": [],
            "compilers": [],
            "arch": "h100",
            "device": "cuda",
            "dtype": "bfloat16",
            "granularity": "hour",
            "mode": "inference",
            "startTime": "{{ startTime }}",
            "stopTime": "{{ stopTime }}",
            "suites": ["torchbench", "huggingface", "timm_models"],
            "workflowId": 0,
            "branches": ["main"]
          }
        }
        """,
    ),
    # baseline window: past 7 days, aggregated with "max";
    # comparison window: the most recent 2 days
    policy=Policy(
        frequency=Frequency(value=1, unit="days"),
        range=RangeConfig(
            baseline=DayRangeWindow(value=7),
            comparison=DayRangeWindow(value=2),
        ),
        metrics={
            "passrate": RegressionPolicy(
                name="passrate",
                condition="greater_equal",
                threshold=0.9,
                baseline_aggregation="max",
            ),
            "geomean": RegressionPolicy(
                name="geomean",
                condition="greater_equal",
                threshold=0.95,
                baseline_aggregation="max",
            ),
            "compression_ratio": RegressionPolicy(
                name="compression_ratio",
                condition="greater_equal",
                threshold=0.9,
                baseline_aggregation="max",
            ),
        },
        notification_config={
            "type": "github",
            "repo": "pytorch/test-infra",
            # key must match the GitHubNotificationConfig field name;
            # the previous "issue" key was dropped when building the config
            "issue_number": "7081",
        },
    ),
)
69+
70+
# Registry of all known regression configs, keyed by each config's own id.
BENCHMARK_REGRESSION_CONFIG = BenchmarkRegressionConfigBook(
    configs={COMPILER_BENCHMARK_CONFIG.id: COMPILER_BENCHMARK_CONFIG},
)
75+
76+
def get_benchmark_regression_config(config_id: str) -> BenchmarkConfig:
    """Get benchmark regression config by config id.

    Args:
        config_id: id under which the config is registered in
            BENCHMARK_REGRESSION_CONFIG.
    Returns:
        The BenchmarkConfig registered for ``config_id``.
    Raises:
        ValueError: if no config is registered under ``config_id``.
    """
    try:
        return BENCHMARK_REGRESSION_CONFIG[config_id]
    except KeyError as err:
        # Chain explicitly so the traceback shows the failed lookup as the
        # cause rather than as an unrelated error during handling.
        raise ValueError(f"Invalid config id: {config_id}") from err
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
from __future__ import annotations

import json
from dataclasses import dataclass, field, fields
from datetime import datetime, timedelta
from typing import Any, ClassVar, Dict, Literal, Optional, Set, Type, TypeVar, Union

import requests
from jinja2 import Environment, Template, meta
8+
9+
10+
# -------- Frequency --------
11+
@dataclass(frozen=True)
class Frequency:
    """
    How often a report should be generated.

    Attributes:
        value: number of units (e.g. 7 for 7 days).
        unit: unit of time, either "days" or "weeks".

    Methods:
        to_timedelta: the frequency as a datetime.timedelta.
        get_text: the frequency as "<value> <unit>" text.
    """

    value: int
    unit: Literal["days", "weeks"]

    def to_timedelta(self) -> timedelta:
        """Return the frequency as a timedelta; reject any other unit."""
        if self.unit not in ("days", "weeks"):
            raise ValueError(f"Unsupported unit: {self.unit}")
        # "days" and "weeks" are both keyword arguments of timedelta.
        return timedelta(**{self.unit: self.value})

    def get_text(self):
        """Return the frequency formatted as "<value> <unit>"."""
        return f"{self.value} {self.unit}"
37+
38+
39+
# -------- Source --------
40+
# Shared environment used to parse templates when looking for their variables.
_JINJA_ENV = Environment(autoescape=False)


@dataclass
class BenchmarkApiSource:
    """
    Describes where benchmark data is queried from.

    api_query_url: url of the api to query
    api_endpoint_params_template: jinja2 template of the api endpoint's query params;
        its rendered text must be valid JSON
    type: kind of source, "benchmark_time_series_api" by default
    default_ctx: context values used when rendering, overridable per call
    """

    api_query_url: str
    api_endpoint_params_template: str
    type: Literal["benchmark_time_series_api", "other"] = "benchmark_time_series_api"
    default_ctx: Dict[str, Any] = field(default_factory=dict)

    def required_template_vars(self) -> set[str]:
        """Return the names of the variables the template leaves undeclared."""
        parsed = _JINJA_ENV.parse(self.api_endpoint_params_template)
        undeclared = meta.find_undeclared_variables(parsed)
        return set(undeclared)

    def render(self, ctx: Dict[str, Any], strict: bool = True) -> dict:
        """Render the template with ``default_ctx`` overlaid by ``ctx`` and parse it as JSON.

        With ``strict`` set, a ValueError is raised when a template variable
        has no value in the merged context.
        """
        # Caller-supplied values win over the defaults.
        context = dict(self.default_ctx)
        context.update(ctx)

        if strict:
            absent = self.required_template_vars().difference(context)
            if absent:
                raise ValueError(f"Missing required vars: {absent}")

        template = Template(self.api_endpoint_params_template)
        return json.loads(template.render(**context))
70+
71+
72+
# -------- Policy: range windows --------
73+
@dataclass
class DayRangeWindow:
    """A window whose length is given as a number of days."""

    # Length of the window, in days.
    value: int
    # "raw" means the window is fetched from the source data.
    source: Literal["raw"] = "raw"
78+
79+
@dataclass
class RangeConfig:
    """
    Baseline and comparison windows of a policy.

    - baseline: window used to build the baseline value
    - comparison: window whose data is compared against the baseline value
    """

    baseline: DayRangeWindow
    comparison: DayRangeWindow

    def total_timedelta(self) -> timedelta:
        """Combined length of the baseline and comparison windows."""
        span_days = self.baseline.value + self.comparison.value
        return timedelta(days=span_days)

    def comparison_timedelta(self) -> timedelta:
        """Length of the comparison window."""
        window = self.comparison
        return timedelta(days=window.value)

    def baseline_timedelta(self) -> timedelta:
        """Length of the baseline window."""
        window = self.baseline
        return timedelta(days=window.value)
95+
96+
# -------- Policy: metrics --------
97+
@dataclass
class RegressionPolicy:
    """
    Regression rule for a single metric.

    The new value is checked against ``target = baseline * threshold``;
    ``condition`` names the relation that must hold:
      - "greater_than": new value must be strictly greater than target
      - "less_than": new value must be strictly less than target
      - "equal_to": new value must be ~= target within rel_tol
      - "greater_equal": new value must be greater than or equal to target
      - "less_equal": new value must be less than or equal to target
    """

    name: str
    condition: Literal[
        "greater_than", "less_than", "equal_to", "greater_equal", "less_equal"
    ]
    threshold: float
    baseline_aggregation: Literal[
        "avg", "max", "min", "p50", "p90", "p95", "latest", "earliest"
    ] = "max"
    rel_tol: float = 1e-3  # used only for "equal_to"

    def is_violation(self, value: float, baseline: float) -> bool:
        """Return True when ``value`` breaks the policy relative to ``baseline``.

        Raises:
            ValueError: if ``condition`` is not one of the supported names.
        """
        target = baseline * self.threshold

        # Each entry answers "is the required relation broken?"; the lambdas
        # keep evaluation lazy so only the selected comparison runs.
        broken_when = {
            "greater_than": lambda: value <= target,
            "greater_equal": lambda: value < target,
            "less_than": lambda: value >= target,
            "less_equal": lambda: value > target,
            # |value - target| must stay within rel_tol * max(1, |target|)
            "equal_to": lambda: abs(value - target)
            > self.rel_tol * max(1.0, abs(target)),
        }

        check = broken_when.get(self.condition)
        if check is None:
            raise ValueError(f"Unknown condition: {self.condition}")
        return check()
139+
# Self-type for the alternate constructor, so each subclass gets its own type back.
T = TypeVar("T", bound="BaseNotificationConfig")


class BaseNotificationConfig:
    """Shared helpers for notification configs that are built from plain dicts.

    Subclasses are expected to be dataclasses: ``from_dict`` relies on
    ``dataclasses.fields`` to discover which keys to accept.
    """

    # every subclass must override this
    type_tag: ClassVar[str]

    @classmethod
    def from_dict(cls: Type[T], d: Dict[str, Any]) -> T:
        """Build an instance from ``d``, ignoring keys that are not fields.

        Keys absent from ``d`` are left out of the constructor call so the
        dataclass defaults apply instead of being overwritten with ``None``.
        """
        known = {f.name for f in fields(cls)}  # type: ignore[arg-type]
        kwargs = {key: val for key, val in d.items() if key in known}
        return cls(**kwargs)  # type: ignore

    @classmethod
    def matches(cls, d: Dict[str, Any]) -> bool:
        """Return True when the ``"type"`` entry of ``d`` equals this class's tag."""
        return d.get("type") == cls.type_tag


@dataclass
class GitHubNotificationConfig(BaseNotificationConfig):
    """Notification target: a comment on a GitHub issue.

    Attributes:
        type: notification type, "github".
        repo: repository in "owner/repo" form.
        issue_number: number of the issue to comment on.
    """

    type: str = "github"
    repo: str = ""
    issue_number: str = ""
    type_tag: ClassVar[str] = "github"

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> GitHubNotificationConfig:
        """Build the config from a dict with keys ``type``, ``repo`` and ``issue``.

        ``issue`` (string or int) is accepted as an alias for the
        ``issue_number`` field; an explicit ``issue_number`` key wins when
        both are present.
        """
        normalized = dict(d)  # never mutate the caller's dict
        if "issue_number" not in normalized and "issue" in normalized:
            normalized["issue_number"] = str(normalized["issue"])
        return super().from_dict(normalized)

    def create_github_comment(
        self, body: str, github_token: str, timeout: float = 30.0
    ) -> Dict[str, Any]:
        """
        Create a new comment on the GitHub issue described by this config.

        Args:
            body: text of the comment
            github_token: GitHub personal access token or GitHub Actions token
            timeout: max seconds to wait for the GitHub API (default: 30)

        Returns:
            The GitHub API response as a dict (JSON).

        Raises:
            requests.exceptions.RequestException on network/timeout/HTTP error
        """
        url = f"https://api.github.com/repos/{self.repo}/issues/{self.issue_number}/comments"
        headers = {
            "Authorization": f"token {github_token}",
            "Accept": "application/vnd.github+json",
            "User-Agent": "bench-reporter/1.0",
        }
        # Without a timeout, requests waits indefinitely on a stalled connection.
        resp = requests.post(url, headers=headers, json={"body": body}, timeout=timeout)
        resp.raise_for_status()
        return resp.json()
184+
185+
@dataclass
class Policy:
    """Regression policy: report frequency, data windows, per-metric rules
    and an optional notification target.

    Attributes:
        frequency: how often the report is generated.
        range: baseline and comparison windows.
        metrics: regression rule per metric name.
        notification_config: optional dict describing where to notify;
            its "type" entry selects the notification kind.
    """

    frequency: Frequency
    range: RangeConfig
    metrics: Dict[str, RegressionPolicy]
    notification_config: Optional[Dict[str, Any]] = None

    def get_github_notification_config(self) -> Optional[GitHubNotificationConfig]:
        """Return the GitHub notification config, or None when there is none.

        None is returned when no notification config is set or when the
        configured notification is not of the GitHub type.
        """
        raw = self.notification_config
        if not raw:
            return None
        # Build the config directly from the GitHub class; the previously
        # called helper `notification_from_dict` is not defined in this module.
        if not GitHubNotificationConfig.matches(raw):
            return None
        return GitHubNotificationConfig.from_dict(raw)
196+
197+
198+
# -------- Top-level benchmark regression config --------
199+
@dataclass
class BenchmarkConfig:
    """
    A single benchmark regression configuration.

    Attributes:
        name: the name of the benchmark
        id: the id of the benchmark; must be unique per benchmark and must
            not change once set
        source: where the benchmark data is queried from
        policy: the regression policy applied to that data
    """

    name: str
    id: str
    source: BenchmarkApiSource
    policy: Policy
214+
215+
216+
@dataclass
class BenchmarkRegressionConfigBook:
    """Lookup table of benchmark configs keyed by config id."""

    configs: Dict[str, BenchmarkConfig] = field(default_factory=dict)

    def __getitem__(self, key: str) -> BenchmarkConfig:
        """Return the config stored under ``key``; raise KeyError when there is none."""
        found = self.configs.get(key)
        if found:
            return found
        raise KeyError(f"Config {key} not found")

0 commit comments

Comments
 (0)