Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Continuation of JWKS support from PR #71 #85

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 48 additions & 1 deletion _example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ import (
"github.com/lestrrat-go/jwx/v2/jwt"
)

type dynamicTokenAuth struct {
keySet []byte
}

func (d *dynamicTokenAuth) JWTAuth() (*jwtauth.JWTAuth, error) {
keySet, err := jwtauth.NewKeySet(d.keySet)
if err != nil {
return nil, err
}
return keySet, nil
}

var tokenAuth *jwtauth.JWTAuth

func init() {
Expand All @@ -76,7 +88,8 @@ func init() {
// For debugging/example purposes, we generate and print
// a sample jwt token with claims `user_id:123` here:
_, tokenString, _ := tokenAuth.Encode(map[string]interface{}{"user_id": 123})
fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString)
fmt.Printf("DEBUG: a sample jwt for /admin is %s\n\n", tokenString)
fmt.Printf("DEBUG: a sample jwt for /rotate is %s\n\n", sampleJWTRotate)
}

func main() {
Expand Down Expand Up @@ -105,6 +118,23 @@ func router() http.Handler {
})
})

r.Group(func(r chi.Router) {
dynamicTokenAuth := dynamicTokenAuth{keySet: keySet}
// Seek, verify and validate JWT tokens based on keys returned by the callback function
r.Use(jwtauth.VerifierDynamic(dynamicTokenAuth.JWTAuth))

// Handle valid / invalid tokens. In this example, we use
// the provided authenticator middleware, but you can write your
// own very easily, look at the Authenticator method in jwtauth.go
// and tweak it, its not scary.
r.Use(jwtauth.Authenticator)

r.Get("/rotate", func(w http.ResponseWriter, r *http.Request) {
_, claims, _ := jwtauth.FromContext(r.Context())
w.Write([]byte(fmt.Sprintf("protected area. hi %v", claims["user_id"])))
})
})

