Skip to content

Commit 7364a65

Browse files
authored
Merge pull request #356 from omertuc/auth
Role-based authorization layer
2 parents bb57e5d + 83e9e23 commit 7364a65

34 files changed

+1092
-133
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ dependencies = [
3737
"openai==1.99.9",
3838
"sqlalchemy>=2.0.42",
3939
"semver<4.0.0",
40+
"jsonpath-ng>=1.6.1",
4041
]
4142

4243

src/app/endpoints/authorized.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Handler for REST API call to authorized endpoint."""
22

33
import logging
4-
from typing import Any
4+
from typing import Annotated, Any
55

66
from fastapi import APIRouter, Depends
77

8+
from auth.interface import AuthTuple
89
from auth import get_auth_dependency
910
from models.responses import AuthorizedResponse, UnauthorizedResponse, ForbiddenResponse
1011

@@ -31,7 +32,7 @@
3132

3233
@router.post("/authorized", responses=authorized_responses)
3334
async def authorized_endpoint_handler(
34-
auth: Any = Depends(auth_dependency),
35+
auth: Annotated[AuthTuple, Depends(auth_dependency)],
3536
) -> AuthorizedResponse:
3637
"""
3738
Handle request to the /authorized endpoint.

src/app/endpoints/config.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
"""Handler for REST API call to retrieve service configuration."""
22

33
import logging
4-
from typing import Any
4+
from typing import Annotated, Any
55

6-
from fastapi import APIRouter, Request
6+
from fastapi import APIRouter, Request, Depends
77

8-
from models.config import Configuration
8+
from auth.interface import AuthTuple
9+
from auth import get_auth_dependency
10+
from authorization.middleware import authorize
911
from configuration import configuration
12+
from models.config import Action, Configuration
1013
from utils.endpoints import check_configuration_loaded
1114

1215
logger = logging.getLogger(__name__)
1316
router = APIRouter(tags=["config"])
1417

18+
auth_dependency = get_auth_dependency()
19+
1520

