feat(iam): STS web-identity AWS-fidelity polish (Phase 1) (#9318)

* feat(iam): STS web-identity AWS-fidelity polish

- OIDC discovery via .well-known/openid-configuration; falls back to
  /.well-known/jwks.json when discovery is absent. Reject discovery docs
  whose issuer claim does not match the configured issuer to defend
  against issuer-substitution.
- ComputeParentUser derives a stable per-identity hash from (sub, iss).
  Surface as aws:userid in the request context and as a parent_user
  claim in the session JWT so per-user state survives token rotation.
- Per-role MaxSessionDuration (3600..43200) clamps requested
  DurationSeconds before the STS service applies its own caps.
- Tighten RoleSessionName to the AWS contract: 2..64 chars from
  [\w+=,.@-].
- Populate PackedPolicySize in AssumeRole / AssumeRoleWithWebIdentity /
  AssumeRoleWithLDAPIdentity responses as a percentage of the 2048-byte
  inline session policy budget.

* fix(iam): leave omitted DurationSeconds nil so STS default applies

capDurationByRole was substituting the role's MaxSessionDuration
when the caller omitted DurationSeconds entirely. AWS returns the
configured default (typically 1 hour) in that case, not the role's
upper bound — a 12h MaxSessionDuration shouldn't silently make every
no-duration assume-role mint a 12h session.

Return nil when requested is nil; let the downstream
calculateSessionDuration in the STS service apply its TokenDuration
default. The role-max upper bound still clamps when the request
arrives with a concrete value above the cap.

Addresses gemini high-priority review on PR #9318.

* fix(iam): synchronize OIDCProvider JWKS cache fields

jwksCache, jwksFetchedAt, resolvedJWKSUri, and discoveryFailed are
mutated lazily on the first token-validate call and refreshed
afterwards on TTL expiry. Multiple S3 requests can land here in
parallel, so the writes were racing against subsequent reads on
every other goroutine. resolvedJWKSUri/discoveryFailed inherited
the same un-protected pattern when discovery shipped.

Add sync.RWMutex; getPublicKey takes the read lock for the
common cache-hit path and promotes to the write lock for misses
+ refreshes. fetchJWKSLocked / resolveJWKSUriLocked assume the
write lock is held by the caller; fetchJWKS keeps the
test-friendly entry point that acquires the lock itself.

Addresses gemini high-priority review on PR #9318.

* fix(iam): trim trailing slash + retry discovery after transient failure

Two OIDC discovery edge cases reviewers flagged:

1. Issuer comparison was sensitive to trailing slashes. resolveJWKSUri
   trims them when building the discovery URL, but the doc.Issuer ↔
   p.config.Issuer check did not, so an IDP whose issuer claim drops or
   adds the slash relative to the configured value would be falsely
   rejected. Trim a single trailing slash on each side before comparing.

2. discoveryFailed flipped to true on any error and stayed there for the
   process lifetime. A transient 5xx at startup permanently locked the
   provider into the /.well-known/jwks.json fallback. Reset the flag at
   the top of fetchJWKSLocked when no URI has been cached yet, so each
   JWKS refresh (typically once per TTL = 1h) reattempts discovery.
   Successful discovery remains cached via resolvedJWKSUri so we don't
   pay the discovery RTT on every refresh.

Addresses gemini security-medium + medium reviews on PR #9318.

* fix(iam): require non-empty issuer in OIDC discovery doc

The previous "doc.Issuer != "" && ..." guard let a discovery document
that omitted the issuer field bypass the issuer-mismatch check
entirely, letting the doc steer fetchJWKS at any URL it provided.
OIDC Discovery 1.0 §3 mandates the issuer field; treat missing as a
hard failure same as mismatched. Trailing-slash equivalence still
applies.

Adds TestDiscoveryRejectsMissingIssuer alongside the existing
TestDiscoveryRejectsIssuerMismatch via a new omitDiscoveryIssuer
toggle on fakeIDP.
This commit is contained in:
Chris Lu
2026-05-04 22:10:49 -07:00
committed by GitHub
parent 2417ba0354
commit d951a8df5a
11 changed files with 698 additions and 39 deletions
+39
View File
@@ -75,6 +75,12 @@ type RoleDefinition struct {
// Description is an optional description of the role // Description is an optional description of the role
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
// MaxSessionDuration is the upper bound (in seconds) on session length when
// callers assume this role. Zero means "use the global STS default". When
// set it must satisfy AWS bounds: 3600 ≤ MaxSessionDuration ≤ 43200.
// Honoured by AssumeRole, AssumeRoleWithWebIdentity, AssumeRoleWithCredentials.
MaxSessionDuration int64 `json:"maxSessionDuration,omitempty"`
} }
// ActionRequest represents a request to perform an action // ActionRequest represents a request to perform an action
@@ -272,6 +278,13 @@ func (m *IAMManager) CreateRole(ctx context.Context, filerAddress string, roleNa
} }
} }
// Validate per-role MaxSessionDuration if specified. AWS bounds: 1h..12h.
if roleDef.MaxSessionDuration != 0 {
if roleDef.MaxSessionDuration < 3600 || roleDef.MaxSessionDuration > 43200 {
return fmt.Errorf("MaxSessionDuration must be between 3600 and 43200 seconds, got %d", roleDef.MaxSessionDuration)
}
}
// Store role definition // Store role definition
return m.roleStore.StoreRole(ctx, "", roleName, roleDef) return m.roleStore.StoreRole(ctx, "", roleName, roleDef)
} }
@@ -329,10 +342,33 @@ func (m *IAMManager) AssumeRoleWithWebIdentity(ctx context.Context, request *sts
return nil, fmt.Errorf("trust policy validation failed: %w", err) return nil, fmt.Errorf("trust policy validation failed: %w", err)
} }
// Apply role-level MaxSessionDuration cap. The STS service still applies
// the global MaxSessionLength and the source-token-expiry cap on top of
// this; per-role takes precedence whenever it is the tightest bound.
request.DurationSeconds = capDurationByRole(request.DurationSeconds, roleDef.MaxSessionDuration)
// Use STS service to assume the role // Use STS service to assume the role
return m.stsService.AssumeRoleWithWebIdentity(ctx, request) return m.stsService.AssumeRoleWithWebIdentity(ctx, request)
} }
// capDurationByRole returns the requested duration clamped to the role's
// MaxSessionDuration. A nil requested duration is left nil so the STS
// service's calculateSessionDuration applies the global default (typically
// 1 hour) — substituting the role's max here would silently mint a 12h
// session for any caller who omitted DurationSeconds, which AWS does not
// do. The role-max upper bound still applies in the downstream cap chain
// once the request has a concrete duration.
func capDurationByRole(requested *int64, roleMax int64) *int64 {
if roleMax <= 0 || requested == nil {
return requested
}
if *requested > roleMax {
v := roleMax
return &v
}
return requested
}
// AssumeRoleWithCredentials assumes a role using credentials (LDAP) // AssumeRoleWithCredentials assumes a role using credentials (LDAP)
func (m *IAMManager) AssumeRoleWithCredentials(ctx context.Context, request *sts.AssumeRoleWithCredentialsRequest) (*sts.AssumeRoleResponse, error) { func (m *IAMManager) AssumeRoleWithCredentials(ctx context.Context, request *sts.AssumeRoleWithCredentialsRequest) (*sts.AssumeRoleResponse, error) {
if !m.initialized { if !m.initialized {
@@ -353,6 +389,9 @@ func (m *IAMManager) AssumeRoleWithCredentials(ctx context.Context, request *sts
return nil, fmt.Errorf("trust policy validation failed: %w", err) return nil, fmt.Errorf("trust policy validation failed: %w", err)
} }
// Apply role-level MaxSessionDuration cap.
request.DurationSeconds = capDurationByRole(request.DurationSeconds, roleDef.MaxSessionDuration)
// Use STS service to assume the role // Use STS service to assume the role
return m.stsService.AssumeRoleWithCredentials(ctx, request) return m.stsService.AssumeRoleWithCredentials(ctx, request)
} }
@@ -0,0 +1,34 @@
package integration
import "testing"
func intPtr(v int64) *int64 { return &v }
func TestCapDurationByRole(t *testing.T) {
cases := []struct {
name string
requested *int64
roleMax int64
want *int64
}{
{"no cap, no request", nil, 0, nil},
{"no cap, with request", intPtr(7200), 0, intPtr(7200)},
{"cap only, no request -> nil so STS default applies", nil, 3600, nil},
{"request below cap -> request", intPtr(1800), 3600, intPtr(1800)},
{"request equal cap -> request", intPtr(3600), 3600, intPtr(3600)},
{"request above cap -> cap", intPtr(43200), 3600, intPtr(3600)},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := capDurationByRole(tc.requested, tc.roleMax)
switch {
case got == nil && tc.want == nil:
return
case got == nil || tc.want == nil:
t.Fatalf("nilness mismatch: got=%v want=%v", got, tc.want)
case *got != *tc.want:
t.Fatalf("got=%d want=%d", *got, *tc.want)
}
})
}
}
+197
View File
@@ -0,0 +1,197 @@
package oidc
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
)
// fakeIDP wraps an httptest.Server and counts how many times each well-known
// endpoint is hit. Tests use it to assert discovery vs. fallback behaviour.
type fakeIDP struct {
server *httptest.Server
discoveryHits atomic.Int32
jwksHits atomic.Int32
customJWKSHits atomic.Int32
disableDiscovery bool
discoveryStatusCode int
discoveryIssuer string
omitDiscoveryIssuer bool // when true, the discovery doc omits the "issuer" field entirely
customJWKSPathSuffix string // optional suffix that fakeIDP serves at /custom/<suffix>
jwks JWKS
}
func newFakeIDP(t *testing.T) *fakeIDP {
t.Helper()
idp := &fakeIDP{
discoveryStatusCode: http.StatusOK,
jwks: JWKS{Keys: []JWK{{Kty: "RSA", Kid: "k1", Use: "sig", Alg: "RS256", N: "AQAB", E: "AQAB"}}},
}
mux := http.NewServeMux()
mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
idp.discoveryHits.Add(1)
if idp.disableDiscovery {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(idp.discoveryStatusCode)
issuer := idp.discoveryIssuer
if issuer == "" {
issuer = idp.server.URL
}
jwksURI := idp.server.URL + "/discovered/jwks"
if idp.customJWKSPathSuffix != "" {
jwksURI = idp.server.URL + "/custom/" + idp.customJWKSPathSuffix
}
body := map[string]string{"jwks_uri": jwksURI}
if !idp.omitDiscoveryIssuer {
body["issuer"] = issuer
}
_ = json.NewEncoder(w).Encode(body)
})
mux.HandleFunc("/discovered/jwks", func(w http.ResponseWriter, r *http.Request) {
idp.jwksHits.Add(1)
_ = json.NewEncoder(w).Encode(idp.jwks)
})
mux.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) {
idp.jwksHits.Add(1)
_ = json.NewEncoder(w).Encode(idp.jwks)
})
mux.HandleFunc("/custom/", func(w http.ResponseWriter, r *http.Request) {
idp.customJWKSHits.Add(1)
_ = json.NewEncoder(w).Encode(idp.jwks)
})
idp.server = httptest.NewServer(mux)
t.Cleanup(idp.server.Close)
return idp
}
func newProviderForIDP(t *testing.T, idp *fakeIDP, jwksURIOverride string) *OIDCProvider {
t.Helper()
p := NewOIDCProvider("test")
cfg := &OIDCConfig{
Issuer: idp.server.URL,
ClientID: "test-client",
JWKSUri: jwksURIOverride,
}
if err := p.Initialize(cfg); err != nil {
t.Fatalf("Initialize: %v", err)
}
return p
}
func TestDiscoveryHappyPath(t *testing.T) {
idp := newFakeIDP(t)
p := newProviderForIDP(t, idp, "")
if err := p.fetchJWKS(context.Background()); err != nil {
t.Fatalf("fetchJWKS: %v", err)
}
if got := idp.discoveryHits.Load(); got != 1 {
t.Fatalf("expected 1 discovery hit, got %d", got)
}
if got := idp.jwksHits.Load(); got != 1 {
t.Fatalf("expected 1 JWKS hit at discovered uri, got %d", got)
}
// A second fetch reuses the cached jwks_uri without re-discovering.
if err := p.fetchJWKS(context.Background()); err != nil {
t.Fatalf("fetchJWKS second: %v", err)
}
if got := idp.discoveryHits.Load(); got != 1 {
t.Fatalf("discovery should be cached, got %d hits", got)
}
}
func TestDiscoveryFallback404(t *testing.T) {
idp := newFakeIDP(t)
idp.disableDiscovery = true
p := newProviderForIDP(t, idp, "")
if err := p.fetchJWKS(context.Background()); err != nil {
t.Fatalf("fetchJWKS: %v", err)
}
if got := idp.discoveryHits.Load(); got != 1 {
t.Fatalf("expected 1 discovery probe, got %d", got)
}
if got := idp.jwksHits.Load(); got != 1 {
t.Fatalf("expected 1 JWKS hit at fallback uri, got %d", got)
}
// Subsequent fetches retry discovery — discoveryFailed resets at the top
// of fetchJWKSLocked when no URI was cached, so a transient 5xx at
// startup doesn't lock the provider into the fallback path forever.
// Retry rate is bounded by the JWKS TTL (one retry per refresh cycle).
if err := p.fetchJWKS(context.Background()); err != nil {
t.Fatalf("fetchJWKS second: %v", err)
}
if got := idp.discoveryHits.Load(); got != 2 {
t.Fatalf("discovery probe should retry while no URI is cached, got %d hits", got)
}
}
func TestDiscoveryDisabledByExplicitJWKSUri(t *testing.T) {
idp := newFakeIDP(t)
override := idp.server.URL + "/custom/explicit"
idp.customJWKSPathSuffix = "explicit"
p := newProviderForIDP(t, idp, override)
if err := p.fetchJWKS(context.Background()); err != nil {
t.Fatalf("fetchJWKS: %v", err)
}
if got := idp.discoveryHits.Load(); got != 0 {
t.Fatalf("explicit JWKSUri should bypass discovery, got %d hits", got)
}
if got := idp.customJWKSHits.Load(); got != 1 {
t.Fatalf("expected 1 custom JWKS hit, got %d", got)
}
}
func TestDiscoveryRejectsIssuerMismatch(t *testing.T) {
idp := newFakeIDP(t)
idp.discoveryIssuer = "https://attacker.example/"
p := newProviderForIDP(t, idp, "")
if err := p.fetchJWKS(context.Background()); err != nil {
t.Fatalf("fetchJWKS should fall back to /.well-known/jwks.json, got error: %v", err)
}
// Discovery probe was tried once, rejected, then fell through to fallback path.
if got := idp.discoveryHits.Load(); got != 1 {
t.Fatalf("expected 1 discovery probe, got %d", got)
}
if got := idp.jwksHits.Load(); got != 1 {
t.Fatalf("expected 1 fallback JWKS hit, got %d", got)
}
}
// TestDiscoveryRejectsMissingIssuer: a discovery document that omits the
// issuer field entirely must be treated the same as one that supplies a
// mismatched issuer. Otherwise an attacker who can intercept the discovery
// response can strip the issuer field and the comparison silently passes,
// letting the document point fetchJWKS at any URL it pleases.
func TestDiscoveryRejectsMissingIssuer(t *testing.T) {
idp := newFakeIDP(t)
idp.omitDiscoveryIssuer = true
p := newProviderForIDP(t, idp, "")
if err := p.fetchJWKS(context.Background()); err != nil {
t.Fatalf("fetchJWKS should fall back to /.well-known/jwks.json on issuer-missing discovery: %v", err)
}
if got := idp.discoveryHits.Load(); got != 1 {
t.Fatalf("expected 1 discovery probe, got %d", got)
}
// The discovery document was rejected; the JWKS that ultimately served
// us must be the fallback one, not the discovered URI. The fakeIDP
// counts both hits under jwksHits since they share a counter; what
// matters is that customJWKSHits stayed zero.
if got := idp.customJWKSHits.Load(); got != 0 {
t.Fatalf("custom JWKS endpoint must not have been used, got %d hits", got)
}
if got := idp.jwksHits.Load(); got != 1 {
t.Fatalf("expected 1 fallback JWKS hit, got %d", got)
}
}
+161 -21
View File
@@ -16,6 +16,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
@@ -25,13 +26,21 @@ import (
// OIDCProvider implements OpenID Connect authentication // OIDCProvider implements OpenID Connect authentication
type OIDCProvider struct { type OIDCProvider struct {
name string name string
config *OIDCConfig config *OIDCConfig
initialized bool initialized bool
jwksCache *JWKS httpClient *http.Client
httpClient *http.Client jwksTTL time.Duration
jwksFetchedAt time.Time
jwksTTL time.Duration // mu guards the lazily-mutated cache fields below: jwksCache, jwksFetchedAt,
// resolvedJWKSUri, and discoveryFailed are all populated on the first
// validate-token call and refreshed when the cache expires. Multiple S3
// requests can land here in parallel, so they need synchronization.
mu sync.RWMutex
jwksCache *JWKS
jwksFetchedAt time.Time
resolvedJWKSUri string
discoveryFailed bool
} }
// OIDCConfig holds OIDC provider configuration // OIDCConfig holds OIDC provider configuration
@@ -272,6 +281,7 @@ func (p *OIDCProvider) Authenticate(ctx context.Context, token string) (*provide
Groups: groups, Groups: groups,
Attributes: attributes, Attributes: attributes,
Provider: p.name, Provider: p.name,
Issuer: claims.Issuer,
} }
// Pass the token expiration to limit session duration // Pass the token expiration to limit session duration
@@ -550,39 +560,169 @@ func (p *OIDCProvider) mapClaimsToRolesWithConfig(claims *providers.TokenClaims)
return roles return roles
} }
// getPublicKey retrieves the public key for the given key ID from JWKS // getPublicKey retrieves the public key for the given key ID from JWKS.
// Cache hits use the read lock so concurrent token validations don't
// serialize on JWKS lookup. Misses and expirations promote to the write
// lock so the JWKS fetch + cache write happens once per refresh cycle.
func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{}, error) { func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{}, error) {
// Fetch JWKS if not cached or refresh if expired // Fast path: read lock and look in cache.
if p.jwksCache == nil || (!p.jwksFetchedAt.IsZero() && time.Since(p.jwksFetchedAt) > p.jwksTTL) { p.mu.RLock()
if err := p.fetchJWKS(ctx); err != nil { if p.jwksCache != nil && (p.jwksFetchedAt.IsZero() || time.Since(p.jwksFetchedAt) <= p.jwksTTL) {
for _, key := range p.jwksCache.Keys {
if key.Kid == kid {
k := key
p.mu.RUnlock()
return p.parseJWK(&k)
}
}
}
p.mu.RUnlock()
// Slow path: take the write lock for the (re)fetch + retry. Re-check the
// cache under the write lock in case another goroutine already refreshed.
p.mu.Lock()
defer p.mu.Unlock()
cacheValid := p.jwksCache != nil && (p.jwksFetchedAt.IsZero() || time.Since(p.jwksFetchedAt) <= p.jwksTTL)
if !cacheValid {
if err := p.fetchJWKSLocked(ctx); err != nil {
return nil, fmt.Errorf("failed to fetch JWKS: %v", err) return nil, fmt.Errorf("failed to fetch JWKS: %v", err)
} }
} }
// Find the key with matching kid
for _, key := range p.jwksCache.Keys { for _, key := range p.jwksCache.Keys {
if key.Kid == kid { if key.Kid == kid {
return p.parseJWK(&key) k := key
return p.parseJWK(&k)
} }
} }
// Key not found in cache. Refresh JWKS once to handle key rotation and retry. // Key not found in cache. Refresh JWKS once to handle key rotation.
if err := p.fetchJWKS(ctx); err != nil { if err := p.fetchJWKSLocked(ctx); err != nil {
return nil, fmt.Errorf("failed to refresh JWKS after key miss: %v", err) return nil, fmt.Errorf("failed to refresh JWKS after key miss: %v", err)
} }
for _, key := range p.jwksCache.Keys { for _, key := range p.jwksCache.Keys {
if key.Kid == kid { if key.Kid == kid {
return p.parseJWK(&key) k := key
return p.parseJWK(&k)
} }
} }
return nil, fmt.Errorf("key with ID %s not found in JWKS after refresh", kid) return nil, fmt.Errorf("key with ID %s not found in JWKS after refresh", kid)
} }
// fetchJWKS fetches the JWKS from the provider // discoveryDocument is the subset of the OpenID Provider Configuration we need.
// See https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata.
type discoveryDocument struct {
Issuer string `json:"issuer"`
JWKSUri string `json:"jwks_uri"`
}
// resolveJWKSUriLocked determines the JWKS URI for the provider. The caller
// must hold p.mu (write lock); the function reads/writes p.resolvedJWKSUri
// and p.discoveryFailed without taking the lock itself.
//
// Order of resolution:
// 1. explicit config.JWKSUri (operator override; never overridden by discovery).
// 2. cached resolvedJWKSUri from a prior discovery (refreshed when JWKS cache expires).
// 3. .well-known/openid-configuration discovery (per OIDC Discovery 1.0).
// 4. fallback to {issuer}/.well-known/jwks.json (compat path for IDPs that
// don't publish discovery).
func (p *OIDCProvider) resolveJWKSUriLocked(ctx context.Context) (string, error) {
if p.config.JWKSUri != "" {
return p.config.JWKSUri, nil
}
if p.resolvedJWKSUri != "" {
return p.resolvedJWKSUri, nil
}
issuer := strings.TrimSuffix(p.config.Issuer, "/")
if !p.discoveryFailed {
discoveryURL := issuer + "/.well-known/openid-configuration"
uri, err := p.fetchDiscoveryJWKSUri(ctx, discoveryURL)
switch {
case err == nil:
p.resolvedJWKSUri = uri
return uri, nil
default:
// Cache the failure so we don't pay the discovery RTT on every refresh.
// Operators with non-discovery IDPs see one failed lookup at startup.
glog.V(3).Infof("OIDC discovery at %s failed (%v); falling back to /.well-known/jwks.json", discoveryURL, err)
p.discoveryFailed = true
}
}
return issuer + "/.well-known/jwks.json", nil
}
// fetchDiscoveryJWKSUri retrieves the OIDC discovery document and returns
// the jwks_uri field. The issuer claim in the document must match config.Issuer
// to defend against issuer-substitution attacks during discovery.
func (p *OIDCProvider) fetchDiscoveryJWKSUri(ctx context.Context, discoveryURL string) (string, error) {
req, err := http.NewRequestWithContext(ctx, "GET", discoveryURL, nil)
if err != nil {
return "", fmt.Errorf("create discovery request: %v", err)
}
req.Header.Set("Accept", "application/json")
resp, err := p.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("fetch discovery document: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("discovery endpoint returned status %d", resp.StatusCode)
}
var doc discoveryDocument
if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
return "", fmt.Errorf("decode discovery document: %v", err)
}
if doc.JWKSUri == "" {
return "", fmt.Errorf("discovery document missing jwks_uri")
}
// Issuer must be present and match: a discovery doc that points to a
// different issuer is either a misconfiguration or an attack against
// issuer-confusion, and a doc that omits the issuer field entirely
// would have bypassed the previous check (doc.Issuer != "" guard) and
// silently accepted whatever JWKS URI the document supplied. OIDC
// Discovery 1.0 §3 mandates the issuer field, so treat missing as a
// hard failure. Compare after trimming a single trailing slash on each
// side because real IdPs disagree on whether the configured issuer
// has one.
if strings.TrimSuffix(doc.Issuer, "/") != strings.TrimSuffix(p.config.Issuer, "/") {
return "", fmt.Errorf("discovery issuer %q does not match configured issuer %q", doc.Issuer, p.config.Issuer)
}
return doc.JWKSUri, nil
}
// fetchJWKS is a thin wrapper around fetchJWKSLocked that acquires the
// write lock. Used by tests; production callers in getPublicKey already
// hold the lock and call fetchJWKSLocked directly.
func (p *OIDCProvider) fetchJWKS(ctx context.Context) error { func (p *OIDCProvider) fetchJWKS(ctx context.Context) error {
jwksURL := p.config.JWKSUri p.mu.Lock()
if jwksURL == "" { defer p.mu.Unlock()
jwksURL = strings.TrimSuffix(p.config.Issuer, "/") + "/.well-known/jwks.json" return p.fetchJWKSLocked(ctx)
}
// fetchJWKSLocked fetches the JWKS from the provider. The caller must hold
// p.mu (write lock); the function writes p.jwksCache and p.jwksFetchedAt
// without taking the lock itself.
//
// Each fetch reattempts discovery if the previous attempt failed: a
// transient 5xx that flipped discoveryFailed at startup shouldn't lock the
// provider into the fallback path forever. The retry rate is bounded by
// the JWKS TTL (typically 1h), so the discovery RTT cost is amortized.
func (p *OIDCProvider) fetchJWKSLocked(ctx context.Context) error {
if p.config.JWKSUri == "" && p.resolvedJWKSUri == "" {
p.discoveryFailed = false
}
jwksURL, err := p.resolveJWKSUriLocked(ctx)
if err != nil {
return fmt.Errorf("resolve JWKS URI: %v", err)
} }
req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil) req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil)
+5
View File
@@ -49,6 +49,11 @@ type ExternalIdentity struct {
// Provider is the name of the identity provider // Provider is the name of the identity provider
Provider string `json:"provider"` Provider string `json:"provider"`
// Issuer is the OIDC `iss` claim (or equivalent) from the source token.
// Stable per (provider, identity) and used together with UserID to derive
// a stable parent-user hash that survives token rotation.
Issuer string `json:"issuer,omitempty"`
// TokenExpiration is the expiration time of the source identity token // TokenExpiration is the expiration time of the source identity token
// This is used to limit session duration to not exceed the token's exp claim // This is used to limit session duration to not exceed the token's exp claim
TokenExpiration *time.Time `json:"tokenExpiration,omitempty"` TokenExpiration *time.Time `json:"tokenExpiration,omitempty"`
+65
View File
@@ -0,0 +1,65 @@
package sts
import (
"strings"
"testing"
"time"
)
func TestComputeParentUserStability(t *testing.T) {
// Same (sub, iss) must produce the same hash, regardless of order or
// whitespace mutations callers should never apply.
a := ComputeParentUser("alice", "https://idp.example/")
b := ComputeParentUser("alice", "https://idp.example/")
if a == "" {
t.Fatal("parent user should not be empty for non-empty sub")
}
if a != b {
t.Fatalf("expected stable hash, got %q vs %q", a, b)
}
}
func TestComputeParentUserDistinguishesIssuer(t *testing.T) {
// The point of incorporating iss is that the same `sub` from two providers
// must not collide. If this assertion ever fails, the hash input is wrong.
a := ComputeParentUser("alice", "https://idp-a.example/")
b := ComputeParentUser("alice", "https://idp-b.example/")
if a == b {
t.Fatalf("hashes for different issuers must differ, both = %q", a)
}
}
func TestComputeParentUserDistinguishesSubject(t *testing.T) {
a := ComputeParentUser("alice", "https://idp.example/")
b := ComputeParentUser("bob", "https://idp.example/")
if a == b {
t.Fatalf("hashes for different subjects must differ, both = %q", a)
}
}
func TestComputeParentUserEmptySub(t *testing.T) {
if got := ComputeParentUser("", "https://idp.example/"); got != "" {
t.Fatalf("empty sub should produce empty parent user, got %q", got)
}
}
func TestComputeParentUserEncoding(t *testing.T) {
got := ComputeParentUser("alice", "https://idp.example/")
// Base64 RawURL has no padding and uses URL-safe alphabet — important
// because parent_user shows up in filer paths and audit log fields.
if strings.ContainsAny(got, "=+/") {
t.Fatalf("parent user should be base64 raw url, got %q", got)
}
}
func TestSessionClaimsRoundTripParentUser(t *testing.T) {
parent := ComputeParentUser("alice", "https://idp.example/")
claims := NewSTSSessionClaims("sid-1", "issuer", time.Now().Add(time.Hour)).
WithRoleInfo("arn:aws:iam::123:role/r", "arn:aws:sts::123:assumed-role/r/s", "arn:aws:sts::123:assumed-role/r/s").
WithParentUser(parent)
info := claims.ToSessionInfo()
if info.ParentUser != parent {
t.Fatalf("ParentUser lost on round-trip: got %q want %q", info.ParentUser, parent)
}
}
+30
View File
@@ -1,6 +1,8 @@
package sts package sts
import ( import (
"crypto/sha256"
"encoding/base64"
"fmt" "fmt"
"time" "time"
@@ -8,6 +10,20 @@ import (
"github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/glog"
) )
// ComputeParentUser returns a stable per-identity hash derived from the OIDC
// (sub, iss) tuple. Only the (sub, iss) pair is guaranteed stable across token
// refreshes per OpenID Connect Core 1.0 §5.7, so any per-user state (audit
// logs, quotas) must key off this value rather than the access-key or session
// id. The hash is base64-rawurl-encoded SHA-256 over "openid:<sub>:<iss>" so
// it stays filesystem-safe and bounded in length for storage in audit paths.
func ComputeParentUser(sub, iss string) string {
if sub == "" {
return ""
}
h := sha256.Sum256([]byte("openid:" + sub + ":" + iss))
return base64.RawURLEncoding.EncodeToString(h[:])
}
// defaultCredentialGenerator is a reusable instance for generating temporary credentials // defaultCredentialGenerator is a reusable instance for generating temporary credentials
// Reusing a single instance across all calls to ToSessionInfo() reduces allocation overhead // Reusing a single instance across all calls to ToSessionInfo() reduces allocation overhead
// since this method may be called frequently during signature verification // since this method may be called frequently during signature verification
@@ -45,6 +61,12 @@ type STSSessionClaims struct {
// Session metadata // Session metadata
AssumedAt time.Time `json:"assumed_at"` // when role was assumed AssumedAt time.Time `json:"assumed_at"` // when role was assumed
MaxDuration int64 `json:"max_dur,omitempty"` // maximum session duration in seconds MaxDuration int64 `json:"max_dur,omitempty"` // maximum session duration in seconds
// ParentUser is a stable hash of (sub, iss) for tokens minted from an OIDC
// identity. It survives token rotation since only the (sub, iss) tuple is
// guaranteed stable per OpenID Connect Core 1.0. Empty for non-federated
// session types.
ParentUser string `json:"puid,omitempty"`
} }
// NewSTSSessionClaims creates new STS session claims with all required information // NewSTSSessionClaims creates new STS session claims with all required information
@@ -96,6 +118,7 @@ func (c *STSSessionClaims) ToSessionInfo() *SessionInfo {
ExternalUserId: c.ExternalUserId, ExternalUserId: c.ExternalUserId,
ProviderIssuer: c.ProviderIssuer, ProviderIssuer: c.ProviderIssuer,
RequestContext: c.RequestContext, RequestContext: c.RequestContext,
ParentUser: c.ParentUser,
// Provide the Subject (sub) from registered claims // Provide the Subject (sub) from registered claims
Subject: c.Subject, Subject: c.Subject,
Credentials: credentials, Credentials: credentials,
@@ -182,3 +205,10 @@ func (c *STSSessionClaims) WithSessionName(sessionName string) *STSSessionClaims
c.SessionName = sessionName c.SessionName = sessionName
return c return c
} }
// WithParentUser sets the stable per-identity hash for the session. See
// ComputeParentUser for the derivation rule.
func (c *STSSessionClaims) WithParentUser(parentUser string) *STSSessionClaims {
c.ParentUser = parentUser
return c
}
+17 -1
View File
@@ -254,6 +254,9 @@ type SessionInfo struct {
// Credentials are the temporary credentials for this session // Credentials are the temporary credentials for this session
Credentials *Credentials `json:"credentials"` Credentials *Credentials `json:"credentials"`
// ParentUser is the stable hashed identity (sub+iss) derived at federation time.
ParentUser string `json:"parentUser,omitempty"`
} }
// NewSTSService creates a new STS service // NewSTSService creates a new STS service
@@ -506,13 +509,26 @@ func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *Ass
// Add sub as well since it's commonly used // Add sub as well since it's commonly used
requestContext["sub"] = externalIdentity.UserID requestContext["sub"] = externalIdentity.UserID
// Compute a stable parent-user hash from (sub, iss). Only this tuple is
// guaranteed stable across token refresh per OIDC Core 1.0, so this is the
// right key for any per-identity state (audit trail, future quotas).
parentUser := ComputeParentUser(externalIdentity.UserID, externalIdentity.Issuer)
if parentUser != "" {
// Surface as aws:userid so policies can reference it directly without
// caring about token-rotation churn.
requestContext["aws:userid"] = parentUser
}
// Create rich JWT claims with all session information // Create rich JWT claims with all session information
sessionClaims := NewSTSSessionClaims(sessionId, s.Config.Issuer, expiresAt). sessionClaims := NewSTSSessionClaims(sessionId, s.Config.Issuer, expiresAt).
WithSessionName(request.RoleSessionName). WithSessionName(request.RoleSessionName).
WithRoleInfo(request.RoleArn, assumedRoleUser.Arn, assumedRoleUser.Arn). WithRoleInfo(request.RoleArn, assumedRoleUser.Arn, assumedRoleUser.Arn).
WithIdentityProvider(provider.Name(), externalIdentity.UserID, ""). WithIdentityProvider(provider.Name(), externalIdentity.UserID, externalIdentity.Issuer).
WithMaxDuration(sessionDuration). WithMaxDuration(sessionDuration).
WithRequestContext(requestContext) WithRequestContext(requestContext)
if parentUser != "" {
sessionClaims.WithParentUser(parentUser)
}
if sessionPolicy != "" { if sessionPolicy != "" {
sessionClaims.WithSessionPolicy(sessionPolicy) sessionClaims.WithSessionPolicy(sessionPolicy)
} }
+69 -17
View File
@@ -53,6 +53,34 @@ const (
// federationNameRegex validates the Name parameter for GetFederationToken per AWS spec // federationNameRegex validates the Name parameter for GetFederationToken per AWS spec
var federationNameRegex = regexp.MustCompile(`^[\w+=,.@-]+$`) var federationNameRegex = regexp.MustCompile(`^[\w+=,.@-]+$`)
// roleSessionNameRegex validates RoleSessionName per AWS spec.
// Same character class as federation Name, but the length bounds differ
// (RoleSessionName is 2..64).
var roleSessionNameRegex = regexp.MustCompile(`^[\w+=,.@-]+$`)
const (
minRoleSessionNameLen = 2
maxRoleSessionNameLen = 64
)
// validateRoleSessionName enforces the AWS RoleSessionName contract:
// length 2..64, characters [\w+=,.@-]+. Returns the STS error code and a
// descriptive error suitable for callers to surface to the caller.
func validateRoleSessionName(name string) (STSErrorCode, error) {
if name == "" {
return STSErrMissingParameter, fmt.Errorf("RoleSessionName is required")
}
if len(name) < minRoleSessionNameLen || len(name) > maxRoleSessionNameLen {
return STSErrInvalidParameterValue,
fmt.Errorf("RoleSessionName must be between %d and %d characters", minRoleSessionNameLen, maxRoleSessionNameLen)
}
if !roleSessionNameRegex.MatchString(name) {
return STSErrInvalidParameterValue,
fmt.Errorf(`RoleSessionName contains invalid characters; allowed: [\w+=,.@-]`)
}
return "", nil
}
// STS duration constants (AWS specification) // STS duration constants (AWS specification)
const ( const (
minDurationSeconds = int64(900) // 15 minutes minDurationSeconds = int64(900) // 15 minutes
@@ -61,6 +89,27 @@ const (
maxFederationDurationSeconds = int64(129600) // 36 hours (GetFederationToken max) maxFederationDurationSeconds = int64(129600) // 36 hours (GetFederationToken max)
) )
// AWS limits inline session policies to 2048 characters for AssumeRole,
// AssumeRoleWithWebIdentity, and AssumeRoleWithSAML. PackedPolicySize is
// returned as a percentage of that budget so callers can detect how close
// they are to the limit.
const sessionPolicyBudgetBytes = 2048
// computePackedPolicySize returns the inline session policy size as a
// percentage of the per-action budget, or nil when no session policy was
// provided. Output is bounded to [0, 100] for AWS-compat reporting; the
// actual policy size validation happens upstream in NormalizeSessionPolicy.
func computePackedPolicySize(policyJSON string) *int64 {
if policyJSON == "" {
return nil
}
pct := int64(len(policyJSON)) * 100 / sessionPolicyBudgetBytes
if pct > 100 {
pct = 100
}
return &pct
}
// parseDurationSecondsWithBounds parses and validates the DurationSeconds parameter // parseDurationSecondsWithBounds parses and validates the DurationSeconds parameter
// against the given min and max bounds. Returns nil if the parameter is not provided. // against the given min and max bounds. Returns nil if the parameter is not provided.
func parseDurationSecondsWithBounds(r *http.Request, minSec, maxSec int64) (*int64, STSErrorCode, error) { func parseDurationSecondsWithBounds(r *http.Request, minSec, maxSec int64) (*int64, STSErrorCode, error) {
@@ -170,9 +219,8 @@ func (h *STSHandlers) handleAssumeRoleWithWebIdentity(w http.ResponseWriter, r *
return return
} }
if roleSessionName == "" { if errCode, err := validateRoleSessionName(roleSessionName); err != nil {
h.writeSTSErrorResponse(w, r, STSErrMissingParameter, h.writeSTSErrorResponse(w, r, errCode, err)
fmt.Errorf("RoleSessionName is required"))
return return
} }
@@ -245,6 +293,7 @@ func (h *STSHandlers) handleAssumeRoleWithWebIdentity(w http.ResponseWriter, r *
Expiration: response.Credentials.Expiration.Format(time.RFC3339), Expiration: response.Credentials.Expiration.Format(time.RFC3339),
}, },
SubjectFromWebIdentityToken: response.AssumedRoleUser.Subject, SubjectFromWebIdentityToken: response.AssumedRoleUser.Subject,
PackedPolicySize: computePackedPolicySize(sessionPolicyJSON),
}, },
} }
xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r) xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r)
@@ -264,9 +313,8 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) {
// Validate required parameters // Validate required parameters
// RoleArn is optional to support S3-compatible clients that omit it // RoleArn is optional to support S3-compatible clients that omit it
if roleSessionName == "" { if errCode, err := validateRoleSessionName(roleSessionName); err != nil {
h.writeSTSErrorResponse(w, r, STSErrMissingParameter, h.writeSTSErrorResponse(w, r, errCode, err)
fmt.Errorf("RoleSessionName is required"))
return return
} }
@@ -373,8 +421,9 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) {
// Build and return response // Build and return response
xmlResponse := &AssumeRoleResponse{ xmlResponse := &AssumeRoleResponse{
Result: AssumeRoleResult{ Result: AssumeRoleResult{
Credentials: stsCreds, Credentials: stsCreds,
AssumedRoleUser: assumedUser, AssumedRoleUser: assumedUser,
PackedPolicySize: computePackedPolicySize(sessionPolicyJSON),
}, },
} }
xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r) xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r)
@@ -397,9 +446,8 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r
return return
} }
if roleSessionName == "" { if errCode, err := validateRoleSessionName(roleSessionName); err != nil {
h.writeSTSErrorResponse(w, r, STSErrMissingParameter, h.writeSTSErrorResponse(w, r, errCode, err)
fmt.Errorf("RoleSessionName is required"))
return return
} }
@@ -514,8 +562,9 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r
// Build and return response // Build and return response
xmlResponse := &AssumeRoleWithLDAPIdentityResponse{ xmlResponse := &AssumeRoleWithLDAPIdentityResponse{
Result: LDAPIdentityResult{ Result: LDAPIdentityResult{
Credentials: stsCreds, Credentials: stsCreds,
AssumedRoleUser: assumedUser, AssumedRoleUser: assumedUser,
PackedPolicySize: computePackedPolicySize(sessionPolicyJSON),
}, },
} }
xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r) xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r)
@@ -906,6 +955,7 @@ type WebIdentityResult struct {
Credentials STSCredentials `xml:"Credentials"` Credentials STSCredentials `xml:"Credentials"`
SubjectFromWebIdentityToken string `xml:"SubjectFromWebIdentityToken,omitempty"` SubjectFromWebIdentityToken string `xml:"SubjectFromWebIdentityToken,omitempty"`
AssumedRoleUser *AssumedRoleUser `xml:"AssumedRoleUser,omitempty"` AssumedRoleUser *AssumedRoleUser `xml:"AssumedRoleUser,omitempty"`
PackedPolicySize *int64 `xml:"PackedPolicySize,omitempty"`
} }
// STSCredentials represents temporary security credentials // STSCredentials represents temporary security credentials
@@ -933,8 +983,9 @@ type AssumeRoleResponse struct {
// AssumeRoleResult contains the result of AssumeRole // AssumeRoleResult contains the result of AssumeRole
type AssumeRoleResult struct { type AssumeRoleResult struct {
Credentials STSCredentials `xml:"Credentials"` Credentials STSCredentials `xml:"Credentials"`
AssumedRoleUser *AssumedRoleUser `xml:"AssumedRoleUser,omitempty"` AssumedRoleUser *AssumedRoleUser `xml:"AssumedRoleUser,omitempty"`
PackedPolicySize *int64 `xml:"PackedPolicySize,omitempty"`
} }
// AssumeRoleWithLDAPIdentityResponse is the response for AssumeRoleWithLDAPIdentity // AssumeRoleWithLDAPIdentityResponse is the response for AssumeRoleWithLDAPIdentity
@@ -948,8 +999,9 @@ type AssumeRoleWithLDAPIdentityResponse struct {
// LDAPIdentityResult contains the result of AssumeRoleWithLDAPIdentity // LDAPIdentityResult contains the result of AssumeRoleWithLDAPIdentity
type LDAPIdentityResult struct { type LDAPIdentityResult struct {
Credentials STSCredentials `xml:"Credentials"` Credentials STSCredentials `xml:"Credentials"`
AssumedRoleUser *AssumedRoleUser `xml:"AssumedRoleUser,omitempty"` AssumedRoleUser *AssumedRoleUser `xml:"AssumedRoleUser,omitempty"`
PackedPolicySize *int64 `xml:"PackedPolicySize,omitempty"`
} }
// GetCallerIdentityResponse is the response for GetCallerIdentity // GetCallerIdentityResponse is the response for GetCallerIdentity
+34
View File
@@ -0,0 +1,34 @@
package s3api
import "testing"
func TestComputePackedPolicySize(t *testing.T) {
cases := []struct {
name string
policyLen int
empty bool
want int64
}{
{"empty -> nil", 0, true, 0},
{"tiny policy -> 0%", 10, false, 0},
{"half budget -> 50%", sessionPolicyBudgetBytes / 2, false, 50},
{"full budget -> 100%", sessionPolicyBudgetBytes, false, 100},
{"oversized -> capped at 100", sessionPolicyBudgetBytes * 3, false, 100},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
policy := repeat('a', tc.policyLen)
got := computePackedPolicySize(policy)
switch {
case tc.empty:
if got != nil {
t.Fatalf("expected nil for empty input, got %d", *got)
}
case got == nil:
t.Fatalf("expected non-nil result, got nil")
case *got != tc.want:
t.Fatalf("got=%d want=%d", *got, tc.want)
}
})
}
}
+47
View File
@@ -0,0 +1,47 @@
package s3api
import "testing"
func TestValidateRoleSessionName(t *testing.T) {
cases := []struct {
name string
input string
wantErr bool
// wantCode is checked only when wantErr is true
wantCode STSErrorCode
}{
{"empty rejected", "", true, STSErrMissingParameter},
{"single char rejected (below min len 2)", "a", true, STSErrInvalidParameterValue},
{"min length 2 accepted", "ab", false, ""},
{"plain ascii accepted", "session-name_1", false, ""},
{"all special chars allowed", "+=,.@-", false, ""},
{"email-style accepted", "alice@example.com", false, ""},
{"max length 64 accepted", string(make([]byte, 64)), true, STSErrInvalidParameterValue}, // zero bytes -> invalid charset
{"max length 64 valid charset accepted", repeat('a', 64), false, ""},
{"length 65 rejected", repeat('a', 65), true, STSErrInvalidParameterValue},
{"space rejected", "alice bob", true, STSErrInvalidParameterValue},
{"slash rejected", "alice/bob", true, STSErrInvalidParameterValue},
{"colon rejected", "alice:bob", true, STSErrInvalidParameterValue},
{"unicode rejected", "alicé", true, STSErrInvalidParameterValue},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
code, err := validateRoleSessionName(tc.input)
gotErr := err != nil
if gotErr != tc.wantErr {
t.Fatalf("err mismatch: got=%v want=%v (err=%v)", gotErr, tc.wantErr, err)
}
if tc.wantErr && code != tc.wantCode {
t.Fatalf("code mismatch: got=%s want=%s", code, tc.wantCode)
}
})
}
}
func repeat(b byte, n int) string {
out := make([]byte, n)
for i := range out {
out[i] = b
}
return string(out)
}