123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228 |
- package oidc
- import (
- "context"
- "errors"
- "fmt"
- "io/ioutil"
- "net/http"
- "sync"
- "time"
- "github.com/pquerna/cachecontrol"
- jose "gopkg.in/square/go-jose.v2"
- )
// keysExpiryDelta is the tolerated clock skew between this client and the
// OpenID Connect server when judging whether the cached keys have expired.
//
// verify only re-fetches the remote key set when the current time plus this
// delta is past the cache expiry, i.e. when the cached keys have already
// expired or will expire within this window. Otherwise a signature that no
// cached key can verify is rejected without contacting the provider.
const keysExpiryDelta time.Duration = 30 * time.Second
- // NewRemoteKeySet returns a KeySet that can validate JSON web tokens by using HTTP
- // GETs to fetch JSON web token sets hosted at a remote URL. This is automatically
- // used by NewProvider using the URLs returned by OpenID Connect discovery, but is
- // exposed for providers that don't support discovery or to prevent round trips to the
- // discovery URL.
- //
- // The returned KeySet is a long lived verifier that caches keys based on cache-control
- // headers. Reuse a common remote key set instead of creating new ones as needed.
- //
- // The behavior of the returned KeySet is undefined once the context is canceled.
- func NewRemoteKeySet(ctx context.Context, jwksURL string) KeySet {
- return newRemoteKeySet(ctx, jwksURL, time.Now)
- }
- func newRemoteKeySet(ctx context.Context, jwksURL string, now func() time.Time) *remoteKeySet {
- if now == nil {
- now = time.Now
- }
- return &remoteKeySet{jwksURL: jwksURL, ctx: ctx, now: now}
- }
- type remoteKeySet struct {
- jwksURL string
- ctx context.Context
- now func() time.Time
- // guard all other fields
- mu sync.Mutex
- // inflight suppresses parallel execution of updateKeys and allows
- // multiple goroutines to wait for its result.
- inflight *inflight
- // A set of cached keys and their expiry.
- cachedKeys []jose.JSONWebKey
- expiry time.Time
- }
- // inflight is used to wait on some in-flight request from multiple goroutines.
- type inflight struct {
- doneCh chan struct{}
- keys []jose.JSONWebKey
- err error
- }
- func newInflight() *inflight {
- return &inflight{doneCh: make(chan struct{})}
- }
- // wait returns a channel that multiple goroutines can receive on. Once it returns
- // a value, the inflight request is done and result() can be inspected.
- func (i *inflight) wait() <-chan struct{} {
- return i.doneCh
- }
- // done can only be called by a single goroutine. It records the result of the
- // inflight request and signals other goroutines that the result is safe to
- // inspect.
- func (i *inflight) done(keys []jose.JSONWebKey, err error) {
- i.keys = keys
- i.err = err
- close(i.doneCh)
- }
- // result cannot be called until the wait() channel has returned a value.
- func (i *inflight) result() ([]jose.JSONWebKey, error) {
- return i.keys, i.err
- }
- func (r *remoteKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) {
- jws, err := jose.ParseSigned(jwt)
- if err != nil {
- return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
- }
- return r.verify(ctx, jws)
- }
- func (r *remoteKeySet) verify(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
- // We don't support JWTs signed with multiple signatures.
- keyID := ""
- for _, sig := range jws.Signatures {
- keyID = sig.Header.KeyID
- break
- }
- keys, expiry := r.keysFromCache()
- // Don't check expiry yet. This optimizes for when the provider is unavailable.
- for _, key := range keys {
- if keyID == "" || key.KeyID == keyID {
- if payload, err := jws.Verify(&key); err == nil {
- return payload, nil
- }
- }
- }
- if !r.now().Add(keysExpiryDelta).After(expiry) {
- // Keys haven't expired, don't refresh.
- return nil, errors.New("failed to verify id token signature")
- }
- keys, err := r.keysFromRemote(ctx)
- if err != nil {
- return nil, fmt.Errorf("fetching keys %v", err)
- }
- for _, key := range keys {
- if keyID == "" || key.KeyID == keyID {
- if payload, err := jws.Verify(&key); err == nil {
- return payload, nil
- }
- }
- }
- return nil, errors.New("failed to verify id token signature")
- }
- func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey, expiry time.Time) {
- r.mu.Lock()
- defer r.mu.Unlock()
- return r.cachedKeys, r.expiry
- }
- // keysFromRemote syncs the key set from the remote set, records the values in the
- // cache, and returns the key set.
- func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
- // Need to lock to inspect the inflight request field.
- r.mu.Lock()
- // If there's not a current inflight request, create one.
- if r.inflight == nil {
- r.inflight = newInflight()
- // This goroutine has exclusive ownership over the current inflight
- // request. It releases the resource by nil'ing the inflight field
- // once the goroutine is done.
- go func() {
- // Sync keys and finish inflight when that's done.
- keys, expiry, err := r.updateKeys()
- r.inflight.done(keys, err)
- // Lock to update the keys and indicate that there is no longer an
- // inflight request.
- r.mu.Lock()
- defer r.mu.Unlock()
- if err == nil {
- r.cachedKeys = keys
- r.expiry = expiry
- }
- // Free inflight so a different request can run.
- r.inflight = nil
- }()
- }
- inflight := r.inflight
- r.mu.Unlock()
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case <-inflight.wait():
- return inflight.result()
- }
- }
- func (r *remoteKeySet) updateKeys() ([]jose.JSONWebKey, time.Time, error) {
- req, err := http.NewRequest("GET", r.jwksURL, nil)
- if err != nil {
- return nil, time.Time{}, fmt.Errorf("oidc: can't create request: %v", err)
- }
- resp, err := doRequest(r.ctx, req)
- if err != nil {
- return nil, time.Time{}, fmt.Errorf("oidc: get keys failed %v", err)
- }
- defer resp.Body.Close()
- body, err := ioutil.ReadAll(resp.Body)
- if err != nil {
- return nil, time.Time{}, fmt.Errorf("unable to read response body: %v", err)
- }
- if resp.StatusCode != http.StatusOK {
- return nil, time.Time{}, fmt.Errorf("oidc: get keys failed: %s %s", resp.Status, body)
- }
- var keySet jose.JSONWebKeySet
- err = unmarshalResp(resp, body, &keySet)
- if err != nil {
- return nil, time.Time{}, fmt.Errorf("oidc: failed to decode keys: %v %s", err, body)
- }
- // If the server doesn't provide cache control headers, assume the
- // keys expire immediately.
- expiry := r.now()
- _, e, err := cachecontrol.CachableResponse(req, resp, cachecontrol.Options{})
- if err == nil && e.After(expiry) {
- expiry = e
- }
- return keySet.Keys, expiry, nil
- }
|