Skip to content

Commit 24680eb

Browse files
committed
refactor: improved jwt authenticator middleware to handle expired cookies
1 parent db022d2 commit 24680eb

File tree

2 files changed

+109
-83
lines changed

2 files changed

+109
-83
lines changed

http/middleware/auth/authenticator.go

+3-7
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,17 @@ type Authenticator interface {
1515
err error,
1616
errorCode *string,
1717
),
18-
refreshTokenFn func(
19-
w http.ResponseWriter,
20-
r *http.Request,
21-
) error,
22-
authenticateFn func(next http.Handler) http.Handler,
2318
) func(next http.Handler) http.Handler
2419
AuthenticateFromHeader(
2520
token gojwttoken.Token,
2621
) func(next http.Handler) http.Handler
2722
AuthenticateFromCookie(
2823
token gojwttoken.Token,
29-
cookieName string,
24+
cookieRefreshTokenName,
25+
cookieAccessTokenName string,
3026
refreshTokenFn func(
3127
w http.ResponseWriter,
3228
r *http.Request,
33-
) error,
29+
) (*map[gojwttoken.Token]string, error),
3430
) func(next http.Handler) http.Handler
3531
}

http/middleware/auth/middleware.go

+106-76
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package auth
22

33
import (
44
"errors"
5-
"github.com/golang-jwt/jwt/v5"
65
gojwt "github.com/ralvarezdev/go-jwt"
76
gojwtnethttp "github.com/ralvarezdev/go-jwt/net/http"
87
gojwtnethttpctx "github.com/ralvarezdev/go-jwt/net/http/context"
@@ -49,11 +48,6 @@ func (m *Middleware) Authenticate(
4948
err error,
5049
errorCode *string,
5150
),
52-
refreshTokenFn func(
53-
w http.ResponseWriter,
54-
r *http.Request,
55-
) error,
56-
authenticateFn func(next http.Handler) http.Handler,
5751
) func(next http.Handler) http.Handler {
5852
return func(next http.Handler) http.Handler {
5953
return http.HandlerFunc(
@@ -64,33 +58,10 @@ func (m *Middleware) Authenticate(
6458
token,
6559
)
6660
if err != nil {
67-
// Check if the error is a token expired error
68-
if token == gojwttoken.RefreshToken || !errors.Is(
69-
err,
70-
jwt.ErrTokenExpired,
71-
) || refreshTokenFn == nil {
72-
failHandler(
73-
w,
74-
err,
75-
ErrCodeInvalidTokenClaims,
76-
)
77-
return
78-
}
79-
80-
// Refresh the token
81-
if err = refreshTokenFn(w, r); err != nil {
82-
failHandler(
83-
w,
84-
err,
85-
ErrCodeFailedToRefreshToken,
86-
)
87-
return
88-
}
89-
90-
// Authenticate again
91-
authenticateFn(next).ServeHTTP(
61+
failHandler(
9262
w,
93-
r,
63+
err,
64+
ErrCodeInvalidTokenClaims,
9465
)
9566
return
9667
}
@@ -153,8 +124,6 @@ func (m *Middleware) AuthenticateFromHeader(
153124
token,
154125
rawToken,
155126
failHandler,
156-
nil,
157-
nil,
158127
)(next).ServeHTTP(
159128
w,
160129
r,
@@ -167,64 +136,125 @@ func (m *Middleware) AuthenticateFromHeader(
167136
// AuthenticateFromCookie return the middleware function that authenticates the request from the cookie
168137
func (m *Middleware) AuthenticateFromCookie(
169138
token gojwttoken.Token,
170-
cookieName string,
139+
cookieRefreshTokenName,
140+
cookieAccessTokenName string,
171141
refreshTokenFn func(
172142
w http.ResponseWriter,
173143
r *http.Request,
174-
) error,
144+
) (*map[gojwttoken.Token]string, error),
175145
) func(next http.Handler) http.Handler {
146+
var cookieName string
147+
if token == gojwttoken.AccessToken {
148+
cookieName = cookieAccessTokenName
149+
} else if token == gojwttoken.RefreshToken {
150+
cookieName = cookieRefreshTokenName
151+
}
152+
176153
// Create the fail handler function
177154
failHandler := func(
155+
cookieName string,
156+
) func(
178157
w http.ResponseWriter,
179158
err error,
180159
errorCode *string,
181160
) {
182-
m.handler.HandleError(
183-
w,
184-
gonethttpresponse.NewFailResponseError(
185-
cookieName,
186-
err.Error(),
187-
errorCode,
188-
http.StatusUnauthorized,
189-
),
190-
)
161+
return func(
162+
w http.ResponseWriter,
163+
err error,
164+
errorCode *string,
165+
) {
166+
m.handler.HandleError(
167+
w,
168+
gonethttpresponse.NewFailResponseError(
169+
cookieName,
170+
err.Error(),
171+
errorCode,
172+
http.StatusUnauthorized,
173+
),
174+
)
175+
}
191176
}
192177

193178
// Create the authenticate function
194-
var authenticateFn func(next http.Handler) http.Handler
195-
authenticateFn = func(next http.Handler) http.Handler {
196-
return http.HandlerFunc(
197-
func(w http.ResponseWriter, r *http.Request) {
198-
// Get the cookie
199-
cookie, err := r.Cookie(cookieName)
179+
var authenticateFn func(*map[gojwttoken.Token]string) func(next http.Handler) http.Handler
180+
authenticateFn = func(rawTokens *map[gojwttoken.Token]string) func(next http.Handler) http.Handler {
181+
return func(next http.Handler) http.Handler {
182+
return http.HandlerFunc(
183+
func(w http.ResponseWriter, r *http.Request) {
184+
var rawToken string
185+
var cookie *http.Cookie
186+
var err error
187+
var ok bool
200188

201-
// Return an error if the cookie is missing
202-
if err != nil {
203-
failHandler(
204-
w,
205-
gonethttp.ErrCookieNotFound,
206-
gonethttp.ErrCodeCookieNotFound,
207-
)
208-
return
209-
}
189+
// Get the cookie
190+
if rawTokens != nil {
191+
// Get the raw token from the map
192+
rawToken, ok = (*rawTokens)[token]
210193

211-
// Get the raw token from the cookie
212-
rawToken := cookie.Value
194+
// Return an error if the token is missing
195+
if !ok {
196+
failHandler(cookieName)(
197+
w,
198+
gonethttp.ErrCookieNotFound,
199+
gonethttp.ErrCodeCookieNotFound,
200+
)
201+
return
202+
}
203+
} else {
204+
// Get the cookie from the request
205+
cookie, err = r.Cookie(cookieAccessTokenName)
213206

214-
// Call the Authenticate function
215-
m.Authenticate(
216-
token,
217-
rawToken,
218-
failHandler,
219-
refreshTokenFn,
220-
authenticateFn,
221-
)(next).ServeHTTP(
222-
w,
223-
r,
224-
)
225-
},
226-
)
207+
// Check if there was an error getting the cookie
208+
if err == nil {
209+
// Get the raw token from the cookie
210+
rawToken = cookie.Value
211+
} else if errors.Is(err, http.ErrNoCookie) {
212+
// Check if the token can be refreshed
213+
if token == gojwttoken.AccessToken && refreshTokenFn != nil {
214+
// Refresh the token
215+
rawTokens, err = refreshTokenFn(w, r)
216+
if err != nil {
217+
failHandler(cookieRefreshTokenName)(
218+
w,
219+
err,
220+
ErrCodeFailedToRefreshToken,
221+
)
222+
return
223+
}
224+
225+
// Authenticate again
226+
authenticateFn(rawTokens)(next).ServeHTTP(
227+
w,
228+
r,
229+
)
230+
return
231+
}
232+
}
233+
}
234+
235+
// Check if the raw token is empty
236+
if rawToken == "" {
237+
failHandler(cookieAccessTokenName)(
238+
w,
239+
gonethttp.ErrCookieNotFound,
240+
gonethttp.ErrCodeCookieNotFound,
241+
)
242+
return
243+
}
244+
245+
// Call the Authenticate function
246+
m.Authenticate(
247+
token,
248+
rawToken,
249+
failHandler,
250+
)(next).ServeHTTP(
251+
w,
252+
r,
253+
)
254+
},
255+
)
256+
}
227257
}
228258

229-
return authenticateFn
259+
return authenticateFn(nil)
230260
}

0 commit comments

Comments
 (0)