diff --git a/docs/auth.md b/docs/auth.md
index fd795f02..ec55b860 100644
--- a/docs/auth.md
+++ b/docs/auth.md
@@ -249,6 +249,8 @@ authorization:
 - `info` - Access the `/` endpoint, `/info` endpoint, `/readiness` endpoint, and `/liveness` endpoint
 - `get_config` - Access the `/config` endpoint
 - `get_models` - Access the `/models` endpoint
+- `list_providers` - Access the `/providers` endpoint
+- `get_provider` - Access the `/providers/{provider_id}` endpoint
 - `get_metrics` - Access the `/metrics` endpoint
 - `list_conversations` - Access the `/conversations` endpoint
 - `list_other_conversations` - Access conversations not owned by the user
diff --git a/docs/openapi.json b/docs/openapi.json
index 34492061..adb3be69 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -203,6 +203,119 @@
                 }
             }
         },
+        "/v1/providers": {
+            "get": {
+                "tags": [
+                    "providers"
+                ],
+                "summary": "Providers Endpoint Handler",
+                "description": "Handle requests to the /providers endpoint.\n\nProcess GET requests to the /providers endpoint, returning a list of available\nproviders from the Llama Stack service.\n\nRaises:\n    HTTPException: If unable to connect to the Llama Stack server or if\n    providers retrieval fails for any reason.\n\nReturns:\n    ProvidersListResponse: An object containing the list of available providers.",
+                "operationId": "providers_endpoint_handler_v1_providers_get",
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/ProvidersListResponse"
+                                }
+                            }
+                        },
+
+                        "providers": {
+                            "agents": [
+                                {
+                                    "provider_id": "meta-reference",
+                                    "provider_type": "inline::meta-reference"
+                                }
+                            ],
+                            "datasetio": [
+                                {
+                                    "provider_id": "huggingface",
+                                    "provider_type": "remote::huggingface"
+                                },
+                                {
+                                    "provider_id": "localfs",
+                                    "provider_type": "inline::localfs"
+                                }
+                            ],
+                            "inference": [
+                                {
+                                    "provider_id": "sentence-transformers",
+                                    "provider_type": "inline::sentence-transformers"
+                                },
+                                {
+                                    "provider_id": "openai",
+                                    "provider_type": "remote::openai"
+                                }
+                            ]
+                        }
+                    },
+                    "500": {
+                        "description": "Connection to Llama Stack is broken"
+                    }
+                }
+            }
+        },
+        "/v1/providers/{provider_id}": {
+            "get": {
+                "summary": "Retrieve a single provider by ID",
+                "description": "Fetches detailed information about a specific provider, including its API, configuration, health status, provider ID, and type. Returns a 404 error if the provider with the specified ID does not exist, or a 500 error if there is a problem connecting to the Llama Stack service.",
+                "parameters": [
+                    {
+                        "name": "provider_id",
+                        "in": "path",
+                        "required": true,
+                        "description": "Unique identifier of the provider to retrieve",
+                        "schema": {
+                            "type": "string"
+                        }
+                    }
+                ],
+                "responses": {
+                    "200": {
+                        "description": "Provider found successfully",
+                        "content": {
+                            "application/json": {
+                                "schema": { "$ref": "#/components/schemas/ProviderResponse" },
+                                "example": {
+                                    "api": "inference",
+                                    "config": {"api_key": "********"},
+                                    "health": {
+                                        "status": "Not Implemented",
+                                        "message": "Provider does not implement health check"
+                                    },
+                                    "provider_id": "openai",
+                                    "provider_type": "remote::openai"
+                                }
+                            }
+                        }
+                    },
+                    "404": {
+                        "description": "Provider with the specified ID was not found",
+                        "content": {
+                            "application/json": {
+                                "example": {
+                                    "response": "Provider with given id not found"
+                                }
+                            }
+                        }
+                    },
+                    "500": {
+                        "description": "Unable to retrieve provider due to server error or connection issues",
+                        "content": {
+                            "application/json": {
+                                "example": {
+                                    "response": "Unable to retrieve list of providers",
+                                    "cause": "Connection to Llama Stack is broken"
+                                }
+                            }
+                        }
+                    }
+                },
+                "tags": ["providers"]
+            }
+        },
         "/v1/query": {
             "post": {
                 "tags": [
@@ -1246,6 +1359,8 @@
                     "get_models",
                     "get_tools",
                     "get_shields",
+                    "list_providers",
+                    "get_provider",
                     "get_metrics",
                     "get_config",
                     "info",
@@ -2762,6 +2877,105 @@
                "title": "ProviderHealthStatus",
                "description": "Model representing the health status of a provider.\n\nAttributes:\n    provider_id: The ID of the provider.\n    status: The health status ('ok', 'unhealthy', 'not_implemented').\n    message: Optional message about the health status."
            },
+            "ProviderResponse": {
+                "type": "object",
+                "title": "ProviderResponse",
+                "description": "Model representing a provider and its configuration, health, and identification details.",
+                "properties": {
+                    "api": {
+                        "type": "string",
+                        "description": "The API name this provider implements"
+                    },
+                    "config": {
+                        "type": "object",
+                        "description": "Configuration parameters for the provider",
+                        "additionalProperties": true
+                    },
+                    "health": {
+                        "type": "object",
+                        "description": "Current health status of the provider",
+                        "additionalProperties": true
+                    },
+                    "provider_id": {
+                        "type": "string",
+                        "description": "Unique identifier for the provider"
+                    },
+                    "provider_type": {
+                        "type": "string",
+                        "description": "The type of provider implementation"
+                    }
+                },
+                "required": [
+                    "api",
+                    "config",
+                    "health",
+                    "provider_id",
+                    "provider_type"
+                ],
+                "example": {
+                    "api": "inference",
+                    "config": {"api_key": "********"},
+                    "health": {
+                        "status": "Not Implemented",
+                        "message": "Provider does not implement health check"
+                    },
+                    "provider_id": "openai",
+                    "provider_type": "remote::openai"
+                }
+            },
+            "ProvidersListResponse": {
+                "type": "object",
+                "properties": {
+                    "providers": {
+                        "type": "object",
+                        "description": "Mapping of API type to list of its available providers",
+                        "additionalProperties": {
+                            "type": "array",
+                            "items": {
+                                "type": "object",
+                                "properties": {
+                                    "provider_id": {
+                                        "type": "string",
+                                        "description": "Unique local identifier of provider"
+                                    },
+                                    "provider_type": {
+                                        "type": "string",
+                                        "description": "Llama stack identifier of provider (e.g. remote::openai or inline::localfs)"
+                                    }
+                                },
+                                "required": ["provider_id", "provider_type"]
+                            }
+                        },
+                        "examples": [
+                            {
+                                "inference": [
+                                    {
+                                        "provider_id": "sentence-transformers",
"inline::sentence-transformers" + }, + { + "provider_id": "openai", + "provider_type": "remote::openai" + } + ], + "datasetio": [ + { + "provider_id": "huggingface", + "provider_type": "remote::huggingface" + }, + { + "provider_id": "localfs", + "provider_type": "inline::localfs" + } + ] + } + ] + } + }, + "required": ["providers"], + "title": "ProvidersListResponse", + "description": "Model representing a response to providers request." + }, "QueryRequest": { "properties": { "query": { diff --git a/src/app/endpoints/providers.py b/src/app/endpoints/providers.py new file mode 100644 index 00000000..426804bd --- /dev/null +++ b/src/app/endpoints/providers.py @@ -0,0 +1,215 @@ +"""Handler for REST API calls to list and retrieve available providers.""" + +import logging +from typing import Annotated, Any + +from fastapi import APIRouter, HTTPException, Request, status +from fastapi.params import Depends +from llama_stack_client import APIConnectionError + +from authentication import get_auth_dependency +from authentication.interface import AuthTuple +from authorization.middleware import authorize +from client import AsyncLlamaStackClientHolder +from configuration import configuration +from models.config import Action +from models.responses import ProvidersListResponse, ProviderResponse +from utils.endpoints import check_configuration_loaded + +logger = logging.getLogger(__name__) +router = APIRouter(tags=["providers"]) + + +providers_responses: dict[int | str, dict[str, Any]] = { + 200: { + "providers": { + "agents": [ + { + "provider_id": "meta-reference", + "provider_type": "inline::meta-reference", + } + ], + "datasetio": [ + {"provider_id": "huggingface", "provider_type": "remote::huggingface"}, + {"provider_id": "localfs", "provider_type": "inline::localfs"}, + ], + "inference": [ + { + "provider_id": "sentence-transformers", + "provider_type": "inline::sentence-transformers", + }, + {"provider_id": "openai", "provider_type": "remote::openai"}, + ], + } + }, + 500: {"description": "Connection to Llama Stack is broken"}, +} + +provider_responses: dict[int | str, dict[str, Any]] = { + 200: { + "api": "inference", + "config": {"api_key": "********"}, + "health": { + "status": "Not Implemented", + "message": "Provider does not implement health check", + }, + "provider_id": "openai", + "provider_type": "remote::openai", + }, + 404: {"response": "Provider with given id not found"}, + 500: { + "response": "Unable to retrieve list of providers", + "cause": "Connection to Llama Stack is broken", + }, +} + + +@router.get("/providers", responses=providers_responses) +@authorize(Action.LIST_PROVIDERS) +async def providers_endpoint_handler( + request: Request, + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], +) -> ProvidersListResponse: + """ + Handle GET requests to list all available providers. + + Retrieves providers from the Llama Stack service, groups them by API type. + + Raises: + HTTPException: + - 500 if configuration is not loaded, + - 500 if unable to connect to Llama Stack, + - 500 for any unexpected retrieval errors. + + Returns: + ProvidersListResponse: Object mapping API types to lists of providers. 
+ """ + # Used only by the middleware + _ = auth + + # Nothing interesting in the request + _ = request + + check_configuration_loaded(configuration) + + llama_stack_configuration = configuration.llama_stack_configuration + logger.info("Llama stack config: %s", llama_stack_configuration) + + try: + # try to get Llama Stack client + client = AsyncLlamaStackClientHolder().get_client() + # retrieve providers + providers = await client.providers.list() + providers = [dict(p) for p in providers] + return ProvidersListResponse(providers=group_providers(providers)) + + # connection to Llama Stack server + except APIConnectionError as e: + logger.error("Unable to connect to Llama Stack: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "response": "Unable to connect to Llama Stack", + "cause": str(e), + }, + ) from e + # any other exception that can occur during model listing + except Exception as e: + logger.error("Unable to retrieve list of providers: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "response": "Unable to retrieve list of providers", + "cause": str(e), + }, + ) from e + + +def group_providers(providers: list[dict[str, Any]]) -> dict[str, list[dict[str, Any]]]: + """Group a list of providers by their API type. + + Args: + providers: List of provider dictionaries. Each must contain + 'api', 'provider_id', and 'provider_type' keys. + + Returns: + Mapping from API type to list of providers containing + only 'provider_id' and 'provider_type'. + """ + result: dict[str, list[dict[str, Any]]] = {} + for provider in providers: + result.setdefault(provider["api"], []).append( + { + "provider_id": provider["provider_id"], + "provider_type": provider["provider_type"], + } + ) + return result + + +@router.get("/providers/{provider_id}", responses=provider_responses) +@authorize(Action.GET_PROVIDER) +async def get_provider_endpoint_handler( + request: Request, + provider_id: str, + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], +) -> ProviderResponse: + """Retrieve a single provider by its unique ID. + + Raises: + HTTPException: + - 404 if provider with the given ID is not found, + - 500 if unable to connect to Llama Stack, + - 500 for any unexpected retrieval errors. + + Returns: + ProviderResponse: A single provider's details including API, config, health, + provider_id, and provider_type. 
+ """ + # Used only by the middleware + _ = auth + + # Nothing interesting in the request + _ = request + + check_configuration_loaded(configuration) + + llama_stack_configuration = configuration.llama_stack_configuration + logger.info("Llama stack config: %s", llama_stack_configuration) + + try: + # try to get Llama Stack client + client = AsyncLlamaStackClientHolder().get_client() + # retrieve providers + providers = await client.providers.list() + p = [dict(p) for p in providers] + match = next((item for item in p if item["provider_id"] == provider_id), None) + if not match: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"response": f"Provider with id '{provider_id}' not found"}, + ) + return ProviderResponse(**match) + + # connection to Llama Stack server + except HTTPException: + raise + except APIConnectionError as e: + logger.error("Unable to connect to Llama Stack: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "response": "Unable to connect to Llama Stack", + "cause": str(e), + }, + ) from e + # any other exception that can occur during model listing + except Exception as e: + logger.error("Unable to retrieve list of providers: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "response": "Unable to retrieve list of providers", + "cause": str(e), + }, + ) from e diff --git a/src/app/routers.py b/src/app/routers.py index 7cd98203..66c70766 100644 --- a/src/app/routers.py +++ b/src/app/routers.py @@ -6,6 +6,7 @@ info, models, shields, + providers, root, query, health, @@ -31,6 +32,7 @@ def include_routers(app: FastAPI) -> None: app.include_router(models.router, prefix="/v1") app.include_router(tools.router, prefix="/v1") app.include_router(shields.router, prefix="/v1") + app.include_router(providers.router, prefix="/v1") app.include_router(query.router, prefix="/v1") app.include_router(streaming_query.router, prefix="/v1") app.include_router(config.router, prefix="/v1") diff --git a/src/models/config.py b/src/models/config.py index 544cefa8..82fa2ece 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -364,6 +364,8 @@ class Action(str, Enum): GET_MODELS = "get_models" GET_TOOLS = "get_tools" GET_SHIELDS = "get_shields" + LIST_PROVIDERS = "list_providers" + GET_PROVIDER = "get_provider" GET_METRICS = "get_metrics" GET_CONFIG = "get_config" diff --git a/src/models/responses.py b/src/models/responses.py index e89f0b36..fc2916dc 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -1,9 +1,10 @@ """Models for REST API responses.""" -from typing import Any, Optional +from typing import Any, Optional, Union from pydantic import AnyUrl, BaseModel, Field +from llama_stack_client.types import ProviderInfo from models.cache_entry import ConversationData @@ -85,6 +86,65 @@ class ShieldsResponse(BaseModel): ) +class ProvidersListResponse(BaseModel): + """Model representing a response to providers request.""" + + providers: dict[str, list[dict[str, Any]]] = Field( + ..., + description="List of available API types and their corresponding providers", + examples=[ + { + "inference": [ + { + "provider_id": "sentence-transformers", + "provider_type": "inline::sentence-transformers", + }, + {"provider_id": "openai", "provider_type": "remote::openai"}, + ], + "agents": [ + { + "provider_id": "meta-reference", + "provider_type": "inline::meta-reference", + }, + ], + "datasetio": [ + { + "provider_id": "huggingface", + "provider_type": 
"remote::huggingface", + }, + {"provider_id": "localfs", "provider_type": "inline::localfs"}, + ], + }, + ], + ) + + +class ProviderResponse(ProviderInfo): + """Model representing a response to get specific provider request.""" + + api: str = Field( + ..., + description="The API this provider implements", + example="inference", + ) # type: ignore + config: dict[str, Union[bool, float, str, list[Any], object, None]] = Field( + ..., + description="Provider configuration parameters", + example={"api_key": "********"}, + ) # type: ignore + health: dict[str, Union[bool, float, str, list[Any], object, None]] = Field( + ..., + description="Current health status of the provider", + example={"status": "OK", "message": "Healthy"}, + ) # type: ignore + provider_id: str = Field( + ..., description="Unique provider identifier", example="openai" + ) # type: ignore + provider_type: str = Field( + ..., description="Provider implementation type", example="remote::openai" + ) # type: ignore + + class RAGChunk(BaseModel): """Model representing a RAG chunk used in the response.""" diff --git a/tests/unit/app/endpoints/test_providers.py b/tests/unit/app/endpoints/test_providers.py new file mode 100644 index 00000000..167400cc --- /dev/null +++ b/tests/unit/app/endpoints/test_providers.py @@ -0,0 +1,163 @@ +"""Unit tests for the /providers REST API endpoints.""" + +from unittest.mock import AsyncMock + +import pytest +from fastapi import HTTPException, Request, status +from llama_stack_client import APIConnectionError + +from app.endpoints.providers import ( + get_provider_endpoint_handler, + providers_endpoint_handler, +) + + +@pytest.mark.asyncio +async def test_providers_endpoint_configuration_not_loaded(mocker): + """Test that /providers endpoint raises HTTP 500 if configuration is not loaded.""" + mocker.patch("app.endpoints.providers.configuration", None) + request = Request(scope={"type": "http"}) + auth = ("user", "token", {}) + + with pytest.raises(HTTPException) as e: + await providers_endpoint_handler(request=request, auth=auth) + assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + + +@pytest.mark.asyncio +async def test_providers_endpoint_connection_error(mocker): + """Test that /providers endpoint raises HTTP 500 if Llama Stack connection fails.""" + mock_client = AsyncMock() + mock_client.providers.list.side_effect = APIConnectionError(request=None) + mocker.patch( + "app.endpoints.providers.AsyncLlamaStackClientHolder" + ).return_value.get_client.return_value = mock_client + + request = Request(scope={"type": "http"}) + auth = ("user", "token", {}) + + with pytest.raises(HTTPException) as e: + await providers_endpoint_handler(request=request, auth=auth) + assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert "Unable to connect to Llama Stack" in e.value.detail["response"] + + +@pytest.mark.asyncio +async def test_providers_endpoint_success(mocker): + """Test that /providers endpoint returns a grouped list of providers on success.""" + provider_list = [ + { + "api": "inference", + "provider_id": "openai", + "provider_type": "remote::openai", + }, + { + "api": "inference", + "provider_id": "st", + "provider_type": "inline::sentence-transformers", + }, + { + "api": "datasetio", + "provider_id": "huggingface", + "provider_type": "remote::huggingface", + }, + ] + mock_client = AsyncMock() + mock_client.providers.list.return_value = provider_list + mocker.patch( + "app.endpoints.providers.AsyncLlamaStackClientHolder" + ).return_value.get_client.return_value = 
+    ).return_value.get_client.return_value = mock_client
+
+    request = Request(scope={"type": "http"})
+    auth = ("user", "token", {})
+
+    response = await providers_endpoint_handler(request=request, auth=auth)
+    assert "inference" in response.providers
+    assert len(response.providers["inference"]) == 2
+    assert "datasetio" in response.providers
+
+
+@pytest.mark.asyncio
+async def test_get_provider_not_found(mocker):
+    """Test that /providers/{provider_id} endpoint raises HTTP 404 if the provider is not found."""
+    mock_client = AsyncMock()
+    mock_client.providers.list.return_value = []
+    mocker.patch(
+        "app.endpoints.providers.AsyncLlamaStackClientHolder"
+    ).return_value.get_client.return_value = mock_client
+
+    request = Request(scope={"type": "http"})
+    auth = ("user", "token", {})
+
+    with pytest.raises(HTTPException) as e:
+        await get_provider_endpoint_handler(
+            request=request, provider_id="openai", auth=auth
+        )
+    assert e.value.status_code == status.HTTP_404_NOT_FOUND
+    assert "not found" in e.value.detail["response"]
+
+
+@pytest.mark.asyncio
+async def test_get_provider_success(mocker):
+    """Test that /providers/{provider_id} endpoint returns provider details on success."""
+    provider = {
+        "api": "inference",
+        "provider_id": "openai",
+        "provider_type": "remote::openai",
+        "config": {"api_key": "*****"},
+        "health": {"status": "OK", "message": "Healthy"},
+    }
+    mock_client = AsyncMock()
+    mock_client.providers.list.return_value = [provider]
+    mocker.patch(
+        "app.endpoints.providers.AsyncLlamaStackClientHolder"
+    ).return_value.get_client.return_value = mock_client
+
+    request = Request(scope={"type": "http"})
+    auth = ("user", "token", {})
+
+    response = await get_provider_endpoint_handler(
+        request=request, provider_id="openai", auth=auth
+    )
+    assert response.provider_id == "openai"
+    assert response.api == "inference"
+
+
+@pytest.mark.asyncio
+async def test_get_provider_connection_error(mocker):
+    """Test that /providers/{provider_id} raises HTTP 500 if Llama Stack connection fails."""
+    mock_client = AsyncMock()
+    mock_client.providers.list.side_effect = APIConnectionError(request=None)
+    mocker.patch(
+        "app.endpoints.providers.AsyncLlamaStackClientHolder"
+    ).return_value.get_client.return_value = mock_client
+
+    request = Request(scope={"type": "http"})
+    auth = ("user", "token", {})
+
+    with pytest.raises(HTTPException) as e:
+        await get_provider_endpoint_handler(
+            request=request, provider_id="openai", auth=auth
+        )
+    assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+    assert "Unable to connect to Llama Stack" in e.value.detail["response"]
+
+
+@pytest.mark.asyncio
+async def test_get_provider_unexpected_exception(mocker):
+    """Test that /providers/{provider_id} endpoint raises HTTP 500 for unexpected exceptions."""
+    mock_client = AsyncMock()
+    mock_client.providers.list.side_effect = Exception("boom")
+    mocker.patch(
+        "app.endpoints.providers.AsyncLlamaStackClientHolder"
+    ).return_value.get_client.return_value = mock_client
+
+    request = Request(scope={"type": "http"})
+    auth = ("user", "token", {})
+
+    with pytest.raises(HTTPException) as e:
+        await get_provider_endpoint_handler(
+            request=request, provider_id="openai", auth=auth
+        )
+    assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+    assert "Unable to retrieve list of providers" in e.value.detail["response"]
diff --git a/tests/unit/app/test_routers.py b/tests/unit/app/test_routers.py
index 9b0b5520..e466fca4 100644
--- a/tests/unit/app/test_routers.py
+++ b/tests/unit/app/test_routers.py
@@ -13,6 +13,7 @@
     info,
     models,
     shields,
+    providers,
     query,
     health,
     config,
@@ -63,12 +64,13 @@ def test_include_routers() -> None:
     include_routers(app)
 
     # are all routers added?
-    assert len(app.routers) == 14
+    assert len(app.routers) == 15
     assert root.router in app.get_routers()
     assert info.router in app.get_routers()
     assert models.router in app.get_routers()
     assert tools.router in app.get_routers()
     assert shields.router in app.get_routers()
+    assert providers.router in app.get_routers()
     assert query.router in app.get_routers()
     assert streaming_query.router in app.get_routers()
     assert config.router in app.get_routers()
@@ -86,12 +88,13 @@ def test_check_prefixes() -> None:
     include_routers(app)
 
     # are all routers added?
-    assert len(app.routers) == 14
+    assert len(app.routers) == 15
     assert app.get_router_prefix(root.router) == ""
     assert app.get_router_prefix(info.router) == "/v1"
     assert app.get_router_prefix(models.router) == "/v1"
     assert app.get_router_prefix(tools.router) == "/v1"
     assert app.get_router_prefix(shields.router) == "/v1"
+    assert app.get_router_prefix(providers.router) == "/v1"
    assert app.get_router_prefix(query.router) == "/v1"
    assert app.get_router_prefix(streaming_query.router) == "/v1"
    assert app.get_router_prefix(config.router) == "/v1"