Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bin/router/src/jwt/context.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};

use hive_router_plan_executor::execution::jwt_forward::JwtForwardingError;
use jsonwebtoken::TokenData;
Expand All @@ -10,7 +10,7 @@ pub type JwtTokenPayload = TokenData<JwtClaims>;
pub struct JwtRequestContext {
pub token_prefix: Option<String>,
pub token_raw: String,
pub token_payload: JwtTokenPayload,
pub token_payload: Arc<JwtTokenPayload>,
}

impl JwtRequestContext {
Expand Down
4 changes: 2 additions & 2 deletions bin/router/src/jwt/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use ntex::{
web,
};

#[derive(Debug, thiserror::Error)]
#[derive(Debug, thiserror::Error, Clone)]
pub enum LookupError {
#[error("failed to locate the value in the incoming request")]
LookupFailed,
Expand All @@ -21,7 +21,7 @@ pub enum LookupError {
FailedToParseHeader(InvalidHeaderValue),
}

#[derive(Debug, thiserror::Error)]
#[derive(Debug, thiserror::Error, Clone)]
pub enum JwtError {
#[error("jwt header lookup failed: {0}")]
LookupFailed(LookupError),
Expand Down
44 changes: 33 additions & 11 deletions bin/router/src/jwt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::{
errors::{JwtError, LookupError},
jwks_manager::{JwksManager, JwksSourceError},
},
shared_state::JwtClaimsCache,
};

pub struct JwtAuthRuntime {
Expand Down Expand Up @@ -265,26 +266,47 @@ impl JwtAuthRuntime {
Ok(token_data)
}

pub fn validate_request(&self, request: &mut HttpRequest) -> Result<(), JwtError> {
let valid_jwks = self.jwks.all();
pub async fn validate_request(
&self,
request: &mut HttpRequest,
cache: &JwtClaimsCache,
) -> Result<(), JwtError> {
let (maybe_prefix, token) = match self.lookup(request) {
Ok((p, t)) => (p, t),
Err(e) => {
// No token found, but this is only an error if auth is required.
if self.config.require_authentication.is_some_and(|v| v) {
return Err(JwtError::LookupFailed(e));
}
return Ok(());
}
};

let validation_result = cache
.try_get_with(token.clone(), async {
let valid_jwks = self.jwks.all();
self.authenticate(&valid_jwks, request)
.map(|(payload, _, _)| Arc::new(payload))
})
Comment on lines +286 to +290
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This implementation has a critical issue: it attempts to pass &mut HttpRequest into an async block. HttpRequest is not Send, so this will cause a compilation error because the request cannot be safely sent across threads.

Additionally, the current approach is inefficient. validate_request calls self.lookup(request) to get the token, and then the async block calls self.authenticate(...), which in turn calls self.lookup(request) again, performing the same work twice on a cache miss.

To fix this, the authentication logic should be performed directly on the token string that has already been extracted. This avoids the compilation error and the redundant work.

Suggested change
.try_get_with(token.clone(), async {
let valid_jwks = self.jwks.all();
self.authenticate(&valid_jwks, request)
.map(|(payload, _, _)| Arc::new(payload))
})
.try_get_with(token.clone(), async {
let valid_jwks = self.jwks.all();
let header = decode_header(&token).map_err(JwtError::InvalidJwtHeader)?;
let jwk = self.find_matching_jwks(&header, &valid_jwks)?;
let payload = self.decode_and_validate_token(&token, &jwk.keys)?;
Ok(Arc::new(payload))
})

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fine

.await;

match self.authenticate(&valid_jwks, request) {
Ok((token_payload, maybe_token_prefix, token)) => {
match validation_result {
Ok(token_payload) => {
request.extensions_mut().insert(JwtRequestContext {
token_payload,
token_raw: token,
token_prefix: maybe_token_prefix,
token_prefix: maybe_prefix,
});
Ok(())
}
Err(e) => {
warn!("jwt token error: {:?}", e);

Err(err) => {
warn!("jwt token error: {:?}", err);
if self.config.require_authentication.is_some_and(|v| v) {
return Err(e);
Err((*err).clone())
} else {
Ok(())
}
}
}

Ok(())
}
}
5 changes: 4 additions & 1 deletion bin/router/src/pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ pub async fn graphql_request_handler(
}

if let Some(jwt) = &shared_state.jwt_auth_runtime {
match jwt.validate_request(req) {
match jwt
.validate_request(req, &shared_state.jwt_claims_cache)
.await
{
Ok(_) => (),
Err(err) => return err.make_response(),
}
Expand Down
13 changes: 13 additions & 0 deletions bin/router/src/shared_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,24 @@ use hive_router_plan_executor::headers::{
use moka::future::Cache;
use std::sync::Arc;

use crate::jwt::context::JwtTokenPayload;
use crate::jwt::JwtAuthRuntime;
use crate::pipeline::cors::{CORSConfigError, Cors};
use crate::pipeline::progressive_override::{OverrideLabelsCompileError, OverrideLabelsEvaluator};

pub type JwtClaimsCache = Cache<String, Arc<JwtTokenPayload>>;

pub struct RouterSharedState {
pub validation_plan: ValidationPlan,
pub parse_cache: Cache<u64, Arc<graphql_parser::query::Document<'static, String>>>,
pub router_config: Arc<HiveRouterConfig>,
pub headers_plan: HeaderRulesPlan,
pub override_labels_evaluator: OverrideLabelsEvaluator,
pub cors_runtime: Option<Cors>,
/// Cache for validated JWT claims to avoid re-parsing on every request.
/// The cache key is the raw JWT token string.
/// Stores the parsed claims payload for 5s.
pub jwt_claims_cache: JwtClaimsCache,
pub jwt_auth_runtime: Option<JwtAuthRuntime>,
}

Expand All @@ -30,6 +37,12 @@ impl RouterSharedState {
headers_plan: compile_headers_plan(&router_config.headers).map_err(Box::new)?,
parse_cache: moka::future::Cache::new(1000),
cors_runtime: Cors::from_config(&router_config.cors).map_err(Box::new)?,
jwt_claims_cache: Cache::builder()
// Consistent with parse_cache and prevents unbounded memory usage.
.max_capacity(1000)
// We can have it configurable in the future if needed.
.time_to_live(std::time::Duration::from_secs(5))
.build(),
router_config: router_config.clone(),
override_labels_evaluator: OverrideLabelsEvaluator::from_config(
&router_config.override_labels,
Expand Down
Loading