Skip to content

Commit 670e918

Browse files
committed
Auth
1 parent 6e61258 commit 670e918

File tree

17 files changed

+815
-2
lines changed

17 files changed

+815
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ dependencies = [
3636
"openai==1.99.1",
3737
"sqlalchemy>=2.0.42",
3838
"email-validator>=2.2.0",
39+
"jsonpath-ng>=1.6.1",
3940
]
4041

4142
[tool.pyright]

src/app/endpoints/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from models.config import Configuration
99
from configuration import configuration
10+
from authorization.middleware import authorize
11+
from authorization.models import Action
1012
from utils.endpoints import check_configuration_loaded
1113

1214
logger = logging.getLogger(__name__)
@@ -56,6 +58,7 @@
5658

5759

5860
@router.get("/config", responses=get_config_responses)
61+
@authorize(Action.GET_CONFIG)
5962
def config_endpoint_handler(_request: Request) -> Configuration:
6063
"""
6164
Handle requests to the /config endpoint.

src/app/endpoints/conversations.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from auth import get_auth_dependency
2020
from app.database import get_session
2121
from utils.endpoints import check_configuration_loaded, validate_conversation_ownership
22+
from authorization.middleware import authorize
23+
from authorization.models import Action
2224
from utils.suid import check_suid
2325

2426
logger = logging.getLogger("app.endpoints.handlers")
@@ -200,6 +202,7 @@ def get_conversations_list_endpoint_handler(
200202

201203

202204
@router.get("/conversations/{conversation_id}", responses=conversation_responses)
205+
@authorize(Action.GET_CONVERSATION)
203206
async def get_conversation_endpoint_handler(
204207
conversation_id: str,
205208
auth: Any = Depends(auth_dependency),
@@ -309,6 +312,7 @@ async def get_conversation_endpoint_handler(
309312
@router.delete(
310313
"/conversations/{conversation_id}", responses=conversation_delete_responses
311314
)
315+
@authorize(Action.DELETE_CONVERSATION)
312316
async def delete_conversation_endpoint_handler(
313317
conversation_id: str,
314318
auth: Any = Depends(auth_dependency),

src/app/endpoints/feedback.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
from auth import get_auth_dependency
1111
from auth.interface import AuthTuple
12+
from authorization.middleware import authorize
13+
from authorization.models import Action
1214
from configuration import configuration
1315
from models.responses import (
1416
ErrorResponse,
@@ -79,6 +81,7 @@ async def assert_feedback_enabled(_request: Request) -> None:
7981

8082

8183
@router.post("", responses=feedback_response)
84+
@authorize(Action.FEEDBACK)
8285
def feedback_endpoint_handler(
8386
feedback_request: FeedbackRequest,
8487
auth: Annotated[AuthTuple, Depends(auth_dependency)],

src/app/endpoints/metrics.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77
CONTENT_TYPE_LATEST,
88
)
99

10+
from authorization.middleware import authorize
11+
from authorization.models import Action
1012
from metrics.utils import setup_model_metrics
1113

1214
router = APIRouter(tags=["metrics"])
1315

1416

1517
@router.get("/metrics", response_class=PlainTextResponse)
18+
@authorize(Action.GET_METRICS)
1619
async def metrics_endpoint_handler(_request: Request) -> PlainTextResponse:
1720
"""
1821
Handle request to the /metrics endpoint.

src/app/endpoints/models.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,25 @@
33
import logging
44
from typing import Any
55

6+
from fastapi.params import Depends
67
from llama_stack_client import APIConnectionError
78
from fastapi import APIRouter, HTTPException, Request, status
89

910
from client import AsyncLlamaStackClientHolder
1011
from configuration import configuration
12+
from authorization.middleware import authorize
13+
from authorization.models import Action
1114
from models.responses import ModelsResponse
1215
from utils.endpoints import check_configuration_loaded
16+
from auth import get_auth_dependency
1317

1418
logger = logging.getLogger(__name__)
1519
router = APIRouter(tags=["models"])
1620

1721

22+
auth_dependency = get_auth_dependency()
23+
24+
1825
models_responses: dict[int | str, dict[str, Any]] = {
1926
200: {
2027
"models": [
@@ -43,7 +50,8 @@
4350

4451

4552
@router.get("/models", responses=models_responses)
46-
async def models_endpoint_handler(_request: Request) -> ModelsResponse:
53+
@authorize(Action.GET_MODELS)
54+
async def models_endpoint_handler(_request: Request, auth: Any = Depends(get_auth_dependency())) -> ModelsResponse:
4755
"""
4856
Handle requests to the /models endpoint.
4957
@@ -57,6 +65,10 @@ async def models_endpoint_handler(_request: Request) -> ModelsResponse:
5765
Returns:
5866
ModelsResponse: An object containing the list of available models.
5967
"""
68+
69+
# Used only by the middleware
70+
_ = auth
71+
6072
check_configuration_loaded(configuration)
6173

6274
llama_stack_configuration = configuration.llama_stack_configuration

src/app/endpoints/query.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
get_system_prompt,
3535
validate_conversation_ownership,
3636
)
37+
from authorization.middleware import authorize
38+
from authorization.models import Action
3739
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
3840
from utils.suid import get_suid
3941

@@ -147,6 +149,7 @@ def evaluate_model_hints(
147149

148150

149151
@router.post("/query", responses=query_response)
152+
@authorize(Action.QUERY)
150153
async def query_endpoint_handler(
151154
query_request: QueryRequest,
152155
auth: Annotated[AuthTuple, Depends(auth_dependency)],

src/app/endpoints/streaming_query.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
from auth import get_auth_dependency
2121
from auth.interface import AuthTuple
22+
from authorization.middleware import authorize
23+
from authorization.models import Action
2224
from client import AsyncLlamaStackClientHolder
2325
from configuration import configuration
2426
import metrics
@@ -384,6 +386,7 @@ def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]:
384386

385387

386388
@router.post("/streaming_query")
389+
@authorize(Action.STREAMING_QUERY)
387390
async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
388391
_request: Request,
389392
query_request: QueryRequest,

src/auth/jwk_token.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
from asyncio import Lock
55
from typing import Any, Callable
6+
import json
67

78
from fastapi import Request, HTTPException, status
89
from authlib.jose import JsonWebKey, KeySet, jwt, Key
@@ -188,4 +189,4 @@ async def __call__(self, request: Request) -> tuple[str, str, str]:
188189

189190
logger.info("Successfully authenticated user %s (ID: %s)", username, user_id)
190191

191-
return user_id, username, user_token
192+
return user_id, username, json.dumps(claims)

src/authorization/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Authorization module for role-based access control."""

0 commit comments

Comments
 (0)