Skip to content

Commit 99e95f9

Browse files
committed
fix bug2
ghstack-source-id: 5875715 Pull-Request: #7092
1 parent 83cb0ce commit 99e95f9

File tree

4 files changed

+794
-0
lines changed

4 files changed

+794
-0
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from dataclasses import dataclass, field
2+
from typing import Any, Dict, List
3+
4+
import requests
5+
6+
7+
# The data class to provide api response model from get_time_series api
8+
9+
10+
@dataclass
class TimeRange:
    """Time window covered by a benchmark time-series response.

    Boundaries are kept as the raw strings the API returned
    (presumably ISO-8601 timestamps -- not parsed or validated here).
    """

    # Start of the window, as returned by the API.
    start: str
    # End of the window, as returned by the API.
    end: str
14+
15+
16+
@dataclass
class BenchmarkTimeSeriesItem:
    """One time series returned by the get_time_series API."""

    # Labels identifying this series; exact schema is defined by the API,
    # not enforced here.
    group_info: Dict[str, Any]
    # Number of data points the API reports for this series.
    num_of_dp: int
    # Raw data points; a fresh empty list per instance (default_factory
    # avoids the shared-mutable-default pitfall).
    data: List[Dict[str, Any]] = field(default_factory=list)
21+
22+
23+
@dataclass
class BenchmarkTimeSeriesApiData:
    """The "data" portion of a get_time_series API response."""

    # All time series returned for the query.
    time_series: List[BenchmarkTimeSeriesItem]
    # Overall time window the series cover.
    time_range: TimeRange
