[WIP][AQUA] GPU Shape Recommendation #1221

Open · wants to merge 13 commits into base: main
12 changes: 7 additions & 5 deletions ads/aqua/cli.py
@@ -96,18 +96,20 @@ def _validate_value(flag, value):
"If you intend to chain a function call to the result, please separate the "
"flag and the subsequent function call with separator `-`."
)

@staticmethod
def install():
"""Install ADS Aqua Extension from wheel file. Set enviroment variable `AQUA_EXTENSTION_PATH` to change the wheel file path.

Returns
-------
int:
Installation status.
"""
import subprocess

wheel_file_path = os.environ.get("AQUA_EXTENSTION_PATH", "/ads/extension/adsjupyterlab_aqua_extension*.whl")
status = subprocess.run(f"pip install {wheel_file_path}",shell=True)
return status.check_returncode
wheel_file_path = os.environ.get(
"AQUA_EXTENSTION_PATH", "/ads/extension/adsjupyterlab_aqua_extension*.whl"
)
status = subprocess.run(f"pip install {wheel_file_path}", shell=True, check=False)
return status.returncode
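As a usage sketch only, overriding the wheel location before installing looks roughly like this (the `AquaCommand` class name and the wheel path are assumptions; only the `install()` method and `AQUA_EXTENSTION_PATH` are visible in this diff):

```python
# Hedged sketch: point the installer at a custom wheel, then call install().
import os

from ads.aqua.cli import AquaCommand  # class name assumed, not shown in this hunk

os.environ["AQUA_EXTENSTION_PATH"] = "/tmp/adsjupyterlab_aqua_extension-0.1-py3-none-any.whl"  # placeholder path
exit_code = AquaCommand.install()  # pip process return code; 0 indicates success
print(exit_code)
```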
21 changes: 21 additions & 0 deletions ads/aqua/common/entities.py
@@ -46,6 +46,17 @@ class Config:
arbitrary_types_allowed = True
protected_namespaces = ()

class ComputeRank(Serializable):
"""
Represents the relative cost and performance ranking of a compute shape.
"""

cost: Optional[int] = Field(
default=None,
description="The relative cost rank of the shape. Range is [10 (most cost-effective), 100 (most expensive)].",
)

performance: Optional[int] = Field(
default=None,
description="The relative performance rank of the shape. Range is [10 (lowest performance), 110 (highest performance)].",
)

class GPUSpecs(Serializable):
"""
@@ -61,6 +72,12 @@ class GPUSpecs(Serializable):
gpu_type: Optional[str] = Field(
default=None, description="The type of GPU (e.g., 'V100, A100, H100')."
)
quantization: Optional[List[str]] = Field(
default_factory=list,
description="The quantization formats supported by the shape (e.g., bitsandbytes, fp8).",
)
ranking: Optional[ComputeRank] = Field(
default=None,
description="The relative cost and performance ranking of the shape.",
)
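
For illustration, a hedged sketch of how the new fields compose (the `gpu_count` and `gpu_memory_in_gbs` fields are assumed from the existing `GPUSpecs` definition, which is folded out of this diff):

```python
# Hedged sketch: populating GPUSpecs with the new quantization and ranking fields.
from ads.aqua.common.entities import ComputeRank, GPUSpecs

specs = GPUSpecs(
    gpu_count=8,                 # assumed pre-existing field, not visible in this hunk
    gpu_memory_in_gbs=640,       # assumed pre-existing field, not visible in this hunk
    gpu_type="H100",
    quantization=["fp8", "bitsandbytes"],
    ranking=ComputeRank(cost=100, performance=110),
)
print(specs)
```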


class GPUShapesIndex(Serializable):
@@ -84,6 +101,10 @@ class ComputeShapeSummary(Serializable):
including CPU, memory, and optional GPU characteristics.
"""

available: Optional[bool] = Field(
default=False,
description="True if the shape is available in the user's tenancy.",
)
core_count: Optional[int] = Field(
default=None,
description="Total number of CPU cores available for the compute shape.",
5 changes: 5 additions & 0 deletions ads/aqua/common/errors.py
@@ -55,6 +55,11 @@ class AquaValueError(AquaError, ValueError):
def __init__(self, reason, status=403, service_payload=None):
super().__init__(reason, status, service_payload)

class AquaRecommendationError(AquaError):
"""Exception raised for models incompatible with shape recommendation tool."""

def __init__(self, reason, status=400, service_payload=None):
super().__init__(reason, status, service_payload)
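A hedged sketch of how callers might raise this error for unsupported models; the helper below is hypothetical and not part of this PR (the real checks live in the shape-recommendation module):

```python
# Hypothetical helper for illustration only.
from ads.aqua.common.errors import AquaRecommendationError

SUPPORTED_ARCHITECTURES = {"decoder-only"}  # illustrative assumption


def validate_architecture(architecture: str) -> None:
    """Raise AquaRecommendationError for models the tool cannot size."""
    if architecture not in SUPPORTED_ARCHITECTURES:
        raise AquaRecommendationError(
            f"Architecture '{architecture}' is not supported by the shape recommendation tool."
        )
```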

class AquaFileNotFoundError(AquaError, FileNotFoundError):
"""Exception raised for missing target file."""
56 changes: 49 additions & 7 deletions ads/aqua/common/utils.py
@@ -1229,10 +1229,10 @@ def load_gpu_shapes_index(
auth: Optional[Dict[str, Any]] = None,
) -> GPUShapesIndex:
"""
Load the GPU shapes index, preferring the OS bucket copy over the local one.
Load the GPU shapes index, merging based on freshness.

Attempts to read `gpu_shapes_index.json` from OCI Object Storage first;
if that succeeds, those entries will override the local defaults.
Compares last-modified timestamps of local and remote files,
merging the shapes from the fresher file on top of the older one.

Parameters
----------
@@ -1253,7 +1253,9 @@
file_name = "gpu_shapes_index.json"

# Try remote load
remote_data: Dict[str, Any] = {}
local_data, remote_data = {}, {}
local_mtime, remote_mtime = None, None

if CONDA_BUCKET_NS:
try:
auth = auth or authutil.default_signer()
@@ -1263,8 +1265,24 @@
logger.debug(
"Loading GPU shapes index from Object Storage: %s", storage_path
)
with fsspec.open(storage_path, mode="r", **auth) as f:

fs = fsspec.filesystem("oci", **auth)
with fs.open(storage_path, mode="r") as f:
remote_data = json.load(f)

remote_info = fs.info(storage_path)
remote_mtime_str = remote_info.get("timeModified", None)
if remote_mtime_str:
# Convert OCI timestamp (e.g., 'Mon, 04 Aug 2025 06:37:13 GMT') to epoch time
remote_mtime = datetime.strptime(
remote_mtime_str, "%a, %d %b %Y %H:%M:%S %Z"
).timestamp()

logger.debug(
"Remote GPU shapes last-modified time: %s",
datetime.fromtimestamp(remote_mtime).strftime("%Y-%m-%d %H:%M:%S"),
)

logger.debug(
"Loaded %d shapes from Object Storage",
len(remote_data.get("shapes", {})),
@@ -1273,12 +1291,19 @@
logger.debug("Remote load failed (%s); falling back to local", ex)

# Load local copy
local_data: Dict[str, Any] = {}
local_path = os.path.join(os.path.dirname(__file__), "../resources", file_name)
try:
logger.debug("Loading GPU shapes index from local file: %s", local_path)
with open(local_path) as f:
local_data = json.load(f)

local_mtime = os.path.getmtime(local_path)

logger.debug(
"Local GPU shapes last-modified time: %s",
datetime.fromtimestamp(local_mtime).strftime("%Y-%m-%d %H:%M:%S"),
)

logger.debug(
"Loaded %d shapes from local file", len(local_data.get("shapes", {}))
)
@@ -1288,7 +1313,24 @@
# Merge: remote shapes override local
local_shapes = local_data.get("shapes", {})
remote_shapes = remote_data.get("shapes", {})
merged_shapes = {**local_shapes, **remote_shapes}
merged_shapes = {}

if local_mtime and remote_mtime:
if remote_mtime >= local_mtime:
logger.debug("Remote data is fresher or equal; merging remote over local.")
merged_shapes = {**local_shapes, **remote_shapes}
else:
logger.debug("Local data is fresher; merging local over remote.")
merged_shapes = {**remote_shapes, **local_shapes}
elif remote_shapes:
logger.debug("Only remote shapes available.")
merged_shapes = remote_shapes
elif local_shapes:
logger.debug("Only local shapes available.")
merged_shapes = local_shapes
else:
logger.error("No GPU shapes data found in either source.")
merged_shapes = {}

return GPUShapesIndex(shapes=merged_shapes)
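A minimal sketch of the merge precedence implemented above, using illustrative shape entries (the real files carry full GPUSpecs payloads): with dict unpacking, keys from the right-hand mapping win, so the fresher file's shapes override the older file's on conflicts while entries unique to either file are kept.

```python
# Illustrative entries only.
older = {"VM.GPU.A10.1": {"gpu_count": 1}, "BM.GPU4.8": {"gpu_count": 8}}
fresher = {"BM.GPU4.8": {"gpu_count": 8, "quantization": ["fp8"]}}

merged = {**older, **fresher}  # right-hand (fresher) mapping wins on key conflicts
assert merged["BM.GPU4.8"]["quantization"] == ["fp8"]
assert "VM.GPU.A10.1" in merged  # entries unique to the older file are preserved
```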

36 changes: 36 additions & 0 deletions ads/aqua/extension/deployment_handler.py
@@ -57,6 +57,15 @@ def get(self, id: Union[str, List[str]] = None):
return self.get_deployment_config(
model_id=id.split(",") if "," in id else id
)
elif paths.startswith("aqua/deployments/recommend_shapes"):
Review comment: NIT: /recommended_shapes would be better.

if not id or not isinstance(id, str):
raise HTTPError(
400,
f"Invalid request format for {self.request.path}. "
"Expected a single model OCID specified as --model_id",
)
id = id.replace(" ", "")
return self.get_recommend_shape(model_id=id)
elif paths.startswith("aqua/deployments/shapes"):
return self.list_shapes()
elif paths.startswith("aqua/deployments"):
@@ -161,6 +170,32 @@ def get_deployment_config(self, model_id: Union[str, List[str]]):

return self.finish(deployment_config)

def get_recommend_shape(self, model_id: str):
"""
Retrieves the recommended GPU deployment shapes for a single Aqua model.

Parameters
----------
model_id : str
A single model ID (str).

Returns
-------
None
Finishes the request with the ShapeRecommendationReport for the model
(this handler calls recommend_shape with generate_table=False).
"""
app = AquaDeploymentApp()

compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)

recommend_report = app.recommend_shape(
model_id=model_id,
compartment_id=compartment_id,
generate_table=False,
)

return self.finish(recommend_report)
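
As a hedged illustration of calling the new endpoint from a notebook client; the server URL, auth header, and model OCID below are placeholders, and the route is inferred from the handler registration further down:

```python
# Placeholders throughout: adjust the server URL, token, and model OCID.
import requests

base_url = "http://localhost:8888"
headers = {"Authorization": "token <jupyter-server-token>"}
model_ocid = "ocid1.datasciencemodel.oc1..<unique_id>"

response = requests.get(
    f"{base_url}/aqua/deployments/recommend_shapes/{model_ocid}",
    headers=headers,
    timeout=60,
)
response.raise_for_status()
print(response.json())  # serialized ShapeRecommendationReport
```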

def list_shapes(self):
"""
Lists the valid model deployment shapes.
@@ -408,6 +443,7 @@ def get(self, model_deployment_id):
("deployments/?([^/]*)/params", AquaDeploymentParamsHandler),
("deployments/config/?([^/]*)", AquaDeploymentHandler),
("deployments/shapes/?([^/]*)", AquaDeploymentHandler),
("deployments/recommend_shapes/?([^/]*)", AquaDeploymentHandler),
("deployments/?([^/]*)", AquaDeploymentHandler),
("deployments/?([^/]*)/activate", AquaDeploymentHandler),
("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),
1 change: 1 addition & 0 deletions ads/aqua/modeldeployment/constants.py
@@ -11,3 +11,4 @@

DEFAULT_WAIT_TIME = 12000
DEFAULT_POLL_INTERVAL = 10

57 changes: 55 additions & 2 deletions ads/aqua/modeldeployment/deployment.py
@@ -8,11 +8,12 @@
import shlex
import threading
from datetime import datetime, timedelta
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

from cachetools import TTLCache, cached
from oci.data_science.models import ModelDeploymentShapeSummary
from pydantic import ValidationError
from rich.table import Table

from ads.aqua.app import AquaApp, logger
from ads.aqua.common.entities import (
@@ -63,14 +64,22 @@
ModelDeploymentConfigSummary,
MultiModelDeploymentConfigLoader,
)
from ads.aqua.modeldeployment.constants import DEFAULT_POLL_INTERVAL, DEFAULT_WAIT_TIME
from ads.aqua.modeldeployment.constants import (
DEFAULT_POLL_INTERVAL,
DEFAULT_WAIT_TIME,
)
from ads.aqua.modeldeployment.entities import (
AquaDeployment,
AquaDeploymentDetail,
ConfigValidationError,
CreateModelDeploymentDetails,
)
from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig
from ads.aqua.shaperecommend.recommend import AquaShapeRecommend
from ads.aqua.shaperecommend.shape_report import (
RequestRecommend,
ShapeRecommendationReport,
)
from ads.common.object_storage_details import ObjectStorageDetails
from ads.common.utils import UNKNOWN, get_log_links
from ads.common.work_request import DataScienceWorkRequest
@@ -1243,6 +1252,50 @@ def validate_deployment_params(
)
return {"valid": True}

def recommend_shape(self, **kwargs) -> Union[Table, ShapeRecommendationReport]:
"""
For the CLI (generate_table=True), generates a rich Table listing the valid
GPU deployment shapes for the provided model and configuration.

For the API (generate_table=False), generates a JSON-serializable report with the valid
GPU deployment shapes for the provided model and configuration.

Validates the request parameters, then delegates to AquaShapeRecommend.which_shapes
to build the recommendation data (and, when generate_table=True, the rich Table).

Parameters
----------
**kwargs
Keyword arguments used to build a RequestRecommend instance
(e.g., model_id, compartment_id, generate_table).

Returns
-------
Table (generate_table=True)
A table-formatted recommendation report with compatible deployment shapes,
or troubleshooting info citing the largest shapes if no shape is suitable.

ShapeRecommendationReport (generate_table=False)
A recommendation report with compatible deployment shapes, or troubleshooting info
citing the largest shapes if no shape is suitable.

Raises
------
AquaValueError
If the input parameters are invalid or the model type is unsupported by the tool
(no recommendation report is generated).
"""
try:
request = RequestRecommend(**kwargs)
except ValidationError as e:
custom_error = build_pydantic_error_message(e)
raise AquaValueError( # noqa: B904
f"Failed to request shape recommendation due to invalid input parameters: {custom_error}"
)

shape_recommend = AquaShapeRecommend()
shape_recommend_report = shape_recommend.which_shapes(request)

return shape_recommend_report
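
A hedged usage sketch of the new method; the model OCID is a placeholder and the keyword names mirror the handler call above:

```python
# Hedged sketch; not part of this PR.
from rich.console import Console

from ads.aqua.modeldeployment.deployment import AquaDeploymentApp

app = AquaDeploymentApp()

# CLI-style: render the recommendations as a rich Table.
table = app.recommend_shape(
    model_id="ocid1.datasciencemodel.oc1..<unique_id>",  # placeholder OCID
    generate_table=True,
)
Console().print(table)

# API-style: get the structured ShapeRecommendationReport instead.
report = app.recommend_shape(
    model_id="ocid1.datasciencemodel.oc1..<unique_id>",  # placeholder OCID
    generate_table=False,
)
print(report)
```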

@telemetry(entry_point="plugin=deployment&action=list_shapes", name="aqua")
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now))
def list_shapes(self, **kwargs) -> List[ComputeShapeSummary]: