diff --git a/test/s3/remote_cache/Makefile b/test/s3/remote_cache/Makefile index af1537d53..762db65c5 100644 --- a/test/s3/remote_cache/Makefile +++ b/test/s3/remote_cache/Makefile @@ -127,6 +127,7 @@ start-primary: check-deps -webdav.port=$(PRIMARY_WEBDAV_PORT) \ -s3.allowDeleteBucketNotEmpty=true \ -s3.config=s3_config.json \ + -volume.allowUntrustedRemoteEndpoints \ -dir=$(PRIMARY_DIR) \ -ip=127.0.0.1 \ -ip.bind=127.0.0.1 \ diff --git a/test/volume_server/framework/cluster.go b/test/volume_server/framework/cluster.go index 437f17d3e..86d75415d 100644 --- a/test/volume_server/framework/cluster.go +++ b/test/volume_server/framework/cluster.go @@ -232,6 +232,10 @@ func (c *Cluster) startVolume(dataDirs []string) error { "-readMode=" + c.profile.ReadMode, "-concurrentUploadLimitMB=" + strconv.Itoa(c.profile.ConcurrentUploadLimitMB), "-concurrentDownloadLimitMB=" + strconv.Itoa(c.profile.ConcurrentDownloadLimitMB), + // Integration tests deliberately exercise loopback S3 endpoints + // (the test rig boots weed-mini next to the volume server); allow + // the SSRF guard to be bypassed for them. + "-volume.allowUntrustedRemoteEndpoints", } if c.profile.InflightUploadTimeout > 0 { args = append(args, "-inflightUploadDataTimeout="+c.profile.InflightUploadTimeout.String()) diff --git a/test/volume_server/framework/cluster_mixed.go b/test/volume_server/framework/cluster_mixed.go index 6c6b76adf..30a850223 100644 --- a/test/volume_server/framework/cluster_mixed.go +++ b/test/volume_server/framework/cluster_mixed.go @@ -256,6 +256,8 @@ func (c *MixedVolumeCluster) startGoVolume(index int, dataDir string) error { "-readMode=" + c.profile.ReadMode, "-concurrentUploadLimitMB=" + strconv.Itoa(c.profile.ConcurrentUploadLimitMB), "-concurrentDownloadLimitMB=" + strconv.Itoa(c.profile.ConcurrentDownloadLimitMB), + // Integration tests deliberately exercise loopback S3 endpoints; allow the SSRF guard to be bypassed for them. + "-volume.allowUntrustedRemoteEndpoints", } if c.profile.InflightUploadTimeout > 0 { args = append(args, "-inflightUploadDataTimeout="+c.profile.InflightUploadTimeout.String()) diff --git a/test/volume_server/framework/cluster_multi.go b/test/volume_server/framework/cluster_multi.go index 57748bcb0..5a869999b 100644 --- a/test/volume_server/framework/cluster_multi.go +++ b/test/volume_server/framework/cluster_multi.go @@ -227,6 +227,8 @@ func (c *MultiVolumeCluster) startVolume(index int, dataDir string) error { "-readMode=" + c.profile.ReadMode, "-concurrentUploadLimitMB=" + strconv.Itoa(c.profile.ConcurrentUploadLimitMB), "-concurrentDownloadLimitMB=" + strconv.Itoa(c.profile.ConcurrentDownloadLimitMB), + // Integration tests deliberately exercise loopback S3 endpoints; allow the SSRF guard to be bypassed for them. + "-volume.allowUntrustedRemoteEndpoints", } if c.profile.InflightUploadTimeout > 0 { args = append(args, "-inflightUploadDataTimeout="+c.profile.InflightUploadTimeout.String()) diff --git a/weed/command/mini.go b/weed/command/mini.go index 860053752..53fbd4b21 100644 --- a/weed/command/mini.go +++ b/weed/command/mini.go @@ -346,6 +346,7 @@ func initMiniVolumeFlags() { miniOptions.v.inflightDownloadDataTimeout = cmdMini.Flag.Duration("volume.inflightDownloadDataTimeout", 60*time.Second, "inflight download data wait timeout") miniOptions.v.hasSlowRead = cmdMini.Flag.Bool("volume.hasSlowRead", true, "if true, prevents slow reads from blocking other requests") miniOptions.v.readBufferSizeMB = cmdMini.Flag.Int("volume.readBufferSizeMB", 4, "read buffer size in MB") + miniOptions.v.allowUntrustedRemoteEndpoints = cmdMini.Flag.Bool("volume.allowUntrustedRemoteEndpoints", false, "if true, FetchAndWriteNeedle accepts arbitrary remote S3 endpoints including loopback / link-local hosts. Default rejects internal / metadata endpoints.") miniOptions.v.preStopSeconds = cmdMini.Flag.Int("volume.preStopSeconds", 1, "number of seconds between stop send heartbeats and stop volume server (default: 1 for mini)") } diff --git a/weed/command/server.go b/weed/command/server.go index 96e0c1f11..b52ad0f87 100644 --- a/weed/command/server.go +++ b/weed/command/server.go @@ -155,6 +155,7 @@ func init() { serverOptions.v.hasSlowRead = cmdServer.Flag.Bool("volume.hasSlowRead", true, " if true, this prevents slow reads from blocking other requests, but large file read P99 latency will increase.") serverOptions.v.readBufferSizeMB = cmdServer.Flag.Int("volume.readBufferSizeMB", 4, " larger values can optimize query performance but will increase some memory usage,Use with hasSlowRead normally") + serverOptions.v.allowUntrustedRemoteEndpoints = cmdServer.Flag.Bool("volume.allowUntrustedRemoteEndpoints", false, "if true, FetchAndWriteNeedle accepts arbitrary remote S3 endpoints including loopback / link-local hosts. Default rejects internal / metadata endpoints.") s3Options.port = cmdServer.Flag.Int("s3.port", 8333, "s3 server http listen port") s3Options.portHttps = cmdServer.Flag.Int("s3.port.https", 0, "s3 server https listen port") diff --git a/weed/command/volume.go b/weed/command/volume.go index 2b86f16c6..bc38a1f6d 100644 --- a/weed/command/volume.go +++ b/weed/command/volume.go @@ -70,13 +70,14 @@ type VolumeServerOptions struct { metricsHttpPort *int metricsHttpIp *string // pulseSeconds *int - inflightUploadDataTimeout *time.Duration - inflightDownloadDataTimeout *time.Duration - hasSlowRead *bool - readBufferSizeMB *int - ldbTimeout *int64 - debug *bool - debugPort *int + inflightUploadDataTimeout *time.Duration + inflightDownloadDataTimeout *time.Duration + hasSlowRead *bool + readBufferSizeMB *int + ldbTimeout *int64 + allowUntrustedRemoteEndpoints *bool + debug *bool + debugPort *int // shutdownCtx, when non-nil, tells startVolumeServer to shut down once the // ctx is cancelled. Used by integration tests and by weed mini; nil for // standalone weed volume. @@ -120,6 +121,7 @@ func init() { v.inflightDownloadDataTimeout = cmdVolume.Flag.Duration("inflightDownloadDataTimeout", 60*time.Second, "inflight download data wait timeout of volume servers") v.hasSlowRead = cmdVolume.Flag.Bool("hasSlowRead", true, " if true, this prevents slow reads from blocking other requests, but large file read P99 latency will increase.") v.readBufferSizeMB = cmdVolume.Flag.Int("readBufferSizeMB", 4, " larger values can optimize query performance but will increase some memory usage,Use with hasSlowRead normally.") + v.allowUntrustedRemoteEndpoints = cmdVolume.Flag.Bool("volume.allowUntrustedRemoteEndpoints", false, "if true, FetchAndWriteNeedle accepts arbitrary remote S3 endpoints including loopback / link-local hosts. Default rejects internal / metadata endpoints.") v.debug = cmdVolume.Flag.Bool("debug", false, "serves runtime profiling data via pprof on the port specified by -debug.port") v.debugPort = cmdVolume.Flag.Int("debug.port", 6060, "http port for debugging") } @@ -302,6 +304,7 @@ func (v VolumeServerOptions) startVolumeServer(volumeFolders, maxVolumeCounts, v *v.hasSlowRead, *v.readBufferSizeMB, *v.ldbTimeout, + *v.allowUntrustedRemoteEndpoints, ) // starting grpc server grpcS := v.startGrpcService(volumeServer) diff --git a/weed/remote_storage/s3/s3_storage_client.go b/weed/remote_storage/s3/s3_storage_client.go index d4e00ce7f..023d6cb7f 100644 --- a/weed/remote_storage/s3/s3_storage_client.go +++ b/weed/remote_storage/s3/s3_storage_client.go @@ -36,6 +36,14 @@ func (s s3RemoteStorageMaker) HasBucket() bool { } func (s s3RemoteStorageMaker) Make(conf *remote_pb.RemoteConf) (remote_storage.RemoteStorageClient, error) { + return MakeWithHTTPClient(conf, nil) +} + +// MakeWithHTTPClient builds an s3 remote storage client using the supplied +// *http.Client (or the AWS SDK default when nil). Callers that need to pin +// the dial path against DNS rebinding can pass a client whose transport has +// a guarded DialContext. +func MakeWithHTTPClient(conf *remote_pb.RemoteConf, httpClient *http.Client) (remote_storage.RemoteStorageClient, error) { client := &s3RemoteStorageClient{ supportTagging: true, conf: conf, @@ -46,6 +54,9 @@ func (s s3RemoteStorageMaker) Make(conf *remote_pb.RemoteConf) (remote_storage.R S3ForcePathStyle: aws.Bool(conf.S3ForcePathStyle), S3DisableContentMD5Validation: aws.Bool(true), } + if httpClient != nil { + config.HTTPClient = httpClient + } if conf.S3AccessKey != "" && conf.S3SecretKey != "" { config.Credentials = credentials.NewStaticCredentials(conf.S3AccessKey, conf.S3SecretKey, "") } else if conf.S3AccessKey == "" && conf.S3SecretKey == "" { diff --git a/weed/server/volume_grpc_remote.go b/weed/server/volume_grpc_remote.go index 3658cb256..ad6eebe21 100644 --- a/weed/server/volume_grpc_remote.go +++ b/weed/server/volume_grpc_remote.go @@ -3,18 +3,179 @@ package weed_server import ( "context" "fmt" + "net" + "net/http" + "net/url" + "strings" "sync" "time" "github.com/seaweedfs/seaweedfs/weed/operation" "github.com/seaweedfs/seaweedfs/weed/pb/volume_server_pb" "github.com/seaweedfs/seaweedfs/weed/remote_storage" + s3remote "github.com/seaweedfs/seaweedfs/weed/remote_storage/s3" "github.com/seaweedfs/seaweedfs/weed/security" "github.com/seaweedfs/seaweedfs/weed/storage/needle" "github.com/seaweedfs/seaweedfs/weed/storage/types" ) +// lookupIPAddrFunc resolves a host to one or more IP addresses. It is a +// package-level variable so tests can substitute a deterministic resolver. +var lookupIPAddrFunc = net.DefaultResolver.LookupIPAddr + +// blockedIMDSHosts lists hostnames that target cloud instance metadata +// services (IMDS). These are blocked regardless of how they happen to +// resolve, because some environments alias the IMDS address under a name. +var blockedIMDSHosts = map[string]struct{}{ + "metadata.google.internal": {}, + "metadata": {}, +} + +// validateRemoteEndpoint returns an error if the supplied S3 endpoint is not +// safe to dial from a server that has network access to cluster-internal +// hosts. It rejects empty/non-http(s) schemes, loopback/link-local/ +// unspecified addresses, RFC 1918 + CGNAT ranges, and well-known IMDS +// hostnames. Operators that legitimately fetch from private hosts can opt +// out with -volume.allowUntrustedRemoteEndpoints. +func validateRemoteEndpoint(ctx context.Context, endpoint string) error { + if strings.TrimSpace(endpoint) == "" { + return fmt.Errorf("remote endpoint is empty") + } + u, parseErr := url.Parse(endpoint) + if parseErr != nil { + return fmt.Errorf("parse remote endpoint %q: %w", endpoint, parseErr) + } + scheme := strings.ToLower(u.Scheme) + if scheme != "http" && scheme != "https" { + return fmt.Errorf("remote endpoint %q must use http or https, got %q", endpoint, u.Scheme) + } + host := u.Hostname() + if host == "" { + return fmt.Errorf("remote endpoint %q has no host", endpoint) + } + lowerHost := strings.ToLower(host) + if _, ok := blockedIMDSHosts[lowerHost]; ok { + return fmt.Errorf("remote endpoint %q targets instance metadata service", endpoint) + } + if ip := net.ParseIP(host); ip != nil { + if err := checkBlockedIP(endpoint, ip); err != nil { + return err + } + return nil + } + resolveCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + addrs, lookupErr := lookupIPAddrFunc(resolveCtx, host) + if lookupErr != nil { + return fmt.Errorf("resolve remote endpoint host %q: %w", host, lookupErr) + } + for _, addr := range addrs { + if err := checkBlockedIP(endpoint, addr.IP); err != nil { + return err + } + } + return nil +} + +// imdsIPv4 is the AWS/Azure/GCP IPv4 IMDS address. It is link-local and is +// already covered by IsLinkLocalUnicast, but is named explicitly so the +// error message is unambiguous in logs. +var imdsIPv4 = net.ParseIP("169.254.169.254") + +// cgnatNet is the RFC 6598 carrier-grade NAT range (100.64.0.0/10). The +// stdlib's IsPrivate covers RFC 1918 but not CGNAT, so check it explicitly. +var cgnatNet = &net.IPNet{IP: net.IPv4(100, 64, 0, 0), Mask: net.CIDRMask(10, 32)} + +func checkBlockedIP(endpoint string, ip net.IP) error { + if ip == nil { + return nil + } + if ip.Equal(imdsIPv4) { + return fmt.Errorf("remote endpoint %q targets instance metadata service %s", endpoint, ip) + } + switch { + case ip.IsLoopback(): + return fmt.Errorf("remote endpoint %q resolves to loopback address %s", endpoint, ip) + case ip.IsUnspecified(): + return fmt.Errorf("remote endpoint %q resolves to unspecified address %s", endpoint, ip) + case ip.IsLinkLocalUnicast(), ip.IsLinkLocalMulticast(): + return fmt.Errorf("remote endpoint %q resolves to link-local address %s", endpoint, ip) + case ip.IsInterfaceLocalMulticast(): + return fmt.Errorf("remote endpoint %q resolves to interface-local address %s", endpoint, ip) + case ip.IsPrivate(): + return fmt.Errorf("remote endpoint %q resolves to private address %s", endpoint, ip) + case cgnatNet.Contains(ip): + return fmt.Errorf("remote endpoint %q resolves to CGNAT address %s", endpoint, ip) + } + return nil +} + +// guardedDialer returns a DialContext that resolves the host itself and +// re-applies checkBlockedIP to every resolved address immediately before +// dialing. This closes the DNS-rebinding window between +// validateRemoteEndpoint and the actual TCP connect performed by the AWS S3 +// client: even if the attacker's DNS flips to 127.0.0.1 (or any other +// blocked range) after the up-front check, the dial is refused. +func guardedDialer(endpoint string) func(ctx context.Context, network, addr string) (net.Conn, error) { + dialer := &net.Dialer{Timeout: 30 * time.Second, KeepAlive: 30 * time.Second} + return func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, splitErr := net.SplitHostPort(addr) + if splitErr != nil { + return nil, splitErr + } + // If the host is already a literal IP just validate and dial it. + if ip := net.ParseIP(host); ip != nil { + if err := checkBlockedIP(endpoint, ip); err != nil { + return nil, err + } + return dialer.DialContext(ctx, network, addr) + } + // Otherwise resolve, validate every answer, and dial the first IP + // that passes the deny list. Using a literal-IP target prevents the + // kernel resolver in net.Dialer from looking the name up a second + // time inside Dial and getting a different answer. + addrs, lookupErr := lookupIPAddrFunc(ctx, host) + if lookupErr != nil { + return nil, fmt.Errorf("resolve remote endpoint host %q: %w", host, lookupErr) + } + var firstBlockErr error + for _, a := range addrs { + if err := checkBlockedIP(endpoint, a.IP); err != nil { + if firstBlockErr == nil { + firstBlockErr = err + } + continue + } + return dialer.DialContext(ctx, network, net.JoinHostPort(a.IP.String(), port)) + } + if firstBlockErr != nil { + return nil, firstBlockErr + } + return nil, fmt.Errorf("resolve remote endpoint host %q: no addresses", host) + } +} + +// newGuardedHTTPClient returns an *http.Client whose transport refuses to +// dial addresses that fail checkBlockedIP at connect time. It is meant for +// per-request use; do not share across remote configs. +func newGuardedHTTPClient(endpoint string) *http.Client { + return &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: guardedDialer(endpoint), + ForceAttemptHTTP2: true, + MaxIdleConns: 16, + IdleConnTimeout: 60 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + } +} + func (vs *VolumeServer) FetchAndWriteNeedle(ctx context.Context, req *volume_server_pb.FetchAndWriteNeedleRequest) (resp *volume_server_pb.FetchAndWriteNeedleResponse, err error) { + if err := vs.checkGrpcAdminAuth(ctx); err != nil { + return nil, err + } if err := vs.CheckMaintenanceMode(); err != nil { return nil, err } @@ -27,7 +188,25 @@ func (vs *VolumeServer) FetchAndWriteNeedle(ctx context.Context, req *volume_ser remoteConf := req.RemoteConf - client, getClientErr := remote_storage.GetRemoteStorage(remoteConf) + var client remote_storage.RemoteStorageClient + var getClientErr error + if !vs.AllowUntrustedRemoteEndpoints && remoteConf != nil && remoteConf.Type == "s3" { + // Endpoint validation is S3-specific: only RemoteConf.S3Endpoint + // is a URL the volume server dials directly. Other backends + // (gcs, azure, ...) authenticate against their own SDKs and + // don't accept an attacker-controlled host. + if validateErr := validateRemoteEndpoint(ctx, remoteConf.S3Endpoint); validateErr != nil { + return nil, fmt.Errorf("reject remote endpoint: %w", validateErr) + } + // Build a one-shot S3 client whose dial path re-validates the + // resolved IP every time. This pins the validated endpoint against + // DNS rebinding (a hostname that resolves to a public IP for + // validateRemoteEndpoint and then flips to 127.0.0.1 / 169.254.x.x + // when the AWS SDK dials). + client, getClientErr = s3remote.MakeWithHTTPClient(remoteConf, newGuardedHTTPClient(remoteConf.S3Endpoint)) + } else { + client, getClientErr = remote_storage.GetRemoteStorage(remoteConf) + } if getClientErr != nil { return nil, fmt.Errorf("get remote client: %w", getClientErr) } diff --git a/weed/server/volume_grpc_remote_test.go b/weed/server/volume_grpc_remote_test.go new file mode 100644 index 000000000..9992bebc0 --- /dev/null +++ b/weed/server/volume_grpc_remote_test.go @@ -0,0 +1,270 @@ +package weed_server + +import ( + "context" + "errors" + "net" + "strings" + "sync/atomic" + "testing" +) + +// stubLookup returns a resolver func that maps the supplied hostnames to +// the supplied IP addresses, and errors for any host that is not in the map. +func stubLookup(t *testing.T, mapping map[string][]net.IP) func(ctx context.Context, host string) ([]net.IPAddr, error) { + t.Helper() + return func(_ context.Context, host string) ([]net.IPAddr, error) { + ips, ok := mapping[host] + if !ok { + return nil, &net.DNSError{Err: "no such host", Name: host, IsNotFound: true} + } + out := make([]net.IPAddr, 0, len(ips)) + for _, ip := range ips { + out = append(out, net.IPAddr{IP: ip}) + } + return out, nil + } +} + +func TestValidateRemoteEndpoint(t *testing.T) { + originalLookup := lookupIPAddrFunc + t.Cleanup(func() { lookupIPAddrFunc = originalLookup }) + + lookupIPAddrFunc = stubLookup(t, map[string][]net.IP{ + "s3.us-east-1.amazonaws.com": {net.ParseIP("52.216.10.10")}, + "internal.example.com": {net.ParseIP("127.0.0.1")}, + "linklocal.example.com": {net.ParseIP("169.254.10.20")}, + "private.example.com": {net.ParseIP("10.1.2.3")}, + "private172.example.com": {net.ParseIP("172.20.0.5")}, + "private192.example.com": {net.ParseIP("192.168.1.1")}, + "cgnat.example.com": {net.ParseIP("100.64.0.42")}, + }) + + cases := []struct { + name string + endpoint string + wantErr bool + wantSub string + }{ + { + name: "empty", + endpoint: "", + wantErr: true, + wantSub: "empty", + }, + { + name: "loopback literal", + endpoint: "http://127.0.0.1:8080", + wantErr: true, + wantSub: "loopback", + }, + { + name: "ipv6 loopback", + endpoint: "http://[::1]:8080", + wantErr: true, + wantSub: "loopback", + }, + { + name: "imds ipv4", + endpoint: "http://169.254.169.254/", + wantErr: true, + wantSub: "metadata", + }, + { + name: "unspecified ipv4", + endpoint: "http://0.0.0.0/", + wantErr: true, + wantSub: "unspecified", + }, + { + name: "link-local ipv6", + endpoint: "http://[fe80::1]/", + wantErr: true, + wantSub: "link-local", + }, + { + name: "ftp scheme", + endpoint: "ftp://example.com/", + wantErr: true, + wantSub: "http or https", + }, + { + name: "missing scheme", + endpoint: "example.com/", + wantErr: true, + wantSub: "http or https", + }, + { + name: "imds hostname", + endpoint: "http://metadata.google.internal/", + wantErr: true, + wantSub: "metadata service", + }, + { + name: "imds short hostname", + endpoint: "http://metadata/", + wantErr: true, + wantSub: "metadata service", + }, + { + name: "host resolves to loopback", + endpoint: "https://internal.example.com/", + wantErr: true, + wantSub: "loopback", + }, + { + name: "host resolves to link-local", + endpoint: "https://linklocal.example.com/", + wantErr: true, + wantSub: "link-local", + }, + { + name: "rfc1918 10/8 literal", + endpoint: "http://10.0.0.1/", + wantErr: true, + wantSub: "private", + }, + { + name: "rfc1918 172.16/12 literal", + endpoint: "http://172.16.5.5/", + wantErr: true, + wantSub: "private", + }, + { + name: "rfc1918 192.168/16 literal", + endpoint: "http://192.168.0.1/", + wantErr: true, + wantSub: "private", + }, + { + name: "cgnat literal", + endpoint: "http://100.64.0.1/", + wantErr: true, + wantSub: "CGNAT", + }, + { + name: "host resolves to rfc1918 10/8", + endpoint: "https://private.example.com/", + wantErr: true, + wantSub: "private", + }, + { + name: "host resolves to rfc1918 172/12", + endpoint: "https://private172.example.com/", + wantErr: true, + wantSub: "private", + }, + { + name: "host resolves to rfc1918 192.168/16", + endpoint: "https://private192.example.com/", + wantErr: true, + wantSub: "private", + }, + { + name: "host resolves to cgnat", + endpoint: "https://cgnat.example.com/", + wantErr: true, + wantSub: "CGNAT", + }, + { + name: "public s3", + endpoint: "https://s3.us-east-1.amazonaws.com/", + wantErr: false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := validateRemoteEndpoint(context.Background(), tc.endpoint) + if tc.wantErr { + if err == nil { + t.Fatalf("expected error for %q, got nil", tc.endpoint) + } + if tc.wantSub != "" && !strings.Contains(err.Error(), tc.wantSub) { + t.Fatalf("expected error to contain %q, got %v", tc.wantSub, err) + } + return + } + if err != nil { + t.Fatalf("unexpected error for %q: %v", tc.endpoint, err) + } + }) + } +} + +func TestValidateRemoteEndpointResolverFailure(t *testing.T) { + originalLookup := lookupIPAddrFunc + t.Cleanup(func() { lookupIPAddrFunc = originalLookup }) + + resolveErr := errors.New("simulated DNS failure") + lookupIPAddrFunc = func(_ context.Context, _ string) ([]net.IPAddr, error) { + return nil, resolveErr + } + + err := validateRemoteEndpoint(context.Background(), "https://does-not-resolve.example.com/") + if err == nil { + t.Fatal("expected error when resolver fails") + } + if !strings.Contains(err.Error(), "resolve remote endpoint host") { + t.Fatalf("expected resolver error wrapping, got %v", err) + } +} + +// TestGuardedDialerRebind simulates a DNS rebinding attack: the host first +// resolves to a public address (passing validateRemoteEndpoint) and then +// flips to 127.0.0.1 on the very next lookup (what the AWS SDK would do at +// dial time). The dial path must refuse the loopback answer instead of +// connecting to it. +func TestGuardedDialerRebind(t *testing.T) { + originalLookup := lookupIPAddrFunc + t.Cleanup(func() { lookupIPAddrFunc = originalLookup }) + + const host = "rebind.example.com" + endpoint := "https://" + host + "/" + + var calls atomic.Int32 + lookupIPAddrFunc = func(_ context.Context, name string) ([]net.IPAddr, error) { + if name != host { + return nil, &net.DNSError{Err: "no such host", Name: name, IsNotFound: true} + } + if calls.Add(1) == 1 { + return []net.IPAddr{{IP: net.ParseIP("52.216.10.10")}}, nil + } + return []net.IPAddr{{IP: net.ParseIP("127.0.0.1")}}, nil + } + + if err := validateRemoteEndpoint(context.Background(), endpoint); err != nil { + t.Fatalf("first-pass validation should accept public IP, got %v", err) + } + + dial := guardedDialer(endpoint) + conn, err := dial(context.Background(), "tcp", host+":443") + if conn != nil { + conn.Close() + t.Fatalf("guarded dialer must refuse loopback rebind, got conn") + } + if err == nil || !strings.Contains(err.Error(), "loopback") { + t.Fatalf("guarded dialer should fail with loopback error, got %v", err) + } +} + +// TestGuardedDialerLiteralBlocked confirms that a literal blocked IP target +// is refused without any DNS lookup. +func TestGuardedDialerLiteralBlocked(t *testing.T) { + originalLookup := lookupIPAddrFunc + t.Cleanup(func() { lookupIPAddrFunc = originalLookup }) + lookupIPAddrFunc = func(_ context.Context, name string) ([]net.IPAddr, error) { + t.Fatalf("resolver should not be called for IP literal target, got lookup of %q", name) + return nil, nil + } + + dial := guardedDialer("http://10.0.0.5:80") + conn, err := dial(context.Background(), "tcp", "10.0.0.5:80") + if conn != nil { + conn.Close() + t.Fatalf("guarded dialer must refuse rfc1918 literal, got conn") + } + if err == nil || !strings.Contains(err.Error(), "private") { + t.Fatalf("guarded dialer should fail with private-address error, got %v", err) + } +} diff --git a/weed/server/volume_server.go b/weed/server/volume_server.go index 7bab70276..974156e28 100644 --- a/weed/server/volume_server.go +++ b/weed/server/volume_server.go @@ -43,17 +43,18 @@ type VolumeServer struct { guard *security.Guard grpcDialOption grpc.DialOption - needleMapKind storage.NeedleMapKind - ldbTimout int64 - FixJpgOrientation bool - ReadMode string - compactionBytePerSecond int64 - maintenanceBytePerSecond int64 - metricsAddress string - metricsIntervalSec int - fileSizeLimitBytes int64 - isHeartbeating bool - stopChan chan bool + needleMapKind storage.NeedleMapKind + ldbTimout int64 + FixJpgOrientation bool + ReadMode string + AllowUntrustedRemoteEndpoints bool + compactionBytePerSecond int64 + maintenanceBytePerSecond int64 + metricsAddress string + metricsIntervalSec int + fileSizeLimitBytes int64 + isHeartbeating bool + stopChan chan bool } func NewVolumeServer(adminMux, publicMux *http.ServeMux, ip string, @@ -76,6 +77,7 @@ func NewVolumeServer(adminMux, publicMux *http.ServeMux, ip string, hasSlowRead bool, readBufferSizeMB int, ldbTimeout int64, + allowUntrustedRemoteEndpoints bool, ) *VolumeServer { v := util.GetViper() @@ -111,6 +113,7 @@ func NewVolumeServer(adminMux, publicMux *http.ServeMux, ip string, readBufferSizeMB: readBufferSizeMB, ldbTimout: ldbTimeout, whiteList: whiteList, + AllowUntrustedRemoteEndpoints: allowUntrustedRemoteEndpoints, } whiteList = append(whiteList, util.StringSplit(v.GetString("guard.white_list"), ",")...)