Skip to content

Commit aee5d72

Browse files
committed
auth5
1 parent b3989af commit aee5d72

File tree

12 files changed

+252
-34
lines changed

12 files changed

+252
-34
lines changed

src/app/endpoints/config.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""Handler for REST API call to retrieve service configuration."""
22

33
import logging
4-
from typing import Any
4+
from typing import Annotated, Any
55

6-
from fastapi import APIRouter, Request
6+
from fastapi import APIRouter, Request, Depends
77

8+
from auth.interface import AuthTuple
9+
from auth import get_auth_dependency
810
from authorization.middleware import authorize
911
from configuration import configuration
1012
from models.config import Action, Configuration
@@ -13,6 +15,8 @@
1315
logger = logging.getLogger(__name__)
1416
router = APIRouter(tags=["config"])
1517

18+
auth_dependency = get_auth_dependency()
19+
1620

1721
get_config_responses: dict[int | str, dict[str, Any]] = {
1822
200: {
@@ -58,7 +62,10 @@
5862

5963
@router.get("/config", responses=get_config_responses)
6064
@authorize(Action.GET_CONFIG)
61-
def config_endpoint_handler(_request: Request) -> Configuration:
65+
def config_endpoint_handler(
66+
auth: Annotated[AuthTuple, Depends(auth_dependency)],
67+
_request: Request,
68+
) -> Configuration:
6269
"""
6370
Handle requests to the /config endpoint.
6471
@@ -68,6 +75,9 @@ def config_endpoint_handler(_request: Request) -> Configuration:
6875
Returns:
6976
Configuration: The loaded service configuration object.
7077
"""
78+
# Used only for authorization
79+
_ = auth
80+
7181
# ensure that configuration is loaded
7282
check_configuration_loaded(configuration)
7383

src/app/endpoints/metrics.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,30 @@
11
"""Handler for REST API call to provide metrics."""
22

3+
from typing import Annotated
34
from fastapi.responses import PlainTextResponse
4-
from fastapi import APIRouter, Request
5+
from fastapi import APIRouter, Request, Depends
56
from prometheus_client import (
67
generate_latest,
78
CONTENT_TYPE_LATEST,
89
)
910

11+
from auth.interface import AuthTuple
12+
from auth import get_auth_dependency
1013
from authorization.middleware import authorize
1114
from models.config import Action
1215
from metrics.utils import setup_model_metrics
1316

1417
router = APIRouter(tags=["metrics"])
1518

19+
auth_dependency = get_auth_dependency()
20+
1621

1722
@router.get("/metrics", response_class=PlainTextResponse)
1823
@authorize(Action.GET_METRICS)
19-
async def metrics_endpoint_handler(_request: Request) -> PlainTextResponse:
24+
async def metrics_endpoint_handler(
25+
auth: Annotated[AuthTuple, Depends(auth_dependency)],
26+
_request: Request,
27+
) -> PlainTextResponse:
2028
"""
2129
Handle request to the /metrics endpoint.
2230
@@ -27,6 +35,9 @@ async def metrics_endpoint_handler(_request: Request) -> PlainTextResponse:
2735
set up, then responds with the current metrics snapshot in
2836
Prometheus format.
2937
"""
38+
# Used only for authorization
39+
_ = auth
40+
3041
# Setup the model metrics if not already done. This is a one-time setup
3142
# and will not be run again on subsequent calls to this endpoint
3243
await setup_model_metrics()

src/auth/jwk_token.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44
from asyncio import Lock
55
from typing import Any, Callable
6-
import json
76

87
from fastapi import Request, HTTPException, status
98
from authlib.jose import JsonWebKey, KeySet, jwt, Key
@@ -189,4 +188,4 @@ async def __call__(self, request: Request) -> tuple[str, str, str]:
189188

190189
logger.info("Successfully authenticated user %s (ID: %s)", username, user_id)
191190

192-
return user_id, username, json.dumps(claims)
191+
return user_id, username, user_token

src/authorization/engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def __init__(self, access_rules: list[AccessRule]):
145145

146146
async def check_access(self, action: Action, user_roles: UserRoles) -> bool:
147147
"""Check if the user has access to the specified action based on their roles."""
148-
if action != Action.ADMIN and self.check_access(action.ADMIN, user_roles):
148+
if action != Action.ADMIN and await self.check_access(Action.ADMIN, user_roles):
149149
# Recurse to check if the roles allow the user to perform the admin action,
150150
# if they do, then we allow any action
151151
return True

src/authorization/middleware.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ async def _perform_authorization_check(action: Action, kwargs: dict[str, Any]) -
7373
detail="Internal server error",
7474
) from exc
7575

76-
if not access_resolver.check_access(
76+
if not await access_resolver.check_access(
7777
action, await role_resolver.resolve_roles(auth)
7878
):
7979
raise HTTPException(
@@ -84,12 +84,18 @@ async def _perform_authorization_check(action: Action, kwargs: dict[str, Any]) -
8484

8585
def authorize(action: Action) -> Callable:
8686
"""Check authorization for an endpoint (async version)."""
87+
import asyncio
8788

8889
def decorator(func: Callable) -> Callable:
8990
@wraps(func)
9091
async def wrapper(*args: Any, **kwargs: Any) -> Any:
9192
await _perform_authorization_check(action, kwargs)
92-
return await func(*args, **kwargs)
93+
94+
# Handle both sync and async functions
95+
result = func(*args, **kwargs)
96+
if asyncio.iscoroutine(result):
97+
return await result
98+
return result
9399

94100
return wrapper
95101

tests/unit/app/endpoints/test_config.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,19 @@
77
from configuration import AppConfig
88

99

10-
def test_config_endpoint_handler_configuration_not_loaded(mocker):
10+
@pytest.mark.asyncio
11+
async def test_config_endpoint_handler_configuration_not_loaded(mocker):
1112
"""Test the config endpoint handler."""
13+
# Mock authorization resolvers
14+
mock_resolvers = mocker.patch(
15+
"authorization.middleware.get_authorization_resolvers"
16+
)
17+
mock_role_resolver = mocker.AsyncMock()
18+
mock_access_resolver = mocker.AsyncMock()
19+
mock_role_resolver.resolve_roles.return_value = []
20+
mock_access_resolver.check_access.return_value = True
21+
mock_resolvers.return_value = (mock_role_resolver, mock_access_resolver)
22+
1223
mocker.patch(
1324
"app.endpoints.config.configuration._configuration",
1425
new=None,
@@ -20,14 +31,27 @@ def test_config_endpoint_handler_configuration_not_loaded(mocker):
2031
"type": "http",
2132
}
2233
)
23-
with pytest.raises(HTTPException) as e:
24-
config_endpoint_handler(request)
25-
assert e.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
26-
assert e.detail["response"] == "Configuration is not loaded"
34+
auth = ("test_user", "token", {})
35+
with pytest.raises(HTTPException) as exc_info:
36+
await config_endpoint_handler(auth=auth, _request=request)
2737

38+
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
39+
assert exc_info.value.detail["response"] == "Configuration is not loaded"
2840

29-
def test_config_endpoint_handler_configuration_loaded():
41+
42+
@pytest.mark.asyncio
43+
async def test_config_endpoint_handler_configuration_loaded(mocker):
3044
"""Test the config endpoint handler."""
45+
# Mock authorization resolvers
46+
mock_resolvers = mocker.patch(
47+
"authorization.middleware.get_authorization_resolvers"
48+
)
49+
mock_role_resolver = mocker.AsyncMock()
50+
mock_access_resolver = mocker.AsyncMock()
51+
mock_role_resolver.resolve_roles.return_value = []
52+
mock_access_resolver.check_access.return_value = True
53+
mock_resolvers.return_value = (mock_role_resolver, mock_access_resolver)
54+
3155
config_dict = {
3256
"name": "foo",
3357
"service": {
@@ -49,15 +73,21 @@ def test_config_endpoint_handler_configuration_loaded():
4973
"authentication": {
5074
"module": "noop",
5175
},
76+
"authorization": {"access_rules": []},
5277
"customization": None,
5378
}
5479
cfg = AppConfig()
5580
cfg.init_from_dict(config_dict)
81+
82+
# Mock configuration
83+
mocker.patch("app.endpoints.config.configuration", cfg)
84+
5685
request = Request(
5786
scope={
5887
"type": "http",
5988
}
6089
)
61-
response = config_endpoint_handler(request)
90+
auth = ("test_user", "token", {})
91+
response = await config_endpoint_handler(auth=auth, _request=request)
6292
assert response is not None
6393
assert response == cfg.configuration

tests/unit/app/endpoints/test_feedback.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,20 @@ async def test_assert_feedback_enabled(mocker):
6262
],
6363
ids=["no_categories", "with_negative_categories"],
6464
)
65-
def test_feedback_endpoint_handler(mocker, feedback_request_data):
65+
@pytest.mark.asyncio
66+
async def test_feedback_endpoint_handler(mocker, feedback_request_data):
6667
"""Test that feedback_endpoint_handler processes feedback for different payloads."""
6768

69+
# Mock authorization resolvers
70+
mock_resolvers = mocker.patch(
71+
"authorization.middleware.get_authorization_resolvers"
72+
)
73+
mock_role_resolver = mocker.AsyncMock()
74+
mock_access_resolver = mocker.AsyncMock()
75+
mock_role_resolver.resolve_roles.return_value = []
76+
mock_access_resolver.check_access.return_value = True
77+
mock_resolvers.return_value = (mock_role_resolver, mock_access_resolver)
78+
6879
# Mock the dependencies
6980
mocker.patch("app.endpoints.feedback.assert_feedback_enabled", return_value=None)
7081
mocker.patch("app.endpoints.feedback.store_feedback", return_value=None)
@@ -74,7 +85,7 @@ def test_feedback_endpoint_handler(mocker, feedback_request_data):
7485
feedback_request.model_dump.return_value = feedback_request_data
7586

7687
# Call the endpoint handler
77-
result = feedback_endpoint_handler(
88+
result = await feedback_endpoint_handler(
7889
feedback_request=feedback_request,
7990
_ensure_feedback_enabled=assert_feedback_enabled,
8091
auth=("test_user_id", "test_username", "test_token"),
@@ -84,8 +95,19 @@ def test_feedback_endpoint_handler(mocker, feedback_request_data):
8495
assert result.response == "feedback received"
8596

8697

87-
def test_feedback_endpoint_handler_error(mocker):
98+
@pytest.mark.asyncio
99+
async def test_feedback_endpoint_handler_error(mocker):
88100
"""Test that feedback_endpoint_handler raises an HTTPException on error."""
101+
# Mock authorization resolvers
102+
mock_resolvers = mocker.patch(
103+
"authorization.middleware.get_authorization_resolvers"
104+
)
105+
mock_role_resolver = mocker.AsyncMock()
106+
mock_access_resolver = mocker.AsyncMock()
107+
mock_role_resolver.resolve_roles.return_value = []
108+
mock_access_resolver.check_access.return_value = True
109+
mock_resolvers.return_value = (mock_role_resolver, mock_access_resolver)
110+
89111
# Mock the dependencies
90112
mocker.patch("app.endpoints.feedback.assert_feedback_enabled", return_value=None)
91113
mocker.patch(
@@ -98,7 +120,7 @@ def test_feedback_endpoint_handler_error(mocker):
98120

99121
# Call the endpoint handler and assert it raises an exception
100122
with pytest.raises(HTTPException) as exc_info:
101-
feedback_endpoint_handler(
123+
await feedback_endpoint_handler(
102124
feedback_request=feedback_request,
103125
_ensure_feedback_enabled=assert_feedback_enabled,
104126
auth=("test_user_id", "test_username", "test_token"),

tests/unit/app/endpoints/test_health.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Unit tests for the /health REST API endpoint."""
22