27+
28+
29+
@dataclass
class BenchmarkTimeSeriesApiResponse:
    """Top-level response envelope of the get_time_series API."""

    # Parsed "data" section of the response payload.
    data: BenchmarkTimeSeriesApiData

    @classmethod
    def from_request(
        cls, url: str, query: dict, timeout: int = 180
    ) -> "BenchmarkTimeSeriesApiResponse":
        """
        Send a POST request and parse into BenchmarkTimeSeriesApiResponse.

        Args:
            url: API endpoint
            query: JSON payload sent as the request body
            timeout: max seconds to wait for connect + response (default: 180)
        Returns:
            BenchmarkTimeSeriesApiResponse
        Raises:
            requests.exceptions.RequestException: on network/timeout/HTTP error
            RuntimeError: if the API returns an "error" field or malformed data
        """
        resp = requests.post(url, json=query, timeout=timeout)
        resp.raise_for_status()
        payload = resp.json()

        if "error" in payload:
            raise RuntimeError(f"API error: {payload['error']}")
        try:
            tr = TimeRange(**payload["data"]["time_range"])
            ts = [
                BenchmarkTimeSeriesItem(**item)
                for item in payload["data"]["time_series"]
            ]
        except Exception as e:
            # Chain the original exception so the root cause stays visible
            # in the traceback instead of being swallowed.
            raise RuntimeError(f"Malformed API payload: {e}") from e
        return cls(data=BenchmarkTimeSeriesApiData(time_series=ts, time_range=tr))
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from common.config_model import (
2+
BenchmarkApiSource,
3+
BenchmarkConfig,
4+
BenchmarkRegressionConfigBook,
5+
DayRangeWindow,
6+
Frequency,
7+
Policy,
8+
RangeConfig,
9+
RegressionPolicy,
10+
)
11+
12+
13+
# Compiler benchmark regression config
14+
# todo(elainewy): eventually each team should configure
15+
# their own benchmark regression config, currently placed
16+
# here for lambda
17+
18+
19+
COMPILER_BENCHMARK_CONFIG = BenchmarkConfig(
    name="Compiler Benchmark Regression",
    id="compiler_regression",
    source=BenchmarkApiSource(
        api_query_url="https://hud.pytorch.org/api/benchmark/get_time_series",
        type="benchmark_time_series_api",
        # currently we only detect the regression for h100 with dtype bfloat16, and mode inference
        # we can extend this to other devices, dtypes and mode in the future
        # NOTE: startTime/stopTime are Jinja2 template variables filled in
        # when the query window is rendered (see BenchmarkApiSource.render)
        api_endpoint_params_template="""
        {
            "name": "compiler_precompute",
            "query_params": {
                "commits": [],
                "compilers": [],
                "arch": "h100",
                "device": "cuda",
                "dtype": "bfloat16",
                "granularity": "hour",
                "mode": "inference",
                "startTime": "{{ startTime }}",
                "stopTime": "{{ stopTime }}",
                "suites": ["torchbench", "huggingface", "timm_models"],
                "workflowId": 0,
                "branches": ["main"]
            }
        }
        """,
    ),
    # baseline window: past 7 days; comparison window: last 2 days.
    # How the baseline is aggregated is configured per metric below
    # (currently "max" for all metrics).
    policy=Policy(
        frequency=Frequency(value=1, unit="days"),
        range=RangeConfig(
            baseline=DayRangeWindow(value=7),
            comparison=DayRangeWindow(value=2),
        ),
        # higher is better for all three metrics: a new value below
        # threshold * aggregated-baseline counts as a regression
        metrics={
            "passrate": RegressionPolicy(
                name="passrate",
                condition="greater_equal",
                threshold=0.9,
                baseline_aggregation="max",
            ),
            "geomean": RegressionPolicy(
                name="geomean",
                condition="greater_equal",
                threshold=0.95,
                baseline_aggregation="max",
            ),
            "compression_ratio": RegressionPolicy(
                name="compression_ratio",
                condition="greater_equal",
                threshold=0.9,
                baseline_aggregation="max",
            ),
        },
        # detected regressions are reported to this GitHub issue
        notification_config={
            "type": "github",
            "repo": "pytorch/test-infra",
            "issue": "7081",
        },
    ),
)
81+
82+
# Registry of all benchmark regression configs, keyed by config id.
BENCHMARK_REGRESSION_CONFIG = BenchmarkRegressionConfigBook(
    configs={
        "compiler_regression": COMPILER_BENCHMARK_CONFIG,
    }
)
87+
88+
89+
def get_benchmark_regression_config(config_id: str) -> BenchmarkConfig:
    """Get benchmark regression config by config id.

    Args:
        config_id: unique id of the config (e.g. "compiler_regression").
    Returns:
        The matching BenchmarkConfig.
    Raises:
        ValueError: if no config is registered under ``config_id``.
    """
    try:
        return BENCHMARK_REGRESSION_CONFIG[config_id]
    except KeyError as e:
        # Chain the lookup failure so the original KeyError stays visible
        # in the traceback instead of appearing as an unrelated error.
        raise ValueError(f"Invalid config id: {config_id}") from e
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
from __future__ import annotations
2+
3+
import json
4+
from dataclasses import dataclass, field
5+
from datetime import timedelta
6+
from typing import Any, Dict, Literal, Optional
7+
8+
from jinja2 import Environment, meta, Template
9+
10+
11+
# -------- Frequency --------
12+
@dataclass(frozen=True)
class Frequency:
    """
    How often the regression report should be generated.

    The smallest supported frequency is 1 day.

    Attributes:
        value: Number of units (e.g., 7 for 7 days).
        unit: Unit of time, either "days" or "weeks".

    Methods:
        to_timedelta: Convert the frequency into a datetime.timedelta.
        get_text: Return the frequency as human-readable text.
    """

    value: int
    unit: Literal["days", "weeks"]

    def to_timedelta(self) -> timedelta:
        """Convert frequency N days or M weeks into a datetime.timedelta."""
        if self.unit not in ("days", "weeks"):
            raise ValueError(f"Unsupported unit: {self.unit}")
        # timedelta accepts both "days" and "weeks" as keyword arguments,
        # so the unit name can be forwarded directly.
        return timedelta(**{self.unit: self.value})

    def get_text(self):
        """Return the frequency as text, e.g. "7 days"."""
        return f"{self.value} {self.unit}"
40+
41+
42+
# -------- Source --------
# Shared Jinja2 environment for parsing query-param templates.
# autoescape is off because the templates render JSON, not HTML.
_JINJA_ENV = Environment(autoescape=False)
46+
@dataclass
class BenchmarkApiSource:
    """
    Describes where and how to fetch the benchmark data for a config.

    Attributes:
        api_query_url: URL of the API to query.
        api_endpoint_params_template: Jinja2 template for the endpoint's
            query params payload (rendered to JSON).
        type: kind of source backing this config.
        default_ctx: default context values used when rendering the template.
    """

    api_query_url: str
    api_endpoint_params_template: str
    type: Literal["benchmark_time_series_api", "other"] = "benchmark_time_series_api"
    default_ctx: Dict[str, Any] = field(default_factory=dict)

    def required_template_vars(self) -> set[str]:
        """Return the names of all variables the template expects."""
        parsed = _JINJA_ENV.parse(self.api_endpoint_params_template)
        return set(meta.find_undeclared_variables(parsed))

    def render(self, ctx: Dict[str, Any], strict: bool = True) -> dict:
        """Render with caller-supplied context (no special casing for start/end)."""
        # Caller-supplied values take precedence over the defaults.
        context = {**self.default_ctx, **ctx}

        if strict:
            # Fail fast when the template needs variables nobody supplied.
            missing = self.required_template_vars() - context.keys()
            if missing:
                raise ValueError(f"Missing required vars: {missing}")
        body = Template(self.api_endpoint_params_template).render(**context)
        return json.loads(body)
