diff --git a/docs/openapi.json b/docs/openapi.json index 9e035a40..68021384 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -689,7 +689,17 @@ } }, "400": { - "description": "Missing or invalid credentials provided by client", + "description": "Missing or invalid credentials provided by client for the noop and noop-with-token authentication modules", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UnauthorizedResponse" + } + } + } + }, + "401": { + "description": "Missing or invalid credentials provided by client for the k8s authentication module", "content": { "application/json": { "schema": { @@ -922,19 +932,27 @@ "John Doe", "Adam Smith" ] + }, + "skip_userid_check": { + "type": "boolean", + "title": "Skip User Id Check", + "description": "Whether to skip the user ID check", + "examples": [true, false] } }, "type": "object", "required": [ "user_id", - "username" + "username", + "skip_userid_check" ], "title": "AuthorizedResponse", "description": "Model representing a response to an authorization request.\n\nAttributes:\n user_id: The ID of the logged in user.\n username: The name of the logged in user.", "examples": [ { "user_id": "123e4567-e89b-12d3-a456-426614174000", - "username": "user1" + "username": "user1", + "skip_userid_check": false } ] }, diff --git a/docs/openapi.md b/docs/openapi.md index 4259266a..f12f8eb1 100644 --- a/docs/openapi.md +++ b/docs/openapi.md @@ -397,7 +397,8 @@ Returns: | Status Code | Description | Component | |-------------|-------------|-----------| | 200 | The user is logged-in and authorized to access OLS | [AuthorizedResponse](#authorizedresponse) | -| 400 | Missing or invalid credentials provided by client | [UnauthorizedResponse](#unauthorizedresponse) | +| 400 | Missing or invalid credentials provided by client for noop and noop-with-token | [UnauthorizedResponse](#unauthorizedresponse) | +| 401 | Missing or invalid credentials provided by client for k8s | [UnauthorizedResponse](#unauthorizedresponse) | | 403 | User is not authorized | [ForbiddenResponse](#forbiddenresponse) | ## GET `/metrics` @@ -515,6 +516,7 @@ Attributes: |-------|------|-------------| | user_id | string | User ID, for example UUID | | username | string | User name | +| skip_userid_check | bool | Whether to skip user_id check | ## CORSConfiguration diff --git a/docs/output.md b/docs/output.md index f7e9b18e..f9e1d71f 100644 --- a/docs/output.md +++ b/docs/output.md @@ -397,7 +397,8 @@ Returns: | Status Code | Description | Component | |-------------|-------------|-----------| | 200 | The user is logged-in and authorized to access OLS | [AuthorizedResponse](#authorizedresponse) | -| 400 | Missing or invalid credentials provided by client | [UnauthorizedResponse](#unauthorizedresponse) | +| 400 | Missing or invalid credentials provided by client for noop and noop-with-token | [UnauthorizedResponse](#unauthorizedresponse) | +| 401 | Missing or invalid credentials provided by client for k8s | [UnauthorizedResponse](#unauthorizedresponse) | | 403 | User is not authorized | [ForbiddenResponse](#forbiddenresponse) | ## GET `/metrics` @@ -509,12 +510,14 @@ Model representing a response to an authorization request. Attributes: user_id: The ID of the logged in user. username: The name of the logged in user. + skip_userid_check: Whether to skip user_id check | Field | Type | Description | |-------|------|-------------| | user_id | string | User ID, for example UUID | | username | string | User name | +| skip_userid_check | bool | skip user_id check | ## CORSConfiguration diff --git a/src/app/endpoints/authorized.py b/src/app/endpoints/authorized.py index 8fa029ee..07294f02 100644 --- a/src/app/endpoints/authorized.py +++ b/src/app/endpoints/authorized.py @@ -20,7 +20,13 @@ "model": AuthorizedResponse, }, 400: { - "description": "Missing or invalid credentials provided by client", + "description": "Missing or invalid credentials provided by client for the noop and" + "noop-with-token authentication modules", + "model": UnauthorizedResponse, + }, + 401: { + "description": "Missing or invalid credentials provided by client for the" + "k8s authentication module", "model": UnauthorizedResponse, }, 403: { @@ -44,5 +50,7 @@ async def authorized_endpoint_handler( AuthorizedResponse: Contains the user ID and username of the authenticated user. """ # Ignore the user token, we should not return it in the response - user_id, user_name, _ = auth - return AuthorizedResponse(user_id=user_id, username=user_name) + user_id, user_name, skip_userid_check, _ = auth + return AuthorizedResponse( + user_id=user_id, username=user_name, skip_userid_check=skip_userid_check + ) diff --git a/src/app/endpoints/feedback.py b/src/app/endpoints/feedback.py index 2e4ee22c..f809d4c7 100644 --- a/src/app/endpoints/feedback.py +++ b/src/app/endpoints/feedback.py @@ -110,7 +110,7 @@ async def feedback_endpoint_handler( """ logger.debug("Feedback received %s", str(feedback_request)) - user_id, _, _ = auth + user_id, _, _, _ = auth try: store_feedback(user_id, feedback_request.model_dump(exclude={"model_config"})) except Exception as e: @@ -195,7 +195,7 @@ async def update_feedback_status( Returns: FeedbackStatusUpdateResponse: Indicates whether feedback is enabled. """ - user_id, _, _ = auth + user_id, _, _, _ = auth requested_status = feedback_update_request.get_value() with feedback_status_lock: diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 1f4e6967..5705ba26 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -180,7 +180,7 @@ async def query_endpoint_handler( # log Llama Stack configuration logger.info("Llama stack config: %s", configuration.llama_stack_configuration) - user_id, _, token = auth + user_id, _, _, token = auth user_conversation: UserConversation | None = None if query_request.conversation_id: diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index fe2fd7bd..d6007e96 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -556,7 +556,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals # log Llama Stack configuration logger.info("Llama stack config: %s", configuration.llama_stack_configuration) - user_id, _user_name, token = auth + user_id, _user_name, _skip_userid_check, token = auth user_conversation: UserConversation | None = None if query_request.conversation_id: diff --git a/src/auth/interface.py b/src/auth/interface.py index dcc5cca5..0e5295f6 100644 --- a/src/auth/interface.py +++ b/src/auth/interface.py @@ -9,15 +9,26 @@ from fastapi import Request -from constants import DEFAULT_USER_NAME, DEFAULT_USER_UID, NO_USER_TOKEN +from constants import ( + DEFAULT_USER_NAME, + DEFAULT_SKIP_USER_ID_CHECK, + DEFAULT_USER_UID, + NO_USER_TOKEN, +) UserID = str UserName = str +SkipUserIdCheck = bool Token = str -AuthTuple = tuple[UserID, UserName, Token] +AuthTuple = tuple[UserID, UserName, SkipUserIdCheck, Token] -NO_AUTH_TUPLE: AuthTuple = (DEFAULT_USER_UID, DEFAULT_USER_NAME, NO_USER_TOKEN) +NO_AUTH_TUPLE: AuthTuple = ( + DEFAULT_USER_UID, + DEFAULT_USER_NAME, + DEFAULT_SKIP_USER_ID_CHECK, + NO_USER_TOKEN, +) class AuthInterface(ABC): # pylint: disable=too-few-public-methods diff --git a/src/auth/jwk_token.py b/src/auth/jwk_token.py index f892b222..e18ab0ae 100644 --- a/src/auth/jwk_token.py +++ b/src/auth/jwk_token.py @@ -118,6 +118,7 @@ def __init__( """Initialize the required allowed paths for authorization checks.""" self.virtual_path: str = virtual_path self.config: JwkConfiguration = config + self.skip_userid_check = False async def __call__(self, request: Request) -> AuthTuple: """Authenticate the JWT in the headers against the keys from the JWK url.""" @@ -190,4 +191,4 @@ async def __call__(self, request: Request) -> AuthTuple: logger.info("Successfully authenticated user %s (ID: %s)", username, user_id) - return user_id, username, user_token + return user_id, username, self.skip_userid_check, user_token diff --git a/src/auth/k8s.py b/src/auth/k8s.py index 5ff43dde..3479cb40 100644 --- a/src/auth/k8s.py +++ b/src/auth/k8s.py @@ -11,7 +11,6 @@ from kubernetes.config import ConfigException from configuration import configuration -from auth.utils import extract_user_token from auth.interface import AuthInterface from constants import DEFAULT_VIRTUAL_PATH @@ -227,8 +226,9 @@ class K8SAuthDependency(AuthInterface): # pylint: disable=too-few-public-method def __init__(self, virtual_path: str = DEFAULT_VIRTUAL_PATH) -> None: """Initialize the required allowed paths for authorization checks.""" self.virtual_path = virtual_path + self.skip_userid_check = False - async def __call__(self, request: Request) -> tuple[str, str, str]: + async def __call__(self, request: Request) -> tuple[str, str, bool, str]: """Validate FastAPI Requests for authentication and authorization. Args: @@ -236,9 +236,23 @@ async def __call__(self, request: Request) -> tuple[str, str, str]: Returns: The user's UID and username if authentication and authorization succeed - user_id check is skipped with noop auth to allow consumers provide user_id + user_id check should never be skipped with K8s authentication + If user_id check should be skipped - always return False for k8s + User's token """ - token = extract_user_token(request.headers) + authorization_header = request.headers.get("Authorization") + if not authorization_header: + raise HTTPException( + status_code=401, detail="Unauthorized: No auth header found" + ) + + token = _extract_bearer_token(authorization_header) + if not token: + raise HTTPException( + status_code=401, + detail="Unauthorized: Bearer token not found or invalid", + ) + user_info = get_user_info(token) if user_info is None: raise HTTPException( @@ -267,4 +281,9 @@ async def __call__(self, request: Request) -> tuple[str, str, str]: logger.error("API exception during SubjectAccessReview: %s", e) raise HTTPException(status_code=403, detail="Internal server error") from e - return user_info.user.uid, user_info.user.username, token + return ( + user_info.user.uid, + user_info.user.username, + self.skip_userid_check, + token, + ) diff --git a/src/auth/noop.py b/src/auth/noop.py index 6e55fe72..07da498d 100644 --- a/src/auth/noop.py +++ b/src/auth/noop.py @@ -21,8 +21,9 @@ class NoopAuthDependency(AuthInterface): # pylint: disable=too-few-public-metho def __init__(self, virtual_path: str = DEFAULT_VIRTUAL_PATH) -> None: """Initialize the required allowed paths for authorization checks.""" self.virtual_path = virtual_path + self.skip_userid_check = True - async def __call__(self, request: Request) -> tuple[str, str, str]: + async def __call__(self, request: Request) -> tuple[str, str, bool, str]: """Validate FastAPI Requests for authentication and authorization. Args: @@ -39,4 +40,4 @@ async def __call__(self, request: Request) -> tuple[str, str, str]: # try to extract user ID from request user_id = request.query_params.get("user_id", DEFAULT_USER_UID) logger.debug("Retrieved user ID: %s", user_id) - return user_id, DEFAULT_USER_NAME, NO_USER_TOKEN + return user_id, DEFAULT_USER_NAME, self.skip_userid_check, NO_USER_TOKEN diff --git a/src/auth/noop_with_token.py b/src/auth/noop_with_token.py index 27937e4e..5bbd3c77 100644 --- a/src/auth/noop_with_token.py +++ b/src/auth/noop_with_token.py @@ -32,8 +32,9 @@ class NoopWithTokenAuthDependency( def __init__(self, virtual_path: str = DEFAULT_VIRTUAL_PATH) -> None: """Initialize the required allowed paths for authorization checks.""" self.virtual_path = virtual_path + self.skip_userid_check = True - async def __call__(self, request: Request) -> tuple[str, str, str]: + async def __call__(self, request: Request) -> tuple[str, str, bool, str]: """Validate FastAPI Requests for authentication and authorization. Args: @@ -52,4 +53,4 @@ async def __call__(self, request: Request) -> tuple[str, str, str]: # try to extract user ID from request user_id = request.query_params.get("user_id", DEFAULT_USER_UID) logger.debug("Retrieved user ID: %s", user_id) - return user_id, DEFAULT_USER_NAME, user_token + return user_id, DEFAULT_USER_NAME, self.skip_userid_check, user_token diff --git a/src/authorization/resolvers.py b/src/authorization/resolvers.py index 8e330275..760f55d7 100644 --- a/src/authorization/resolvers.py +++ b/src/authorization/resolvers.py @@ -81,7 +81,7 @@ def evaluate_role_rules(rule: JwtRoleRule, jwt_claims: dict[str, Any]) -> UserRo @staticmethod def _get_claims(auth: AuthTuple) -> dict[str, Any]: """Get the JWT claims from the auth tuple.""" - _, _, token = auth + _, _, _, token = auth if token == constants.NO_USER_TOKEN: # No claims for guests return {} diff --git a/src/constants.py b/src/constants.py index 2260d510..f5982b44 100644 --- a/src/constants.py +++ b/src/constants.py @@ -31,6 +31,7 @@ # Authentication constants DEFAULT_VIRTUAL_PATH = "/ls-access" DEFAULT_USER_NAME = "lightspeed-user" +DEFAULT_SKIP_USER_ID_CHECK = True DEFAULT_USER_UID = "00000000-0000-0000-0000-000" # default value for token when no token is provided NO_USER_TOKEN = "" diff --git a/src/models/responses.py b/src/models/responses.py index 9d1a0ac0..d2e85740 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -307,6 +307,7 @@ class AuthorizedResponse(BaseModel): Attributes: user_id: The ID of the logged in user. username: The name of the logged in user. + skip_userid_check: Whether to skip the user ID check. """ user_id: str = Field( @@ -319,6 +320,11 @@ class AuthorizedResponse(BaseModel): description="User name", examples=["John Doe", "Adam Smith"], ) + skip_userid_check: bool = Field( + ..., + description="Whether to skip the user ID check", + examples=[True, False], + ) # provides examples for /docs endpoint model_config = { @@ -327,6 +333,7 @@ class AuthorizedResponse(BaseModel): { "user_id": "123e4567-e89b-12d3-a456-426614174000", "username": "user1", + "skip_userid_check": False, } ] } diff --git a/tests/e2e/features/authorized_noop.feature b/tests/e2e/features/authorized_noop.feature index 9ce63782..9f04f7d4 100644 --- a/tests/e2e/features/authorized_noop.feature +++ b/tests/e2e/features/authorized_noop.feature @@ -15,7 +15,7 @@ Feature: Authorized endpoint API tests for the noop authentication module Then The status code of the response is 200 And The body of the response is the following """ - {"user_id": "00000000-0000-0000-0000-000","username": "lightspeed-user"} + {"user_id": "00000000-0000-0000-0000-000","username": "lightspeed-user","skip_userid_check": true} """ Scenario: Check if the authorized endpoint works when auth token is not provided @@ -24,7 +24,7 @@ Feature: Authorized endpoint API tests for the noop authentication module Then The status code of the response is 200 And The body of the response is the following """ - {"user_id": "test_user","username": "lightspeed-user"} + {"user_id": "test_user","username": "lightspeed-user","skip_userid_check": true} """ Scenario: Check if the authorized endpoint works when user_id is not provided @@ -34,7 +34,7 @@ Feature: Authorized endpoint API tests for the noop authentication module Then The status code of the response is 200 And The body of the response is the following """ - {"user_id": "00000000-0000-0000-0000-000","username": "lightspeed-user"} + {"user_id": "00000000-0000-0000-0000-000","username": "lightspeed-user","skip_userid_check": true} """ Scenario: Check if the authorized endpoint works when providing empty user_id @@ -44,7 +44,7 @@ Feature: Authorized endpoint API tests for the noop authentication module Then The status code of the response is 200 And The body of the response is the following """ - {"user_id": "","username": "lightspeed-user"} + {"user_id": "","username": "lightspeed-user","skip_userid_check": true} """ Scenario: Check if the authorized endpoint works when providing proper user_id @@ -54,5 +54,5 @@ Feature: Authorized endpoint API tests for the noop authentication module Then The status code of the response is 200 And The body of the response is the following """ - {"user_id": "test_user","username": "lightspeed-user"} + {"user_id": "test_user","username": "lightspeed-user","skip_userid_check": true} """ \ No newline at end of file diff --git a/tests/e2e/features/authorized_noop_token.feature b/tests/e2e/features/authorized_noop_token.feature index a57d2211..d324f977 100644 --- a/tests/e2e/features/authorized_noop_token.feature +++ b/tests/e2e/features/authorized_noop_token.feature @@ -1,5 +1,5 @@ @Authorized -Feature: Authorized endpoint API tests for the noop-with-token +Feature: Authorized endpoint API tests for the noop-with-token authentication module Background: Given The service is started locally @@ -26,7 +26,7 @@ Feature: Authorized endpoint API tests for the noop-with-token Then The status code of the response is 200 And The body of the response is the following """ - {"user_id": "00000000-0000-0000-0000-000","username": "lightspeed-user"} + {"user_id": "00000000-0000-0000-0000-000","username": "lightspeed-user","skip_userid_check": true} """ Scenario: Check if the authorized endpoint works when providing empty user_id @@ -36,7 +36,7 @@ Feature: Authorized endpoint API tests for the noop-with-token Then The status code of the response is 200 And The body of the response is the following """ - {"user_id": "","username": "lightspeed-user"} + {"user_id": "","username": "lightspeed-user","skip_userid_check": true} """ Scenario: Check if the authorized endpoint works when providing proper user_id @@ -46,7 +46,7 @@ Feature: Authorized endpoint API tests for the noop-with-token Then The status code of the response is 200 And The body of the response is the following """ - {"user_id": "test_user","username": "lightspeed-user"} + {"user_id": "test_user","username": "lightspeed-user","skip_userid_check": true} """ Scenario: Check if the authorized endpoint works with proper user_id but bearer token is not present diff --git a/tests/integration/test_openapi_json.py b/tests/integration/test_openapi_json.py index 6051e00d..15e7d87c 100644 --- a/tests/integration/test_openapi_json.py +++ b/tests/integration/test_openapi_json.py @@ -71,7 +71,7 @@ def test_servers_section_present(spec: dict): ("/v1/conversations/{conversation_id}", "delete", {"200", "404", "503", "422"}), ("/readiness", "get", {"200", "503"}), ("/liveness", "get", {"200"}), - ("/authorized", "post", {"200", "400", "403"}), + ("/authorized", "post", {"200", "400", "401", "403"}), ("/metrics", "get", {"200"}), ], ) diff --git a/tests/unit/app/endpoints/test_authorized.py b/tests/unit/app/endpoints/test_authorized.py index a1d4144d..160bbe1e 100644 --- a/tests/unit/app/endpoints/test_authorized.py +++ b/tests/unit/app/endpoints/test_authorized.py @@ -7,7 +7,7 @@ from app.endpoints.authorized import authorized_endpoint_handler from auth.utils import extract_user_token -MOCK_AUTH = ("test-id", "test-user", "token") +MOCK_AUTH = ("test-id", "test-user", True, "token") @pytest.mark.asyncio @@ -18,6 +18,7 @@ async def test_authorized_endpoint(): assert response.model_dump() == { "user_id": "test-id", "username": "test-user", + "skip_userid_check": True, } diff --git a/tests/unit/app/endpoints/test_feedback.py b/tests/unit/app/endpoints/test_feedback.py index e462999d..1f13b031 100644 --- a/tests/unit/app/endpoints/test_feedback.py +++ b/tests/unit/app/endpoints/test_feedback.py @@ -82,7 +82,7 @@ async def test_feedback_endpoint_handler(mocker, feedback_request_data): result = await feedback_endpoint_handler( feedback_request=feedback_request, _ensure_feedback_enabled=assert_feedback_enabled, - auth=("test_user_id", "test_username", "test_token"), + auth=("test_user_id", "test_username", False, "test_token"), ) # Assert that the expected response is returned @@ -109,7 +109,7 @@ async def test_feedback_endpoint_handler_error(mocker): await feedback_endpoint_handler( feedback_request=feedback_request, _ensure_feedback_enabled=assert_feedback_enabled, - auth=("test_user_id", "test_username", "test_token"), + auth=("test_user_id", "test_username", False, "test_token"), ) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -204,7 +204,7 @@ async def test_update_feedback_status_different(mocker): req = FeedbackStatusUpdateRequest(status=False) resp = await update_feedback_status( req, - auth=("test_user_id", "test_username", "test_token"), + auth=("test_user_id", "test_username", False, "test_token"), ) assert resp.status == { "previous_status": True, @@ -221,7 +221,7 @@ async def test_update_feedback_status_no_change(mocker): req = FeedbackStatusUpdateRequest(status=True) resp = await update_feedback_status( req, - auth=("test_user_id", "test_username", "test_token"), + auth=("test_user_id", "test_username", False, "test_token"), ) assert resp.status == { "previous_status": True, diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index a1aee5b1..3b3d64f3 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -31,7 +31,7 @@ from utils.types import ToolCallSummary, TurnSummary from authorization.resolvers import NoopRolesResolver -MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token") +MOCK_AUTH = ("mock_user_id", "mock_username", False, "mock_token") @pytest.fixture @@ -108,7 +108,7 @@ async def test_query_endpoint_handler_configuration_not_loaded(mocker, dummy_req await query_endpoint_handler( query_request=query_request, request=dummy_request, - auth=["test-user", "", "token"], + auth=("test-user", "", False, "token"), ) assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.value.detail["response"] == "Configuration is not loaded" @@ -1214,7 +1214,7 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker, dummy_requ _ = await query_endpoint_handler( request=dummy_request, query_request=QueryRequest(query="test query"), - auth=("user123", "username", "auth_token_123"), + auth=("user123", "username", False, "auth_token_123"), mcp_headers=None, ) diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 38888c54..38983666 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -47,7 +47,7 @@ from authorization.resolvers import NoopRolesResolver from utils.types import ToolCallSummary, TurnSummary -MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token") +MOCK_AUTH = ("mock_user_id", "mock_username", False, "mock_token") def mock_database_operations(mocker): @@ -1318,7 +1318,7 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker): await streaming_query_endpoint_handler( request, QueryRequest(query="test query"), - auth=("user123", "username", "auth_token_123"), + auth=("user123", "username", False, "auth_token_123"), mcp_headers=None, ) diff --git a/tests/unit/auth/test_jwk_token.py b/tests/unit/auth/test_jwk_token.py index cb76e7d3..296e17f5 100644 --- a/tests/unit/auth/test_jwk_token.py +++ b/tests/unit/auth/test_jwk_token.py @@ -174,9 +174,10 @@ def set_auth_header(request: Request, token: str): def ensure_test_user_id_and_name(auth_tuple, expected_token): """Utility to ensure that the values in the auth tuple match the test values.""" - user_id, username, token = auth_tuple + user_id, username, skip_userid_check, token = auth_tuple assert user_id == TEST_USER_ID assert username == TEST_USER_NAME + assert skip_userid_check is False assert token == expected_token @@ -259,10 +260,13 @@ async def test_no_auth_header( dependency = JwkTokenAuthDependency(default_jwk_configuration) - user_id, username, token_claims = await dependency(no_token_request) + user_id, username, skip_userid_check, token_claims = await dependency( + no_token_request + ) assert user_id == DEFAULT_USER_UID assert username == DEFAULT_USER_NAME + assert skip_userid_check is True assert token_claims == NO_USER_TOKEN diff --git a/tests/unit/auth/test_k8s.py b/tests/unit/auth/test_k8s.py index e7e6b477..85608fe9 100644 --- a/tests/unit/auth/test_k8s.py +++ b/tests/unit/auth/test_k8s.py @@ -94,11 +94,12 @@ async def test_auth_dependency_valid_token(mocker): } ) - user_uid, username, token = await dependency(request) + user_uid, username, skip_userid_check, token = await dependency(request) # Check if the correct user info has been returned assert user_uid == "valid-uid" assert username == "valid-user" + assert skip_userid_check is False assert token == "valid-token" @@ -165,11 +166,12 @@ async def test_cluster_id_is_used_for_kube_admin(mocker): return_value="some-cluster-id", ) - user_uid, username, token = await dependency(request) + user_uid, username, skip_userid_check, token = await dependency(request) # check if the correct user info has been returned assert user_uid == "some-cluster-id" assert username == "kube:admin" + assert skip_userid_check is False assert token == "valid-token" diff --git a/tests/unit/auth/test_noop.py b/tests/unit/auth/test_noop.py index 77a30f01..0ead030e 100644 --- a/tests/unit/auth/test_noop.py +++ b/tests/unit/auth/test_noop.py @@ -13,11 +13,12 @@ async def test_noop_auth_dependency(): request = Request(scope={"type": "http", "query_string": b""}) # Call the dependency - user_id, username, user_token = await dependency(request) + user_id, username, skip_userid_check, user_token = await dependency(request) # Assert the expected values assert user_id == DEFAULT_USER_UID assert username == DEFAULT_USER_NAME + assert skip_userid_check is True assert user_token == NO_USER_TOKEN @@ -29,9 +30,10 @@ async def test_noop_auth_dependency_custom_user_id(): request = Request(scope={"type": "http", "query_string": b"user_id=test-user"}) # Call the dependency - user_id, username, user_token = await dependency(request) + user_id, username, skip_userid_check, user_token = await dependency(request) # Assert the expected values assert user_id == "test-user" assert username == DEFAULT_USER_NAME + assert skip_userid_check is True assert user_token == NO_USER_TOKEN diff --git a/tests/unit/auth/test_noop_with_token.py b/tests/unit/auth/test_noop_with_token.py index 4c3bc77e..c5003f20 100644 --- a/tests/unit/auth/test_noop_with_token.py +++ b/tests/unit/auth/test_noop_with_token.py @@ -22,11 +22,12 @@ async def test_noop_with_token_auth_dependency(): ) # Call the dependency - user_id, username, user_token = await dependency(request) + user_id, username, skip_userid_check, user_token = await dependency(request) # Assert the expected values assert user_id == DEFAULT_USER_UID assert username == DEFAULT_USER_NAME + assert skip_userid_check is True assert user_token == "spongebob-token" @@ -46,11 +47,12 @@ async def test_noop_with_token_auth_dependency_custom_user_id(): ) # Call the dependency - user_id, username, user_token = await dependency(request) + user_id, username, skip_userid_check, user_token = await dependency(request) # Assert the expected values assert user_id == "test-user" assert username == DEFAULT_USER_NAME + assert skip_userid_check is True assert user_token == "spongebob-token" diff --git a/tests/unit/authorization/test_resolvers.py b/tests/unit/authorization/test_resolvers.py index f8b9720a..d0dbf36e 100644 --- a/tests/unit/authorization/test_resolvers.py +++ b/tests/unit/authorization/test_resolvers.py @@ -50,7 +50,7 @@ async def test_resolve_roles_redhat_employee(self): } # Mock auth tuple with JWT claims as third element - auth = ("user", "token", claims_to_token(jwt_claims)) + auth = ("user", "token", False, claims_to_token(jwt_claims)) roles = await jwt_resolver.resolve_roles(auth) assert "employee" in roles @@ -74,7 +74,7 @@ async def test_resolve_roles_no_match(self): } # Mock auth tuple with JWT claims as third element - auth = ("user", "token", claims_to_token(jwt_claims)) + auth = ("user", "token", False, claims_to_token(jwt_claims)) roles = await jwt_resolver.resolve_roles(auth) assert len(roles) == 0 @@ -97,7 +97,7 @@ async def test_resolve_roles_match_operator_email_domain(self): "email": "employee@redhat.com", } - auth = ("user", "token", claims_to_token(jwt_claims)) + auth = ("user", "token", False, claims_to_token(jwt_claims)) roles = await jwt_resolver.resolve_roles(auth) assert "redhat_employee" in roles @@ -120,7 +120,7 @@ async def test_resolve_roles_match_operator_no_match(self): "email": "user@example.com", } - auth = ("user", "token", claims_to_token(jwt_claims)) + auth = ("user", "token", False, claims_to_token(jwt_claims)) roles = await jwt_resolver.resolve_roles(auth) assert len(roles) == 0 @@ -166,7 +166,7 @@ async def test_resolve_roles_match_operator_non_string_value(self): "user_id": 12345, # Non-string value } - auth = ("user", "token", claims_to_token(jwt_claims)) + auth = ("user", "token", False, claims_to_token(jwt_claims)) roles = await jwt_resolver.resolve_roles(auth) assert len(roles) == 0 # Non-string values don't match regex diff --git a/tests/unit/models/responses/test_authorized_response.py b/tests/unit/models/responses/test_authorized_response.py index 08721317..f3412462 100644 --- a/tests/unit/models/responses/test_authorized_response.py +++ b/tests/unit/models/responses/test_authorized_response.py @@ -15,6 +15,7 @@ def test_constructor(self) -> None: ar = AuthorizedResponse( user_id="123e4567-e89b-12d3-a456-426614174000", username="testuser", + skip_userid_check=False, ) assert ar.user_id == "123e4567-e89b-12d3-a456-426614174000" assert ar.username == "testuser" diff --git a/tests/unit/models/test_responses.py b/tests/unit/models/test_responses.py new file mode 100644 index 00000000..42ca4bea --- /dev/null +++ b/tests/unit/models/test_responses.py @@ -0,0 +1,88 @@ +"""Tests for QueryResponse. StatusResponse, AuthorizedResponse, and UnauthorizedResponse models.""" + +import pytest + +from pydantic import ValidationError + +from models.responses import ( + QueryResponse, + StatusResponse, + AuthorizedResponse, + UnauthorizedResponse, +) + + +class TestQueryResponse: + """Test cases for the QueryResponse model.""" + + def test_constructor(self) -> None: + """Test the QueryResponse constructor.""" + qr = QueryResponse( + conversation_id="123e4567-e89b-12d3-a456-426614174000", + response="LLM answer", + ) + assert qr.conversation_id == "123e4567-e89b-12d3-a456-426614174000" + assert qr.response == "LLM answer" + + def test_optional_conversation_id(self) -> None: + """Test the QueryResponse with default conversation ID.""" + qr = QueryResponse(response="LLM answer") + assert qr.conversation_id is None + assert qr.response == "LLM answer" + + +class TestStatusResponse: + """Test cases for the StatusResponse model.""" + + def test_constructor_feedback_enabled(self) -> None: + """Test the StatusResponse constructor.""" + sr = StatusResponse(functionality="feedback", status={"enabled": True}) + assert sr.functionality == "feedback" + assert sr.status == {"enabled": True} + + def test_constructor_feedback_disabled(self) -> None: + """Test the StatusResponse constructor.""" + sr = StatusResponse(functionality="feedback", status={"enabled": False}) + assert sr.functionality == "feedback" + assert sr.status == {"enabled": False} + + +class TestAuthorizedResponse: + """Test cases for the AuthorizedResponse model.""" + + def test_constructor(self) -> None: + """Test the AuthorizedResponse constructor.""" + ar = AuthorizedResponse( + user_id="123e4567-e89b-12d3-a456-426614174000", + username="testuser", + skip_userid_check=True, + ) + assert ar.user_id == "123e4567-e89b-12d3-a456-426614174000" + assert ar.username == "testuser" + assert ar.skip_userid_check is True + + def test_constructor_fields_required(self) -> None: + """Test the AuthorizedResponse constructor.""" + with pytest.raises(ValidationError): + AuthorizedResponse(username="testuser") # pyright: ignore + + with pytest.raises(ValidationError): + AuthorizedResponse( + user_id="123e4567-e89b-12d3-a456-426614174000" + ) # pyright: ignore + + +class TestUnauthorizedResponse: + """Test cases for the UnauthorizedResponse model.""" + + def test_constructor(self) -> None: + """Test the UnauthorizedResponse constructor.""" + ur = UnauthorizedResponse( + detail="Missing or invalid credentials provided by client" + ) + assert ur.detail == "Missing or invalid credentials provided by client" + + def test_constructor_fields_required(self) -> None: + """Test the UnauthorizedResponse constructor.""" + with pytest.raises(Exception): + UnauthorizedResponse() # pyright: ignore