diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 9be5492..bdc286e 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -53,7 +53,7 @@ func ValidateComponentsMW(dbInst *db.DB, logger *zap.Logger) gin.HandlerFunc { } } -func AuthenticationMW(prov *auth.Provider, logger *zap.Logger) gin.HandlerFunc { +func AuthenticationMW(prov *auth.Provider, logger *zap.Logger, secretKey string) gin.HandlerFunc { return func(c *gin.Context) { if prov.Disabled { logger.Info("authentication is disabled") @@ -70,72 +70,32 @@ func AuthenticationMW(prov *auth.Provider, logger *zap.Logger) gin.HandlerFunc { } rawToken := strings.TrimPrefix(authHeader, "Bearer ") - // Parse the JWT token and validate it using the Keycloak public key - token, err := jwt.Parse(rawToken, func(token *jwt.Token) (interface{}, error) { //nolint:revive - // Validate the token's signing method - // if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { - // return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - // } - - key, err := prov.GetPublicKey() - if err != nil { - return nil, fmt.Errorf("error while getting public key: %w", err) - } - - return key, nil - }) - - if err != nil { - logger.Error("failed to parse and validate a token", zap.Error(err)) - apiErrors.RaiseNotAuthorizedErr(c, apiErrors.ErrAuthNotAuthenticated) - return - } - if !token.Valid { - apiErrors.RaiseNotAuthorizedErr(c, apiErrors.ErrAuthNotAuthenticated) - return - } - - c.Next() - } -} - -// AuthenticationV1DeprecatedMW is a middleware for handling authentication only for Metric Processor. -// TODO: remove this middleware after migration to Keycloak. -func AuthenticationV1DeprecatedMW(prov *auth.Provider, logger *zap.Logger, secretKey string) gin.HandlerFunc { - return func(c *gin.Context) { - if prov.Disabled { - logger.Info("authentication is disabled") - c.Next() - return - } - - logger.Info("start to process authentication request for V1") - - authHeader := c.GetHeader("Authorization") - if authHeader == "" { - apiErrors.RaiseNotAuthorizedErr(c, apiErrors.ErrAuthNotAuthenticated) - return - } - - rawToken := strings.TrimPrefix(authHeader, "Bearer ") - // Parse the JWT token and validate it using the Keycloak public key token, err := jwt.Parse(rawToken, func(token *jwt.Token) (interface{}, error) { - // Validate the token's signing method - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + switch token.Method.(type) { + // TODO: remove HMAC method after migration. + case *jwt.SigningMethodHMAC: + logger.Info("HMAC token deteced, using secret key for validation") + if secretKey == "" { + return nil, fmt.Errorf("secret key is not configured for HMAC token validation") + } + return []byte(secretKey), nil + + case *jwt.SigningMethodRSA: + logger.Info("RSA token detected, using Keycloak public key for validation") + key, err := prov.GetPublicKey() + if err != nil { + return nil, fmt.Errorf("error while getting public key: %w", err) + } + return key, nil + + default: return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } - - return []byte(secretKey), nil }) - if err != nil { - logger.Error("failed to parse and validate a token", zap.Error(err)) - apiErrors.RaiseNotAuthorizedErr(c, apiErrors.ErrAuthNotAuthenticated) - return - } - - if !token.Valid { + if err != nil || !token.Valid { + logger.Error("failed to parse or validate a token", zap.Error(err)) apiErrors.RaiseNotAuthorizedErr(c, apiErrors.ErrAuthNotAuthenticated) return } diff --git a/internal/api/routes.go b/internal/api/routes.go index 102e118..60c7877 100644 --- a/internal/api/routes.go +++ b/internal/api/routes.go @@ -28,7 +28,7 @@ func (a *API) InitRoutes() { { v1API.GET("component_status", v1.GetComponentsStatusHandler(a.db, a.log)) v1API.POST("component_status", - AuthenticationV1DeprecatedMW(a.oa2Prov, a.log, a.secretKeyV1), + AuthenticationMW(a.oa2Prov, a.log, a.secretKeyV1), v1.PostComponentStatusHandler(a.db, a.log), ) @@ -38,27 +38,27 @@ func (a *API) InitRoutes() { v2API := a.r.Group(v2Group) { v2API.GET("components", v2.GetComponentsHandler(a.db, a.log)) - v2API.POST("components", AuthenticationMW(a.oa2Prov, a.log), v2.PostComponentHandler(a.db, a.log)) + v2API.POST("components", AuthenticationMW(a.oa2Prov, a.log, a.secretKeyV1), v2.PostComponentHandler(a.db, a.log)) v2API.GET("components/:id", v2.GetComponentHandler(a.db, a.log)) v2API.GET("incidents", v2.GetIncidentsHandler(a.db, a.log)) v2API.POST("incidents", - AuthenticationMW(a.oa2Prov, a.log), + AuthenticationMW(a.oa2Prov, a.log, a.secretKeyV1), ValidateComponentsMW(a.db, a.log), v2.PostIncidentHandler(a.db, a.log), ) v2API.GET("incidents/:incidentID", v2.GetIncidentHandler(a.db, a.log)) v2API.PATCH("incidents/:incidentID", - AuthenticationMW(a.oa2Prov, a.log), + AuthenticationMW(a.oa2Prov, a.log, a.secretKeyV1), CheckEventExistenceMW(a.db, a.log), v2.PatchIncidentHandler(a.db, a.log)) v2API.POST("incidents/:incidentID/extract", - AuthenticationMW(a.oa2Prov, a.log), + AuthenticationMW(a.oa2Prov, a.log, a.secretKeyV1), CheckEventExistenceMW(a.db, a.log), ValidateComponentsMW(a.db, a.log), v2.PostIncidentExtractHandler(a.db, a.log)) v2API.PATCH("incidents/:incidentID/updates/:updateID", - AuthenticationMW(a.oa2Prov, a.log), + AuthenticationMW(a.oa2Prov, a.log, a.secretKeyV1), CheckEventExistenceMW(a.db, a.log), v2.PatchEventUpdateTextHandler(a.db, a.log)) // Paginated events.