3+
import pytest
34
from unittest.mock import Mock
45

56
from llama_stack.providers.datatypes import HealthStatus
@@ -12,8 +13,19 @@
1213
from models.responses import ProviderHealthStatus, ReadinessResponse
1314

1415

16+
@pytest.mark.asyncio
1517
async def test_readiness_probe_fails_due_to_unhealthy_providers(mocker):
1618
"""Test the readiness endpoint handler fails when providers are unhealthy."""
19+
# Mock authorization resolvers
20+
mock_resolvers = mocker.patch(
21+
"authorization.middleware.get_authorization_resolvers"
22+
)
23+
mock_role_resolver = mocker.AsyncMock()
24+
mock_access_resolver = mocker.AsyncMock()
25+
mock_role_resolver.resolve_roles.return_value = []
26+
mock_access_resolver.check_access.return_value = True
27+
mock_resolvers.return_value = (mock_role_resolver, mock_access_resolver)
28+
1729
# Mock get_providers_health_statuses to return an unhealthy provider
1830
mock_get_providers_health_statuses = mocker.patch(
1931
"app.endpoints.health.get_providers_health_statuses"
@@ -38,8 +50,19 @@ async def test_readiness_probe_fails_due_to_unhealthy_providers(mocker):
3850
assert mock_response.status_code == 503
3951

4052

53+
@pytest.mark.asyncio
4154
async def test_readiness_probe_success_when_all_providers_healthy(mocker):
4255
"""Test the readiness endpoint handler succeeds when all providers are healthy."""
56+
# Mock authorization resolvers
57+
mock_resolvers = mocker.patch(
58+
"authorization.middleware.get_authorization_resolvers"
59+
)
60+
mock_role_resolver = mocker.AsyncMock()
61+
mock_access_resolver = mocker.AsyncMock()
62+
mock_role_resolver.resolve_roles.return_value = []
63+
mock_access_resolver.check_access.return_value = True
64+
mock_resolvers.return_value = (mock_role_resolver, mock_access_resolver)
65+
4366
# Mock get_providers_health_statuses to return healthy providers
4467
mock_get_providers_health_statuses = mocker.patch(
4568
"app.endpoints.health.get_providers_health_statuses"
@@ -70,10 +93,21 @@ async def test_readiness_probe_success_when_all_providers_healthy(mocker):
7093
assert len(response.providers) == 0
7194

7295

73-
def test_liveness_probe():
96+
@pytest.mark.asyncio
97+
async def test_liveness_probe(mocker):
7498
"""Test the liveness endpoint handler."""
99+
# Mock authorization resolvers
100+
mock_resolvers = mocker.patch(
101+
"authorization.middleware.get_authorization_resolvers"
102+
)
103+
mock_role_resolver = mocker.AsyncMock()
104+
mock_access_resolver = mocker.AsyncMock()
105+
mock_role_resolver.resolve_roles.return_value = []
106+
mock_access_resolver.check_access.return_value = True
107+
mock_resolvers.return_value = (mock_role_resolver, mock_access_resolver)
108+
75109
auth = ("test_user", "token", {})
76-
response = liveness_probe_get_method(auth=auth)
110+
response = await liveness_probe_get_method(auth=auth)
77111
assert response is not None
78112
assert response.alive is True
79113

tests/unit/app/endpoints/test_info.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
"""Unit tests for the /info REST API endpoint."""
22

3+
import pytest
34
from fastapi import Request
45

56
from app.endpoints.info import info_endpoint_handler
67
from configuration import AppConfig
78

89

9-
def test_info_endpoint():
10+
@pytest.mark.asyncio
11+
async def test_info_endpoint(mocker):
1012
"""Test the info endpoint handler."""
1113
config_dict = {
1214
"name": "foo",
@@ -27,16 +29,32 @@ def test_info_endpoint():
2729
"feedback_enabled": False,
2830
},
2931
"customization": None,
32+
"authorization": {"access_rules": []},
33+
"authentication": {"module": "noop"},
3034
}
3135
cfg = AppConfig()
3236
cfg.init_from_dict(config_dict)
37+
38+
# Mock configuration
39+
mocker.patch("configuration.configuration", cfg)
40+
41+
# Mock authorization resolvers
42+
mock_resolvers = mocker.patch(
43+
"authorization.middleware.get_authorization_resolvers"
44+
)
45+
mock_role_resolver = mocker.AsyncMock()
46+
mock_access_resolver = mocker.AsyncMock()
47+
mock_role_resolver.resolve_roles.return_value = []
48+
mock_access_resolver.check_access.return_value = True
49+
mock_resolvers.return_value = (mock_role_resolver, mock_access_resolver)
50+
3351
request = Request(
3452
scope={
3553
"type": "http",
3654
}
3755
)
3856
auth = ("test_user", "token", {})
39-
response = info_endpoint_handler(auth, request)
57+
response = await info_endpoint_handler(auth=auth, _request=request)
4058
assert response is not None
4159
assert response.name is not None
4260
assert response.version is not None

0 commit comments

Comments
 (0)