diff --git a/cmd/jwt-wallet/main_test.go b/cmd/jwt-wallet/main_test.go index 5343ef8..369765f 100644 --- a/cmd/jwt-wallet/main_test.go +++ b/cmd/jwt-wallet/main_test.go @@ -9,12 +9,12 @@ import ( "testing" "time" - "github.com/Kong/go-pdk/test" - secp256k1 "github.com/btcsuite/btcd/btcec" - "github.com/golang-jwt/jwt/v4" jwtwallet "github.com/FigureTechnologies/kong-jwt-wallet" "github.com/FigureTechnologies/kong-jwt-wallet/grants" "github.com/FigureTechnologies/kong-jwt-wallet/signing" + "github.com/Kong/go-pdk/test" + secp256k1 "github.com/btcsuite/btcd/btcec" + "github.com/golang-jwt/jwt/v4" "github.com/stretchr/testify/assert" ) @@ -68,7 +68,7 @@ func TestMissingAddrClaim(t *testing.T) { pkBytes, _ := hex.DecodeString("8C037EFC21AB3F0F8D32CF209D90FDBF41D10071FF600BA66A30EFA994F268A3") prvk, pubk := secp256k1.PrivKeyFromBytes(secp256k1.S256(), pkBytes) - claims := GenerateClaims("", pubk) + claims := GenerateClaims("", pubk, base64.RawURLEncoding) token := jwt.NewWithClaims(signing.NewSecp256k1Signer(), claims) sig, _ := token.SignedString(prvk) @@ -89,7 +89,7 @@ func TestMissingSubClaim(t *testing.T) { pkBytes, _ := hex.DecodeString("8C037EFC21AB3F0F8D32CF209D90FDBF41D10071FF600BA66A30EFA994F268A3") prvk, _ := secp256k1.PrivKeyFromBytes(secp256k1.S256(), pkBytes) - claims := GenerateClaims("tbMadeUpAddr", nil) + claims := GenerateClaims("tbMadeUpAddr", nil, base64.RawURLEncoding) token := jwt.NewWithClaims(signing.NewSecp256k1Signer(), claims) sig, _ := token.SignedString(prvk) @@ -109,7 +109,7 @@ func TestExpiredToken(t *testing.T) { pkBytes, _ := hex.DecodeString("8C037EFC21AB3F0F8D32CF209D90FDBF41D10071FF600BA66A30EFA994F268A3") prvk, pubk := secp256k1.PrivKeyFromBytes(secp256k1.S256(), pkBytes) - claims := GenerateClaims("tb1MadeUpAddr", pubk) + claims := GenerateClaims("tb1MadeUpAddr", pubk, base64.RawURLEncoding) claims.ExpiresAt = jwt.NewNumericDate(time.Date(1999, 12, 31, 11, 10, 0, 0, time.Local)) token := jwt.NewWithClaims(signing.NewSecp256k1Signer(), claims) sig, _ := token.SignedString(prvk) @@ -126,10 +126,40 @@ func TestExpiredToken(t *testing.T) { assert.Equal(t, 401, env.ClientRes.Status) } -func TestValidJwt(t *testing.T) { +func TestValidUrlEncodedJwt(t *testing.T) { + pkBytes, _ := hex.DecodeString("8C037EFC21AB3F0F8D32CF209D90FDBF41D10071FF600BA66A30EFA994F268A3") + prvk, pubk := secp256k1.PrivKeyFromBytes(secp256k1.S256(), pkBytes) + claims := GenerateClaims("tp1y34frcm3hmnmgszmnxufcyw4aeslplsz8hkuxv", pubk, base64.RawURLEncoding) + token := jwt.NewWithClaims(signing.NewSecp256k1Signer(), claims) + sig, _ := token.SignedString(prvk) + + r := ioutil.NopCloser(bytes.NewReader([]byte(subjectJSONString))) + GetDoFunc = func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: r, + }, nil + } + + env, err := test.New(t, test.Request{ + Method: "GET", + Url: "http://example.com", + Headers: map[string][]string{"Authorization": {"Bearer " + sig}}, + }) + assert.NoError(t, err) + + env.DoHttp(config) + + assert.Equal(t, 200, env.ClientRes.Status) + assert.NotEmpty(t, env.ServiceReq.Headers.Get("x-wallet-access")) + assert.Empty(t, env.ServiceReq.Headers.Get("x-sender")) + assert.Equal(t, subjectJSONString, env.ServiceReq.Headers.Get("x-wallet-access")) +} + +func TestValidNonUrlEncodedJwt(t *testing.T) { pkBytes, _ := hex.DecodeString("8C037EFC21AB3F0F8D32CF209D90FDBF41D10071FF600BA66A30EFA994F268A3") prvk, pubk := secp256k1.PrivKeyFromBytes(secp256k1.S256(), pkBytes) - claims := GenerateClaims("tp1y34frcm3hmnmgszmnxufcyw4aeslplsz8hkuxv", pubk) + claims := GenerateClaims("tp1y34frcm3hmnmgszmnxufcyw4aeslplsz8hkuxv", pubk, base64.RawStdEncoding) token := jwt.NewWithClaims(signing.NewSecp256k1Signer(), claims) sig, _ := token.SignedString(prvk) @@ -160,7 +190,7 @@ func TestValidJwtWithEmptyRbacUrl(t *testing.T) { pkBytes, _ := hex.DecodeString("8C037EFC21AB3F0F8D32CF209D90FDBF41D10071FF600BA66A30EFA994F268A3") prvk, pubk := secp256k1.PrivKeyFromBytes(secp256k1.S256(), pkBytes) addr := "tp1y34frcm3hmnmgszmnxufcyw4aeslplsz8hkuxv" - claims := GenerateClaims(addr, pubk) + claims := GenerateClaims(addr, pubk, base64.RawURLEncoding) token := jwt.NewWithClaims(signing.NewSecp256k1Signer(), claims) sig, _ := token.SignedString(prvk) @@ -186,7 +216,7 @@ func TestIncorrectAddress(t *testing.T) { pkBytes, _ := hex.DecodeString("8C037EFC21AB3F0F8D32CF209D90FDBF41D10071FF600BA66A30EFA994F268A3") prvk, pubk := secp256k1.PrivKeyFromBytes(secp256k1.S256(), pkBytes) - claims := GenerateClaims("tp1rr4d0eu62pgt4edw38d2ev27798pfhdhp5ttha", pubk) + claims := GenerateClaims("tp1rr4d0eu62pgt4edw38d2ev27798pfhdhp5ttha", pubk, base64.RawURLEncoding) token := jwt.NewWithClaims(signing.NewSecp256k1Signer(), claims) sig, _ := token.SignedString(prvk) @@ -203,10 +233,10 @@ func TestIncorrectAddress(t *testing.T) { assert.Equal(t, "address does not match public key", env.ClientRes.Body) } -func GenerateClaims(addr string, pubKey *secp256k1.PublicKey) *signing.Claims { +func GenerateClaims(addr string, pubKey *secp256k1.PublicKey, encoding *base64.Encoding) *signing.Claims { compressedKey := "" if pubKey != nil { - compressedKey = base64.RawURLEncoding.EncodeToString(pubKey.SerializeCompressed()) + compressedKey = encoding.EncodeToString(pubKey.SerializeCompressed()) } loc, _ := time.LoadLocation("GMT") return &signing.Claims{ diff --git a/jwt-wallet.go b/jwt-wallet.go index 747d45f..a64e654 100644 --- a/jwt-wallet.go +++ b/jwt-wallet.go @@ -8,9 +8,9 @@ import ( "hash" "strings" - "github.com/cosmos/btcutil/bech32" "github.com/FigureTechnologies/kong-jwt-wallet/grants" "github.com/FigureTechnologies/kong-jwt-wallet/signing" + "github.com/cosmos/btcutil/bech32" "golang.org/x/crypto/ripemd160" "github.com/Kong/go-pdk" @@ -146,11 +146,14 @@ func verifyAddress(addr string, pubKey string, kong *pdk.PDK) bool { hrp := addr[0:separator] keyB64 := strings.Split(pubKey, ",")[0] - keyBytes, err := base64.RawURLEncoding.DecodeString(keyB64) + var keyBytes, err = base64.RawStdEncoding.DecodeString(keyB64) if err != nil { - kong.Log.Err("Could not decode public key") - return false + keyBytes, err = base64.RawURLEncoding.DecodeString(keyB64) + if err != nil { + kong.Log.Err("Could not decode public key") + return false + } } hash160Bytes := Hash160(keyBytes) diff --git a/signing/signing.go b/signing/signing.go index b7d1ba9..05fc574 100644 --- a/signing/signing.go +++ b/signing/signing.go @@ -39,9 +39,12 @@ func ParseKey(kong *pdk.PDK) func(token *jwt.Token) (interface{}, error) { return nil, fmt.Errorf("no subject") } keyB64 := strings.Split(sub, ",")[0] - keyBytes, err := base64.RawURLEncoding.DecodeString(keyB64) + var keyBytes, err = base64.RawStdEncoding.DecodeString(keyB64) if err != nil { - return nil, err + keyBytes, err = base64.RawURLEncoding.DecodeString(keyB64) + if err != nil { + return nil, err + } } pubk, err := secp256k1.ParsePubKey(keyBytes, secp256k1.S256()) if err != nil {