75+
76+
77+
# -------- Policy: range windows --------
78+
@dataclass
class DayRangeWindow:
    """A data window measured in whole days."""

    # Number of days in the window.
    value: int
    # "raw" indicates the datapoints are fetched directly from the source data.
    source: Literal["raw"] = "raw"
83+
84+
85+
@dataclass
class RangeConfig:
    """
    Baseline and comparison windows for a given policy.

    - baseline: the window the baseline value is computed from
    - comparison: the window whose data is compared against the baseline value
    """

    baseline: DayRangeWindow
    comparison: DayRangeWindow

    def total_timedelta(self) -> timedelta:
        """Combined span of the baseline and comparison windows."""
        return self.baseline_timedelta() + self.comparison_timedelta()

    def comparison_timedelta(self) -> timedelta:
        """Span of the comparison window alone."""
        return timedelta(days=self.comparison.value)

    def baseline_timedelta(self) -> timedelta:
        """Span of the baseline window alone."""
        return timedelta(days=self.baseline.value)
104+
105+
106+
# -------- Policy: metrics --------
107+
@dataclass
class RegressionPolicy:
    """
    Per-metric rule deciding whether a new value regressed vs. a baseline.

    The new value is compared against ``target = baseline * threshold``:
    - "greater_than": higher is better; new value must be strictly greater
    - "less_than": lower is better; new value must be strictly lower
    - "equal_to": new value should be ~= target within rel_tol
    - "greater_equal": higher is better; new value must be greater or equal
    - "less_equal": lower is better; new value must be less or equal
    """

    name: str
    condition: Literal[
        "greater_than", "less_than", "equal_to", "greater_equal", "less_equal"
    ]
    threshold: float
    baseline_aggregation: Literal[
        "avg", "max", "min", "p50", "p90", "p95", "latest", "earliest"
    ] = "max"
    rel_tol: float = 1e-3  # used only for "equal_to"

    def is_violation(self, value: float, baseline: float) -> bool:
        """Return True when ``value`` violates this policy against ``baseline``."""
        target = baseline * self.threshold

        if self.condition == "equal_to":
            # |value - target| must stay within rel_tol * max(1, |target|)
            allowed = self.rel_tol * max(1.0, abs(target))
            return abs(value - target) > allowed

        # For the ordered conditions a violation is simply "the required
        # comparison does NOT hold".
        satisfied = {
            "greater_than": value > target,
            "greater_equal": value >= target,
            "less_than": value < target,
            "less_equal": value <= target,
        }
        if self.condition not in satisfied:
            raise ValueError(f"Unknown condition: {self.condition}")
        return not satisfied[self.condition]
154+
155+
156+
@dataclass
class Policy:
    """Full regression policy: cadence, data windows and per-metric rules."""

    # How often the regression report is generated.
    frequency: Frequency
    # Baseline/comparison windows the data is pulled from.
    range: RangeConfig
    # Per-metric regression rules, keyed by metric name.
    metrics: Dict[str, RegressionPolicy]

    # TODO(elainewy): add notification config
    notification_config: Optional[Dict[str, Any]] = None
164+
165+
166+
# -------- Top-level benchmark regression config --------
167+
@dataclass
168+
class BenchmarkConfig:
169+
"""
170+
Represents a single benchmark regression configuration.
171+
- BenchmarkConfig defines the benchmark regression config for a given benchmark.
172+
- source: defines the source of the benchmark data we want to query
173+
- policy: defines the policy for the benchmark regressions, including frequency to
174+
generate the report, range of the baseline and new values, and regression thresholds
175+
for metrics
176+
- name: the name of the benchmark
177+
- id: the id of the benchmark, this must be unique for each benchmark, and cannot be changed once set
178+
"""
179+
180+
name: str
181+
id: str
182+
source: BenchmarkApiSource
183+
policy: Policy
184+
185+
186+
@dataclass
class BenchmarkRegressionConfigBook:
    """Lookup table of benchmark regression configs, keyed by config id."""

    configs: Dict[str, BenchmarkConfig] = field(default_factory=dict)

    def __getitem__(self, key: str) -> BenchmarkConfig:
        """Return the config registered under ``key``.

        Raises:
            KeyError: if no config is registered under ``key``.
        """
        try:
            # Direct lookup instead of `.get()` + truthiness test: a
            # falsy-but-present entry must not be reported as missing,
            # and this avoids the double lookup.
            return self.configs[key]
        except KeyError:
            # Re-raise with a friendlier message; suppress the bare
            # KeyError context since this one carries the same key.
            raise KeyError(f"Config {key} not found") from None

0 commit comments

Comments
 (0)