Skip to content

Commit 15d9b51

Browse files
Refresh tokens and endpoint
1 parent e7afadc commit 15d9b51

File tree

6 files changed

+102
-18
lines changed

6 files changed

+102
-18
lines changed

Cargo.lock

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ log = "0.4.22"
1414
ring = "0.17.8"
1515
rustls = { version = "0.23.12", features = ["ring"], default-features = false, optional = true }
1616
serde = { version = "1.0.209", features = ["derive"] }
17+
serde_repr = "0.1.19"
1718
simplelog = "0.12.2"
1819
sqlite = "0.36.1"
1920
tokio = { version = "1.40.0", features = ["full"] }

config.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,12 @@ placeholders = true
2525

2626
[auth]
2727
route = "/auth"
28+
refresh_subroute = "/session"
2829
secret_path = "secret"
29-
valid_secs = 604_800 # one week
30+
# Session tokens are used to authenticate requests. They should be short-lived.
31+
# Refresh tokens are used to generate new session tokens. They should be long-lived and cached by clients.
32+
valid_secs_refresh = 604_800 # one week
33+
valid_secs_session = 900 # 15 minutes
3034

3135
[cookie]
3236
route = "/cookie"

src/auth.rs

Lines changed: 73 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,26 @@
11
use std::sync::{Arc, OnceLock};
22

3-
use axum::{extract::State, http::StatusCode, routing::post, Json, Router};
3+
use axum::{
4+
extract::State,
5+
http::{HeaderMap, StatusCode},
6+
routing::post,
7+
Json, Router,
8+
};
49
use jsonwebtoken::{get_current_timestamp, DecodingKey, EncodingKey, Validation};
510
use log::{info, warn};
611
use ring::rand::{SecureRandom, SystemRandom};
712
use serde::{Deserialize, Serialize};
13+
use serde_repr::{Deserialize_repr, Serialize_repr};
814

915
use crate::{util, AppState};
1016

1117
#[derive(Deserialize, Clone)]
1218
pub struct AuthConfig {
1319
route: String,
20+
refresh_subroute: String,
1421
secret_path: String,
15-
valid_secs: u64,
22+
valid_secs_refresh: u64,
23+
valid_secs_session: u64,
1624
}
1725

1826
#[derive(Deserialize)]
@@ -21,11 +29,19 @@ pub struct AuthRequest {
2129
password: String,
2230
}
2331

32+
#[repr(u8)]
33+
#[derive(Deserialize_repr, Serialize_repr, PartialEq, Eq)]
34+
pub enum TokenKind {
35+
Refresh = 0,
36+
Session = 1,
37+
}
38+
2439
#[derive(Deserialize, Serialize)]
2540
pub struct Claims {
26-
sub: String, // account id as a string
27-
crt: u64, // creation timestamp in UTC
28-
exp: u64, // expiration timestamp in UTC
41+
sub: String, // account id as a string
42+
crt: u64, // creation timestamp in UTC
43+
exp: u64, // expiration timestamp in UTC
44+
kind: TokenKind, // kind of token
2945
}
3046

3147
static SECRET_KEY: OnceLock<Vec<u8>> = OnceLock::new();
@@ -55,37 +71,48 @@ pub fn register(
5571
rng: &SystemRandom,
5672
) -> Router<Arc<AppState>> {
5773
let route = &config.route;
74+
let refresh_route = util::get_subroute(route, &config.refresh_subroute);
5875
info!("Registering auth route @ {}", route);
76+
info!("\tRefresh route @ {}", refresh_route);
5977
check_secret(&config.secret_path, rng);
60-
routes.route(route, post(do_auth))
78+
routes
79+
.route(route, post(do_auth))
80+
.route(&refresh_route, post(do_refresh))
6181
}
6282

