1
1
use std:: sync:: { Arc , OnceLock } ;
2
2
3
- use axum:: { extract:: State , http:: StatusCode , routing:: post, Json , Router } ;
3
+ use axum:: {
4
+ extract:: State ,
5
+ http:: { HeaderMap , StatusCode } ,
6
+ routing:: post,
7
+ Json , Router ,
8
+ } ;
4
9
use jsonwebtoken:: { get_current_timestamp, DecodingKey , EncodingKey , Validation } ;
5
10
use log:: { info, warn} ;
6
11
use ring:: rand:: { SecureRandom , SystemRandom } ;
7
12
use serde:: { Deserialize , Serialize } ;
13
+ use serde_repr:: { Deserialize_repr , Serialize_repr } ;
8
14
9
15
use crate :: { util, AppState } ;
10
16
11
17
#[ derive( Deserialize , Clone ) ]
12
18
pub struct AuthConfig {
13
19
route : String ,
20
+ refresh_subroute : String ,
14
21
secret_path : String ,
15
- valid_secs : u64 ,
22
+ valid_secs_refresh : u64 ,
23
+ valid_secs_session : u64 ,
16
24
}
17
25
18
26
#[ derive( Deserialize ) ]
@@ -21,11 +29,19 @@ pub struct AuthRequest {
21
29
password : String ,
22
30
}
23
31
32
+ #[ repr( u8 ) ]
33
+ #[ derive( Deserialize_repr , Serialize_repr , PartialEq , Eq ) ]
34
+ pub enum TokenKind {
35
+ Refresh = 0 ,
36
+ Session = 1 ,
37
+ }
38
+
24
39
#[ derive( Deserialize , Serialize ) ]
25
40
pub struct Claims {
26
- sub : String , // account id as a string
27
- crt : u64 , // creation timestamp in UTC
28
- exp : u64 , // expiration timestamp in UTC
41
+ sub : String , // account id as a string
42
+ crt : u64 , // creation timestamp in UTC
43
+ exp : u64 , // expiration timestamp in UTC
44
+ kind : TokenKind , // kind of token
29
45
}
30
46
31
47
static SECRET_KEY : OnceLock < Vec < u8 > > = OnceLock :: new ( ) ;
@@ -55,37 +71,48 @@ pub fn register(
55
71
rng : & SystemRandom ,
56
72
) -> Router < Arc < AppState > > {
57
73
let route = & config. route ;
74
+ let refresh_route = util:: get_subroute ( route, & config. refresh_subroute ) ;
58
75
info ! ( "Registering auth route @ {}" , route) ;
76
+ info ! ( "\t Refresh route @ {}" , refresh_route) ;
59
77
check_secret ( & config. secret_path , rng) ;
60
- routes. route ( route, post ( do_auth) )
78
+ routes
79
+ . route ( route, post ( do_auth) )
80
+ . route ( & refresh_route, post ( do_refresh) )
61
81
}
62
82
63
- fn gen_jwt ( account_id : i64 , valid_secs : u64 ) -> Result < String , String > {
83
+ fn gen_jwt ( auth_config : & AuthConfig , account_id : i64 , kind : TokenKind ) -> Result < String , String > {
64
84
let secret = SECRET_KEY . get ( ) . unwrap ( ) ;
65
85
let key = EncodingKey :: from_secret ( secret) ;
86
+
87
+ let valid_secs = match kind {
88
+ TokenKind :: Refresh => auth_config. valid_secs_refresh ,
89
+ TokenKind :: Session => auth_config. valid_secs_session ,
90
+ } ;
91
+
66
92
let crt = get_current_timestamp ( ) ;
67
93
let exp = crt + valid_secs;
68
94
let claims = Claims {
69
95
sub : account_id. to_string ( ) ,
70
96
crt,
71
97
exp,
98
+ kind,
72
99
} ;
100
+
73
101
jsonwebtoken:: encode ( & jsonwebtoken:: Header :: default ( ) , & claims, & key)
74
102
. map_err ( |e| format ! ( "JWT error: {}" , e) )
75
103
}
76
104
77
105
fn get_validator ( account_id : Option < i64 > ) -> Validation {
78
106
let mut validation = Validation :: default ( ) ;
79
107
// required claims
80
- validation. required_spec_claims . insert ( "crt" . to_string ( ) ) ;
81
- validation. required_spec_claims . insert ( "exp" . to_string ( ) ) ;
82
108
validation. required_spec_claims . insert ( "sub" . to_string ( ) ) ;
109
+ validation. required_spec_claims . insert ( "exp" . to_string ( ) ) ;
83
110
// ensure account ID matches if passed in
84
111
validation. sub = account_id. map ( |id| id. to_string ( ) ) ;
85
112
validation
86
113
}
87
114
88
- pub fn validate_jwt ( jwt : & str ) -> Result < i64 , String > {
115
+ pub fn validate_jwt ( jwt : & str , kind : TokenKind ) -> Result < i64 , String > {
89
116
let Some ( secret) = SECRET_KEY . get ( ) else {
90
117
return Err ( "Auth module not initialized" . to_string ( ) ) ;
91
118
} ;
@@ -102,6 +129,10 @@ pub fn validate_jwt(jwt: &str) -> Result<i64, String> {
102
129
return Err ( "Expired JWT" . to_string ( ) ) ;
103
130
}
104
131
132
+ if token. claims . kind != kind {
133
+ return Err ( "Bad token kind" . to_string ( ) ) ;
134
+ }
135
+
105
136
match token. claims . sub . parse ( ) {
106
137
Ok ( id) => Ok ( id) ,
107
138
Err ( e) => Err ( format ! ( "Bad account ID: {}" , e) ) ,
@@ -118,8 +149,11 @@ async fn do_auth(
118
149
warn ! ( "Auth error: {}" , e) ;
119
150
( StatusCode :: UNAUTHORIZED , "Invalid credentials" . to_string ( ) )
120
151
} ) ?;
121
- let valid_secs = app. config . auth . as_ref ( ) . unwrap ( ) . valid_secs ;
122
- match gen_jwt ( account_id, valid_secs) {
152
+ match gen_jwt (
153
+ app. config . auth . as_ref ( ) . unwrap ( ) ,
154
+ account_id,
155
+ TokenKind :: Refresh ,
156
+ ) {
123
157
Ok ( jwt) => Ok ( jwt) ,
124
158
Err ( e) => {
125
159
warn ! ( "Auth error: {}" , e) ;
@@ -130,3 +164,30 @@ async fn do_auth(
130
164
}
131
165
}
132
166
}
167
+
168
+ async fn do_refresh (
169
+ State ( app) : State < Arc < AppState > > ,
170
+ headers : HeaderMap ,
171
+ ) -> Result < String , ( StatusCode , String ) > {
172
+ assert ! ( app. is_tls) ;
173
+ let db = app. db . lock ( ) . await ;
174
+ // TODO validate the refresh token against the last password reset timestamp
175
+ let account_id = match util:: validate_authed_request ( & headers, TokenKind :: Refresh ) {
176
+ Ok ( id) => id,
177
+ Err ( e) => return Err ( ( StatusCode :: UNAUTHORIZED , e) ) ,
178
+ } ;
179
+ match gen_jwt (
180
+ app. config . auth . as_ref ( ) . unwrap ( ) ,
181
+ account_id,
182
+ TokenKind :: Session ,
183
+ ) {
184
+ Ok ( jwt) => Ok ( jwt) ,
185
+ Err ( e) => {
186
+ warn ! ( "Refresh error: {}" , e) ;
187
+ Err ( (
188
+ StatusCode :: INTERNAL_SERVER_ERROR ,
189
+ "Server error" . to_string ( ) ,
190
+ ) )
191
+ }
192
+ }
193
+ }
0 commit comments