seaweedfs/weed/iam/oidc/oidc_provider.go
Chris Lu d951a8df5a feat(iam): STS web-identity AWS-fidelity polish (Phase 1) (#9318)
* feat(iam): STS web-identity AWS-fidelity polish

- OIDC discovery via .well-known/openid-configuration; falls back to
  /.well-known/jwks.json when discovery is absent. Reject discovery docs
  whose issuer claim does not match the configured issuer to defend
  against issuer-substitution.
- ComputeParentUser derives a stable per-identity hash from (sub, iss).
  Surface as aws:userid in the request context and as a parent_user
  claim in the session JWT so per-user state survives token rotation.
- Per-role MaxSessionDuration (3600..43200) clamps requested
  DurationSeconds before the STS service applies its own caps.
- Tighten RoleSessionName to the AWS contract: 2..64 chars from
  [\w+=,.@-].
- Populate PackedPolicySize in AssumeRole / AssumeRoleWithWebIdentity /
  AssumeRoleWithLDAPIdentity responses as a percentage of the 2048-byte
  inline session policy budget.
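
Two of the rules above are easy to misread, so here is a minimal sketch of
them. It is not the code from this change: the names roleSessionNameRE,
validRoleSessionName, and packedPolicyPercent are hypothetical, and the
integer rounding in the percentage is an assumption.

```go
package main

import (
	"fmt"
	"regexp"
)

// roleSessionNameRE encodes the contract quoted above: 2..64 characters
// drawn from [\w+=,.@-]. The variable name is hypothetical.
var roleSessionNameRE = regexp.MustCompile(`^[\w+=,.@-]{2,64}$`)

// validRoleSessionName reports whether name satisfies that contract.
func validRoleSessionName(name string) bool {
	return roleSessionNameRE.MatchString(name)
}

// maxInlinePolicyBytes is the 2048-byte inline session policy budget.
const maxInlinePolicyBytes = 2048

// packedPolicyPercent expresses the policy size as a whole-number
// percentage of the budget. Truncating integer division is an assumption.
func packedPolicyPercent(policy string) int {
	return len(policy) * 100 / maxInlinePolicyBytes
}

func main() {
	fmt.Println(validRoleSessionName("ci-runner@example.com")) // true
	fmt.Println(validRoleSessionName("a"))                     // false: too short
	fmt.Println(packedPolicyPercent(string(make([]byte, 512)))) // 25
}
```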

* fix(iam): leave omitted DurationSeconds nil so STS default applies

capDurationByRole was substituting the role's MaxSessionDuration
when the caller omitted DurationSeconds entirely. AWS returns the
configured default (typically 1 hour) in that case, not the role's
upper bound — a 12h MaxSessionDuration shouldn't silently make every
no-duration assume-role mint a 12h session.

Return nil when requested is nil; let the downstream
calculateSessionDuration in the STS service apply its TokenDuration
default. The role-max upper bound still clamps when the request
arrives with a concrete value above the cap.
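
A minimal sketch of the behaviour described above, assuming the duration is
carried as a *int64 of seconds. capDuration is a hypothetical stand-in for
capDurationByRole, whose real signature is not shown here; treating a
non-positive role maximum as "no cap" is also an assumption.

```go
package main

import "fmt"

// capDuration sketches the fixed behaviour. requested is the caller's
// DurationSeconds (nil when omitted); roleMax is the role's
// MaxSessionDuration in seconds.
func capDuration(requested *int64, roleMax int64) *int64 {
	if requested == nil {
		// Omitted: return nil so the downstream STS default applies.
		return nil
	}
	if roleMax > 0 && *requested > roleMax {
		capped := roleMax
		return &capped
	}
	return requested
}

func main() {
	const twelveHours = int64(43200)

	fmt.Println(capDuration(nil, twelveHours) == nil) // true

	asked := int64(50000)
	fmt.Println(*capDuration(&asked, twelveHours)) // 43200

	short := int64(900)
	fmt.Println(*capDuration(&short, twelveHours)) // 900
}
```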

Addresses gemini high-priority review on PR #9318.

* fix(iam): synchronize OIDCProvider JWKS cache fields

jwksCache, jwksFetchedAt, resolvedJWKSUri, and discoveryFailed are
mutated lazily on the first token-validate call and refreshed
afterwards on TTL expiry. Multiple S3 requests can land here in
parallel, so the writes were racing against subsequent reads on
every other goroutine. resolvedJWKSUri/discoveryFailed inherited
the same unprotected pattern when discovery shipped.

Add sync.RWMutex; getPublicKey takes the read lock for the
common cache-hit path, then releases it and takes the write lock
for misses + refreshes, re-checking the cache once the write lock
is held. fetchJWKSLocked / resolveJWKSUriLocked assume the
write lock is held by the caller; fetchJWKS keeps the
test-friendly entry point that acquires the lock itself.
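
The locking scheme can be sketched in isolation. This is an illustration
under assumptions, not the provider code: keyCache, newKeyCache, and the
fetch callback are hypothetical stand-ins for the JWKS cache and its HTTP
fetch, and unlike getPublicKey the sketch does not refetch on a key miss.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// keyCache is a hypothetical stand-in for the provider's JWKS cache fields.
type keyCache struct {
	mu        sync.RWMutex
	keys      map[string]string // kid -> key material
	fetchedAt time.Time
	ttl       time.Duration
	fetches   int                      // number of fetches, for the demo
	fetch     func() map[string]string // stand-in for the JWKS HTTP fetch
}

func newKeyCache(ttlSeconds int, fetch func() map[string]string) *keyCache {
	return &keyCache{ttl: time.Duration(ttlSeconds) * time.Second, fetch: fetch}
}

// fresh must be called with c.mu held (read or write).
func (c *keyCache) fresh() bool {
	return c.keys != nil && time.Since(c.fetchedAt) <= c.ttl
}

func (c *keyCache) get(kid string) (string, bool) {
	// Fast path: shared lock, no fetch.
	c.mu.RLock()
	if c.fresh() {
		if k, ok := c.keys[kid]; ok {
			c.mu.RUnlock()
			return k, true
		}
	}
	c.mu.RUnlock()

	// Slow path: sync.RWMutex cannot upgrade in place, so take the write
	// lock and re-check, since another goroutine may have refreshed.
	c.mu.Lock()
	defer c.mu.Unlock()
	if !c.fresh() {
		c.keys = c.fetch()
		c.fetchedAt = time.Now()
		c.fetches++
	}
	k, ok := c.keys[kid]
	return k, ok
}

func main() {
	c := newKeyCache(3600, func() map[string]string {
		return map[string]string{"k1": "pem-bytes"}
	})
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			c.get("k1")
		}()
	}
	wg.Wait()
	fmt.Println("fetches:", c.fetches) // fetches: 1
}
```

The re-check under the write lock is what keeps eight concurrent callers
from triggering eight fetches.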

Addresses gemini high-priority review on PR #9318.

* fix(iam): trim trailing slash + retry discovery after transient failure

Two OIDC discovery edge cases reviewers flagged:

1. Issuer comparison was sensitive to trailing slashes. resolveJWKSUri
   trims them when building the discovery URL, but the doc.Issuer ↔
   p.config.Issuer check did not, so an IDP whose issuer claim drops or
   adds the slash relative to the configured value would be falsely
   rejected. Trim a single trailing slash on each side before comparing.

2. discoveryFailed flipped to true on any error and stayed there for the
   process lifetime. A transient 5xx at startup permanently locked the
   provider into the /.well-known/jwks.json fallback. Reset the flag at
   the top of fetchJWKSLocked when no URI has been cached yet, so each
   JWKS refresh (typically once per TTL = 1h) reattempts discovery.
   Successful discovery remains cached via resolvedJWKSUri so we don't
   pay the discovery RTT on every refresh.
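
The retry behaviour in item 2 can be sketched without any HTTP. The
resolver type, refresh, and newFlakyResolver below are hypothetical; the
discover callback stands in for the discovery request, and the example
hostnames are placeholders.

```go
package main

import (
	"errors"
	"fmt"
)

// resolver is a hypothetical reduction of the provider's discovery state.
type resolver struct {
	issuer          string
	resolvedURI     string
	discoveryFailed bool
	discover        func() (string, error) // stand-in for the discovery call
}

// refresh returns the JWKS URI to use for one refresh cycle.
func (r *resolver) refresh() string {
	// No URI cached yet: forget an earlier failure so discovery is retried.
	if r.resolvedURI == "" {
		r.discoveryFailed = false
	}
	if r.resolvedURI != "" {
		return r.resolvedURI // successful discovery stays cached
	}
	if !r.discoveryFailed {
		if uri, err := r.discover(); err == nil {
			r.resolvedURI = uri
			return uri
		}
		r.discoveryFailed = true
	}
	return r.issuer + "/.well-known/jwks.json"
}

// newFlakyResolver builds a resolver whose discovery fails the given number
// of times and then succeeds. It also returns a pointer to the call counter.
func newFlakyResolver(issuer, uri string, failures int) (*resolver, *int) {
	calls := 0
	r := &resolver{issuer: issuer}
	r.discover = func() (string, error) {
		calls++
		if calls <= failures {
			return "", errors.New("transient 5xx from IdP")
		}
		return uri, nil
	}
	return r, &calls
}

func main() {
	r, calls := newFlakyResolver("https://idp.example.com", "https://idp.example.com/keys", 1)
	fmt.Println(r.refresh()) // https://idp.example.com/.well-known/jwks.json
	fmt.Println(r.refresh()) // https://idp.example.com/keys
	fmt.Println(r.refresh()) // https://idp.example.com/keys
	fmt.Println("discovery calls:", *calls) // discovery calls: 2
}
```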

Addresses gemini security-medium + medium reviews on PR #9318.

* fix(iam): require non-empty issuer in OIDC discovery doc

The previous "doc.Issuer != "" && ..." guard let a discovery document
that omitted the issuer field bypass the issuer-mismatch check
entirely, letting the doc steer fetchJWKS at any URL it provided.
OIDC Discovery 1.0 §3 mandates the issuer field; treat a missing issuer
as a hard failure, the same as a mismatch. Trailing-slash equivalence still
applies.
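
A minimal sketch of the resulting check, written as a standalone function.
checkDiscoveryIssuer is a hypothetical name, and the explicit empty-string
branch is a choice of this sketch: a plain inequality check already rejects
an empty issuer whenever the configured issuer is non-empty.

```go
package main

import (
	"fmt"
	"strings"
)

// checkDiscoveryIssuer rejects a discovery document whose issuer is missing
// or differs from the configured issuer, ignoring one trailing slash.
func checkDiscoveryIssuer(docIssuer, configured string) error {
	if docIssuer == "" {
		return fmt.Errorf("discovery document missing issuer")
	}
	if strings.TrimSuffix(docIssuer, "/") != strings.TrimSuffix(configured, "/") {
		return fmt.Errorf("discovery issuer %q does not match configured issuer %q", docIssuer, configured)
	}
	return nil
}

func main() {
	const configured = "https://idp.example.com/"
	for _, doc := range []string{
		"https://idp.example.com",  // accepted: differs only by trailing slash
		"",                         // rejected: missing issuer
		"https://evil.example.com", // rejected: mismatch
	} {
		fmt.Printf("%q -> %v\n", doc, checkDiscoveryIssuer(doc, configured))
	}
}
```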

Adds TestDiscoveryRejectsMissingIssuer alongside the existing
TestDiscoveryRejectsIssuerMismatch via a new omitDiscoveryIssuer
toggle on fakeIDP.
2026-05-04 22:10:49 -07:00


package oidc
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
)
// OIDCProvider implements OpenID Connect authentication
type OIDCProvider struct {
name string
config *OIDCConfig
initialized bool
httpClient *http.Client
jwksTTL time.Duration
// mu guards the lazily-mutated cache fields below: jwksCache, jwksFetchedAt,
// resolvedJWKSUri, and discoveryFailed are all populated on the first
// validate-token call and refreshed when the cache expires. Multiple S3
// requests can land here in parallel, so they need synchronization.
mu sync.RWMutex
jwksCache *JWKS
jwksFetchedAt time.Time
resolvedJWKSUri string
discoveryFailed bool
}
// OIDCConfig holds OIDC provider configuration
type OIDCConfig struct {
// Issuer is the OIDC issuer URL
Issuer string `json:"issuer"`
// ClientID is the OAuth2 client ID
ClientID string `json:"clientId"`
// ClientSecret is the OAuth2 client secret (optional for public clients)
ClientSecret string `json:"clientSecret,omitempty"`
// JWKSUri is the JSON Web Key Set URI
JWKSUri string `json:"jwksUri,omitempty"`
// UserInfoUri is the UserInfo endpoint URI
UserInfoUri string `json:"userInfoUri,omitempty"`
// Scopes are the OAuth2 scopes to request
Scopes []string `json:"scopes,omitempty"`
// RoleMapping defines how to map OIDC claims to roles
RoleMapping *providers.RoleMapping `json:"roleMapping,omitempty"`
// ClaimsMapping defines how to map OIDC claims to identity attributes
ClaimsMapping map[string]string `json:"claimsMapping,omitempty"`
// JWKSCacheTTLSeconds sets how long to cache JWKS before refresh (default 3600 seconds)
JWKSCacheTTLSeconds int `json:"jwksCacheTTLSeconds,omitempty"`
// TLSCACert is the path to the CA certificate file for custom/self-signed certificates
TLSCACert string `json:"tlsCaCert,omitempty"`
// TLSInsecureSkipVerify controls whether to skip TLS verification.
// WARNING: Should only be used in development/testing environments. Never use in production.
TLSInsecureSkipVerify bool `json:"tlsInsecureSkipVerify,omitempty"`
}
// JWKS represents JSON Web Key Set
type JWKS struct {
Keys []JWK `json:"keys"`
}
// JWK represents a JSON Web Key
type JWK struct {
Kty string `json:"kty"` // Key Type (RSA, EC, etc.)
Kid string `json:"kid"` // Key ID
Use string `json:"use"` // Usage (sig for signature)
Alg string `json:"alg"` // Algorithm (RS256, etc.)
N string `json:"n"` // RSA public key modulus
E string `json:"e"` // RSA public key exponent
X string `json:"x"` // EC public key x coordinate
Y string `json:"y"` // EC public key y coordinate
Crv string `json:"crv"` // EC curve
}
// NewOIDCProvider creates a new OIDC provider
func NewOIDCProvider(name string) *OIDCProvider {
return &OIDCProvider{
name: name,
httpClient: &http.Client{Timeout: 30 * time.Second},
}
}
// Name returns the provider name
func (p *OIDCProvider) Name() string {
return p.name
}
// GetIssuer returns the configured issuer URL for efficient provider lookup
func (p *OIDCProvider) GetIssuer() string {
if p.config == nil {
return ""
}
return p.config.Issuer
}
// Initialize initializes the OIDC provider with configuration
func (p *OIDCProvider) Initialize(config interface{}) error {
if config == nil {
return fmt.Errorf("config cannot be nil")
}
oidcConfig, ok := config.(*OIDCConfig)
if !ok {
return fmt.Errorf("invalid config type for OIDC provider")
}
if err := p.validateConfig(oidcConfig); err != nil {
return fmt.Errorf("invalid OIDC configuration: %w", err)
}
p.config = oidcConfig
p.initialized = true
// Configure JWKS cache TTL
if oidcConfig.JWKSCacheTTLSeconds > 0 {
p.jwksTTL = time.Duration(oidcConfig.JWKSCacheTTLSeconds) * time.Second
} else {
p.jwksTTL = time.Hour
}
// Configure HTTP client with TLS settings
tlsConfig := &tls.Config{
InsecureSkipVerify: oidcConfig.TLSInsecureSkipVerify,
MinVersion: tls.VersionTLS12, // Prevent TLS downgrade attacks
}
if oidcConfig.TLSInsecureSkipVerify {
glog.Warningf("OIDC provider %q is configured to skip TLS verification. This is insecure and should not be used in production.", p.name)
}
if oidcConfig.TLSCACert != "" {
// Validate that the CA cert path is absolute to prevent reading unintended files
if !filepath.IsAbs(oidcConfig.TLSCACert) {
return fmt.Errorf("TLSCACert must be an absolute path, got: %s", oidcConfig.TLSCACert)
}
caCert, err := os.ReadFile(oidcConfig.TLSCACert)
if err != nil {
return fmt.Errorf("failed to read CA cert file: %w", err)
}
// Start with the system cert pool to trust public CAs, then add the custom one.
rootCAs, _ := x509.SystemCertPool()
if rootCAs == nil {
rootCAs = x509.NewCertPool()
}
if !rootCAs.AppendCertsFromPEM(caCert) {
return fmt.Errorf("failed to append CA cert from file: %s", oidcConfig.TLSCACert)
}
tlsConfig.RootCAs = rootCAs
}
transport := &http.Transport{
TLSClientConfig: tlsConfig,
}
p.httpClient = &http.Client{
Timeout: 30 * time.Second,
Transport: transport,
}
// For testing, we'll skip the actual OIDC client initialization
return nil
}
// validateConfig validates the OIDC configuration
func (p *OIDCProvider) validateConfig(config *OIDCConfig) error {
if config.Issuer == "" {
return fmt.Errorf("issuer is required")
}
if config.ClientID == "" {
return fmt.Errorf("client ID is required")
}
// Basic URL validation for issuer
if !strings.HasPrefix(config.Issuer, "http") {
return fmt.Errorf("invalid issuer URL format")
}
return nil
}
// Authenticate authenticates a user with an OIDC token
func (p *OIDCProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
if !p.initialized {
return nil, fmt.Errorf("provider not initialized")
}
if token == "" {
return nil, fmt.Errorf("token cannot be empty")
}
// Validate token and get claims
claims, err := p.ValidateToken(ctx, token)
if err != nil {
return nil, err
}
// Map claims to external identity
email, _ := claims.GetClaimString("email")
displayName, _ := claims.GetClaimString("name")
groups, _ := claims.GetClaimStringSlice("groups")
// Debug: Log available claims
glog.V(3).Infof("Available claims: %+v", claims.Claims)
if rolesFromClaims, exists := claims.GetClaimStringSlice("roles"); exists {
glog.V(3).Infof("Roles claim found as string slice: %v", rolesFromClaims)
} else if roleFromClaims, exists := claims.GetClaimString("roles"); exists {
glog.V(3).Infof("Roles claim found as string: %s", roleFromClaims)
} else {
glog.V(3).Infof("No roles claim found in token")
}
// Map claims to roles using configured role mapping
roles := p.mapClaimsToRolesWithConfig(claims)
// Create attributes map and add roles
attributes := make(map[string]string)
if len(roles) > 0 {
// Store roles as a comma-separated string in attributes
attributes["roles"] = strings.Join(roles, ",")
}
// Store all additional claims as attributes
processedClaims := map[string]struct{}{
// user / business claims already handled elsewhere
"sub": {},
"email": {},
"name": {},
"groups": {},
"roles": {},
// standard structural OIDC/JWT claims that should not be exposed as attributes
"iss": {},
"aud": {},
"exp": {},
"iat": {},
"nbf": {},
"jti": {},
}
for key, value := range claims.Claims {
if _, isProcessed := processedClaims[key]; !isProcessed {
if strValue, ok := value.(string); ok {
attributes[key] = strValue
} else if jsonValue, err := json.Marshal(value); err == nil {
attributes[key] = string(jsonValue)
} else {
glog.Warningf("failed to marshal claim %q to JSON for OIDC attributes: %v", key, err)
}
}
}
identity := &providers.ExternalIdentity{
UserID: claims.Subject,
Email: email,
DisplayName: displayName,
Groups: groups,
Attributes: attributes,
Provider: p.name,
Issuer: claims.Issuer,
}
// Pass the token expiration to limit session duration
// This ensures the STS session doesn't exceed the source token's validity
if !claims.ExpiresAt.IsZero() {
identity.TokenExpiration = &claims.ExpiresAt
}
return identity, nil
}
// GetUserInfo retrieves user information from the UserInfo endpoint
func (p *OIDCProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
if !p.initialized {
return nil, fmt.Errorf("provider not initialized")
}
if userID == "" {
return nil, fmt.Errorf("user ID cannot be empty")
}
// For now, we'll use a token-based approach since OIDC UserInfo typically requires a token
// In a real implementation, this would need an access token from the authentication flow
return p.getUserInfoWithToken(ctx, userID, "")
}
// GetUserInfoWithToken retrieves user information using an access token
func (p *OIDCProvider) GetUserInfoWithToken(ctx context.Context, accessToken string) (*providers.ExternalIdentity, error) {
if !p.initialized {
return nil, fmt.Errorf("provider not initialized")
}
if accessToken == "" {
return nil, fmt.Errorf("access token cannot be empty")
}
return p.getUserInfoWithToken(ctx, "", accessToken)
}
// getUserInfoWithToken is the internal implementation for UserInfo endpoint calls
func (p *OIDCProvider) getUserInfoWithToken(ctx context.Context, userID, accessToken string) (*providers.ExternalIdentity, error) {
// Determine UserInfo endpoint URL
userInfoUri := p.config.UserInfoUri
if userInfoUri == "" {
// No UserInfo endpoint configured: assume the common {issuer}/userinfo convention
userInfoUri = strings.TrimSuffix(p.config.Issuer, "/") + "/userinfo"
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, "GET", userInfoUri, nil)
if err != nil {
return nil, fmt.Errorf("failed to create UserInfo request: %v", err)
}
// Set authorization header if access token is provided
if accessToken != "" {
req.Header.Set("Authorization", "Bearer "+accessToken)
}
req.Header.Set("Accept", "application/json")
// Make HTTP request
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to call UserInfo endpoint: %v", err)
}
defer resp.Body.Close()
// Check response status
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("UserInfo endpoint returned status %d", resp.StatusCode)
}
// Parse JSON response
var userInfo map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
return nil, fmt.Errorf("failed to decode UserInfo response: %v", err)
}
glog.V(4).Infof("Received UserInfo response: %+v", userInfo)
// Map UserInfo claims to ExternalIdentity
identity := p.mapUserInfoToIdentity(userInfo)
// If userID was provided but not found in claims, use it
if userID != "" && identity.UserID == "" {
identity.UserID = userID
}
glog.V(3).Infof("Retrieved user info from OIDC provider: %s", identity.UserID)
return identity, nil
}
// ValidateToken validates an OIDC JWT token
func (p *OIDCProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
if !p.initialized {
return nil, fmt.Errorf("provider not initialized")
}
if token == "" {
return nil, fmt.Errorf("token cannot be empty")
}
// Parse token without verification first to get header info
parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
if err != nil {
return nil, fmt.Errorf("failed to parse JWT token: %v", err)
}
// Get key ID from header
kid, ok := parsedToken.Header["kid"].(string)
if !ok {
return nil, fmt.Errorf("missing key ID in JWT header")
}
// Get signing key from JWKS
publicKey, err := p.getPublicKey(ctx, kid)
if err != nil {
return nil, fmt.Errorf("failed to get public key: %v", err)
}
// Parse and validate token with proper signature verification
claims := jwt.MapClaims{}
validatedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
// Verify signing method
switch token.Method.(type) {
case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
return publicKey, nil
default:
return nil, fmt.Errorf("unsupported signing method: %v", token.Header["alg"])
}
})
if err != nil {
// Use JWT library's typed errors for robust error checking
if errors.Is(err, jwt.ErrTokenExpired) {
return nil, fmt.Errorf("%w: %v", providers.ErrProviderTokenExpired, err)
}
return nil, fmt.Errorf("%w: %v", providers.ErrProviderInvalidToken, err)
}
if !validatedToken.Valid {
return nil, fmt.Errorf("%w: token validation failed", providers.ErrProviderInvalidToken)
}
// Validate required claims
issuer, ok := claims["iss"].(string)
if !ok || issuer != p.config.Issuer {
return nil, fmt.Errorf("%w: expected %s, got %s", providers.ErrProviderInvalidIssuer, p.config.Issuer, issuer)
}
// Check audience claim (aud) or authorized party (azp) - Keycloak uses azp
// Per RFC 7519, aud can be either a string or an array of strings
var audienceMatched bool
if audClaim, ok := claims["aud"]; ok {
switch aud := audClaim.(type) {
case string:
if aud == p.config.ClientID {
audienceMatched = true
}
case []interface{}:
for _, a := range aud {
if str, ok := a.(string); ok && str == p.config.ClientID {
audienceMatched = true
break
}
}
}
}
if !audienceMatched {
if azp, ok := claims["azp"].(string); ok && azp == p.config.ClientID {
audienceMatched = true
}
}
if !audienceMatched {
return nil, fmt.Errorf("%w: expected client ID %s", providers.ErrProviderInvalidAudience, p.config.ClientID)
}
subject, ok := claims["sub"].(string)
if !ok {
return nil, fmt.Errorf("%w: missing subject claim", providers.ErrProviderMissingClaims)
}
// Convert to our TokenClaims structure
tokenClaims := &providers.TokenClaims{
Subject: subject,
Issuer: issuer,
Claims: make(map[string]interface{}),
}
// Extract time-based claims (exp, iat, nbf)
for key, target := range map[string]*time.Time{
"exp": &tokenClaims.ExpiresAt,
"iat": &tokenClaims.IssuedAt,
"nbf": &tokenClaims.NotBefore,
} {
if val, ok := claims[key]; ok {
switch v := val.(type) {
case float64:
*target = time.Unix(int64(v), 0)
case json.Number:
if intVal, err := v.Int64(); err == nil {
*target = time.Unix(intVal, 0)
}
}
}
}
// Copy all claims
for key, value := range claims {
tokenClaims.Claims[key] = value
}
return tokenClaims, nil
}
// mapClaimsToRoles maps token claims to SeaweedFS roles (legacy method)
func (p *OIDCProvider) mapClaimsToRoles(claims *providers.TokenClaims) []string {
roles := []string{}
// Get groups from claims
groups, _ := claims.GetClaimStringSlice("groups")
// Basic role mapping based on groups
for _, group := range groups {
switch group {
case "admins":
roles = append(roles, "admin")
case "developers":
roles = append(roles, "readwrite")
case "users":
roles = append(roles, "readonly")
}
}
if len(roles) == 0 {
roles = []string{"readonly"} // Default role
}
return roles
}
// mapClaimsToRolesWithConfig maps token claims to roles using configured role mapping
func (p *OIDCProvider) mapClaimsToRolesWithConfig(claims *providers.TokenClaims) []string {
glog.V(3).Infof("mapClaimsToRolesWithConfig: RoleMapping is nil? %t", p.config.RoleMapping == nil)
if p.config.RoleMapping == nil {
glog.V(2).Infof("No role mapping configured for provider %s, using legacy mapping", p.name)
// Fallback to legacy mapping if no role mapping configured
return p.mapClaimsToRoles(claims)
}
glog.V(3).Infof("Applying %d role mapping rules", len(p.config.RoleMapping.Rules))
roles := []string{}
// Apply role mapping rules
for i, rule := range p.config.RoleMapping.Rules {
glog.V(3).Infof("Rule %d: claim=%s, value=%s, role=%s", i, rule.Claim, rule.Value, rule.Role)
if rule.Matches(claims) {
glog.V(2).Infof("Rule %d matched! Adding role: %s", i, rule.Role)
roles = append(roles, rule.Role)
} else {
glog.V(3).Infof("Rule %d did not match", i)
}
}
// Use default role if no rules matched
if len(roles) == 0 && p.config.RoleMapping.DefaultRole != "" {
glog.V(2).Infof("No rules matched, using default role: %s", p.config.RoleMapping.DefaultRole)
roles = []string{p.config.RoleMapping.DefaultRole}
}
glog.V(2).Infof("Role mapping result: %v", roles)
return roles
}
// getPublicKey retrieves the public key for the given key ID from JWKS.
// Cache hits use the read lock so concurrent token validations don't
// serialize on JWKS lookup. Misses and expirations release the read lock
// and take the write lock, then re-check the cache so an expired JWKS is
// refetched by only one goroutine.
func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{}, error) {
// Fast path: read lock and look in cache.
p.mu.RLock()
if p.jwksCache != nil && (p.jwksFetchedAt.IsZero() || time.Since(p.jwksFetchedAt) <= p.jwksTTL) {
for _, key := range p.jwksCache.Keys {
if key.Kid == kid {
k := key
p.mu.RUnlock()
return p.parseJWK(&k)
}
}
}
p.mu.RUnlock()
// Slow path: take the write lock for the (re)fetch + retry. Re-check the
// cache under the write lock in case another goroutine already refreshed.
p.mu.Lock()
defer p.mu.Unlock()
cacheValid := p.jwksCache != nil && (p.jwksFetchedAt.IsZero() || time.Since(p.jwksFetchedAt) <= p.jwksTTL)
if !cacheValid {
if err := p.fetchJWKSLocked(ctx); err != nil {
return nil, fmt.Errorf("failed to fetch JWKS: %v", err)
}
}
for _, key := range p.jwksCache.Keys {
if key.Kid == kid {
k := key
return p.parseJWK(&k)
}
}
// Key not found in cache. Refresh JWKS once to handle key rotation.
if err := p.fetchJWKSLocked(ctx); err != nil {
return nil, fmt.Errorf("failed to refresh JWKS after key miss: %v", err)
}
for _, key := range p.jwksCache.Keys {
if key.Kid == kid {
k := key
return p.parseJWK(&k)
}
}
return nil, fmt.Errorf("key with ID %s not found in JWKS after refresh", kid)
}
// discoveryDocument is the subset of the OpenID Provider Configuration we need.
// See https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata.
type discoveryDocument struct {
Issuer string `json:"issuer"`
JWKSUri string `json:"jwks_uri"`
}
// resolveJWKSUriLocked determines the JWKS URI for the provider. The caller
// must hold p.mu (write lock); the function reads/writes p.resolvedJWKSUri
// and p.discoveryFailed without taking the lock itself.
//
// Order of resolution:
// 1. explicit config.JWKSUri (operator override; never overridden by discovery).
// 2. cached resolvedJWKSUri from a prior successful discovery (kept for the life of the provider).
// 3. .well-known/openid-configuration discovery (per OIDC Discovery 1.0).
// 4. fallback to {issuer}/.well-known/jwks.json (compat path for IDPs that
// don't publish discovery).
func (p *OIDCProvider) resolveJWKSUriLocked(ctx context.Context) (string, error) {
if p.config.JWKSUri != "" {
return p.config.JWKSUri, nil
}
if p.resolvedJWKSUri != "" {
return p.resolvedJWKSUri, nil
}
issuer := strings.TrimSuffix(p.config.Issuer, "/")
if !p.discoveryFailed {
discoveryURL := issuer + "/.well-known/openid-configuration"
uri, err := p.fetchDiscoveryJWKSUri(ctx, discoveryURL)
switch {
case err == nil:
p.resolvedJWKSUri = uri
return uri, nil
default:
// Remember the failure so the rest of this fetch uses the fallback.
// fetchJWKSLocked clears the flag before the next refresh, so discovery
// is retried at most once per JWKS refresh while no URI is cached.
glog.V(3).Infof("OIDC discovery at %s failed (%v); falling back to /.well-known/jwks.json", discoveryURL, err)
p.discoveryFailed = true
}
}
return issuer + "/.well-known/jwks.json", nil
}
// fetchDiscoveryJWKSUri retrieves the OIDC discovery document and returns
// the jwks_uri field. The issuer claim in the document must match config.Issuer
// to defend against issuer-substitution attacks during discovery.
func (p *OIDCProvider) fetchDiscoveryJWKSUri(ctx context.Context, discoveryURL string) (string, error) {
req, err := http.NewRequestWithContext(ctx, "GET", discoveryURL, nil)
if err != nil {
return "", fmt.Errorf("create discovery request: %v", err)
}
req.Header.Set("Accept", "application/json")
resp, err := p.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("fetch discovery document: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("discovery endpoint returned status %d", resp.StatusCode)
}
var doc discoveryDocument
if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
return "", fmt.Errorf("decode discovery document: %v", err)
}
if doc.JWKSUri == "" {
return "", fmt.Errorf("discovery document missing jwks_uri")
}
// Issuer must be present and match. A discovery doc that names a
// different issuer is either a misconfiguration or an issuer-confusion
// attack, and a doc that omits the issuer field entirely must not be
// allowed to steer the JWKS fetch at whatever URI it supplies. OIDC
// Discovery 1.0 §3 mandates the issuer field, so a missing issuer is a
// hard failure: an empty doc.Issuer never equals the validated,
// non-empty configured issuer. Compare after trimming a single trailing
// slash on each side because real IdPs disagree on whether the issuer
// has one.
if strings.TrimSuffix(doc.Issuer, "/") != strings.TrimSuffix(p.config.Issuer, "/") {
return "", fmt.Errorf("discovery issuer %q does not match configured issuer %q", doc.Issuer, p.config.Issuer)
}
return doc.JWKSUri, nil
}
// fetchJWKS is a thin wrapper around fetchJWKSLocked that acquires the
// write lock. Used by tests; production callers in getPublicKey already
// hold the lock and call fetchJWKSLocked directly.
func (p *OIDCProvider) fetchJWKS(ctx context.Context) error {
p.mu.Lock()
defer p.mu.Unlock()
return p.fetchJWKSLocked(ctx)
}
// fetchJWKSLocked fetches the JWKS from the provider. The caller must hold
// p.mu (write lock); the function writes p.jwksCache and p.jwksFetchedAt
// without taking the lock itself.
//
// Each fetch reattempts discovery if the previous attempt failed: a
// transient 5xx that flipped discoveryFailed at startup shouldn't lock the
// provider into the fallback path forever. The retry rate is bounded by
// the JWKS TTL (typically 1h), so the discovery RTT cost is amortized.
func (p *OIDCProvider) fetchJWKSLocked(ctx context.Context) error {
if p.config.JWKSUri == "" && p.resolvedJWKSUri == "" {
p.discoveryFailed = false
}
jwksURL, err := p.resolveJWKSUriLocked(ctx)
if err != nil {
return fmt.Errorf("resolve JWKS URI: %v", err)
}
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
if err != nil {
return fmt.Errorf("failed to create JWKS request: %v", err)
}
resp, err := p.httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to fetch JWKS: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("JWKS endpoint returned status: %d", resp.StatusCode)
}
var jwks JWKS
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
return fmt.Errorf("failed to decode JWKS response: %v", err)
}
p.jwksCache = &jwks
p.jwksFetchedAt = time.Now()
glog.V(3).Infof("Fetched JWKS with %d keys from %s", len(jwks.Keys), jwksURL)
return nil
}
// parseJWK converts a JWK to a public key
func (p *OIDCProvider) parseJWK(key *JWK) (interface{}, error) {
switch key.Kty {
case "RSA":
return p.parseRSAKey(key)
case "EC":
return p.parseECKey(key)
default:
return nil, fmt.Errorf("unsupported key type: %s", key.Kty)
}
}
// parseRSAKey parses an RSA key from JWK
func (p *OIDCProvider) parseRSAKey(key *JWK) (*rsa.PublicKey, error) {
// Decode the modulus (n)
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
if err != nil {
return nil, fmt.Errorf("failed to decode RSA modulus: %v", err)
}
// Decode the exponent (e)
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
if err != nil {
return nil, fmt.Errorf("failed to decode RSA exponent: %v", err)
}
// Convert exponent bytes to int
var exponent int
for _, b := range eBytes {
exponent = exponent*256 + int(b)
}
// Create RSA public key
pubKey := &rsa.PublicKey{
E: exponent,
}
pubKey.N = new(big.Int).SetBytes(nBytes)
return pubKey, nil
}
// parseECKey parses an Elliptic Curve key from JWK
func (p *OIDCProvider) parseECKey(key *JWK) (*ecdsa.PublicKey, error) {
// Validate required fields
if key.X == "" || key.Y == "" || key.Crv == "" {
return nil, fmt.Errorf("incomplete EC key: missing x, y, or crv parameter")
}
// Get the curve
var curve elliptic.Curve
switch key.Crv {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("unsupported EC curve: %s", key.Crv)
}
// Decode x coordinate
xBytes, err := base64.RawURLEncoding.DecodeString(key.X)
if err != nil {
return nil, fmt.Errorf("failed to decode EC x coordinate: %v", err)
}
// Decode y coordinate
yBytes, err := base64.RawURLEncoding.DecodeString(key.Y)
if err != nil {
return nil, fmt.Errorf("failed to decode EC y coordinate: %v", err)
}
// Create EC public key
pubKey := &ecdsa.PublicKey{
Curve: curve,
X: new(big.Int).SetBytes(xBytes),
Y: new(big.Int).SetBytes(yBytes),
}
// Validate that the point is on the curve
if !curve.IsOnCurve(pubKey.X, pubKey.Y) {
return nil, fmt.Errorf("EC key coordinates are not on the specified curve")
}
return pubKey, nil
}
// mapUserInfoToIdentity maps UserInfo response to ExternalIdentity
func (p *OIDCProvider) mapUserInfoToIdentity(userInfo map[string]interface{}) *providers.ExternalIdentity {
identity := &providers.ExternalIdentity{
Provider: p.name,
Attributes: make(map[string]string),
}
// Map standard OIDC claims
if sub, ok := userInfo["sub"].(string); ok {
identity.UserID = sub
}
if email, ok := userInfo["email"].(string); ok {
identity.Email = email
}
if name, ok := userInfo["name"].(string); ok {
identity.DisplayName = name
}
// Handle groups claim (can be array of strings or single string)
if groupsData, exists := userInfo["groups"]; exists {
switch groups := groupsData.(type) {
case []interface{}:
// Array of groups
for _, group := range groups {
if groupStr, ok := group.(string); ok {
identity.Groups = append(identity.Groups, groupStr)
}
}
case []string:
// Direct string array
identity.Groups = groups
case string:
// Single group as string
identity.Groups = []string{groups}
}
}
// Map configured custom claims
if p.config.ClaimsMapping != nil {
for identityField, oidcClaim := range p.config.ClaimsMapping {
if value, exists := userInfo[oidcClaim]; exists {
if strValue, ok := value.(string); ok {
switch identityField {
case "email":
if identity.Email == "" {
identity.Email = strValue
}
case "displayName":
if identity.DisplayName == "" {
identity.DisplayName = strValue
}
case "userID":
if identity.UserID == "" {
identity.UserID = strValue
}
default:
identity.Attributes[identityField] = strValue
}
}
}
}
}
// Store all additional claims as attributes
for key, value := range userInfo {
if key != "sub" && key != "email" && key != "name" && key != "groups" {
if strValue, ok := value.(string); ok {
identity.Attributes[key] = strValue
} else if jsonValue, err := json.Marshal(value); err == nil {
identity.Attributes[key] = string(jsonValue)
}
}
}
return identity
}