63-
fn gen_jwt(account_id: i64, valid_secs: u64) -> Result<String, String> {
83+
fn gen_jwt(auth_config: &AuthConfig, account_id: i64, kind: TokenKind) -> Result<String, String> {
6484
let secret = SECRET_KEY.get().unwrap();
6585
let key = EncodingKey::from_secret(secret);
86+
87+
let valid_secs = match kind {
88+
TokenKind::Refresh => auth_config.valid_secs_refresh,
89+
TokenKind::Session => auth_config.valid_secs_session,
90+
};
91+
6692
let crt = get_current_timestamp();
6793
let exp = crt + valid_secs;
6894
let claims = Claims {
6995
sub: account_id.to_string(),
7096
crt,
7197
exp,
98+
kind,
7299
};
100+
73101
jsonwebtoken::encode(&jsonwebtoken::Header::default(), &claims, &key)
74102
.map_err(|e| format!("JWT error: {}", e))
75103
}
76104

77105
fn get_validator(account_id: Option<i64>) -> Validation {
78106
let mut validation = Validation::default();
79107
// required claims
80-
validation.required_spec_claims.insert("crt".to_string());
81-
validation.required_spec_claims.insert("exp".to_string());
82108
validation.required_spec_claims.insert("sub".to_string());
109+
validation.required_spec_claims.insert("exp".to_string());
83110
// ensure account ID matches if passed in
84111
validation.sub = account_id.map(|id| id.to_string());
85112
validation
86113
}
87114

88-
pub fn validate_jwt(jwt: &str) -> Result<i64, String> {
115+
pub fn validate_jwt(jwt: &str, kind: TokenKind) -> Result<i64, String> {
89116
let Some(secret) = SECRET_KEY.get() else {
90117
return Err("Auth module not initialized".to_string());
91118
};
@@ -102,6 +129,10 @@ pub fn validate_jwt(jwt: &str) -> Result<i64, String> {
102129
return Err("Expired JWT".to_string());
103130
}
104131

132+
if token.claims.kind != kind {
133+
return Err("Bad token kind".to_string());
134+
}
135+
105136
match token.claims.sub.parse() {
106137
Ok(id) => Ok(id),
107138
Err(e) => Err(format!("Bad account ID: {}", e)),
@@ -118,8 +149,11 @@ async fn do_auth(
118149
warn!("Auth error: {}", e);
119150
(StatusCode::UNAUTHORIZED, "Invalid credentials".to_string())
120151
})?;
121-
let valid_secs = app.config.auth.as_ref().unwrap().valid_secs;
122-
match gen_jwt(account_id, valid_secs) {
152+
match gen_jwt(
153+
app.config.auth.as_ref().unwrap(),
154+
account_id,
155+
TokenKind::Refresh,
156+
) {
123157
Ok(jwt) => Ok(jwt),
124158
Err(e) => {
125159
warn!("Auth error: {}", e);
@@ -130,3 +164,30 @@ async fn do_auth(
130164
}
131165
}
132166
}
167+
168+
async fn do_refresh(
169+
State(app): State<Arc<AppState>>,
170+
headers: HeaderMap,
171+
) -> Result<String, (StatusCode, String)> {
172+
assert!(app.is_tls);
173+
let db = app.db.lock().await;
174+
// TODO validate the refresh token against the last password reset timestamp
175+
let account_id = match util::validate_authed_request(&headers, TokenKind::Refresh) {
176+
Ok(id) => id,
177+
Err(e) => return Err((StatusCode::UNAUTHORIZED, e)),
178+
};
179+
match gen_jwt(
180+
app.config.auth.as_ref().unwrap(),
181+
account_id,
182+
TokenKind::Session,
183+
) {
184+
Ok(jwt) => Ok(jwt),
185+
Err(e) => {
186+
warn!("Refresh error: {}", e);
187+
Err((
188+
StatusCode::INTERNAL_SERVER_ERROR,
189+
"Server error".to_string(),
190+
))
191+
}
192+
}
193+
}

src/cookie.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use ring::rand::{SecureRandom, SystemRandom};
1313
use serde::{Deserialize, Serialize};
1414
use sqlite::Connection;
1515

16-
use crate::{util, AppState};
16+
use crate::{auth::TokenKind, util, AppState};
1717

1818
#[derive(Deserialize, Clone)]
1919
pub struct CookieConfig {
@@ -69,7 +69,7 @@ async fn get_cookie(
6969
assert!(app.is_tls);
7070

7171
let db = app.db.lock().await;
72-
let account_id = match util::validate_authed_request(&headers) {
72+
let account_id = match util::validate_authed_request(&headers, TokenKind::Session) {
7373
Ok(id) => id,
7474
Err(e) => return Err((StatusCode::UNAUTHORIZED, e)),
7575
};

src/util.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use axum::http::HeaderMap;
22
use log::info;
33
use sqlite::{Connection, State};
44

5-
use crate::auth;
5+
use crate::auth::{self, TokenKind};
66

77
const MIN_DATABASE_VERSION: i64 = 6;
88

@@ -23,6 +23,12 @@ pub fn version_to_string(version: usize) -> String {
2323
format!("{}.{}.{}", major, minor, patch)
2424
}
2525

26+
pub fn get_subroute(route: &str, subroute: &str) -> String {
27+
let route_noslash = route.trim_end_matches('/');
28+
let subroute_noslash = subroute.trim_start_matches('/');
29+
format!("{}/{}", route_noslash, subroute_noslash)
30+
}
31+
2632
pub fn connect_to_db(path: &str) -> Connection {
2733
const QUERY: &str = "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='Meta';";
2834
const VERSION_QUERY: &str = "SELECT Value FROM Meta WHERE Key = 'DatabaseVersion';";
@@ -84,7 +90,7 @@ pub fn parse_csv(data: &str) -> Vec<Vec<String>> {
8490
.collect()
8591
}
8692

87-
pub fn validate_authed_request(headers: &HeaderMap) -> Result<i64, String> {
93+
pub fn validate_authed_request(headers: &HeaderMap, kind: TokenKind) -> Result<i64, String> {
8894
let auth_header = headers.get("authorization").ok_or("No auth header")?;
8995
// auth header uses the Bearer scheme
9096
let parts: Vec<&str> = auth_header
@@ -96,7 +102,7 @@ pub fn validate_authed_request(headers: &HeaderMap) -> Result<i64, String> {
96102
return Err("Invalid auth header".to_string());
97103
}
98104
let token = parts[1];
99-
auth::validate_jwt(token)
105+
auth::validate_jwt(token, kind)
100106
}
101107

102108
pub fn find_account(db: &Connection, username: &str) -> Option<Account> {

0 commit comments

Comments
 (0)