// Public routes
r.Group(func(r chi.Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -114,3 +144,20 @@ func router() http.Handler {

return r
}

var (
keySet = []byte(`{
"keys": [
{
"kty": "RSA",
"alg": "RS256",
"kid": "kid",
"use": "sig",
"n": "rgzO_v14UXJ33MvccKI8aIw3YpknVJbRB-m1z1X4j3gaTmmzmb7_naEd1TOKhF6Z1BGupvAKhCs8uHtp5e1PCrp52kzrjv7nqQfDpdppPZmKpwf-OD_lVgLLuCljB71mX9w7T5vI_WiVknuNhm48y0TJQNslpDZum4E2e0BLKUDRKKlo25foGoDuQN535_Xso861U8KsA80jX37BJplQ6IHewV_bbe04NYTVqaFcmLaZCAzh2f8L1h4xt76Y0xF_u8FXt2-rgcWlz17CtZzxC8ZXNI_92pX8CY5LY2eQf_B_n5Rhd5TQvEIdoI1GNBrcKUI9pMeEC4pErcOGgKGH7w",
"e": "AQAB"
}
]
}`)

sampleJWTRotate = `eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImtpZCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.APC4bUOmfbcXjBnZnmyiGBpXqlboTB4Qbh_sqJrgSU5AEQlwzjvDJ79eBlty8h6kfq3i5ffy87s-g82ZoRsHqMjwCIvTOVnoEyDgVu68s9lE32uaA0cc2-hbA13DIBsyIUGjehh9c3h93BrUoUr7n0CHgoKgx2OEw1Bq8vm4EqvmFGF-mr_0qi32uudPy3I15SyP1NJfU0ogQEFUdDHww3c8omDmrTPiGlWZAl9AiBMroDu0nq3UOtC4d5Se-361NEGiZ9J_kHcVWGdoMwsi5KEB0Uf3wAfXK3wcXeRu1pTXYKOV3X3g_2ss6mh65bNMsSx-MZUnQv5v6qZMOxMBUA`
)
21 changes: 21 additions & 0 deletions jwtauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package jwtauth

import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"

"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
)

Expand All @@ -17,6 +20,7 @@ type JWTAuth struct {
verifyKey interface{} // public-key, only used by RSA and ECDSA algorithms
verifier jwt.ParseOption
validateOptions []jwt.ValidateOption
keySet jwk.Set
}

var (
Expand Down Expand Up @@ -50,6 +54,19 @@ func New(alg string, signKey interface{}, verifyKey interface{}, validateOptions
return ja
}

func NewKeySet(keySet []byte) (*JWTAuth, error) {
ks := jwk.NewSet()
err := json.Unmarshal(keySet, &ks)
if err != nil {
return nil, err
}

ja := &JWTAuth{keySet: ks}
ja.verifier = jwt.WithKeySet(ks)

return ja, nil
}

// Verifier http middleware handler will verify a JWT string from a http request.
//
// Verifier will search for a JWT token in a http request, in the order:
Expand Down Expand Up @@ -120,6 +137,10 @@ func VerifyToken(ja *JWTAuth, tokenString string) (jwt.Token, error) {
}

func (ja *JWTAuth) Encode(claims map[string]interface{}) (t jwt.Token, tokenString string, err error) {
if ja.keySet != nil {
return nil, "", fmt.Errorf("encode not supported")
}

t = jwt.New()
for k, v := range claims {
t.Set(k, v)
Expand Down
164 changes: 164 additions & 0 deletions jwtauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"testing"
"time"

"github.com/lestrrat-go/jwx/v2/jws"

"github.com/go-chi/chi/v5"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwa"
Expand Down Expand Up @@ -41,6 +43,27 @@ MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBALxo3PCjFw4QjgOX06QCJIJBnXXNiEYw
DLxxa5/7QyH6y77nCRQyJ3x3UwF9rUD0RCsp4sNdX5kOQ9PUyHyOtCUCAwEAAQ==
-----END PUBLIC KEY-----
`

KeySet = `{
"keys": [
{
"kty": "RSA",
"n": "vGjc8KMXDhCOA5fTpAIkgkGddc2IRjAMvHFrn_tDIfrLvucJFDInfHdTAX2tQPREKyniw11fmQ5D09TIfI60JQ",
"e": "AQAB",
"alg": "RS256",
"kid": "1",
"use": "sig"
},
{
"kty": "RSA",
"n": "foo",
"e": "AQAB",
"alg": "RS256",
"kid": "2",
"use": "sig"
}
]
}`
)

func init() {
Expand All @@ -51,6 +74,57 @@ func init() {
// Tests
//

func TestNewKeySet(t *testing.T) {
_, err := jwtauth.NewKeySet([]byte("not a valid key set"))
if err == nil {
t.Fatal("The error should not be nil")
}

_, err = jwtauth.NewKeySet([]byte(KeySet))
if err != nil {
t.Fatalf(err.Error())
}
}

func TestKeySetRSA(t *testing.T) {
privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String))

privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)

if err != nil {
t.Fatalf(err.Error())
}

KeySetAuth, _ := jwtauth.NewKeySet([]byte(KeySet))
claims := map[string]interface{}{
"key": "val",
"key2": "val2",
"key3": "val3",
}

signed := newJwtRSAToken(jwa.RS256, privateKey, "1", claims)

token, err := KeySetAuth.Decode(signed)

if err != nil {
t.Fatalf("Failed to decode token string %s\n", err.Error())
}

tokenClaims, err := token.AsMap(context.Background())
if err != nil {
t.Fatal(err.Error())
}

if !reflect.DeepEqual(claims, tokenClaims) {
t.Fatalf("The decoded claims don't match the original ones\n")
}

_, _, err = KeySetAuth.Encode(claims)
if err.Error() != "encode not supported" {
t.Fatalf("Expect error to equal %s. Found: %s.", "encode not supported", err.Error())
}
}

func TestSimple(t *testing.T) {
r := chi.NewRouter()

Expand Down Expand Up @@ -279,6 +353,73 @@ func TestMore(t *testing.T) {
}
}

func TestKeySet(t *testing.T) {
privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String))
privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
if err != nil {
t.Fatalf(err.Error())
}

r := chi.NewRouter()

keySet, err := jwtauth.NewKeySet([]byte(KeySet))
if err != nil {
t.Fatalf(err.Error())
}

// Protected routes
r.Group(func(r chi.Router) {
r.Use(jwtauth.Verifier(keySet))

authenticator := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, _, err := jwtauth.FromContext(r.Context())

if err != nil {
http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized)
return
}

if err := jwt.Validate(token); err != nil {
http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized)
return
}

// Token is authenticated, pass it through
next.ServeHTTP(w, r)
})
}
r.Use(authenticator)

r.Get("/admin", func(w http.ResponseWriter, r *http.Request) {
_, claims, err := jwtauth.FromContext(r.Context())

if err != nil {
w.Write([]byte(fmt.Sprintf("error! %v", err)))
return
}

w.Write([]byte(fmt.Sprintf("protected, user:%v", claims["user_id"])))
})
})

// Public routes
r.Group(func(r chi.Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("welcome"))
})
})

ts := httptest.NewServer(r)
defer ts.Close()

h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtRSAToken(jwa.RS256, privateKey, "1", map[string]interface{}{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)}))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" {
t.Fatalf(resp)
}
}

//
// Test helper functions
//
Expand Down Expand Up @@ -340,6 +481,29 @@ func newJwt512Token(secret []byte, claims ...map[string]interface{}) string {
return string(tokenPayload)
}

func newJwtRSAToken(alg jwa.SignatureAlgorithm, secret interface{}, kid string, claims ...map[string]interface{}) string {
token := jwt.New()
if len(claims) > 0 {
for k, v := range claims[0] {
token.Set(k, v)
}
}

headers := jws.NewHeaders()
if kid != "" {
err := headers.Set("kid", kid)
if err != nil {
log.Fatal(err)
}
}

tokenPayload, err := jwt.Sign(token, jwt.WithKey(alg, secret, jws.WithProtectedHeaders(headers)))
if err != nil {
log.Fatal(err)
}
return string(tokenPayload)
}

func newAuthHeader(claims ...map[string]interface{}) http.Header {
h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...))
Expand Down