Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 21 additions & 61 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func ValidateComponentsMW(dbInst *db.DB, logger *zap.Logger) gin.HandlerFunc {
}
}

func AuthenticationMW(prov *auth.Provider, logger *zap.Logger) gin.HandlerFunc {
func AuthenticationMW(prov *auth.Provider, logger *zap.Logger, secretKey string) gin.HandlerFunc {
return func(c *gin.Context) {
if prov.Disabled {
logger.Info("authentication is disabled")
Expand All @@ -70,72 +70,32 @@ func AuthenticationMW(prov *auth.Provider, logger *zap.Logger) gin.HandlerFunc {
}

rawToken := strings.TrimPrefix(authHeader, "Bearer ")
// Parse the JWT token and validate it using the Keycloak public key
token, err := jwt.Parse(rawToken, func(token *jwt.Token) (interface{}, error) { //nolint:revive
// Validate the token's signing method
// if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
// return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
// }

key, err := prov.GetPublicKey()
if err != nil {
return nil, fmt.Errorf("error while getting public key: %w", err)
}

return key, nil
})

if err != nil {
logger.Error("failed to parse and validate a token", zap.Error(err))
apiErrors.RaiseNotAuthorizedErr(c, apiErrors.ErrAuthNotAuthenticated)
return
}

if !token.Valid {
apiErrors.RaiseNotAuthorizedErr(c, apiErrors.ErrAuthNotAuthenticated)
return
}

c.Next()
}
}

// AuthenticationV1DeprecatedMW is a middleware for handling authentication only for Metric Processor.
// TODO: remove this middleware after migration to Keycloak.
func AuthenticationV1DeprecatedMW(prov *auth.Provider, logger *zap.Logger, secretKey string) gin.HandlerFunc {
return func(c *gin.Context) {
if prov.Disabled {
logger.Info("authentication is disabled")
c.Next()
return
}

logger.Info("start to process authentication request for V1")

authHeader := c.GetHeader("Authorization")
if authHeader == "" {
apiErrors.RaiseNotAuthorizedErr(c, apiErrors.ErrAuthNotAuthenticated)
return
}

rawToken := strings.TrimPrefix(authHeader, "Bearer ")
// Parse the JWT token and validate it using the Keycloak public key
token, err := jwt.Parse(rawToken, func(token *jwt.Token) (interface{}, error) {
// Validate the token's signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
switch token.Method.(type) {
// TODO: remove HMAC method after migration.
case *jwt.SigningMethodHMAC:
logger.Info("HMAC token deteced, using secret key for validation")
if secretKey == "" {
return nil, fmt.Errorf("secret key is not configured for HMAC token validation")
}
return []byte(secretKey), nil

case *jwt.SigningMethodRSA:
logger.Info("RSA token detected, using Keycloak public key for validation")
key, err := prov.GetPublicKey()
if err != nil {
return nil, fmt.Errorf("error while getting public key: %w", err)
}
return key, nil

default:
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}

return []byte(secretKey), nil
})

if err != nil {
logger.Error("failed to parse and validate a token", zap.Error(err))
apiErrors.RaiseNotAuthorizedErr(c, apiErrors.ErrAuthNotAuthenticated)
return
}

if !token.Valid {
if err != nil || !token.Valid {
logger.Error("failed to parse or validate a token", zap.Error(err))
apiErrors.RaiseNotAuthorizedErr(c, apiErrors.ErrAuthNotAuthenticated)
return
}
Expand Down
12 changes: 6 additions & 6 deletions internal/api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (a *API) InitRoutes() {
{
v1API.GET("component_status", v1.GetComponentsStatusHandler(a.db, a.log))
v1API.POST("component_status",
AuthenticationV1DeprecatedMW(a.oa2Prov, a.log, a.secretKeyV1),
AuthenticationMW(a.oa2Prov, a.log, a.secretKeyV1),
v1.PostComponentStatusHandler(a.db, a.log),
)

Expand All @@ -38,27 +38,27 @@ func (a *API) InitRoutes() {
v2API := a.r.Group(v2Group)
{
v2API.GET("components", v2.GetComponentsHandler(a.db, a.log))
v2API.POST("components", AuthenticationMW(a.oa2Prov, a.log), v2.PostComponentHandler(a.db, a.log))
v2API.POST("components", AuthenticationMW(a.oa2Prov, a.log, a.secretKeyV1), v2.PostComponentHandler(a.db, a.log))
v2API.GET("components/:id", v2.GetComponentHandler(a.db, a.log))

v2API.GET("incidents", v2.GetIncidentsHandler(a.db, a.log))
v2API.POST("incidents",
AuthenticationMW(a.oa2Prov, a.log),
AuthenticationMW(a.oa2Prov, a.log, a.secretKeyV1),
ValidateComponentsMW(a.db, a.log),
v2.PostIncidentHandler(a.db, a.log),
)
v2API.GET("incidents/:incidentID", v2.GetIncidentHandler(a.db, a.log))
v2API.PATCH("incidents/:incidentID",
AuthenticationMW(a.oa2Prov, a.log),
AuthenticationMW(a.oa2Prov, a.log, a.secretKeyV1),
CheckEventExistenceMW(a.db, a.log),
v2.PatchIncidentHandler(a.db, a.log))
v2API.POST("incidents/:incidentID/extract",
AuthenticationMW(a.oa2Prov, a.log),
AuthenticationMW(a.oa2Prov, a.log, a.secretKeyV1),
CheckEventExistenceMW(a.db, a.log),
ValidateComponentsMW(a.db, a.log),
v2.PostIncidentExtractHandler(a.db, a.log))
v2API.PATCH("incidents/:incidentID/updates/:updateID",
AuthenticationMW(a.oa2Prov, a.log),
AuthenticationMW(a.oa2Prov, a.log, a.secretKeyV1),
CheckEventExistenceMW(a.db, a.log),
v2.PatchEventUpdateTextHandler(a.db, a.log))
// Paginated events.
Expand Down