jwks.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. package oidc
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io/ioutil"
  7. "net/http"
  8. "sync"
  9. "time"
  10. "github.com/pquerna/cachecontrol"
  11. jose "gopkg.in/square/go-jose.v2"
  12. )
  // keysExpiryDelta is the allowed clock skew between a client and the OpenID Connect
// server.
//
// It acts as a refresh window, not as a grace period for expired keys. Cached keys
// are always tried first, whatever their expiry. Only when none of them verifies a
// token is the expiry consulted. The remote key set is then refetched if the cached
// keys have already expired or will expire within keysExpiryDelta. Otherwise
// verification fails without a network round trip.
const keysExpiryDelta = 30 * time.Second
  22. // NewRemoteKeySet returns a KeySet that can validate JSON web tokens by using HTTP
  23. // GETs to fetch JSON web token sets hosted at a remote URL. This is automatically
  24. // used by NewProvider using the URLs returned by OpenID Connect discovery, but is
  25. // exposed for providers that don't support discovery or to prevent round trips to the
  26. // discovery URL.
  27. //
  28. // The returned KeySet is a long lived verifier that caches keys based on cache-control
  29. // headers. Reuse a common remote key set instead of creating new ones as needed.
  30. //
  31. // The behavior of the returned KeySet is undefined once the context is canceled.
  32. func NewRemoteKeySet(ctx context.Context, jwksURL string) KeySet {
  33. return newRemoteKeySet(ctx, jwksURL, time.Now)
  34. }
  35. func newRemoteKeySet(ctx context.Context, jwksURL string, now func() time.Time) *remoteKeySet {
  36. if now == nil {
  37. now = time.Now
  38. }
  39. return &remoteKeySet{jwksURL: jwksURL, ctx: ctx, now: now}
  40. }
  41. type remoteKeySet struct {
  42. jwksURL string
  43. ctx context.Context
  44. now func() time.Time
  45. // guard all other fields
  46. mu sync.Mutex
  47. // inflight suppresses parallel execution of updateKeys and allows
  48. // multiple goroutines to wait for its result.
  49. inflight *inflight
  50. // A set of cached keys and their expiry.
  51. cachedKeys []jose.JSONWebKey
  52. expiry time.Time
  53. }
  54. // inflight is used to wait on some in-flight request from multiple goroutines.
  55. type inflight struct {
  56. doneCh chan struct{}
  57. keys []jose.JSONWebKey
  58. err error
  59. }
  60. func newInflight() *inflight {
  61. return &inflight{doneCh: make(chan struct{})}
  62. }
  63. // wait returns a channel that multiple goroutines can receive on. Once it returns
  64. // a value, the inflight request is done and result() can be inspected.
  65. func (i *inflight) wait() <-chan struct{} {
  66. return i.doneCh
  67. }
  68. // done can only be called by a single goroutine. It records the result of the
  69. // inflight request and signals other goroutines that the result is safe to
  70. // inspect.
  71. func (i *inflight) done(keys []jose.JSONWebKey, err error) {
  72. i.keys = keys
  73. i.err = err
  74. close(i.doneCh)
  75. }
  76. // result cannot be called until the wait() channel has returned a value.
  77. func (i *inflight) result() ([]jose.JSONWebKey, error) {
  78. return i.keys, i.err
  79. }
  80. func (r *remoteKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) {
  81. jws, err := jose.ParseSigned(jwt)
  82. if err != nil {
  83. return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
  84. }
  85. return r.verify(ctx, jws)
  86. }
  87. func (r *remoteKeySet) verify(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
  88. // We don't support JWTs signed with multiple signatures.
  89. keyID := ""
  90. for _, sig := range jws.Signatures {
  91. keyID = sig.Header.KeyID
  92. break
  93. }
  94. keys, expiry := r.keysFromCache()
  95. // Don't check expiry yet. This optimizes for when the provider is unavailable.
  96. for _, key := range keys {
  97. if keyID == "" || key.KeyID == keyID {
  98. if payload, err := jws.Verify(&key); err == nil {
  99. return payload, nil
  100. }
  101. }
  102. }
  103. if !r.now().Add(keysExpiryDelta).After(expiry) {
  104. // Keys haven't expired, don't refresh.
  105. return nil, errors.New("failed to verify id token signature")
  106. }
  107. keys, err := r.keysFromRemote(ctx)
  108. if err != nil {
  109. return nil, fmt.Errorf("fetching keys %v", err)
  110. }
  111. for _, key := range keys {
  112. if keyID == "" || key.KeyID == keyID {
  113. if payload, err := jws.Verify(&key); err == nil {
  114. return payload, nil
  115. }
  116. }
  117. }
  118. return nil, errors.New("failed to verify id token signature")
  119. }
  120. func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey, expiry time.Time) {
  121. r.mu.Lock()
  122. defer r.mu.Unlock()
  123. return r.cachedKeys, r.expiry
  124. }
  125. // keysFromRemote syncs the key set from the remote set, records the values in the
  126. // cache, and returns the key set.
  127. func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
  128. // Need to lock to inspect the inflight request field.
  129. r.mu.Lock()
  130. // If there's not a current inflight request, create one.
  131. if r.inflight == nil {
  132. r.inflight = newInflight()
  133. // This goroutine has exclusive ownership over the current inflight
  134. // request. It releases the resource by nil'ing the inflight field
  135. // once the goroutine is done.
  136. go func() {
  137. // Sync keys and finish inflight when that's done.
  138. keys, expiry, err := r.updateKeys()
  139. r.inflight.done(keys, err)
  140. // Lock to update the keys and indicate that there is no longer an
  141. // inflight request.
  142. r.mu.Lock()
  143. defer r.mu.Unlock()
  144. if err == nil {
  145. r.cachedKeys = keys
  146. r.expiry = expiry
  147. }
  148. // Free inflight so a different request can run.
  149. r.inflight = nil
  150. }()
  151. }
  152. inflight := r.inflight
  153. r.mu.Unlock()
  154. select {
  155. case <-ctx.Done():
  156. return nil, ctx.Err()
  157. case <-inflight.wait():
  158. return inflight.result()
  159. }
  160. }
  161. func (r *remoteKeySet) updateKeys() ([]jose.JSONWebKey, time.Time, error) {
  162. req, err := http.NewRequest("GET", r.jwksURL, nil)
  163. if err != nil {
  164. return nil, time.Time{}, fmt.Errorf("oidc: can't create request: %v", err)
  165. }
  166. resp, err := doRequest(r.ctx, req)
  167. if err != nil {
  168. return nil, time.Time{}, fmt.Errorf("oidc: get keys failed %v", err)
  169. }
  170. defer resp.Body.Close()
  171. body, err := ioutil.ReadAll(resp.Body)
  172. if err != nil {
  173. return nil, time.Time{}, fmt.Errorf("unable to read response body: %v", err)
  174. }
  175. if resp.StatusCode != http.StatusOK {
  176. return nil, time.Time{}, fmt.Errorf("oidc: get keys failed: %s %s", resp.Status, body)
  177. }
  178. var keySet jose.JSONWebKeySet
  179. err = unmarshalResp(resp, body, &keySet)
  180. if err != nil {
  181. return nil, time.Time{}, fmt.Errorf("oidc: failed to decode keys: %v %s", err, body)
  182. }
  183. // If the server doesn't provide cache control headers, assume the
  184. // keys expire immediately.
  185. expiry := r.now()
  186. _, e, err := cachecontrol.CachableResponse(req, resp, cachecontrol.Options{})
  187. if err == nil && e.After(expiry) {
  188. expiry = e
  189. }
  190. return keySet.Keys, expiry, nil
  191. }