diff --git a/.air-cdk.toml b/.air-cdk.toml new file mode 100644 index 0000000..cad544b --- /dev/null +++ b/.air-cdk.toml @@ -0,0 +1,11 @@ +root = "." +tmp_dir = "tmp" + +[build] +bin = "" +cmd = './build.sh' +delay = 100 +exclude_dir = ["tmp", "cdk"] +full_bin = "sam local start-api --port 8160 --template cdk/cdk.out/twosv-api-dev.template.json --env-vars cdk/env.json" +include_ext = ["go"] +kill_delay = "0s" diff --git a/.dockerignore b/.dockerignore index 39729fb..50a2515 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,9 +2,8 @@ * # Whitelist required files -!.env.encrypted -!scripts/* !lambda/* +!router/* !server/* !u2fsimulator/* !u2fserver/* diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index bab1138..4a5b5a8 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,4 +1,4 @@ -* @silinternational/developers -*.tf @silinternational/tf-devs -*.go @silinternational/go-devs -go.* @silinternational/go-devs +* @sil-org/developers +*.tf @sil-org/tf-devs +*.go @sil-org/go-devs +go.* @sil-org/go-devs diff --git a/Dockerfile b/Dockerfile index 3d34820..92c46b0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,32 +1,16 @@ -FROM node:22 +FROM golang:1.24 -ENV GO_VERSION=1.24.4 - -ADD https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip . -ADD https://go.dev/dl/go${GO_VERSION}.linux-amd64.tar.gz . - -RUN < +- LETS_ENCRYPT_EMAIL= +- LETS_ENCRYPT_CA=production +- TLD= +- SANS=mfa-ui.,mfa-app. +- BACKEND1_URL=http://ui:80 +- FRONTEND1_DOMAIN=mfa-ui. +- BACKEND2_URL=http://app:8080 +- FRONTEND2_DOMAIN=mfa-app. + +Create DNS A records (without Cloudflare proxy enabled) for the values defined in `FRONTEND1_DOMAIN` and +`FRONTEND2_DOMAIN` pointing to 127.0.0.1 and wait for DNS propagation. Once all of the above configuration is in place, +run `make demo`. The first time will take several minutes for all the initialization. You can watch Docker logs on the +proxy container to keep tabs on the progress. + +### Lambda + +To exercise the API as it would be used in AWS Lambda, run this command: `air -c .air-cdk.toml`. This will run a +file watcher that will rebuild the app code and the CDK stack, then run `sam local start-api` using the generated +Cloudformation template. This will listen on port 8160. Any code changes will trigger a rebuild and SAM will restart +using the new code. + +Implementation notes: + +- SAM uses Docker internally, which would make it complicated to run with Docker Compose. +- You will need to install CDK and SAM on your computer for this to work. +- It can use the DynamoDB container in Docker Compose, which can be started using `make dbinit`. +- The `make dbinit` command creates an APIKey (key: `EC7C2E16-5028-432F-8AF2-A79A64CF3BC1` +secret: `1ED18444-7238-410B-A536-D6C15A3C`) +- Some unit tests will delete the APIKey created by `make dbinit`. diff --git a/api.go b/api.go index e724740..1f51623 100644 --- a/api.go +++ b/api.go @@ -2,21 +2,32 @@ package mfa import ( "encoding/json" + "errors" "log" "net/http" + "strings" + + "github.com/google/uuid" ) -const IDParam = "id" +const ( + IDParam = "id" + UUIDParam = "uuid" +) // simpleError is a custom error type that can be JSON-encoded for API responses type simpleError struct { - Error string `json:"error"` + Err string `json:"error"` } // newSimpleError creates a new simpleError from the given error -func newSimpleError(err error) simpleError { - return simpleError{Error: err.Error()} -} +func newSimpleError(err error) simpleError { return simpleError{Err: err.Error()} } + +// Error satisfies the error interface. +func (s simpleError) Error() string { return s.Err } + +// Is returns true if the error strings are equal. +func (s simpleError) Is(err error) bool { return s.Err == err.Error() || errors.Is(err, simpleError{}) } // jsonResponse encodes a body as JSON and writes it to the response. It sets the response Content-Type header to // "application/json". @@ -25,6 +36,8 @@ func jsonResponse(w http.ResponseWriter, body interface{}, status int) { switch b := body.(type) { case error: data = newSimpleError(b) + case string: + data = newSimpleError(errors.New(b)) default: data = body } @@ -34,7 +47,12 @@ func jsonResponse(w http.ResponseWriter, body interface{}, status int) { if data != nil { jBody, err = json.Marshal(data) if err != nil { - log.Printf("failed to marshal response body to json: %s", err) + + // SonarQube flagged this as vulnerable to injection attacks. Rather than exhaustively search for places + // where user input is inserted into the error message, I'll just sanitize it as recommended. + sanitizedError := strings.ReplaceAll(strings.ReplaceAll(err.Error(), "\n", "_"), "\r", "_") + + log.Printf("failed to marshal response body to json: %s", sanitizedError) w.WriteHeader(http.StatusInternalServerError) _, _ = w.Write([]byte("failed to marshal response body to json")) return @@ -45,6 +63,15 @@ func jsonResponse(w http.ResponseWriter, body interface{}, status int) { w.WriteHeader(status) _, err = w.Write(jBody) if err != nil { - log.Printf("failed to write response in jsonResponse: %s\n", err) + log.Printf("failed to write response in jsonResponse: %s", err) + } +} + +// NewUUID returns a new V4 UUID value as a text string +func NewUUID() string { + u, err := uuid.NewRandom() + if err != nil { + panic("failed to generate uuid: " + err.Error()) } + return u.String() } diff --git a/apikey.go b/apikey.go index 82b0a4f..60fae39 100644 --- a/apikey.go +++ b/apikey.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "log" "net/http" "strings" "time" @@ -28,6 +29,14 @@ const ( paramOldKeySecret = "oldKeySecret" ) +const ( + apiKeyIsRequired = "apiKeyValue is required" + apiKeyNotFound = "API Key not found" + emailIsRequired = "email is required" +) + +var ErrKeyAlreadyActivated = errors.New("key already activated") + // ApiKey holds API key data from DynamoDB type ApiKey struct { Key string `dynamodbav:"value" json:"value"` @@ -76,7 +85,8 @@ func (k *ApiKey) IsCorrect(given string) error { err := bcrypt.CompareHashAndPassword([]byte(k.HashedSecret), []byte(given)) if err != nil { - return err + return fmt.Errorf("hash does not match plaintext (hash: %s) (plaintext: %v...): %w", + k.HashedSecret, given[0:min(len(given), 4)], err) } return nil @@ -169,14 +179,12 @@ func (k *ApiKey) DecryptLegacy(ciphertext string) (string, error) { iv, err := base64.StdEncoding.DecodeString(parts[0]) if err != nil { - fmt.Printf("failed to decode iv: %s\n", err) - return "", err + return "", fmt.Errorf("failed to decode iv: %w", err) } decodedCipher, err := base64.StdEncoding.DecodeString(parts[1]) if err != nil { - fmt.Printf("failed to decode ciphertext: %s\n", err) - return "", err + return "", fmt.Errorf("failed to decode ciphertext: %w", err) } // plaintext will hold decrypted content, it must be at least as long @@ -194,7 +202,7 @@ func (k *ApiKey) DecryptLegacy(ciphertext string) (string, error) { // ActivatedAt fields. func (k *ApiKey) Activate() error { if k.ActivatedAt != 0 { - return errors.New("key already activated") + return ErrKeyAlreadyActivated } random := make([]byte, 32) @@ -349,36 +357,48 @@ func (a *App) ActivateApiKey(w http.ResponseWriter, r *http.Request) { err := json.NewDecoder(r.Body).Decode(&requestBody) if err != nil { - jsonResponse(w, fmt.Errorf("invalid request: %w", err), http.StatusBadRequest) + log.Printf("invalid request in ActivateApiKey: %s", err) + jsonResponse(w, invalidRequest, http.StatusBadRequest) return } if requestBody.ApiKeyValue == "" { - jsonResponse(w, fmt.Errorf("apiKeyValue is required"), http.StatusBadRequest) + jsonResponse(w, apiKeyIsRequired, http.StatusBadRequest) return } if requestBody.Email == "" { - jsonResponse(w, fmt.Errorf("email is required"), http.StatusBadRequest) + jsonResponse(w, emailIsRequired, http.StatusBadRequest) return } newKey := ApiKey{Key: requestBody.ApiKeyValue, Store: a.db} err = newKey.Load() if err != nil { - jsonResponse(w, fmt.Errorf("key not found: %w", err), http.StatusNotFound) + if strings.Contains(err.Error(), "does not exist") { + jsonResponse(w, apiKeyNotFound, http.StatusNotFound) + } else { + log.Printf("error loading API Key: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) + } return } err = newKey.Activate() if err != nil { - jsonResponse(w, fmt.Errorf("failed to activate key: %w", err), http.StatusBadRequest) + if errors.Is(err, ErrKeyAlreadyActivated) { + jsonResponse(w, err, http.StatusBadRequest) + } else { + log.Printf("failed to activate key: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) + } return } err = newKey.Save() if err != nil { - jsonResponse(w, fmt.Errorf("failed to save key: %w", err), http.StatusInternalServerError) + log.Printf("failed to save key: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) return } @@ -393,25 +413,28 @@ func (a *App) CreateApiKey(w http.ResponseWriter, r *http.Request) { err := json.NewDecoder(r.Body).Decode(&requestBody) if err != nil { - jsonResponse(w, fmt.Errorf("invalid request: %w", err), http.StatusBadRequest) + log.Printf("invalid request in CreateApiKey: %s", err) + jsonResponse(w, invalidRequest, http.StatusBadRequest) return } if requestBody.Email == "" { - jsonResponse(w, fmt.Errorf("email is required"), http.StatusBadRequest) + jsonResponse(w, emailIsRequired, http.StatusBadRequest) return } key, err := NewApiKey(requestBody.Email) if err != nil { - jsonResponse(w, fmt.Errorf("failed to create a random key: %w", err), http.StatusInternalServerError) + log.Printf("failed to create a random key: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) return } key.Store = a.db err = key.Save() if err != nil { - jsonResponse(w, fmt.Errorf("failed to save key: %w", err), http.StatusInternalServerError) + log.Printf("failed to save key: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) return } @@ -425,33 +448,42 @@ func (a *App) CreateApiKey(w http.ResponseWriter, r *http.Request) { func (a *App) RotateApiKey(w http.ResponseWriter, r *http.Request) { requestBody, err := parseRotateKeyRequestBody(r.Body) if err != nil { - jsonResponse(w, fmt.Errorf("invalid request: %w", err), http.StatusBadRequest) + if strings.HasSuffix(err.Error(), "is required") { + jsonResponse(w, err, http.StatusBadRequest) + } else { + log.Printf("invalid request in RotateApiKey: %s", err) + jsonResponse(w, invalidRequest, http.StatusBadRequest) + } return } oldKey := ApiKey{Key: requestBody[paramOldKeyId], Store: a.GetDB()} err = oldKey.loadAndCheck(requestBody[paramOldKeySecret]) if err != nil { - jsonResponse(w, fmt.Errorf("old key is not valid: %w", err), http.StatusNotFound) + log.Printf("old key is not valid: %s", err) + jsonResponse(w, apiKeyNotFound, http.StatusNotFound) return } newKey := ApiKey{Key: requestBody[paramNewKeyId], Store: a.GetDB()} err = newKey.loadAndCheck(requestBody[paramNewKeySecret]) if err != nil { - jsonResponse(w, fmt.Errorf("new key is not valid: %w", err), http.StatusNotFound) + log.Printf("new key is not valid: %s", err) + jsonResponse(w, apiKeyNotFound, http.StatusNotFound) return } totpComplete, totpIncomplete, err := newKey.ReEncryptTOTPs(a.GetDB(), oldKey) if err != nil { - jsonResponse(w, fmt.Errorf("failed to re-encrypt TOTP data: %w", err), http.StatusInternalServerError) + log.Printf("failed to re-encrypt TOTP data: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) return } webauthnComplete, webauthnIncomplete, err := newKey.ReEncryptWebAuthnUsers(a.GetDB(), oldKey) if err != nil { - jsonResponse(w, fmt.Errorf("failed to re-encrypt WebAuthn data: %w", err), http.StatusInternalServerError) + log.Printf("failed to re-encrypt WebAuthn data: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) return } diff --git a/apikey_test.go b/apikey_test.go index 1a41e76..1994321 100644 --- a/apikey_test.go +++ b/apikey_test.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "encoding/base64" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -13,7 +14,6 @@ import ( "testing" "time" - uuid "github.com/satori/go.uuid" "golang.org/x/crypto/bcrypt" ) @@ -190,9 +190,11 @@ func (ms *MfaSuite) TestApiKeyActivate() { } ms.NoError(err) - ms.Regexp(regexp.MustCompile("[A-Za-z0-9+/]{43}="), key.Secret) - ms.NoError(bcrypt.CompareHashAndPassword([]byte(key.HashedSecret), []byte(key.Secret))) - ms.WithinDuration(time.Now(), time.Unix(int64(key.ActivatedAt/1000), 0), time.Minute) + ms.Regexp(regexp.MustCompile("^[A-Za-z0-9+/]{43}=$"), key.Secret, "Secret isn't correct") + ms.NoError(bcrypt.CompareHashAndPassword([]byte(key.HashedSecret), []byte(key.Secret)), + "HashedSecret isn't correct") + ms.WithinDuration(time.Now(), time.Unix(int64(key.ActivatedAt/1000), 0), time.Minute, + "ActivatedAt isn't set to the current time") // ensure no other fields were changed ms.Equal(tt.key.Key, key.Key) @@ -219,7 +221,7 @@ func (ms *MfaSuite) TestActivateApiKey() { name string body any wantStatus int - wantError string + wantError error }{ { name: "not previously activated", @@ -236,7 +238,7 @@ func (ms *MfaSuite) TestActivateApiKey() { "apiKeyValue": key2.Key, }, wantStatus: http.StatusBadRequest, - wantError: "failed to activate key: key already activated", + wantError: ErrKeyAlreadyActivated, }, { name: "missing email", @@ -244,7 +246,7 @@ func (ms *MfaSuite) TestActivateApiKey() { "apiKeyValue": key3.Key, }, wantStatus: http.StatusBadRequest, - wantError: "email is required", + wantError: errors.New("email is required"), }, { name: "missing apiKey", @@ -252,7 +254,7 @@ func (ms *MfaSuite) TestActivateApiKey() { "email": exampleEmail, }, wantStatus: http.StatusBadRequest, - wantError: "apiKeyValue is required", + wantError: errors.New("apiKeyValue is required"), }, { name: "key not found", @@ -261,7 +263,7 @@ func (ms *MfaSuite) TestActivateApiKey() { "apiKeyValue": "not a key", }, wantStatus: http.StatusNotFound, - wantError: "key not found: item does not exist: not a key", + wantError: errors.New("API Key not found"), }, } for _, tt := range tests { @@ -274,7 +276,7 @@ func (ms *MfaSuite) TestActivateApiKey() { ms.Equal(tt.wantStatus, res.Status, fmt.Sprintf("ActivateApiKey response: %s", res.Body)) var se simpleError ms.decodeBody(res.Body, &se) - ms.Equal(tt.wantError, se.Error) + ms.ErrorIs(se, tt.wantError) return } @@ -299,7 +301,7 @@ func (ms *MfaSuite) TestCreateApiKey() { name string body any wantStatus int - wantError string + wantError error }{ { name: "success", @@ -312,7 +314,7 @@ func (ms *MfaSuite) TestCreateApiKey() { name: "missing email", body: map[string]interface{}{}, wantStatus: http.StatusBadRequest, - wantError: "email is required", + wantError: errors.New("email is required"), }, } for _, tt := range tests { @@ -321,11 +323,11 @@ func (ms *MfaSuite) TestCreateApiKey() { req := requestWithUser(tt.body, ApiKey{Store: localStorage}) ms.app.CreateApiKey(res, req) - if tt.wantError != "" { + if tt.wantError != nil { ms.Equal(tt.wantStatus, res.Status, fmt.Sprintf("CreateApiKey response: %s", res.Body)) var se simpleError ms.decodeBody(res.Body, &se) - ms.Equal(tt.wantError, se.Error) + ms.ErrorIs(se, tt.wantError) return } @@ -344,23 +346,16 @@ func (ms *MfaSuite) TestAppRotateApiKey() { key := user.ApiKey must(db.Store(config.ApiKeyTable, key)) - totp := TOTP{ - UUID: uuid.NewV4().String(), - ApiKey: key.Key, - EncryptedTotpKey: mustEncryptLegacy(key, "plain text TOTP key"), - } - must(db.Store(ms.app.GetConfig().TotpTable, totp)) + totp := ms.newPasscode(key) - newKey, err := NewApiKey("email@example.com") - must(err) - must(newKey.Activate()) + newKey := newTestKey() must(db.Store(config.ApiKeyTable, newKey)) tests := []struct { name string body any wantStatus int - wantError string + wantError error }{ { name: "missing oldKeyId", @@ -370,7 +365,7 @@ func (ms *MfaSuite) TestAppRotateApiKey() { paramOldKeySecret: key.Secret, }, wantStatus: http.StatusBadRequest, - wantError: "invalid request: oldKeyId is required", + wantError: errors.New("oldKeyId is required"), }, { name: "missing oldKeySecret", @@ -380,7 +375,7 @@ func (ms *MfaSuite) TestAppRotateApiKey() { paramOldKeyId: key.Key, }, wantStatus: http.StatusBadRequest, - wantError: "invalid request: oldKeySecret is required", + wantError: errors.New("oldKeySecret is required"), }, { name: "missing newKeyId", @@ -390,7 +385,7 @@ func (ms *MfaSuite) TestAppRotateApiKey() { paramOldKeySecret: key.Secret, }, wantStatus: http.StatusBadRequest, - wantError: "invalid request: newKeyId is required", + wantError: errors.New("newKeyId is required"), }, { name: "missing newKeySecret", @@ -400,7 +395,7 @@ func (ms *MfaSuite) TestAppRotateApiKey() { paramOldKeySecret: key.Secret, }, wantStatus: http.StatusBadRequest, - wantError: "invalid request: newKeySecret is required", + wantError: errors.New("newKeySecret is required"), }, { name: "good", @@ -419,11 +414,11 @@ func (ms *MfaSuite) TestAppRotateApiKey() { req := requestWithUser(tt.body, key) ms.app.RotateApiKey(res, req) - if tt.wantError != "" { + if tt.wantError != nil { ms.Equal(tt.wantStatus, res.Status, fmt.Sprintf("CreateApiKey response: %s", res.Body)) var se simpleError ms.decodeBody(res.Body, &se) - ms.Equal(tt.wantError, se.Error) + ms.ErrorIs(se, tt.wantError) return } @@ -448,7 +443,10 @@ func (ms *MfaSuite) TestAppRotateApiKey() { func (ms *MfaSuite) TestNewApiKey() { got, err := NewApiKey(exampleEmail) ms.NoError(err) - ms.Regexp(regexp.MustCompile("[a-f0-9]{40}"), got) + ms.Equal(exampleEmail, got.Email, "Email isn't correct") + ms.Regexp(regexp.MustCompile("^[a-f0-9]{40}$"), got.Key, "Key isn't correct") + ms.WithinDuration(time.Now(), time.Unix(int64(got.CreatedAt)/1000, 0), time.Minute, + "CreatedAt isn't set to the current time") } func (ms *MfaSuite) TestNewCipherBlock() { @@ -510,11 +508,7 @@ func (ms *MfaSuite) TestApiKey_ReEncryptTOTPs() { must(newKey.Activate()) must(ms.app.GetDB().Store(ms.app.GetConfig().ApiKeyTable, newKey)) - must(storage.Store(ms.app.GetConfig().TotpTable, TOTP{ - UUID: uuid.NewV4().String(), - ApiKey: oldKey.Key, - EncryptedTotpKey: mustEncryptLegacy(oldKey, "plain text TOTP key"), - })) + _ = ms.newPasscode(oldKey) complete, incomplete, err := newKey.ReEncryptTOTPs(storage, oldKey) ms.NoError(err) @@ -531,9 +525,7 @@ func (ms *MfaSuite) TestReEncryptWebAuthnUsers() { baseConfigs := getDBConfig(ms) users := getTestWebauthnUsers(ms, baseConfigs) - newKey, err := NewApiKey("email@example.com") - must(err) - must(newKey.Activate()) + newKey := newTestKey() must(ms.app.GetDB().Store(ms.app.GetConfig().ApiKeyTable, newKey)) complete, incomplete, err := newKey.ReEncryptWebAuthnUsers(storage, users[0].ApiKey) @@ -577,9 +569,7 @@ func (ms *MfaSuite) TestReEncryptWebAuthnUser() { } for _, tt := range tests { ms.Run(tt.name, func() { - newKey, err := NewApiKey("email@example.com") - must(err) - must(newKey.Activate()) + newKey := newTestKey() must(ms.app.GetDB().Store(ms.app.GetConfig().ApiKeyTable, newKey)) ms.NotEqual(newKey.Secret, tt.user.ApiKey.Secret) @@ -673,3 +663,10 @@ func (ms *MfaSuite) TestApiKeyReEncryptLegacy() { ms.NoError(err) ms.Equal(plaintext, after) } + +func newTestKey() ApiKey { + apiKey, err := NewApiKey("user@example.com") + must(err) + must(apiKey.Activate()) + return apiKey +} diff --git a/auth.go b/auth.go index 4436c7c..b880bca 100644 --- a/auth.go +++ b/auth.go @@ -50,7 +50,7 @@ func AuthenticateRequest(r *http.Request) (User, error) { return authWebauthnUser(r, localStorage, apiKey) case "totp": - return nil, fmt.Errorf("TOTP is not yet supported") + return authTOTP(apiKey) case "api-key": return nil, nil // no authentication required for api-key diff --git a/build.sh b/build.sh new file mode 100755 index 0000000..2820e42 --- /dev/null +++ b/build.sh @@ -0,0 +1,13 @@ +#!/bin/env bash + +set -e + +set -x + +go build -tags lambda.norpc -ldflags="-s -w" -o bootstrap ./lambda + +( + cd cdk || exit + rm -rf cdk.out/asset.* + cdk synth -q +) diff --git a/cdk/cdk.context.json b/cdk/cdk.context.json new file mode 100644 index 0000000..c3f254d --- /dev/null +++ b/cdk/cdk.context.json @@ -0,0 +1,6 @@ +{ + "cli-telemetry": false, + "acknowledged-issue-numbers": [ + 34892 + ] +} diff --git a/cdk/cdk.go b/cdk/cdk.go index 7572ab3..c5d510c 100644 --- a/cdk/cdk.go +++ b/cdk/cdk.go @@ -29,9 +29,6 @@ func NewCdkStack(scope constructs.Construct, id string, props *CdkStackProps) aw totpTable := getEnv("TOTP_TABLE", "totp") webauthnTable := getEnv("WEBAUTHN_TABLE", "webauthn") lambdaRoleArn := getEnv("LAMBDA_ROLE", "") - if lambdaRoleArn == "" { - panic("LAMBDA_ROLE environment variable must be set") - } functionName := id @@ -41,9 +38,7 @@ func NewCdkStack(scope constructs.Construct, id string, props *CdkStackProps) aw RemovalPolicy: awscdk.RemovalPolicy_RETAIN, // Retain logs when stack is deleted }) - role := awsiam.Role_FromRoleArn(stack, jsii.String("Role"), jsii.String(lambdaRoleArn), nil) - - function := awslambda.NewFunction(stack, jsii.String("Function"), &awslambda.FunctionProps{ + functionProps := &awslambda.FunctionProps{ Code: awslambda.Code_FromAsset(jsii.String("../"), &awss3assets.AssetOptions{ // include only the bootstrap file Exclude: jsii.Strings("**", "!bootstrap"), @@ -52,16 +47,25 @@ func NewCdkStack(scope constructs.Construct, id string, props *CdkStackProps) aw "API_KEY_TABLE": jsii.String(apiKeyTable), "TOTP_TABLE": jsii.String(totpTable), "WEBAUTHN_TABLE": jsii.String(webauthnTable), + "AWS_ENDPOINT": jsii.String(""), }, FunctionName: &functionName, Handler: jsii.String("bootstrap"), LoggingFormat: awslambda.LoggingFormat_JSON, LogGroup: logGroup, MemorySize: jsii.Number(1024.0), - Role: role, Runtime: awslambda.Runtime_PROVIDED_AL2023(), Timeout: awscdk.Duration_Seconds(jsii.Number(5)), - }) + } + + if lambdaRoleArn != "" { + functionProps.Role = awsiam.Role_FromRoleArn(stack, jsii.String("Role"), jsii.String(lambdaRoleArn), nil) + } else { + functionProps.Role = awsiam.Role_FromRoleName(stack, jsii.String("Role"), + jsii.String("service-role/AWSLambdaBasicExecutionRole"), nil) + } + + function := awslambda.NewFunction(stack, jsii.String("Function"), functionProps) api := awsapigateway.NewRestApi(stack, jsii.String("API"), &awsapigateway.RestApiProps{ RestApiName: jsii.String(functionName), @@ -93,19 +97,21 @@ func main() { env = "dev" } - NewCdkStack(app, "twosv-api-"+env, &CdkStackProps{ - awscdk.StackProps{ - Env: &awscdk.Environment{ - Region: jsii.String(os.Getenv("AWS_REGION")), - }, - Tags: &map[string]*string{ - "managed_by": jsii.String("cdk"), - "itse_app_name": jsii.String("twosv-api"), - "itse_app_customer": jsii.String("shared"), - "itse_app_env": jsii.String(env), - }, + props := awscdk.StackProps{ + Tags: &map[string]*string{ + "managed_by": jsii.String("cdk"), + "itse_app_name": jsii.String("twosv-api"), + "itse_app_customer": jsii.String("shared"), + "itse_app_env": jsii.String(env), }, - }) + } + + region := os.Getenv("AWS_REGION") + if region != "" { + props.Env = &awscdk.Environment{Region: ®ion} + } + + NewCdkStack(app, "twosv-api-"+env, &CdkStackProps{props}) app.Synth(nil) } diff --git a/cdk/env.json b/cdk/env.json new file mode 100644 index 0000000..9a75735 --- /dev/null +++ b/cdk/env.json @@ -0,0 +1,8 @@ +{ + "Parameters": { + "AWS_ENDPOINT": "http://172.17.0.1:8000", + "API_KEY_TABLE": "ApiKey", + "TOTP_TABLE": "Totp", + "WEBAUTHN_TABLE": "WebAuthn" + } +} diff --git a/docker-compose.yml b/compose.yaml similarity index 76% rename from docker-compose.yml rename to compose.yaml index 3ce07ab..abf93f3 100644 --- a/docker-compose.yml +++ b/compose.yaml @@ -25,7 +25,7 @@ services: volumes: - ./:/src ports: - - 8080 + - "8161:8080" environment: AWS_REGION: localhost AWS_ENDPOINT: http://dynamo:8000 @@ -44,22 +44,4 @@ services: volumes: - ./demo-ui:/usr/local/apache2/htdocs ports: - - 80 - - sls: - build: . - env_file: - - path: ./local.env - required: false - volumes: - - .:/src - command: ["bash"] - - cdk: - build: . - volumes: - - ./:/src - user: 1000:1001 - env_file: - - path: local.env - required: false + - "80" diff --git a/go.mod b/go.mod index 1ac187b..8a914bd 100644 --- a/go.mod +++ b/go.mod @@ -10,9 +10,9 @@ require ( github.com/aws/aws-sdk-go-v2/service/dynamodb v1.43.3 github.com/fxamacker/cbor/v2 v2.8.0 github.com/go-webauthn/webauthn v0.11.2 - github.com/gorilla/mux v1.8.1 + github.com/google/uuid v1.6.0 github.com/kelseyhightower/envconfig v1.4.0 - github.com/satori/go.uuid v1.2.0 + github.com/pquerna/otp v1.5.0 github.com/stretchr/testify v1.10.0 golang.org/x/crypto v0.39.0 ) @@ -30,12 +30,12 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.25.4 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.2 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.33.21 // indirect - github.com/aws/smithy-go v1.22.3 // indirect + github.com/aws/smithy-go v1.22.2 // indirect + github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-webauthn/x v0.1.21 // indirect github.com/golang-jwt/jwt/v5 v5.2.2 // indirect - github.com/google/go-tpm v0.9.5 // indirect - github.com/google/uuid v1.6.0 // indirect + github.com/google/go-tpm v0.9.3 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index f3836a4..df9ba2f 100644 --- a/go.sum +++ b/go.sum @@ -32,9 +32,12 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.2 h1:XB4z0hbQtpmBnb1FQYvKaCM7 github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.2/go.mod h1:hwRpqkRxnQ58J9blRDrB4IanlXCpcKmsC83EhG77upg= github.com/aws/aws-sdk-go-v2/service/sts v1.33.21 h1:nyLjs8sYJShFYj6aiyjCBI3EcLn1udWrQTjEF+SOXB0= github.com/aws/aws-sdk-go-v2/service/sts v1.33.21/go.mod h1:EhdxtZ+g84MSGrSrHzZiUm9PYiZkrADNja15wtRJSJo= -github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= -github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= +github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fxamacker/cbor/v2 v2.8.0 h1:fFtUGXUzXPHTIUdne5+zzMPTfffl3RD5qYnkY40vtxU= @@ -45,12 +48,10 @@ github.com/go-webauthn/x v0.1.21 h1:nFbckQxudvHEJn2uy1VEi713MeSpApoAv9eRqsb9AdQ= github.com/go-webauthn/x v0.1.21/go.mod h1:sEYohtg1zL4An1TXIUIQ5csdmoO+WO0R4R2pGKaHYKA= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/google/go-tpm v0.9.5 h1:ocUmnDebX54dnW+MQWGQRbdaAcJELsa6PqZhJ48KwVU= -github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= +github.com/google/go-tpm v0.9.3 h1:+yx0/anQuGzi+ssRqeD6WpXjW2L/V0dItUayO0i9sRc= +github.com/google/go-tpm v0.9.3/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= -github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8= github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= @@ -65,11 +66,13 @@ github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RR github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= +github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= -github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= -github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= diff --git a/lambda/main.go b/lambda/main.go index 13f72a7..cad88a4 100644 --- a/lambda/main.go +++ b/lambda/main.go @@ -2,8 +2,6 @@ package main import ( "context" - "encoding/json" - "fmt" "io" "log" "net/http" @@ -13,10 +11,10 @@ import ( "github.com/aws/aws-lambda-go/events" "github.com/aws/aws-lambda-go/lambda" - "github.com/gorilla/mux" "github.com/kelseyhightower/envconfig" mfa "github.com/silinternational/serverless-mfa-api-go" + "github.com/silinternational/serverless-mfa-api-go/router" ) var envConfig mfa.EnvConfig @@ -44,82 +42,15 @@ func main() { lambda.Start(handler) } -func credentialToDelete(req events.APIGatewayProxyRequest) (string, bool) { - if strings.ToLower(req.HTTPMethod) != `delete` { - return "", false - } - - path := req.Path - if !strings.HasPrefix(path, "/webauthn/credential/") { - return "", false - } - - parts := strings.Split(path, `/`) - if len(parts) != 4 { - return "", false - } - - credID := parts[3] - return credID, true -} - -func addDeleteCredentialParamForMux(r *http.Request, key, value string) *http.Request { - vars := map[string]string{key: value} - return mux.SetURLVars(r, vars) -} - func handler(ctx context.Context, req events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) { r := httpRequestFromProxyRequest(ctx, req) + w := newLambdaResponseWriter() app := mfa.NewApp(envConfig) + mux := router.NewMux(app) - if !strings.HasPrefix(r.URL.Path, "/api-key") { - user, err := mfa.AuthenticateRequest(r) - if err != nil { - return clientError(http.StatusUnauthorized, fmt.Sprintf("unable to authenticate request: %s", err)) - } - newCtx := context.WithValue(r.Context(), mfa.UserContextKey, user) - r = r.WithContext(newCtx) - } + mux.ServeHTTP(w, r) - // Use custom lambda http.ResponseWriter - w := newLambdaResponseWriter() - - // This (delete /webauthn/credential/abc123) expects a path param that is - // the id that was previously returned as - // the key_handle_hash from the FinishRegistration call. - // Alternatively, if the id param indicates that a legacy U2F key should be removed - // (e.g. by matching the string "u2f") - // then that user is saved with all of its legacy u2f fields blanked out. - if credIdToDelete, ok := credentialToDelete(req); ok { - // add the id to the context in order for mux to retrieve it - r = addDeleteCredentialParamForMux(r, mfa.IDParam, credIdToDelete) - app.DeleteCredential(w, r) - // Routes other than /webauthn/delete/credential/abc213 - } else { - route := strings.ToLower(fmt.Sprintf("%s %s", req.HTTPMethod, req.Path)) - - switch route { - case "post /api-key": - app.CreateApiKey(w, r) - case "post /api-key/activate": - app.ActivateApiKey(w, r) - case "post /api-key/rotate": - app.RotateApiKey(w, r) - case "post /webauthn/login": - app.BeginLogin(w, r) - case "put /webauthn/login": - app.FinishLogin(w, r) - case "post /webauthn/register": - app.BeginRegistration(w, r) - case "put /webauthn/register": - app.FinishRegistration(w, r) - case "delete /webauthn/user": - app.DeleteUser(w, r) - default: - return clientError(http.StatusNotFound, fmt.Sprintf("The requested route is not supported: %s", route)) - } - } headers := map[string]string{} for k, v := range w.Header() { headers[k] = v[0] @@ -132,20 +63,6 @@ func handler(ctx context.Context, req events.APIGatewayProxyRequest) (events.API }, nil } -// clientError helper for send responses relating to client errors. -func clientError(status int, body string) (events.APIGatewayProxyResponse, error) { - type cError struct { - Error string - } - - js, _ := json.Marshal(cError{Error: body}) - - return events.APIGatewayProxyResponse{ - StatusCode: status, - Body: string(js), - }, nil -} - func httpRequestFromProxyRequest(ctx context.Context, req events.APIGatewayProxyRequest) *http.Request { headers := http.Header{} for k, v := range req.Headers { @@ -162,5 +79,6 @@ func httpRequestFromProxyRequest(ctx context.Context, req events.APIGatewayProxy RequestURI: req.Path, URL: requestURL, } + return r.WithContext(ctx) } diff --git a/lambda/main_test.go b/lambda/main_test.go deleted file mode 100644 index 1d06e94..0000000 --- a/lambda/main_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package main - -import ( - "net/http/httptest" - "testing" - - "github.com/aws/aws-lambda-go/events" - "github.com/gorilla/mux" - "github.com/stretchr/testify/require" - - mfa "github.com/silinternational/serverless-mfa-api-go" -) - -func TestCredentialToDelete(t *testing.T) { - assert := require.New(t) - - tests := []struct { - name string - method string - path string - wantId string - wantOk bool - }{ - { - name: "method not delete", - method: "PUT", - path: "/webauthn/credential/abc123", - wantId: "", - wantOk: false, - }, - { - name: "path not /webauthn ...", - method: "DELETE", - path: "/badstuff/credential/abc123", - wantId: "", - wantOk: false, - }, - { - name: "path not /webauthn/credential ...", - method: "DELETE", - path: "/webauthn/badstuff/abc123", - wantId: "", - wantOk: false, - }, - { - name: "path too few parts", - method: "DELETE", - path: "/webauthn/credential", - wantId: "", - wantOk: false, - }, - { - name: "all good", - method: "DELETE", - path: "/webauthn/credential/abc123", - wantId: "abc123", - wantOk: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := events.APIGatewayProxyRequest{ - HTTPMethod: tt.method, - Path: tt.path, - } - gotId, gotOk := credentialToDelete(req) - assert.Equal(tt.wantOk, gotOk, "test %s: incorrect bool", tt.name) - assert.Equal(tt.wantId, gotId, "test %s: incorrect credential ID", tt.name) - }) - } -} - -func TestAddDeleteCredentialParamForMux(t *testing.T) { - assert := require.New(t) - r := httptest.NewRequest("DELETE", "/webauthn/credential/abc123", nil) - - credId := "abc123" - r = addDeleteCredentialParamForMux(r, mfa.IDParam, credId) - params := mux.Vars(r) - got, ok := params[mfa.IDParam] - assert.True(ok, "didn't find key in mux vars: %v", params) - assert.Equal(credId, got, "incorrect param value") -} diff --git a/local.env.example b/local.env.example index 0728b78..7f642d5 100644 --- a/local.env.example +++ b/local.env.example @@ -3,8 +3,7 @@ # https proxy config DNS_PROVIDER=cloudflare -CLOUDFLARE_EMAIL= -CLOUDFLARE_API_KEY= +CLOUDFLARE_DNS_API_TOKEN= LETS_ENCRYPT_EMAIL= LETS_ENCRYPT_CA=production TLD= diff --git a/openapi.yaml b/openapi.yaml index 4c19aad..1bd390c 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -629,9 +629,123 @@ paths: /webauthn/credential/{id}: delete: - summary: Delete Credential (not currently supported in the Lambda implementation) + summary: Delete Credential + description: >- + DeleteCredential removes a single passkey identified by the id parameter, which is the `key_handle_hash` + returned by the Finish Registration endpoint. + parameters: + - in: path + name: id + schema: + type: string + required: true + description: >- + The unique identifier for the credential. It is the `key_handle_hash` value returned by the + Finish Registration endpoint. If the id is the fixed string "u2f" the legacy U2F data is removed from the + record. responses: 204: description: Success 401: $ref: "#/components/responses/UnauthorizedError" + + /totp: + post: + summary: Create a new passcode (TOTP) + requestBody: + content: + application/json: + schema: + type: object + properties: + issuer: + type: string + description: >- + The name of the service/website for which the user is setting up this TOTP account. + required: true + example: "SIL IdP" + label: + type: string + description: A name for the user's account (such as their username). + required: false + example: "john_smith@example.com" + responses: + 200: + description: Success + content: + application/json: + schema: + type: object + properties: + uuid: + type: string + description: The unique identifier for the passcode. + example: "aaaaaaaa-1111-aaaa-1111-aaaaaaaaaaaa" + totpKey: + type: string + description: The passcode secret key. + example: "0123456789ABCDEF0123456789ABCDEF" + imageUrl: + type: string + description: >- + ImageURL is a base64-encoded image in data URL format like "...". + The image is a QR code that contains the OTPAuthURL, which the user scans to store the shared + secret key and metadata in their authenticator app. + example: "..." + otpAuthUrl: + type: string + description: >- + OTPAuthURL is an otpauth URI that contains the passcode secret key. It may also contain metadata + like issuer, algorithm, and number of digits. + example: "otpauth://totp/idp:john_smith?secret=0123456789ABCDEF0123456789ABCDEF&issuer=SIL%20IdP" + 401: + $ref: "#/components/responses/UnauthorizedError" + /totp/{uuid}: + delete: + summary: Delete a passcode (TOTP) + parameters: + - in: path + name: uuid + schema: + type: string + format: uuid + required: true + description: The unique identifier for the passcode. + responses: + "204": + description: Success + "404": + description: Not found + "401": + $ref: "#/components/responses/UnauthorizedError" + /totp/{uuid}/validate: + delete: + summary: Validate a passcode (TOTP) + parameters: + - in: path + name: uuid + schema: + type: string + format: uuid + required: true + description: The unique identifier for the passcode. + requestBody: + content: + application/json: + schema: + type: object + properties: + code: + type: string + required: true + description: >- + The generated one-time password to validate against the shared secret identified by the UUID + parameter. + example: "123456" + responses: + "200": + description: The code is valid. + "400": + description: Invalid request, e.g. missing code. + "401": + description: Key is not valid or one-time code is not valid. diff --git a/server/authentication_middleware.go b/router/middleware.go similarity index 86% rename from server/authentication_middleware.go rename to router/middleware.go index 501db11..f9ca988 100644 --- a/server/authentication_middleware.go +++ b/router/middleware.go @@ -1,8 +1,8 @@ -package main +package router import ( "context" - "fmt" + "log" "net/http" "strings" @@ -21,7 +21,8 @@ func authenticationMiddleware(next http.Handler) http.Handler { user, err := mfa.AuthenticateRequest(r) if err != nil { - http.Error(w, fmt.Sprintf("unable to authenticate request: %s", err), http.StatusUnauthorized) + log.Printf("unable to authenticate request: %s", err) + http.Error(w, "Unauthorized", http.StatusUnauthorized) return } diff --git a/router/router.go b/router/router.go new file mode 100644 index 0000000..b0dcd42 --- /dev/null +++ b/router/router.go @@ -0,0 +1,35 @@ +package router + +import ( + "net/http" + + mfa "github.com/silinternational/serverless-mfa-api-go" +) + +// NewMux forms a new ServeMux router, see https://pkg.go.dev/net/http#ServeMux. +func NewMux(app *mfa.App) *http.ServeMux { + mux := http.NewServeMux() + + for pattern, handler := range getRoutes(app) { + mux.Handle(pattern, authenticationMiddleware(handler)) + } + return mux +} + +// getRoutes returns a list of routes for the server +func getRoutes(app *mfa.App) map[string]http.HandlerFunc { + return map[string]http.HandlerFunc{ + "POST /api-key/activate": app.ActivateApiKey, + "POST /api-key/rotate": app.RotateApiKey, + "POST /api-key": app.CreateApiKey, + "POST /totp": app.CreateTOTP, + "DELETE /totp/{" + mfa.UUIDParam + "}": app.DeleteTOTP, + "POST /totp/{" + mfa.UUIDParam + "}/validate": app.ValidateTOTP, + "POST /webauthn/register": app.BeginRegistration, + "PUT /webauthn/register": app.FinishRegistration, + "POST /webauthn/login": app.BeginLogin, + "PUT /webauthn/login": app.FinishLogin, + "DELETE /webauthn/user": app.DeleteUser, + "DELETE /webauthn/credential/{" + mfa.IDParam + "}": app.DeleteCredential, + } +} diff --git a/server/main.go b/server/main.go index b6780f2..48393dd 100644 --- a/server/main.go +++ b/server/main.go @@ -1,16 +1,14 @@ package main import ( - "encoding/json" - "fmt" "log" "net/http" "os" - "github.com/gorilla/mux" "github.com/kelseyhightower/envconfig" mfa "github.com/silinternational/serverless-mfa-api-go" + "github.com/silinternational/serverless-mfa-api-go/router" ) var envConfig mfa.EnvConfig @@ -29,104 +27,7 @@ func main() { // ListenAndServe starts an HTTP server with a given address and // handler defined in NewRouter. log.Println("Starting service on port 8080") - router := newRouter(mfa.NewApp(envConfig)) - log.Fatal(http.ListenAndServe(":8080", router)) -} - -// route is used to pass information about a particular route. -type route struct { - Name string - Method string - Pattern string - HandlerFunc http.HandlerFunc -} - -// getRoutes returns a list of routes for the server -func getRoutes(app *mfa.App) []route { - return []route{ - { - Name: "ActivateApiKey", - Method: "POST", - Pattern: "/api-key/activate", - HandlerFunc: app.ActivateApiKey, - }, - { - Name: "RotateApiKey", - Method: "POST", - Pattern: "/api-key/rotate", - HandlerFunc: app.RotateApiKey, - }, - { - Name: "CreateApiKey", - Method: "POST", - Pattern: "/api-key", - HandlerFunc: app.CreateApiKey, - }, - { - Name: "FinishRegistration", - Method: "PUT", - Pattern: "/webauthn/register", - HandlerFunc: app.FinishRegistration, - }, - { - Name: "BeginLogin", - Method: "POST", - Pattern: "/webauthn/login", - HandlerFunc: app.BeginLogin, - }, - { - Name: "FinishLogin", - Method: "PUT", - Pattern: "/webauthn/login", - HandlerFunc: app.FinishLogin, - }, - { - Name: "DeleteUser", - Method: "DELETE", - Pattern: "/webauthn/user", - HandlerFunc: app.DeleteUser, - }, - { // This expects a path param that is the id that was previously returned - // as the key_handle_hash from the FinishRegistration call. - // Alternatively, if the id param indicates that a legacy U2F key should be removed - // (e.g. by matching the string "u2f") - // then that user is saved with all of its legacy u2f fields blanked out. - Name: "DeleteCredential", - Method: "DELETE", - Pattern: fmt.Sprintf("/webauthn/credential/{%s}", mfa.IDParam), - HandlerFunc: app.DeleteCredential, - }, - } -} - -// newRouter forms a new mux router, see https://github.com/gorilla/mux. -func newRouter(app *mfa.App) *mux.Router { - // Create a basic router. - router := mux.NewRouter().StrictSlash(true) - - // authenticate request based on api key and secret in headers - // also adds user to context - router.Use(authenticationMiddleware) - - // Assign the handlers to run when endpoints are called. - for _, route := range getRoutes(app) { - router.Methods(route.Method).Path(route.Pattern).Name(route.Name).Handler(route.HandlerFunc) - } - - router.NotFoundHandler = router.NewRoute().HandlerFunc(notFound).GetHandler() - return router -} - -func notFound(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json; charset=UTF-8") - w.WriteHeader(http.StatusNotFound) - - notFound := map[string]string{ - "Method": r.Method, - "URL": r.URL.String(), - "RequestURI": r.RequestURI, - } - if err := json.NewEncoder(w).Encode(notFound); err != nil { - log.Printf("ERROR could not marshal not found message to JSON: %s", err) - } + app := mfa.NewApp(envConfig) + mux := router.NewMux(app) + log.Fatal(http.ListenAndServe(":8080", mux)) } diff --git a/totp.go b/totp.go index 29df2b7..27d112a 100644 --- a/totp.go +++ b/totp.go @@ -1,7 +1,302 @@ package mfa +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "image/png" + "io" + "log" + "net/http" + "strings" + + "github.com/pquerna/otp/totp" +) + +// TOTPTablePK is the primary key in the TOTP DynamoDB table +const TOTPTablePK = "uuid" + +const ( + internalServerError = "Internal server error" + invalidRequest = "Invalid request" + totpNotFound = "TOTP not found" +) + +// TOTP contains data to represent a Time-based One-Time Passcode (token). The ID and encrypted fields are persisted in +// DynamoDB. The others are non-encrypted and are short-lived. type TOTP struct { - UUID string `dynamodbav:"uuid" json:"uuid"` - ApiKey string `dynamodbav:"apiKey" json:"apiKey"` + // UUID is the unique ID and primary key for the passcode. + UUID string `dynamodbav:"uuid" json:"uuid"` + + // ApiKey is the ID of the API Key used to create and access this passcode. + ApiKey string `dynamodbav:"apiKey" json:"apiKey"` + + // EncryptedTotpKey is the encrypted form of the key of the passcode. EncryptedTotpKey string `dynamodbav:"encryptedTotpKey" json:"encryptedTotpKey"` + + // Key is the passcode secret key. + Key string `dynamodbav:"-" json:"-"` + + // ImageURL is a base64-encoded image in data URL format like "...". The image + // is a QR code that contains the OTPAuthURL, which the user scans to store the shared secret key and metadata in + // their authenticator app. + ImageURL string `dynamodbav:"-" json:"-"` + + // OTPAuthURL is a otpauth URI like "otpauth://totp/idp:john_doe?secret=G5KFM3LNJ5NWQP3O&issuer=idp" that contains + // the passcode secret key. It may also contain metadata like issuer, algorithm, and number of digits. + OTPAuthURL string `dynamodbav:"-" json:"-"` +} + +// debugString is used by the debugger to show useful TOTP information in watched variables +func (t TOTP) debugString() string { + return fmt.Sprintf("UUID: %s, Key: %s, ApiKey: %s", t.UUID, t.Key, t.ApiKey) +} + +// CreateTOTPRequestBody defines the JSON request body for the CreateTOTP endpoint +type CreateTOTPRequestBody struct { + Issuer string `json:"issuer"` + Name string `json:"label"` +} + +// CreateTOTPResponseBody defines the JSON response body for the CreateTOTP endpoint +type CreateTOTPResponseBody struct { + UUID string `json:"uuid"` + TOTPKey string `json:"totpKey"` + OTPAuthURL string `json:"otpAuthUrl"` + ImageURL string `json:"imageUrl"` +} + +// ValidateTOTPRequestBody defines the JSON request body for the ValidateTOTP endpoint +type ValidateTOTPRequestBody struct { + Code string `json:"code"` +} + +// CreateTOTP is the http handler to create a new TOTP passcode. +func (a *App) CreateTOTP(w http.ResponseWriter, r *http.Request) { + requestBody, err := parseCreateTOTPRequestBody(r.Body) + if err != nil { + log.Println("Invalid CreateTOTP request body:", err) + jsonResponse(w, invalidRequest, http.StatusBadRequest) + return + } + + apiKey, err := getAPIKey(r) + if err != nil { + log.Printf("CreateTOTP API key error: %v", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) + return + } + + t, err := newTOTP(a.db, apiKey, requestBody.Issuer, requestBody.Name) + if err != nil { + log.Printf("failed to create a new TOTP: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) + return + } + + responseBody := CreateTOTPResponseBody{ + UUID: t.UUID, + TOTPKey: t.Key, + OTPAuthURL: t.OTPAuthURL, + ImageURL: t.ImageURL, + } + jsonResponse(w, responseBody, http.StatusOK) +} + +// parseCreateTOTPRequestBody parses and validates the CreateTOTP request body +func parseCreateTOTPRequestBody(body io.ReadCloser) (*CreateTOTPRequestBody, error) { + if body == nil { + return nil, fmt.Errorf("empty request body") + } + + requestBody := &CreateTOTPRequestBody{} + err := json.NewDecoder(body).Decode(requestBody) + if err != nil { + return nil, fmt.Errorf("invalid request: %w", err) + } + + if requestBody.Issuer == "" { + return nil, errors.New("issuer is required") + } + + if requestBody.Name == "" { + requestBody.Name = "SecretKey" + } + + return requestBody, nil +} + +// newTOTP creates a new TOTP passcode +func newTOTP(db *Storage, apiKey ApiKey, issuer, name string) (TOTP, error) { + key, err := totp.Generate(totp.GenerateOpts{ + Issuer: issuer, + AccountName: name, + + // This is an increase from our existing Node.js implementation, which uses a 10-byte secret. + SecretSize: 20, + }) + if err != nil { + return TOTP{}, fmt.Errorf("generate failure: %w", err) + } + + cipherText, err := apiKey.EncryptLegacy(key.Secret()) + if err != nil { + return TOTP{}, fmt.Errorf("encrypt failure: %w", err) + } + + image, err := key.Image(164, 164) + if err != nil { + return TOTP{}, fmt.Errorf("image failure: %w", err) + } + + var buf bytes.Buffer + err = png.Encode(&buf, image) + if err != nil { + return TOTP{}, fmt.Errorf("encode failure: %w", err) + } + + imageDataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes()) + + t := TOTP{ + UUID: NewUUID(), + ApiKey: apiKey.Key, + EncryptedTotpKey: cipherText, + Key: key.Secret(), + ImageURL: imageDataURL, + OTPAuthURL: key.URL(), + } + + err = db.Store(envConfig.TotpTable, t) + if err != nil { + return TOTP{}, fmt.Errorf("failed to store TOTP: %w", err) + } + return t, nil +} + +// DeleteTOTP is the http handler to delete a passcode. +func (a *App) DeleteTOTP(w http.ResponseWriter, r *http.Request) { + id := r.PathValue(UUIDParam) + + key, err := getAPIKey(r) + if err != nil { + log.Printf("DeleteTOTP API key error: %v", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) + return + } + + var t TOTP + err = a.db.Load(envConfig.TotpTable, TOTPTablePK, id, &t) + if err != nil { + if strings.Contains(err.Error(), "does not exist") { + jsonResponse(w, totpNotFound, http.StatusNotFound) + } else { + log.Printf("error loading TOTP: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) + } + return + } + + if key.Key != t.ApiKey { + jsonResponse(w, totpNotFound, http.StatusNotFound) + return + } + + err = a.db.Delete(envConfig.TotpTable, TOTPTablePK, id) + if err != nil { + log.Printf("Failed to delete TOTP: %s", err) + jsonResponse(w, "Failed to delete TOTP", http.StatusInternalServerError) + return + } + + jsonResponse(w, nil, http.StatusNoContent) +} + +// ValidateTOTP is the http handler to validate a passcode. +func (a *App) ValidateTOTP(w http.ResponseWriter, r *http.Request) { + requestBody, err := parseValidateTOTPRequestBody(r.Body) + if err != nil { + log.Printf("Invalid ValidateTOTP request body: %s", err) + jsonResponse(w, invalidRequest, http.StatusBadRequest) + return + } + + id := r.PathValue(UUIDParam) + + key, err := getAPIKey(r) + if err != nil { + log.Printf("ValidateTOTP API key error: %v", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) + return + } + + var t TOTP + err = a.db.Load(envConfig.TotpTable, TOTPTablePK, id, &t) + if err != nil { + if strings.Contains(err.Error(), "does not exist") { + jsonResponse(w, totpNotFound, http.StatusNotFound) + } else { + log.Printf("error loading TOTP: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) + } + return + } + + if key.Key != t.ApiKey { + jsonResponse(w, totpNotFound, http.StatusNotFound) + return + } + + secret, err := key.DecryptLegacy(t.EncryptedTotpKey) + if err != nil { + log.Printf("failed to decrypt TOTP key: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) + return + } + t.Key = secret + + valid := totp.Validate(requestBody.Code, t.Key) + if !valid { + jsonResponse(w, "Invalid", http.StatusUnauthorized) + return + } + + jsonResponse(w, "Valid", http.StatusOK) +} + +// parseValidateTOTPRequestBody parses and validates the ValidateTOTP request body +func parseValidateTOTPRequestBody(body io.ReadCloser) (*ValidateTOTPRequestBody, error) { + if body == nil { + return nil, fmt.Errorf("empty request body") + } + + requestBody := &ValidateTOTPRequestBody{} + err := json.NewDecoder(body).Decode(requestBody) + if err != nil { + return nil, fmt.Errorf("invalid request: %w", err) + } + + if requestBody.Code == "" { + return nil, errors.New("code is required") + } + + return requestBody, nil +} + +// authTOTP is a just a placeholder for TOTP. It takes the verified API Key and returns it as an authenticated User +// for later use. +func authTOTP(apiKey ApiKey) (User, error) { + return apiKey, nil +} + +// getAPIKey returns the authenticated API Key from the request context. The authentication middleware or +// early handler processing inserts the key into the context for retrieval by this function. +func getAPIKey(r *http.Request) (ApiKey, error) { + key, ok := r.Context().Value(UserContextKey).(ApiKey) + if !ok { + return ApiKey{}, fmt.Errorf("unable to get API key from request context") + } + + return key, nil } diff --git a/totp_test.go b/totp_test.go new file mode 100644 index 0000000..9da38fa --- /dev/null +++ b/totp_test.go @@ -0,0 +1,284 @@ +package mfa + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "time" + + "github.com/pquerna/otp/totp" +) + +func (ms *MfaSuite) TestAppCreateTOTP() { + ctxWithAPIKey := context.WithValue(context.Background(), UserContextKey, newTestKey()) + reqWithAPIKey := (&http.Request{}).WithContext(ctxWithAPIKey) + + validBody := io.NopCloser(strings.NewReader(`{"issuer":"idp","label":"label"}`)) + validRequest := (&http.Request{Body: validBody}).WithContext(ctxWithAPIKey) + + tests := []struct { + name string + request *http.Request + wantStatus int + }{ + { + name: "bad request body", + request: reqWithAPIKey, + wantStatus: http.StatusBadRequest, + }, + { + name: "valid request", + request: validRequest, + wantStatus: http.StatusOK, + }, + } + for _, tt := range tests { + ms.Run(tt.name, func() { + response := httptest.NewRecorder() + ms.app.CreateTOTP(response, tt.request) + ms.Equalf(tt.wantStatus, response.Code, "incorrect http status, response body: %s", response.Body.String()) + + if tt.wantStatus == http.StatusOK { + var responseBody CreateTOTPResponseBody + ms.NoError(json.Unmarshal(response.Body.Bytes(), &responseBody)) + ms.NotEmpty(responseBody.TOTPKey, "TOTPKey is empty") + ms.NotEmpty(responseBody.OTPAuthURL, "OTPAuthURL is empty") + ms.NotEmpty(responseBody.ImageURL, "ImageURL is empty") + ms.NotEmpty(responseBody.UUID, "UUID is empty") + } + }) + } +} + +func (ms *MfaSuite) TestParseCreateTOTPRequestBody() { + tests := []struct { + name string + body string + want *CreateTOTPRequestBody + wantErr string + }{ + { + name: "empty issuer", + body: `{"issuer":"","label":"john_doe@example.com"}`, + want: nil, + wantErr: "issuer is required", + }, + { + name: "empty label", + body: `{"issuer":"idp","label":""}`, + want: &CreateTOTPRequestBody{ + Issuer: "idp", + Name: "SecretKey", + }, + }, + { + name: "fully specified", + body: `{"issuer":"idp","label":"label"}`, + want: &CreateTOTPRequestBody{ + Issuer: "idp", + Name: "label", + }, + }, + } + for _, tt := range tests { + ms.Run(tt.name, func() { + got, err := parseCreateTOTPRequestBody(io.NopCloser(strings.NewReader(tt.body))) + if tt.wantErr != "" { + ms.Error(err) + ms.Equal(tt.wantErr, err.Error()) + return + } + + ms.NoError(err) + ms.Equal(tt.want, got) + }) + } +} + +func (ms *MfaSuite) TestNewTOTP() { + apiKey := newTestKey() + + got, err := newTOTP(ms.app.GetDB(), apiKey, "issuer", "label") + ms.NoError(err) + ms.Equal(apiKey.Key, got.ApiKey, "ApiKey isn't correct") + ms.Regexp("^[a-zA-Z0-9]{32}$", got.Key, "Key length is not correct. Check SecretSize in totp.GenerateOpts.") + ms.Regexp("^data:image/png;base64,[a-zA-Z0-9/+=]+$", got.ImageURL, "ImageURL isn't correct") + + wantOTPAuthURL := "otpauth://totp/issuer:label?algorithm=SHA1&digits=6&issuer=issuer&period=30&secret=" + got.Key + ms.Equal(wantOTPAuthURL, got.OTPAuthURL, "OTPAuthURL isn't correct") + + plainText, err := apiKey.DecryptLegacy(got.EncryptedTotpKey) + ms.NoError(err) + ms.Equal(got.Key, plainText, "EncryptedTotpKey isn't correct") +} + +func (ms *MfaSuite) TestAppDeleteTOTP() { + key := newTestKey() + otherKey := newTestKey() + testTOTP := ms.newPasscode(key) + + ctxWithAPIKey := context.WithValue(context.Background(), UserContextKey, key) + ctxWithOtherAPIKey := context.WithValue(context.Background(), UserContextKey, otherKey) + + mux := &http.ServeMux{} + mux.HandleFunc("DELETE /totp/{"+UUIDParam+"}", ms.app.DeleteTOTP) + + tests := []struct { + name string + request *http.Request + wantStatus int + }{ + { + name: "wrong UUID", + request: ms.newRequest(ctxWithAPIKey, http.MethodDelete, "/totp/"+NewUUID(), ""), + wantStatus: http.StatusNotFound, + }, + { + name: "correct UUID, wrong key", + request: ms.newRequest(ctxWithOtherAPIKey, http.MethodDelete, "/totp/"+testTOTP.UUID, ""), + wantStatus: http.StatusNotFound, + }, + { + name: "correct UUID, correct key", + request: ms.newRequest(ctxWithAPIKey, http.MethodDelete, "/totp/"+testTOTP.UUID, ""), + wantStatus: http.StatusNoContent, + }, + } + for _, tt := range tests { + ms.Run(tt.name, func() { + response := httptest.NewRecorder() + mux.ServeHTTP(response, tt.request) + + ms.Equalf(tt.wantStatus, response.Code, "incorrect http status, response body: %s", response.Body.String()) + }) + } +} + +func (ms *MfaSuite) TestAppValidateTOTP() { + key := newTestKey() + otherKey := newTestKey() + testTOTP, err := newTOTP(ms.app.db, key, "issuer", "name") + ms.NoError(err) + + ctxWithAPIKey := context.WithValue(context.Background(), UserContextKey, key) + ctxWithOtherAPIKey := context.WithValue(context.Background(), UserContextKey, otherKey) + + now := time.Now() + code, err := totp.GenerateCode(testTOTP.Key, now) + ms.NoError(err) + + mux := &http.ServeMux{} + mux.HandleFunc("POST /totp/{"+UUIDParam+"}/validate", ms.app.ValidateTOTP) + + tests := []struct { + name string + request *http.Request + wantStatus int + }{ + { + name: "wrong UUID", + request: ms.newRequest(ctxWithAPIKey, http.MethodPost, + "/totp/"+NewUUID()+"/validate", `{"code":"`+code+`"}`), + wantStatus: http.StatusNotFound, + }, + { + name: "correct UUID, wrong key", + request: ms.newRequest(ctxWithOtherAPIKey, http.MethodPost, + "/totp/"+testTOTP.UUID+"/validate", `{"code":"`+code+`"}`), + wantStatus: http.StatusNotFound, + }, + { + name: "correct UUID, correct key, wrong code", + request: ms.newRequest(ctxWithAPIKey, http.MethodPost, + "/totp/"+testTOTP.UUID+"/validate", `{"code":"000000"}`), + wantStatus: http.StatusUnauthorized, + }, + { + name: "correct UUID, correct key, correct code", + request: ms.newRequest(ctxWithAPIKey, http.MethodPost, + "/totp/"+testTOTP.UUID+"/validate", `{"code":"`+code+`"}`), + wantStatus: http.StatusOK, + }, + } + for _, tt := range tests { + ms.Run(tt.name, func() { + response := httptest.NewRecorder() + mux.ServeHTTP(response, tt.request) + + ms.Equalf(tt.wantStatus, response.Code, "incorrect http status, response body: %s", response.Body.String()) + }) + } +} + +func (ms *MfaSuite) TestParseValidateTOTPRequestBody() { + tests := []struct { + name string + body io.ReadCloser + want *ValidateTOTPRequestBody + wantErr string + }{ + { + name: "no body", + body: nil, + want: nil, + wantErr: "empty request body", + }, + { + name: "empty", + body: io.NopCloser(strings.NewReader("")), + want: nil, + wantErr: "invalid request: EOF", + }, + { + name: "missing code", + body: io.NopCloser(strings.NewReader("{}")), + want: nil, + wantErr: "code is required", + }, + { + name: "correct", + body: io.NopCloser(strings.NewReader(`{"code":"000000"}`)), + want: &ValidateTOTPRequestBody{ + Code: "000000", + }, + }, + } + for _, tt := range tests { + ms.Run(tt.name, func() { + got, err := parseValidateTOTPRequestBody(tt.body) + if tt.wantErr != "" { + ms.Error(err) + ms.Equal(tt.wantErr, err.Error()) + return + } + + ms.NoError(err) + ms.Equal(tt.want, got) + }) + } +} + +func (ms *MfaSuite) newRequest(ctx context.Context, method, path, body string) *http.Request { + r := &http.Request{ + Method: method, + URL: &url.URL{Path: path}, + } + if body != "" { + r.Body = io.NopCloser(strings.NewReader(body)) + } + return r.WithContext(ctx) +} + +func (ms *MfaSuite) newPasscode(key ApiKey) TOTP { + t := TOTP{ + UUID: NewUUID(), + ApiKey: key.Key, + EncryptedTotpKey: mustEncryptLegacy(key, "plain text TOTP key"), + } + must(ms.app.db.Store(ms.app.GetConfig().TotpTable, t)) + return t +} diff --git a/u2fserver/main.go b/u2fserver/main.go index 28f5763..30103a4 100644 --- a/u2fserver/main.go +++ b/u2fserver/main.go @@ -1,13 +1,10 @@ package main import ( - "encoding/json" "log" "net/http" "os" - "github.com/gorilla/mux" - u2fsim "github.com/silinternational/serverless-mfa-api-go/u2fsimulator" ) @@ -22,49 +19,9 @@ func main() { log.Fatal(http.ListenAndServe(":8080", router)) } -// route is used to pass information about a particular route. -type route struct { - Name string - Method string - Pattern string - HandlerFunc http.HandlerFunc -} - -// Define our routes. -var routes = []route{ - // For information on this, see the doc comment for u2fsimulator.U2fRegistration - { - "RegistrationResponse", - "POST", - "/u2f/registration", - u2fsim.U2fRegistration, - }, -} - -// newRouter forms a new mux router, see https://github.com/gorilla/mux. -func newRouter() *mux.Router { - // Create a basic router. - router := mux.NewRouter().StrictSlash(true) - - // Assign the handlers to run when endpoints are called. - for _, route := range routes { - router.Methods(route.Method).Path(route.Pattern).Name(route.Name).Handler(route.HandlerFunc) - } - - router.NotFoundHandler = router.NewRoute().HandlerFunc(notFound).GetHandler() - return router -} - -func notFound(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json; charset=UTF-8") - w.WriteHeader(http.StatusNotFound) - - notFound := map[string]string{ - "Method": r.Method, - "URL": r.URL.String(), - "RequestURI": r.RequestURI, - } - if err := json.NewEncoder(w).Encode(notFound); err != nil { - log.Printf("ERROR could not marshal not found message to JSON: %s", err) - } +// newRouter forms a new http.ServeMux +func newRouter() *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("POST /u2f/registration", u2fsim.U2fRegistration) + return mux } diff --git a/u2fsimulator/u2fsimulator.go b/u2fsimulator/u2fsimulator.go index f0dde52..cedfee7 100644 --- a/u2fsimulator/u2fsimulator.go +++ b/u2fsimulator/u2fsimulator.go @@ -86,7 +86,7 @@ func jsonResponse(w http.ResponseWriter, body interface{}, status int) { w.WriteHeader(status) _, err = w.Write(jBody) if err != nil { - log.Printf("faild to write response in jsonResponse: %s\n", err) + log.Printf("failed to write response in jsonResponse: %s", err) } } @@ -123,8 +123,8 @@ func getPrivateKey() *ecdsa.PrivateKey { // and also returns the private key func GetAuthDataAndPrivateKey(rpID, keyHandle string) (authDataStr string, authData []byte, privateKey *ecdsa.PrivateKey) { // Add in the RP ID Hash (32 bytes) - RPIDHash := sha256.Sum256([]byte(rpID)) - for _, r := range RPIDHash { + hash := sha256.Sum256([]byte(rpID)) + for _, r := range hash { authData = append(authData, r) } diff --git a/webauthn.go b/webauthn.go index 1a02ee5..5ec0bb0 100644 --- a/webauthn.go +++ b/webauthn.go @@ -11,8 +11,6 @@ import ( "github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/webauthn" - "github.com/gorilla/mux" - uuid "github.com/satori/go.uuid" ) // WebauthnMeta holds metadata about the calling service for use in WebAuthn responses. @@ -51,18 +49,20 @@ type finishLoginResponse struct { func (a *App) BeginRegistration(w http.ResponseWriter, r *http.Request) { user, err := getWebauthnUser(r) if err != nil { - jsonResponse(w, err, http.StatusBadRequest) + log.Printf("failed to get user for BeginRegistration: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) return } // If user.id is empty, treat as new user/registration if user.ID == "" { - user.ID = uuid.NewV4().String() + user.ID = NewUUID() } options, err := user.BeginRegistration() if err != nil { - jsonResponse(w, fmt.Sprintf("failed to begin registration: %s", err), http.StatusBadRequest) + log.Printf("failed to begin registration: %s", err) + jsonResponse(w, invalidRequest, http.StatusBadRequest) return } @@ -79,13 +79,15 @@ func (a *App) BeginRegistration(w http.ResponseWriter, r *http.Request) { func (a *App) FinishRegistration(w http.ResponseWriter, r *http.Request) { user, err := getWebauthnUser(r) if err != nil { - jsonResponse(w, err, http.StatusBadRequest) + log.Printf("failed to get user for FinishRegistration: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) return } keyHandleHash, err := user.FinishRegistration(r) if err != nil { - jsonResponse(w, err, http.StatusBadRequest) + log.Printf("failed to finish registration: %s", err) + jsonResponse(w, invalidRequest, http.StatusBadRequest) return } @@ -101,15 +103,15 @@ func (a *App) FinishRegistration(w http.ResponseWriter, r *http.Request) { func (a *App) BeginLogin(w http.ResponseWriter, r *http.Request) { user, err := getWebauthnUser(r) if err != nil { - jsonResponse(w, err, http.StatusBadRequest) - log.Printf("error getting user from context: %s\n", err) + log.Printf("failed to get user for BeginLogin: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) return } options, err := user.BeginLogin() if err != nil { - jsonResponse(w, err, http.StatusBadRequest) - log.Printf("error beginning user login: %s\n", err) + log.Printf("error beginning user login: %s", err) + jsonResponse(w, invalidRequest, http.StatusBadRequest) return } @@ -121,15 +123,19 @@ func (a *App) BeginLogin(w http.ResponseWriter, r *http.Request) { func (a *App) FinishLogin(w http.ResponseWriter, r *http.Request) { user, err := getWebauthnUser(r) if err != nil { - jsonResponse(w, err, http.StatusBadRequest) - log.Printf("error getting user from context: %s\n", err) + log.Printf("failed to get user for FinishLogin: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) return } credential, err := user.FinishLogin(r) if err != nil { - jsonResponse(w, err, http.StatusBadRequest) - log.Printf("error finishing user login : %s\n", err) + // SonarQube flagged this as vulnerable to injection attacks. Rather than exhaustively search for places + // where user input is inserted into the error message, I'll just sanitize it as recommended. + sanitizedError := strings.ReplaceAll(strings.ReplaceAll(err.Error(), "\n", "_"), "\r", "_") + log.Printf("error finishing user login: %s", sanitizedError) + + jsonResponse(w, invalidRequest, http.StatusBadRequest) return } @@ -146,14 +152,14 @@ func (a *App) FinishLogin(w http.ResponseWriter, r *http.Request) { func (a *App) DeleteUser(w http.ResponseWriter, r *http.Request) { user, err := getWebauthnUser(r) if err != nil { - jsonResponse(w, err, http.StatusBadRequest) - log.Printf("error getting user from context: %s\n", err) + log.Printf("failed to get user for DeleteUser: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) return } if err := user.Delete(); err != nil { - jsonResponse(w, err, http.StatusInternalServerError) log.Printf("error deleting user: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) return } @@ -162,30 +168,39 @@ func (a *App) DeleteUser(w http.ResponseWriter, r *http.Request) { // DeleteCredential is the handler for the "DELETE /webauthn/credential/{credID}" endpoint. It removes a single // passkey identified by "credID", which is the key_handle_hash returned by the FinishRegistration endpoint, or "u2f" -// if it is a legacy U2F credential. +// if it is a legacy U2F credential, in which case that user is saved with all of its legacy u2f fields blanked out. func (a *App) DeleteCredential(w http.ResponseWriter, r *http.Request) { user, err := getWebauthnUser(r) if err != nil { - jsonResponse(w, err, http.StatusBadRequest) - log.Printf("error getting user from context: %s\n", err) + log.Printf("failed to get user for DeleteCredential: %s", err) + jsonResponse(w, internalServerError, http.StatusInternalServerError) return } - params := mux.Vars(r) - credID, ok := params[IDParam] - if !ok || credID == "" { - err := fmt.Errorf("%s path parameter not provided to DeleteCredential", IDParam) - jsonResponse(w, err, http.StatusBadRequest) - log.Printf("%s\n", err) + credID := r.PathValue(IDParam) + if credID == "" { + err := fmt.Errorf("%s path parameter not provided to DeleteCredential, path: %s", IDParam, r.URL.Path) + log.Printf("%s", err) + jsonResponse(w, invalidRequest, http.StatusBadRequest) return } status, err := user.DeleteCredential(credID) if err != nil { - log.Printf("error deleting user credential: %s", err) + log.Printf("error deleting user credential (%d): %s", status, err) } - jsonResponse(w, err, status) + switch status { + case http.StatusNoContent: + jsonResponse(w, nil, status) + case http.StatusNotFound: + jsonResponse(w, "Not found", status) + case http.StatusInternalServerError: + jsonResponse(w, internalServerError, status) + default: + log.Printf("unexpected status code (%d)", status) + jsonResponse(w, internalServerError, http.StatusInternalServerError) + } } // fixStringEncoding converts a string from standard Base64 to Base64-URL @@ -212,7 +227,7 @@ func getWebAuthnFromApiMeta(meta WebauthnMeta) (*webauthn.WebAuthn, error) { Debug: true, }) if err != nil { - fmt.Println(err) + log.Printf("failed to get new webauthn: %s", err) } return web, nil diff --git a/webauthn_test.go b/webauthn_test.go index e3a6cbf..14452c7 100644 --- a/webauthn_test.go +++ b/webauthn_test.go @@ -17,7 +17,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/webauthn" - "github.com/gorilla/mux" "github.com/stretchr/testify/require" u2fsim "github.com/silinternational/serverless-mfa-api-go/u2fsimulator" @@ -173,7 +172,7 @@ func (ms *MfaSuite) Test_BeginRegistration() { httpWriter: newLambdaResponseWriter(), httpReq: http.Request{}, wantBodyContains: []string{ - `"error":"unable to get user from request context"`, + `"error":"` + internalServerError + `"}`, }, }, { @@ -312,13 +311,13 @@ func (ms *MfaSuite) Test_FinishRegistration() { { name: "no user", httpReq: http.Request{}, - wantBodyContains: []string{`"error":"unable to get user from request context"`}, + wantBodyContains: []string{`"error":"` + internalServerError + `"}`}, }, { name: "request has no body", httpReq: reqNoBody, wantBodyContains: []string{ - `"error":"request Body may not be nil in FinishRegistration"`, + `"error":"` + invalidRequest + `"}`, }, }, { @@ -480,13 +479,13 @@ func (ms *MfaSuite) Test_BeginLogin() { name: "no user", httpWriter: newLambdaResponseWriter(), httpReq: http.Request{}, - wantBodyContains: []string{`"error":"unable to get user from request context"`}, + wantBodyContains: []string{`"error":"` + internalServerError + `"}`}, }, { name: "has a user but no credentials", httpWriter: newLambdaResponseWriter(), httpReq: reqNoCredentials, - wantBodyContains: []string{`"error":"Found no credentials for user"`}, + wantBodyContains: []string{`"error":"` + invalidRequest + `"}`}, }, { name: "has a user with credentials", @@ -657,7 +656,7 @@ func (ms *MfaSuite) Test_FinishLogin() { { name: "no user", httpReq: http.Request{}, - wantBodyContains: []string{`"error":"unable to get user from request context"`}, + wantBodyContains: []string{`"error":"` + internalServerError + `"}`}, }, { name: "with first credential", @@ -745,18 +744,14 @@ func Test_GetPublicKeyAsBytes(t *testing.T) { assert.Equal(want, got, "incorrect public Key") } -func Router(app *App) *mux.Router { - router := mux.NewRouter() - router.HandleFunc(fmt.Sprintf("/webauthn/credential/{%s}", IDParam), app.DeleteCredential).Methods("DELETE") +func Router(app *App) http.Handler { + mux := &http.ServeMux{} + mux.HandleFunc(fmt.Sprintf("DELETE /webauthn/credential/{%s}", IDParam), app.DeleteCredential) // Ensure a request without an id gets handled properly - router.HandleFunc("/webauthn/credential/", app.DeleteCredential).Methods("DELETE") - router.HandleFunc("/webauthn/credential", app.DeleteCredential).Methods("DELETE") + mux.HandleFunc("DELETE /webauthn/credential/", app.DeleteCredential) + mux.HandleFunc("DELETE /webauthn/credential", app.DeleteCredential) - // authenticate request based on api key and secret in headers - // also adds user to context - router.Use(testAuthnMiddleware) - - return router + return testAuthnMiddleware(mux) } func testAuthnMiddleware(next http.Handler) http.Handler { @@ -861,7 +856,7 @@ func (ms *MfaSuite) Test_DeleteCredential() { response := httptest.NewRecorder() Router(ms.app).ServeHTTP(response, request) - ms.Equal(tt.wantStatus, response.Code, "incorrect http status") + ms.Equal(tt.wantStatus, response.Code, "incorrect http status, body: %s", response.Body.String()) if tt.wantStatus != http.StatusNoContent { return diff --git a/webauthnuser.go b/webauthnuser.go index a8c7901..c816a7b 100644 --- a/webauthnuser.go +++ b/webauthnuser.go @@ -77,7 +77,7 @@ func NewWebauthnUser(apiConfig WebauthnMeta, storage *Storage, apiKey ApiKey, we err := u.Load() if err != nil { - log.Printf("failed to load user: %s\n", err) + log.Printf("failed to load user: %s", err) } return u } @@ -107,7 +107,7 @@ func (u *WebauthnUser) saveSessionData(sessionData webauthn.SessionData) error { js, err := json.Marshal(sessionData) if err != nil { - log.Printf("error marshaling session data to json. Session data: %+v\n Error: %s\n", sessionData, err) + log.Printf("error marshaling session data to json. Session data: %+v, Error: %s", sessionData, err) return err } @@ -354,7 +354,7 @@ func (u *WebauthnUser) BeginLogin() (*protocol.CredentialAssertion, error) { err = u.saveSessionData(*sessionData) if err != nil { - log.Printf("error saving session data: %s\n", err) + log.Printf("error saving session data: %s", err) return nil, err }