From d951a8df5a72a02fb8029c4f340b6be346899afa Mon Sep 17 00:00:00 2001 From: Chris Lu Date: Mon, 4 May 2026 22:10:49 -0700 Subject: [PATCH] feat(iam): STS web-identity AWS-fidelity polish (Phase 1) (#9318) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(iam): STS web-identity AWS-fidelity polish - OIDC discovery via .well-known/openid-configuration; falls back to /.well-known/jwks.json when discovery is absent. Reject discovery docs whose issuer claim does not match the configured issuer to defend against issuer-substitution. - ComputeParentUser derives a stable per-identity hash from (sub, iss). Surface as aws:userid in the request context and as a parent_user claim in the session JWT so per-user state survives token rotation. - Per-role MaxSessionDuration (3600..43200) clamps requested DurationSeconds before the STS service applies its own caps. - Tighten RoleSessionName to the AWS contract: 2..64 chars from [\w+=,.@-]. - Populate PackedPolicySize in AssumeRole / AssumeRoleWithWebIdentity / AssumeRoleWithLDAPIdentity responses as a percentage of the 2048-byte inline session policy budget. * fix(iam): leave omitted DurationSeconds nil so STS default applies capDurationByRole was substituting the role's MaxSessionDuration when the caller omitted DurationSeconds entirely. AWS returns the configured default (typically 1 hour) in that case, not the role's upper bound — a 12h MaxSessionDuration shouldn't silently make every no-duration assume-role mint a 12h session. Return nil when requested is nil; let the downstream calculateSessionDuration in the STS service apply its TokenDuration default. The role-max upper bound still clamps when the request arrives with a concrete value above the cap. Addresses gemini high-priority review on PR #9318. * fix(iam): synchronize OIDCProvider JWKS cache fields jwksCache, jwksFetchedAt, resolvedJWKSUri, and discoveryFailed are mutated lazily on the first token-validate call and refreshed afterwards on TTL expiry. Multiple S3 requests can land here in parallel, so the writes were racing against subsequent reads on every other goroutine. resolvedJWKSUri/discoveryFailed inherited the same un-protected pattern when discovery shipped. Add sync.RWMutex; getPublicKey takes the read lock for the common cache-hit path and promotes to the write lock for misses + refreshes. fetchJWKSLocked / resolveJWKSUriLocked assume the write lock is held by the caller; fetchJWKS keeps the test-friendly entry point that acquires the lock itself. Addresses gemini high-priority review on PR #9318. * fix(iam): trim trailing slash + retry discovery after transient failure Two OIDC discovery edge cases reviewers flagged: 1. Issuer comparison was sensitive to trailing slashes. resolveJWKSUri trims them when building the discovery URL, but the doc.Issuer ↔ p.config.Issuer check did not, so an IDP whose issuer claim drops or adds the slash relative to the configured value would be falsely rejected. Trim a single trailing slash on each side before comparing. 2. discoveryFailed flipped to true on any error and stayed there for the process lifetime. A transient 5xx at startup permanently locked the provider into the /.well-known/jwks.json fallback. Reset the flag at the top of fetchJWKSLocked when no URI has been cached yet, so each JWKS refresh (typically once per TTL = 1h) reattempts discovery. Successful discovery remains cached via resolvedJWKSUri so we don't pay the discovery RTT on every refresh. Addresses gemini security-medium + medium reviews on PR #9318. * fix(iam): require non-empty issuer in OIDC discovery doc The previous "doc.Issuer != "" && ..." guard let a discovery document that omitted the issuer field bypass the issuer-mismatch check entirely, letting the doc steer fetchJWKS at any URL it provided. OIDC Discovery 1.0 §3 mandates the issuer field; treat missing as a hard failure same as mismatched. Trailing-slash equivalence still applies. Adds TestDiscoveryRejectsMissingIssuer alongside the existing TestDiscoveryRejectsIssuerMismatch via a new omitDiscoveryIssuer toggle on fakeIDP. --- weed/iam/integration/iam_manager.go | 39 ++++ weed/iam/integration/role_max_session_test.go | 34 +++ weed/iam/oidc/oidc_discovery_test.go | 197 ++++++++++++++++++ weed/iam/oidc/oidc_provider.go | 182 ++++++++++++++-- weed/iam/providers/provider.go | 5 + weed/iam/sts/parent_user_test.go | 65 ++++++ weed/iam/sts/session_claims.go | 30 +++ weed/iam/sts/sts_service.go | 18 +- weed/s3api/s3api_sts.go | 86 ++++++-- weed/s3api/sts_packed_policy_test.go | 34 +++ weed/s3api/sts_session_name_test.go | 47 +++++ 11 files changed, 698 insertions(+), 39 deletions(-) create mode 100644 weed/iam/integration/role_max_session_test.go create mode 100644 weed/iam/oidc/oidc_discovery_test.go create mode 100644 weed/iam/sts/parent_user_test.go create mode 100644 weed/s3api/sts_packed_policy_test.go create mode 100644 weed/s3api/sts_session_name_test.go diff --git a/weed/iam/integration/iam_manager.go b/weed/iam/integration/iam_manager.go index e3af56328..4af2587c3 100644 --- a/weed/iam/integration/iam_manager.go +++ b/weed/iam/integration/iam_manager.go @@ -75,6 +75,12 @@ type RoleDefinition struct { // Description is an optional description of the role Description string `json:"description,omitempty"` + + // MaxSessionDuration is the upper bound (in seconds) on session length when + // callers assume this role. Zero means "use the global STS default". When + // set it must satisfy AWS bounds: 3600 ≤ MaxSessionDuration ≤ 43200. + // Honoured by AssumeRole, AssumeRoleWithWebIdentity, AssumeRoleWithCredentials. + MaxSessionDuration int64 `json:"maxSessionDuration,omitempty"` } // ActionRequest represents a request to perform an action @@ -272,6 +278,13 @@ func (m *IAMManager) CreateRole(ctx context.Context, filerAddress string, roleNa } } + // Validate per-role MaxSessionDuration if specified. AWS bounds: 1h..12h. + if roleDef.MaxSessionDuration != 0 { + if roleDef.MaxSessionDuration < 3600 || roleDef.MaxSessionDuration > 43200 { + return fmt.Errorf("MaxSessionDuration must be between 3600 and 43200 seconds, got %d", roleDef.MaxSessionDuration) + } + } + // Store role definition return m.roleStore.StoreRole(ctx, "", roleName, roleDef) } @@ -329,10 +342,33 @@ func (m *IAMManager) AssumeRoleWithWebIdentity(ctx context.Context, request *sts return nil, fmt.Errorf("trust policy validation failed: %w", err) } + // Apply role-level MaxSessionDuration cap. The STS service still applies + // the global MaxSessionLength and the source-token-expiry cap on top of + // this; per-role takes precedence whenever it is the tightest bound. + request.DurationSeconds = capDurationByRole(request.DurationSeconds, roleDef.MaxSessionDuration) + // Use STS service to assume the role return m.stsService.AssumeRoleWithWebIdentity(ctx, request) } +// capDurationByRole returns the requested duration clamped to the role's +// MaxSessionDuration. A nil requested duration is left nil so the STS +// service's calculateSessionDuration applies the global default (typically +// 1 hour) — substituting the role's max here would silently mint a 12h +// session for any caller who omitted DurationSeconds, which AWS does not +// do. The role-max upper bound still applies in the downstream cap chain +// once the request has a concrete duration. +func capDurationByRole(requested *int64, roleMax int64) *int64 { + if roleMax <= 0 || requested == nil { + return requested + } + if *requested > roleMax { + v := roleMax + return &v + } + return requested +} + // AssumeRoleWithCredentials assumes a role using credentials (LDAP) func (m *IAMManager) AssumeRoleWithCredentials(ctx context.Context, request *sts.AssumeRoleWithCredentialsRequest) (*sts.AssumeRoleResponse, error) { if !m.initialized { @@ -353,6 +389,9 @@ func (m *IAMManager) AssumeRoleWithCredentials(ctx context.Context, request *sts return nil, fmt.Errorf("trust policy validation failed: %w", err) } + // Apply role-level MaxSessionDuration cap. + request.DurationSeconds = capDurationByRole(request.DurationSeconds, roleDef.MaxSessionDuration) + // Use STS service to assume the role return m.stsService.AssumeRoleWithCredentials(ctx, request) } diff --git a/weed/iam/integration/role_max_session_test.go b/weed/iam/integration/role_max_session_test.go new file mode 100644 index 000000000..9f011c71f --- /dev/null +++ b/weed/iam/integration/role_max_session_test.go @@ -0,0 +1,34 @@ +package integration + +import "testing" + +func intPtr(v int64) *int64 { return &v } + +func TestCapDurationByRole(t *testing.T) { + cases := []struct { + name string + requested *int64 + roleMax int64 + want *int64 + }{ + {"no cap, no request", nil, 0, nil}, + {"no cap, with request", intPtr(7200), 0, intPtr(7200)}, + {"cap only, no request -> nil so STS default applies", nil, 3600, nil}, + {"request below cap -> request", intPtr(1800), 3600, intPtr(1800)}, + {"request equal cap -> request", intPtr(3600), 3600, intPtr(3600)}, + {"request above cap -> cap", intPtr(43200), 3600, intPtr(3600)}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := capDurationByRole(tc.requested, tc.roleMax) + switch { + case got == nil && tc.want == nil: + return + case got == nil || tc.want == nil: + t.Fatalf("nilness mismatch: got=%v want=%v", got, tc.want) + case *got != *tc.want: + t.Fatalf("got=%d want=%d", *got, *tc.want) + } + }) + } +} diff --git a/weed/iam/oidc/oidc_discovery_test.go b/weed/iam/oidc/oidc_discovery_test.go new file mode 100644 index 000000000..66e4d75e9 --- /dev/null +++ b/weed/iam/oidc/oidc_discovery_test.go @@ -0,0 +1,197 @@ +package oidc + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" +) + +// fakeIDP wraps an httptest.Server and counts how many times each well-known +// endpoint is hit. Tests use it to assert discovery vs. fallback behaviour. +type fakeIDP struct { + server *httptest.Server + discoveryHits atomic.Int32 + jwksHits atomic.Int32 + customJWKSHits atomic.Int32 + disableDiscovery bool + discoveryStatusCode int + discoveryIssuer string + omitDiscoveryIssuer bool // when true, the discovery doc omits the "issuer" field entirely + customJWKSPathSuffix string // optional suffix that fakeIDP serves at /custom/ + jwks JWKS +} + +func newFakeIDP(t *testing.T) *fakeIDP { + t.Helper() + idp := &fakeIDP{ + discoveryStatusCode: http.StatusOK, + jwks: JWKS{Keys: []JWK{{Kty: "RSA", Kid: "k1", Use: "sig", Alg: "RS256", N: "AQAB", E: "AQAB"}}}, + } + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + idp.discoveryHits.Add(1) + if idp.disableDiscovery { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(idp.discoveryStatusCode) + issuer := idp.discoveryIssuer + if issuer == "" { + issuer = idp.server.URL + } + jwksURI := idp.server.URL + "/discovered/jwks" + if idp.customJWKSPathSuffix != "" { + jwksURI = idp.server.URL + "/custom/" + idp.customJWKSPathSuffix + } + body := map[string]string{"jwks_uri": jwksURI} + if !idp.omitDiscoveryIssuer { + body["issuer"] = issuer + } + _ = json.NewEncoder(w).Encode(body) + }) + mux.HandleFunc("/discovered/jwks", func(w http.ResponseWriter, r *http.Request) { + idp.jwksHits.Add(1) + _ = json.NewEncoder(w).Encode(idp.jwks) + }) + mux.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) { + idp.jwksHits.Add(1) + _ = json.NewEncoder(w).Encode(idp.jwks) + }) + mux.HandleFunc("/custom/", func(w http.ResponseWriter, r *http.Request) { + idp.customJWKSHits.Add(1) + _ = json.NewEncoder(w).Encode(idp.jwks) + }) + idp.server = httptest.NewServer(mux) + t.Cleanup(idp.server.Close) + return idp +} + +func newProviderForIDP(t *testing.T, idp *fakeIDP, jwksURIOverride string) *OIDCProvider { + t.Helper() + p := NewOIDCProvider("test") + cfg := &OIDCConfig{ + Issuer: idp.server.URL, + ClientID: "test-client", + JWKSUri: jwksURIOverride, + } + if err := p.Initialize(cfg); err != nil { + t.Fatalf("Initialize: %v", err) + } + return p +} + +func TestDiscoveryHappyPath(t *testing.T) { + idp := newFakeIDP(t) + p := newProviderForIDP(t, idp, "") + + if err := p.fetchJWKS(context.Background()); err != nil { + t.Fatalf("fetchJWKS: %v", err) + } + if got := idp.discoveryHits.Load(); got != 1 { + t.Fatalf("expected 1 discovery hit, got %d", got) + } + if got := idp.jwksHits.Load(); got != 1 { + t.Fatalf("expected 1 JWKS hit at discovered uri, got %d", got) + } + + // A second fetch reuses the cached jwks_uri without re-discovering. + if err := p.fetchJWKS(context.Background()); err != nil { + t.Fatalf("fetchJWKS second: %v", err) + } + if got := idp.discoveryHits.Load(); got != 1 { + t.Fatalf("discovery should be cached, got %d hits", got) + } +} + +func TestDiscoveryFallback404(t *testing.T) { + idp := newFakeIDP(t) + idp.disableDiscovery = true + p := newProviderForIDP(t, idp, "") + + if err := p.fetchJWKS(context.Background()); err != nil { + t.Fatalf("fetchJWKS: %v", err) + } + if got := idp.discoveryHits.Load(); got != 1 { + t.Fatalf("expected 1 discovery probe, got %d", got) + } + if got := idp.jwksHits.Load(); got != 1 { + t.Fatalf("expected 1 JWKS hit at fallback uri, got %d", got) + } + + // Subsequent fetches retry discovery — discoveryFailed resets at the top + // of fetchJWKSLocked when no URI was cached, so a transient 5xx at + // startup doesn't lock the provider into the fallback path forever. + // Retry rate is bounded by the JWKS TTL (one retry per refresh cycle). + if err := p.fetchJWKS(context.Background()); err != nil { + t.Fatalf("fetchJWKS second: %v", err) + } + if got := idp.discoveryHits.Load(); got != 2 { + t.Fatalf("discovery probe should retry while no URI is cached, got %d hits", got) + } +} + +func TestDiscoveryDisabledByExplicitJWKSUri(t *testing.T) { + idp := newFakeIDP(t) + override := idp.server.URL + "/custom/explicit" + idp.customJWKSPathSuffix = "explicit" + p := newProviderForIDP(t, idp, override) + + if err := p.fetchJWKS(context.Background()); err != nil { + t.Fatalf("fetchJWKS: %v", err) + } + if got := idp.discoveryHits.Load(); got != 0 { + t.Fatalf("explicit JWKSUri should bypass discovery, got %d hits", got) + } + if got := idp.customJWKSHits.Load(); got != 1 { + t.Fatalf("expected 1 custom JWKS hit, got %d", got) + } +} + +func TestDiscoveryRejectsIssuerMismatch(t *testing.T) { + idp := newFakeIDP(t) + idp.discoveryIssuer = "https://attacker.example/" + p := newProviderForIDP(t, idp, "") + + if err := p.fetchJWKS(context.Background()); err != nil { + t.Fatalf("fetchJWKS should fall back to /.well-known/jwks.json, got error: %v", err) + } + // Discovery probe was tried once, rejected, then fell through to fallback path. + if got := idp.discoveryHits.Load(); got != 1 { + t.Fatalf("expected 1 discovery probe, got %d", got) + } + if got := idp.jwksHits.Load(); got != 1 { + t.Fatalf("expected 1 fallback JWKS hit, got %d", got) + } +} + +// TestDiscoveryRejectsMissingIssuer: a discovery document that omits the +// issuer field entirely must be treated the same as one that supplies a +// mismatched issuer. Otherwise an attacker who can intercept the discovery +// response can strip the issuer field and the comparison silently passes, +// letting the document point fetchJWKS at any URL it pleases. +func TestDiscoveryRejectsMissingIssuer(t *testing.T) { + idp := newFakeIDP(t) + idp.omitDiscoveryIssuer = true + p := newProviderForIDP(t, idp, "") + + if err := p.fetchJWKS(context.Background()); err != nil { + t.Fatalf("fetchJWKS should fall back to /.well-known/jwks.json on issuer-missing discovery: %v", err) + } + if got := idp.discoveryHits.Load(); got != 1 { + t.Fatalf("expected 1 discovery probe, got %d", got) + } + // The discovery document was rejected; the JWKS that ultimately served + // us must be the fallback one, not the discovered URI. The fakeIDP + // counts both hits under jwksHits since they share a counter; what + // matters is that customJWKSHits stayed zero. + if got := idp.customJWKSHits.Load(); got != 0 { + t.Fatalf("custom JWKS endpoint must not have been used, got %d hits", got) + } + if got := idp.jwksHits.Load(); got != 1 { + t.Fatalf("expected 1 fallback JWKS hit, got %d", got) + } +} diff --git a/weed/iam/oidc/oidc_provider.go b/weed/iam/oidc/oidc_provider.go index dda970665..c9f289aa6 100644 --- a/weed/iam/oidc/oidc_provider.go +++ b/weed/iam/oidc/oidc_provider.go @@ -16,6 +16,7 @@ import ( "os" "path/filepath" "strings" + "sync" "time" "github.com/golang-jwt/jwt/v5" @@ -25,13 +26,21 @@ import ( // OIDCProvider implements OpenID Connect authentication type OIDCProvider struct { - name string - config *OIDCConfig - initialized bool - jwksCache *JWKS - httpClient *http.Client - jwksFetchedAt time.Time - jwksTTL time.Duration + name string + config *OIDCConfig + initialized bool + httpClient *http.Client + jwksTTL time.Duration + + // mu guards the lazily-mutated cache fields below: jwksCache, jwksFetchedAt, + // resolvedJWKSUri, and discoveryFailed are all populated on the first + // validate-token call and refreshed when the cache expires. Multiple S3 + // requests can land here in parallel, so they need synchronization. + mu sync.RWMutex + jwksCache *JWKS + jwksFetchedAt time.Time + resolvedJWKSUri string + discoveryFailed bool } // OIDCConfig holds OIDC provider configuration @@ -272,6 +281,7 @@ func (p *OIDCProvider) Authenticate(ctx context.Context, token string) (*provide Groups: groups, Attributes: attributes, Provider: p.name, + Issuer: claims.Issuer, } // Pass the token expiration to limit session duration @@ -550,39 +560,169 @@ func (p *OIDCProvider) mapClaimsToRolesWithConfig(claims *providers.TokenClaims) return roles } -// getPublicKey retrieves the public key for the given key ID from JWKS +// getPublicKey retrieves the public key for the given key ID from JWKS. +// Cache hits use the read lock so concurrent token validations don't +// serialize on JWKS lookup. Misses and expirations promote to the write +// lock so the JWKS fetch + cache write happens once per refresh cycle. func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{}, error) { - // Fetch JWKS if not cached or refresh if expired - if p.jwksCache == nil || (!p.jwksFetchedAt.IsZero() && time.Since(p.jwksFetchedAt) > p.jwksTTL) { - if err := p.fetchJWKS(ctx); err != nil { + // Fast path: read lock and look in cache. + p.mu.RLock() + if p.jwksCache != nil && (p.jwksFetchedAt.IsZero() || time.Since(p.jwksFetchedAt) <= p.jwksTTL) { + for _, key := range p.jwksCache.Keys { + if key.Kid == kid { + k := key + p.mu.RUnlock() + return p.parseJWK(&k) + } + } + } + p.mu.RUnlock() + + // Slow path: take the write lock for the (re)fetch + retry. Re-check the + // cache under the write lock in case another goroutine already refreshed. + p.mu.Lock() + defer p.mu.Unlock() + + cacheValid := p.jwksCache != nil && (p.jwksFetchedAt.IsZero() || time.Since(p.jwksFetchedAt) <= p.jwksTTL) + if !cacheValid { + if err := p.fetchJWKSLocked(ctx); err != nil { return nil, fmt.Errorf("failed to fetch JWKS: %v", err) } } - - // Find the key with matching kid for _, key := range p.jwksCache.Keys { if key.Kid == kid { - return p.parseJWK(&key) + k := key + return p.parseJWK(&k) } } - // Key not found in cache. Refresh JWKS once to handle key rotation and retry. - if err := p.fetchJWKS(ctx); err != nil { + // Key not found in cache. Refresh JWKS once to handle key rotation. + if err := p.fetchJWKSLocked(ctx); err != nil { return nil, fmt.Errorf("failed to refresh JWKS after key miss: %v", err) } for _, key := range p.jwksCache.Keys { if key.Kid == kid { - return p.parseJWK(&key) + k := key + return p.parseJWK(&k) } } return nil, fmt.Errorf("key with ID %s not found in JWKS after refresh", kid) } -// fetchJWKS fetches the JWKS from the provider +// discoveryDocument is the subset of the OpenID Provider Configuration we need. +// See https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata. +type discoveryDocument struct { + Issuer string `json:"issuer"` + JWKSUri string `json:"jwks_uri"` +} + +// resolveJWKSUriLocked determines the JWKS URI for the provider. The caller +// must hold p.mu (write lock); the function reads/writes p.resolvedJWKSUri +// and p.discoveryFailed without taking the lock itself. +// +// Order of resolution: +// 1. explicit config.JWKSUri (operator override; never overridden by discovery). +// 2. cached resolvedJWKSUri from a prior discovery (refreshed when JWKS cache expires). +// 3. .well-known/openid-configuration discovery (per OIDC Discovery 1.0). +// 4. fallback to {issuer}/.well-known/jwks.json (compat path for IDPs that +// don't publish discovery). +func (p *OIDCProvider) resolveJWKSUriLocked(ctx context.Context) (string, error) { + if p.config.JWKSUri != "" { + return p.config.JWKSUri, nil + } + if p.resolvedJWKSUri != "" { + return p.resolvedJWKSUri, nil + } + + issuer := strings.TrimSuffix(p.config.Issuer, "/") + + if !p.discoveryFailed { + discoveryURL := issuer + "/.well-known/openid-configuration" + uri, err := p.fetchDiscoveryJWKSUri(ctx, discoveryURL) + switch { + case err == nil: + p.resolvedJWKSUri = uri + return uri, nil + default: + // Cache the failure so we don't pay the discovery RTT on every refresh. + // Operators with non-discovery IDPs see one failed lookup at startup. + glog.V(3).Infof("OIDC discovery at %s failed (%v); falling back to /.well-known/jwks.json", discoveryURL, err) + p.discoveryFailed = true + } + } + + return issuer + "/.well-known/jwks.json", nil +} + +// fetchDiscoveryJWKSUri retrieves the OIDC discovery document and returns +// the jwks_uri field. The issuer claim in the document must match config.Issuer +// to defend against issuer-substitution attacks during discovery. +func (p *OIDCProvider) fetchDiscoveryJWKSUri(ctx context.Context, discoveryURL string) (string, error) { + req, err := http.NewRequestWithContext(ctx, "GET", discoveryURL, nil) + if err != nil { + return "", fmt.Errorf("create discovery request: %v", err) + } + req.Header.Set("Accept", "application/json") + + resp, err := p.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("fetch discovery document: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("discovery endpoint returned status %d", resp.StatusCode) + } + + var doc discoveryDocument + if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil { + return "", fmt.Errorf("decode discovery document: %v", err) + } + + if doc.JWKSUri == "" { + return "", fmt.Errorf("discovery document missing jwks_uri") + } + + // Issuer must be present and match: a discovery doc that points to a + // different issuer is either a misconfiguration or an attack against + // issuer-confusion, and a doc that omits the issuer field entirely + // would have bypassed the previous check (doc.Issuer != "" guard) and + // silently accepted whatever JWKS URI the document supplied. OIDC + // Discovery 1.0 §3 mandates the issuer field, so treat missing as a + // hard failure. Compare after trimming a single trailing slash on each + // side because real IdPs disagree on whether the configured issuer + // has one. + if strings.TrimSuffix(doc.Issuer, "/") != strings.TrimSuffix(p.config.Issuer, "/") { + return "", fmt.Errorf("discovery issuer %q does not match configured issuer %q", doc.Issuer, p.config.Issuer) + } + + return doc.JWKSUri, nil +} + +// fetchJWKS is a thin wrapper around fetchJWKSLocked that acquires the +// write lock. Used by tests; production callers in getPublicKey already +// hold the lock and call fetchJWKSLocked directly. func (p *OIDCProvider) fetchJWKS(ctx context.Context) error { - jwksURL := p.config.JWKSUri - if jwksURL == "" { - jwksURL = strings.TrimSuffix(p.config.Issuer, "/") + "/.well-known/jwks.json" + p.mu.Lock() + defer p.mu.Unlock() + return p.fetchJWKSLocked(ctx) +} + +// fetchJWKSLocked fetches the JWKS from the provider. The caller must hold +// p.mu (write lock); the function writes p.jwksCache and p.jwksFetchedAt +// without taking the lock itself. +// +// Each fetch reattempts discovery if the previous attempt failed: a +// transient 5xx that flipped discoveryFailed at startup shouldn't lock the +// provider into the fallback path forever. The retry rate is bounded by +// the JWKS TTL (typically 1h), so the discovery RTT cost is amortized. +func (p *OIDCProvider) fetchJWKSLocked(ctx context.Context) error { + if p.config.JWKSUri == "" && p.resolvedJWKSUri == "" { + p.discoveryFailed = false + } + jwksURL, err := p.resolveJWKSUriLocked(ctx) + if err != nil { + return fmt.Errorf("resolve JWKS URI: %v", err) } req, err := http.NewRequestWithContext(ctx, "GET", jwksURL, nil) diff --git a/weed/iam/providers/provider.go b/weed/iam/providers/provider.go index 98b8343fd..f9fac3dae 100644 --- a/weed/iam/providers/provider.go +++ b/weed/iam/providers/provider.go @@ -49,6 +49,11 @@ type ExternalIdentity struct { // Provider is the name of the identity provider Provider string `json:"provider"` + // Issuer is the OIDC `iss` claim (or equivalent) from the source token. + // Stable per (provider, identity) and used together with UserID to derive + // a stable parent-user hash that survives token rotation. + Issuer string `json:"issuer,omitempty"` + // TokenExpiration is the expiration time of the source identity token // This is used to limit session duration to not exceed the token's exp claim TokenExpiration *time.Time `json:"tokenExpiration,omitempty"` diff --git a/weed/iam/sts/parent_user_test.go b/weed/iam/sts/parent_user_test.go new file mode 100644 index 000000000..6b17b4ce4 --- /dev/null +++ b/weed/iam/sts/parent_user_test.go @@ -0,0 +1,65 @@ +package sts + +import ( + "strings" + "testing" + "time" +) + +func TestComputeParentUserStability(t *testing.T) { + // Same (sub, iss) must produce the same hash, regardless of order or + // whitespace mutations callers should never apply. + a := ComputeParentUser("alice", "https://idp.example/") + b := ComputeParentUser("alice", "https://idp.example/") + if a == "" { + t.Fatal("parent user should not be empty for non-empty sub") + } + if a != b { + t.Fatalf("expected stable hash, got %q vs %q", a, b) + } +} + +func TestComputeParentUserDistinguishesIssuer(t *testing.T) { + // The point of incorporating iss is that the same `sub` from two providers + // must not collide. If this assertion ever fails, the hash input is wrong. + a := ComputeParentUser("alice", "https://idp-a.example/") + b := ComputeParentUser("alice", "https://idp-b.example/") + if a == b { + t.Fatalf("hashes for different issuers must differ, both = %q", a) + } +} + +func TestComputeParentUserDistinguishesSubject(t *testing.T) { + a := ComputeParentUser("alice", "https://idp.example/") + b := ComputeParentUser("bob", "https://idp.example/") + if a == b { + t.Fatalf("hashes for different subjects must differ, both = %q", a) + } +} + +func TestComputeParentUserEmptySub(t *testing.T) { + if got := ComputeParentUser("", "https://idp.example/"); got != "" { + t.Fatalf("empty sub should produce empty parent user, got %q", got) + } +} + +func TestComputeParentUserEncoding(t *testing.T) { + got := ComputeParentUser("alice", "https://idp.example/") + // Base64 RawURL has no padding and uses URL-safe alphabet — important + // because parent_user shows up in filer paths and audit log fields. + if strings.ContainsAny(got, "=+/") { + t.Fatalf("parent user should be base64 raw url, got %q", got) + } +} + +func TestSessionClaimsRoundTripParentUser(t *testing.T) { + parent := ComputeParentUser("alice", "https://idp.example/") + claims := NewSTSSessionClaims("sid-1", "issuer", time.Now().Add(time.Hour)). + WithRoleInfo("arn:aws:iam::123:role/r", "arn:aws:sts::123:assumed-role/r/s", "arn:aws:sts::123:assumed-role/r/s"). + WithParentUser(parent) + + info := claims.ToSessionInfo() + if info.ParentUser != parent { + t.Fatalf("ParentUser lost on round-trip: got %q want %q", info.ParentUser, parent) + } +} diff --git a/weed/iam/sts/session_claims.go b/weed/iam/sts/session_claims.go index ce923504f..975a2f9f2 100644 --- a/weed/iam/sts/session_claims.go +++ b/weed/iam/sts/session_claims.go @@ -1,6 +1,8 @@ package sts import ( + "crypto/sha256" + "encoding/base64" "fmt" "time" @@ -8,6 +10,20 @@ import ( "github.com/seaweedfs/seaweedfs/weed/glog" ) +// ComputeParentUser returns a stable per-identity hash derived from the OIDC +// (sub, iss) tuple. Only the (sub, iss) pair is guaranteed stable across token +// refreshes per OpenID Connect Core 1.0 §5.7, so any per-user state (audit +// logs, quotas) must key off this value rather than the access-key or session +// id. The hash is base64-rawurl-encoded SHA-256 over "openid::" so +// it stays filesystem-safe and bounded in length for storage in audit paths. +func ComputeParentUser(sub, iss string) string { + if sub == "" { + return "" + } + h := sha256.Sum256([]byte("openid:" + sub + ":" + iss)) + return base64.RawURLEncoding.EncodeToString(h[:]) +} + // defaultCredentialGenerator is a reusable instance for generating temporary credentials // Reusing a single instance across all calls to ToSessionInfo() reduces allocation overhead // since this method may be called frequently during signature verification @@ -45,6 +61,12 @@ type STSSessionClaims struct { // Session metadata AssumedAt time.Time `json:"assumed_at"` // when role was assumed MaxDuration int64 `json:"max_dur,omitempty"` // maximum session duration in seconds + + // ParentUser is a stable hash of (sub, iss) for tokens minted from an OIDC + // identity. It survives token rotation since only the (sub, iss) tuple is + // guaranteed stable per OpenID Connect Core 1.0. Empty for non-federated + // session types. + ParentUser string `json:"puid,omitempty"` } // NewSTSSessionClaims creates new STS session claims with all required information @@ -96,6 +118,7 @@ func (c *STSSessionClaims) ToSessionInfo() *SessionInfo { ExternalUserId: c.ExternalUserId, ProviderIssuer: c.ProviderIssuer, RequestContext: c.RequestContext, + ParentUser: c.ParentUser, // Provide the Subject (sub) from registered claims Subject: c.Subject, Credentials: credentials, @@ -182,3 +205,10 @@ func (c *STSSessionClaims) WithSessionName(sessionName string) *STSSessionClaims c.SessionName = sessionName return c } + +// WithParentUser sets the stable per-identity hash for the session. See +// ComputeParentUser for the derivation rule. +func (c *STSSessionClaims) WithParentUser(parentUser string) *STSSessionClaims { + c.ParentUser = parentUser + return c +} diff --git a/weed/iam/sts/sts_service.go b/weed/iam/sts/sts_service.go index 0d0481795..cbd46aece 100644 --- a/weed/iam/sts/sts_service.go +++ b/weed/iam/sts/sts_service.go @@ -254,6 +254,9 @@ type SessionInfo struct { // Credentials are the temporary credentials for this session Credentials *Credentials `json:"credentials"` + + // ParentUser is the stable hashed identity (sub+iss) derived at federation time. + ParentUser string `json:"parentUser,omitempty"` } // NewSTSService creates a new STS service @@ -506,13 +509,26 @@ func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *Ass // Add sub as well since it's commonly used requestContext["sub"] = externalIdentity.UserID + // Compute a stable parent-user hash from (sub, iss). Only this tuple is + // guaranteed stable across token refresh per OIDC Core 1.0, so this is the + // right key for any per-identity state (audit trail, future quotas). + parentUser := ComputeParentUser(externalIdentity.UserID, externalIdentity.Issuer) + if parentUser != "" { + // Surface as aws:userid so policies can reference it directly without + // caring about token-rotation churn. + requestContext["aws:userid"] = parentUser + } + // Create rich JWT claims with all session information sessionClaims := NewSTSSessionClaims(sessionId, s.Config.Issuer, expiresAt). WithSessionName(request.RoleSessionName). WithRoleInfo(request.RoleArn, assumedRoleUser.Arn, assumedRoleUser.Arn). - WithIdentityProvider(provider.Name(), externalIdentity.UserID, ""). + WithIdentityProvider(provider.Name(), externalIdentity.UserID, externalIdentity.Issuer). WithMaxDuration(sessionDuration). WithRequestContext(requestContext) + if parentUser != "" { + sessionClaims.WithParentUser(parentUser) + } if sessionPolicy != "" { sessionClaims.WithSessionPolicy(sessionPolicy) } diff --git a/weed/s3api/s3api_sts.go b/weed/s3api/s3api_sts.go index c1ec653b5..de9f38c0f 100644 --- a/weed/s3api/s3api_sts.go +++ b/weed/s3api/s3api_sts.go @@ -53,6 +53,34 @@ const ( // federationNameRegex validates the Name parameter for GetFederationToken per AWS spec var federationNameRegex = regexp.MustCompile(`^[\w+=,.@-]+$`) +// roleSessionNameRegex validates RoleSessionName per AWS spec. +// Same character class as federation Name, but the length bounds differ +// (RoleSessionName is 2..64). +var roleSessionNameRegex = regexp.MustCompile(`^[\w+=,.@-]+$`) + +const ( + minRoleSessionNameLen = 2 + maxRoleSessionNameLen = 64 +) + +// validateRoleSessionName enforces the AWS RoleSessionName contract: +// length 2..64, characters [\w+=,.@-]+. Returns the STS error code and a +// descriptive error suitable for callers to surface to the caller. +func validateRoleSessionName(name string) (STSErrorCode, error) { + if name == "" { + return STSErrMissingParameter, fmt.Errorf("RoleSessionName is required") + } + if len(name) < minRoleSessionNameLen || len(name) > maxRoleSessionNameLen { + return STSErrInvalidParameterValue, + fmt.Errorf("RoleSessionName must be between %d and %d characters", minRoleSessionNameLen, maxRoleSessionNameLen) + } + if !roleSessionNameRegex.MatchString(name) { + return STSErrInvalidParameterValue, + fmt.Errorf(`RoleSessionName contains invalid characters; allowed: [\w+=,.@-]`) + } + return "", nil +} + // STS duration constants (AWS specification) const ( minDurationSeconds = int64(900) // 15 minutes @@ -61,6 +89,27 @@ const ( maxFederationDurationSeconds = int64(129600) // 36 hours (GetFederationToken max) ) +// AWS limits inline session policies to 2048 characters for AssumeRole, +// AssumeRoleWithWebIdentity, and AssumeRoleWithSAML. PackedPolicySize is +// returned as a percentage of that budget so callers can detect how close +// they are to the limit. +const sessionPolicyBudgetBytes = 2048 + +// computePackedPolicySize returns the inline session policy size as a +// percentage of the per-action budget, or nil when no session policy was +// provided. Output is bounded to [0, 100] for AWS-compat reporting; the +// actual policy size validation happens upstream in NormalizeSessionPolicy. +func computePackedPolicySize(policyJSON string) *int64 { + if policyJSON == "" { + return nil + } + pct := int64(len(policyJSON)) * 100 / sessionPolicyBudgetBytes + if pct > 100 { + pct = 100 + } + return &pct +} + // parseDurationSecondsWithBounds parses and validates the DurationSeconds parameter // against the given min and max bounds. Returns nil if the parameter is not provided. func parseDurationSecondsWithBounds(r *http.Request, minSec, maxSec int64) (*int64, STSErrorCode, error) { @@ -170,9 +219,8 @@ func (h *STSHandlers) handleAssumeRoleWithWebIdentity(w http.ResponseWriter, r * return } - if roleSessionName == "" { - h.writeSTSErrorResponse(w, r, STSErrMissingParameter, - fmt.Errorf("RoleSessionName is required")) + if errCode, err := validateRoleSessionName(roleSessionName); err != nil { + h.writeSTSErrorResponse(w, r, errCode, err) return } @@ -245,6 +293,7 @@ func (h *STSHandlers) handleAssumeRoleWithWebIdentity(w http.ResponseWriter, r * Expiration: response.Credentials.Expiration.Format(time.RFC3339), }, SubjectFromWebIdentityToken: response.AssumedRoleUser.Subject, + PackedPolicySize: computePackedPolicySize(sessionPolicyJSON), }, } xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r) @@ -264,9 +313,8 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) { // Validate required parameters // RoleArn is optional to support S3-compatible clients that omit it - if roleSessionName == "" { - h.writeSTSErrorResponse(w, r, STSErrMissingParameter, - fmt.Errorf("RoleSessionName is required")) + if errCode, err := validateRoleSessionName(roleSessionName); err != nil { + h.writeSTSErrorResponse(w, r, errCode, err) return } @@ -373,8 +421,9 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) { // Build and return response xmlResponse := &AssumeRoleResponse{ Result: AssumeRoleResult{ - Credentials: stsCreds, - AssumedRoleUser: assumedUser, + Credentials: stsCreds, + AssumedRoleUser: assumedUser, + PackedPolicySize: computePackedPolicySize(sessionPolicyJSON), }, } xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r) @@ -397,9 +446,8 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r return } - if roleSessionName == "" { - h.writeSTSErrorResponse(w, r, STSErrMissingParameter, - fmt.Errorf("RoleSessionName is required")) + if errCode, err := validateRoleSessionName(roleSessionName); err != nil { + h.writeSTSErrorResponse(w, r, errCode, err) return } @@ -514,8 +562,9 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r // Build and return response xmlResponse := &AssumeRoleWithLDAPIdentityResponse{ Result: LDAPIdentityResult{ - Credentials: stsCreds, - AssumedRoleUser: assumedUser, + Credentials: stsCreds, + AssumedRoleUser: assumedUser, + PackedPolicySize: computePackedPolicySize(sessionPolicyJSON), }, } xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r) @@ -906,6 +955,7 @@ type WebIdentityResult struct { Credentials STSCredentials `xml:"Credentials"` SubjectFromWebIdentityToken string `xml:"SubjectFromWebIdentityToken,omitempty"` AssumedRoleUser *AssumedRoleUser `xml:"AssumedRoleUser,omitempty"` + PackedPolicySize *int64 `xml:"PackedPolicySize,omitempty"` } // STSCredentials represents temporary security credentials @@ -933,8 +983,9 @@ type AssumeRoleResponse struct { // AssumeRoleResult contains the result of AssumeRole type AssumeRoleResult struct { - Credentials STSCredentials `xml:"Credentials"` - AssumedRoleUser *AssumedRoleUser `xml:"AssumedRoleUser,omitempty"` + Credentials STSCredentials `xml:"Credentials"` + AssumedRoleUser *AssumedRoleUser `xml:"AssumedRoleUser,omitempty"` + PackedPolicySize *int64 `xml:"PackedPolicySize,omitempty"` } // AssumeRoleWithLDAPIdentityResponse is the response for AssumeRoleWithLDAPIdentity @@ -948,8 +999,9 @@ type AssumeRoleWithLDAPIdentityResponse struct { // LDAPIdentityResult contains the result of AssumeRoleWithLDAPIdentity type LDAPIdentityResult struct { - Credentials STSCredentials `xml:"Credentials"` - AssumedRoleUser *AssumedRoleUser `xml:"AssumedRoleUser,omitempty"` + Credentials STSCredentials `xml:"Credentials"` + AssumedRoleUser *AssumedRoleUser `xml:"AssumedRoleUser,omitempty"` + PackedPolicySize *int64 `xml:"PackedPolicySize,omitempty"` } // GetCallerIdentityResponse is the response for GetCallerIdentity diff --git a/weed/s3api/sts_packed_policy_test.go b/weed/s3api/sts_packed_policy_test.go new file mode 100644 index 000000000..d2bc0b258 --- /dev/null +++ b/weed/s3api/sts_packed_policy_test.go @@ -0,0 +1,34 @@ +package s3api + +import "testing" + +func TestComputePackedPolicySize(t *testing.T) { + cases := []struct { + name string + policyLen int + empty bool + want int64 + }{ + {"empty -> nil", 0, true, 0}, + {"tiny policy -> 0%", 10, false, 0}, + {"half budget -> 50%", sessionPolicyBudgetBytes / 2, false, 50}, + {"full budget -> 100%", sessionPolicyBudgetBytes, false, 100}, + {"oversized -> capped at 100", sessionPolicyBudgetBytes * 3, false, 100}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + policy := repeat('a', tc.policyLen) + got := computePackedPolicySize(policy) + switch { + case tc.empty: + if got != nil { + t.Fatalf("expected nil for empty input, got %d", *got) + } + case got == nil: + t.Fatalf("expected non-nil result, got nil") + case *got != tc.want: + t.Fatalf("got=%d want=%d", *got, tc.want) + } + }) + } +} diff --git a/weed/s3api/sts_session_name_test.go b/weed/s3api/sts_session_name_test.go new file mode 100644 index 000000000..029cdf943 --- /dev/null +++ b/weed/s3api/sts_session_name_test.go @@ -0,0 +1,47 @@ +package s3api + +import "testing" + +func TestValidateRoleSessionName(t *testing.T) { + cases := []struct { + name string + input string + wantErr bool + // wantCode is checked only when wantErr is true + wantCode STSErrorCode + }{ + {"empty rejected", "", true, STSErrMissingParameter}, + {"single char rejected (below min len 2)", "a", true, STSErrInvalidParameterValue}, + {"min length 2 accepted", "ab", false, ""}, + {"plain ascii accepted", "session-name_1", false, ""}, + {"all special chars allowed", "+=,.@-", false, ""}, + {"email-style accepted", "alice@example.com", false, ""}, + {"max length 64 accepted", string(make([]byte, 64)), true, STSErrInvalidParameterValue}, // zero bytes -> invalid charset + {"max length 64 valid charset accepted", repeat('a', 64), false, ""}, + {"length 65 rejected", repeat('a', 65), true, STSErrInvalidParameterValue}, + {"space rejected", "alice bob", true, STSErrInvalidParameterValue}, + {"slash rejected", "alice/bob", true, STSErrInvalidParameterValue}, + {"colon rejected", "alice:bob", true, STSErrInvalidParameterValue}, + {"unicode rejected", "alicé", true, STSErrInvalidParameterValue}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + code, err := validateRoleSessionName(tc.input) + gotErr := err != nil + if gotErr != tc.wantErr { + t.Fatalf("err mismatch: got=%v want=%v (err=%v)", gotErr, tc.wantErr, err) + } + if tc.wantErr && code != tc.wantCode { + t.Fatalf("code mismatch: got=%s want=%s", code, tc.wantCode) + } + }) + } +} + +func repeat(b byte, n int) string { + out := make([]byte, n) + for i := range out { + out[i] = b + } + return string(out) +}