1621
get_config_responses: dict[int | str, dict[str, Any]] = {
1722
200: {
@@ -56,7 +61,11 @@
5661

5762

5863
@router.get("/config", responses=get_config_responses)
59-
def config_endpoint_handler(_request: Request) -> Configuration:
64+
@authorize(Action.GET_CONFIG)
65+
async def config_endpoint_handler(
66+
auth: Annotated[AuthTuple, Depends(auth_dependency)],
67+
request: Request,
68+
) -> Configuration:
6069
"""
6170
Handle requests to the /config endpoint.
6271
@@ -66,6 +75,12 @@ def config_endpoint_handler(_request: Request) -> Configuration:
6675
Returns:
6776
Configuration: The loaded service configuration object.
6877
"""
78+
# Used only for authorization
79+
_ = auth
80+
81+
# Nothing interesting in the request
82+
_ = request
83+
6984
# ensure that configuration is loaded
7085
check_configuration_loaded(configuration)
7186

src/app/endpoints/conversations.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,21 @@
55

66
from llama_stack_client import APIConnectionError, NotFoundError
77

8-
from fastapi import APIRouter, HTTPException, status, Depends
8+
from fastapi import APIRouter, HTTPException, Request, status, Depends
99

1010
from client import AsyncLlamaStackClientHolder
1111
from configuration import configuration
12+
from app.database import get_session
13+
from auth import get_auth_dependency
14+
from authorization.middleware import authorize
15+
from models.config import Action
16+
from models.database.conversations import UserConversation
1217
from models.responses import (
1318
ConversationResponse,
1419
ConversationDeleteResponse,
1520
ConversationsListResponse,
1621
ConversationDetails,
1722
)
18-
from models.database.conversations import UserConversation
19-
from auth import get_auth_dependency
20-
from app.database import get_session
2123
from utils.endpoints import check_configuration_loaded, validate_conversation_ownership
2224
from utils.suid import check_suid
2325

@@ -146,7 +148,9 @@ def simplify_session_data(session_data: dict) -> list[dict[str, Any]]:
146148

147149

148150
@router.get("/conversations", responses=conversations_list_responses)
149-
def get_conversations_list_endpoint_handler(
151+
@authorize(Action.LIST_CONVERSATIONS)
152+
async def get_conversations_list_endpoint_handler(
153+
request: Request,
150154
auth: Any = Depends(auth_dependency),
151155
) -> ConversationsListResponse:
152156
"""Handle request to retrieve all conversations for the authenticated user."""
@@ -158,11 +162,16 @@ def get_conversations_list_endpoint_handler(
158162

159163
with get_session() as session:
160164
try:
161-
# Get all conversations for this user
162-
user_conversations = (
163-
session.query(UserConversation).filter_by(user_id=user_id).all()
165+
query = session.query(UserConversation)
166+
167+
filtered_query = (
168+
query
169+
if Action.LIST_OTHERS_CONVERSATIONS in request.state.authorized_actions
170+
else query.filter_by(user_id=user_id)
164171
)
165172

173+
user_conversations = filtered_query.all()
174+
166175
# Return conversation summaries with metadata
167176
conversations = [
168177
ConversationDetails(
@@ -200,7 +209,9 @@ def get_conversations_list_endpoint_handler(
200209

201210

202211
@router.get("/conversations/{conversation_id}", responses=conversation_responses)
212+
@authorize(Action.GET_CONVERSATION)
203213
async def get_conversation_endpoint_handler(
214+
request: Request,
204215
conversation_id: str,
205216
auth: Any = Depends(auth_dependency),
206217
) -> ConversationResponse:
@@ -239,6 +250,9 @@ async def get_conversation_endpoint_handler(
239250
validate_conversation_ownership(
240251
user_id=user_id,
241252
conversation_id=conversation_id,
253+
others_allowed=(
254+
Action.READ_OTHERS_CONVERSATIONS in request.state.authorized_actions
255+
),
242256
)
243257

244258
agent_id = conversation_id
@@ -309,7 +323,9 @@ async def get_conversation_endpoint_handler(
309323
@router.delete(
310324
"/conversations/{conversation_id}", responses=conversation_delete_responses
311325
)
326+
@authorize(Action.DELETE_CONVERSATION)
312327
async def delete_conversation_endpoint_handler(
328+
request: Request,
313329
conversation_id: str,
314330
auth: Any = Depends(auth_dependency),
315331
) -> ConversationDeleteResponse:
@@ -342,6 +358,9 @@ async def delete_conversation_endpoint_handler(
342358
validate_conversation_ownership(
343359
user_id=user_id,
344360
conversation_id=conversation_id,
361+
others_allowed=(
362+
Action.DELETE_OTHERS_CONVERSATIONS in request.state.authorized_actions
363+
),
345364
)
346365

347366
agent_id = conversation_id

src/app/endpoints/feedback.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,21 @@
55
from pathlib import Path
66
import json
77
from datetime import datetime, UTC
8-
from fastapi import APIRouter, Request, HTTPException, Depends, status
8+
from fastapi import APIRouter, HTTPException, Depends, Request, status
99

1010
from auth import get_auth_dependency
1111
from auth.interface import AuthTuple
12+
from authorization.middleware import authorize
1213
from configuration import configuration
14+
from models.config import Action
15+
from models.requests import FeedbackRequest
1316
from models.responses import (
1417
ErrorResponse,
1518
FeedbackResponse,
1619
StatusResponse,
1720
UnauthorizedResponse,
1821
ForbiddenResponse,
1922
)
20-
from models.requests import FeedbackRequest
2123
from utils.suid import get_suid
2224

2325
logger = logging.getLogger(__name__)
@@ -79,7 +81,8 @@ async def assert_feedback_enabled(_request: Request) -> None:
7981

8082

8183
@router.post("", responses=feedback_response)
82-
def feedback_endpoint_handler(
84+
@authorize(Action.FEEDBACK)
85+
async def feedback_endpoint_handler(
8386
feedback_request: FeedbackRequest,
8487
auth: Annotated[AuthTuple, Depends(auth_dependency)],
8588
_ensure_feedback_enabled: Any = Depends(assert_feedback_enabled),

src/app/endpoints/health.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66
"""
77

88
import logging
9-
from typing import Any
9+
from typing import Annotated, Any
1010

1111
from llama_stack.providers.datatypes import HealthStatus
1212

13-
from fastapi import APIRouter, status, Response
13+
from fastapi import APIRouter, status, Response, Depends
1414
from client import AsyncLlamaStackClientHolder
15+
from auth.interface import AuthTuple
16+
from auth import get_auth_dependency
17+
from authorization.middleware import authorize
18+
from models.config import Action
1519
from models.responses import (
1620
LivenessResponse,
1721
ReadinessResponse,
@@ -21,6 +25,8 @@
2125
logger = logging.getLogger(__name__)
2226
router = APIRouter(tags=["health"])
2327

28+
auth_dependency = get_auth_dependency()
29+
2430

2531
async def get_providers_health_statuses() -> list[ProviderHealthStatus]:
2632
"""
@@ -72,14 +78,21 @@ async def get_providers_health_statuses() -> list[ProviderHealthStatus]:
7278

7379

7480
@router.get("/readiness", responses=get_readiness_responses)
75-
async def readiness_probe_get_method(response: Response) -> ReadinessResponse:
81+
@authorize(Action.INFO)
82+
async def readiness_probe_get_method(
83+
auth: Annotated[AuthTuple, Depends(auth_dependency)],
84+
response: Response,
85+
) -> ReadinessResponse:
7686
"""
7787
Handle the readiness probe endpoint, returning service readiness.
7888
7989
If any provider reports an error status, responds with HTTP 503
8090
and details of unhealthy providers; otherwise, indicates the
8191
service is ready.
8292
"""
93+
# Used only for authorization
94+
_ = auth
95+
8396
provider_statuses = await get_providers_health_statuses()
8497

8598
# Check if any provider is unhealthy (not counting not_implemented as unhealthy)
@@ -112,11 +125,17 @@ async def readiness_probe_get_method(response: Response) -> ReadinessResponse:
112125

113126

114127
@router.get("/liveness", responses=get_liveness_responses)
115-
def liveness_probe_get_method() -> LivenessResponse:
128+
@authorize(Action.INFO)
129+
async def liveness_probe_get_method(
130+
auth: Annotated[AuthTuple, Depends(auth_dependency)],
131+
) -> LivenessResponse:
116132
"""
117133
Return the liveness status of the service.
118134
119135
Returns:
120136
LivenessResponse: Indicates that the service is alive.
121137
"""
138+
# Used only for authorization
139+
_ = auth
140+
122141
return LivenessResponse(alive=True)

src/app/endpoints/info.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
11
"""Handler for REST API call to provide info."""
22

33
import logging
4-
from typing import Any
4+
from typing import Annotated, Any
55

66
from fastapi import APIRouter, Request
7+
from fastapi import Depends
78

9+
from auth.interface import AuthTuple
10+
from auth import get_auth_dependency
11+
from authorization.middleware import authorize
812
from configuration import configuration
9-
from version import __version__
13+
from models.config import Action
1014
from models.responses import InfoResponse
15+
from version import __version__
1116

1217
logger = logging.getLogger(__name__)
1318
router = APIRouter(tags=["info"])
1419

20+
auth_dependency = get_auth_dependency()
21+
1522

1623
get_info_responses: dict[int | str, dict[str, Any]] = {
1724
200: {
@@ -22,7 +29,11 @@
2229

2330

2431
@router.get("/info", responses=get_info_responses)
25-
def info_endpoint_handler(_request: Request) -> InfoResponse:
32+
@authorize(Action.INFO)
33+
async def info_endpoint_handler(
34+
auth: Annotated[AuthTuple, Depends(auth_dependency)],
35+
request: Request,
36+
) -> InfoResponse:
2637
"""
2738
Handle request to the /info endpoint.
2839
@@ -32,4 +43,10 @@ def info_endpoint_handler(_request: Request) -> InfoResponse:
3243
Returns:
3344
InfoResponse: An object containing the service's name and version.
3445
"""
46+
# Used only for authorization
47+
_ = auth
48+
49+
# Nothing interesting in the request
50+
_ = request
51+
3552
return InfoResponse(name=configuration.configuration.name, version=__version__)

src/app/endpoints/metrics.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,30 @@
11
"""Handler for REST API call to provide metrics."""
22

3+
from typing import Annotated
34
from fastapi.responses import PlainTextResponse
4-
from fastapi import APIRouter, Request
5+
from fastapi import APIRouter, Request, Depends
56
from prometheus_client import (
67
generate_latest,
78
CONTENT_TYPE_LATEST,
89
)
910

11+
from auth.interface import AuthTuple
12+
from auth import get_auth_dependency
13+
from authorization.middleware import authorize
14+
from models.config import Action
1015
from metrics.utils import setup_model_metrics
1116

1217
router = APIRouter(tags=["metrics"])
1318

19+
auth_dependency = get_auth_dependency()
20+
1421

1522
@router.get("/metrics", response_class=PlainTextResponse)
16-
async def metrics_endpoint_handler(_request: Request) -> PlainTextResponse:
23+
@authorize(Action.GET_METRICS)
24+
async def metrics_endpoint_handler(
25+
auth: Annotated[AuthTuple, Depends(auth_dependency)],
26+
request: Request,
27+
) -> PlainTextResponse:
1728
"""
1829
Handle request to the /metrics endpoint.
1930
@@ -24,6 +35,12 @@ async def metrics_endpoint_handler(_request: Request) -> PlainTextResponse:
2435
set up, then responds with the current metrics snapshot in
2536
Prometheus format.
2637
"""
38+
# Used only for authorization
39+
_ = auth
40+
41+
# Nothing interesting in the request
42+
_ = request
43+
2744
# Setup the model metrics if not already done. This is a one-time setup
2845
# and will not be run again on subsequent calls to this endpoint
2946
await setup_model_metrics()

0 commit comments

Comments
 (0)