diff --git a/fasthttpproxy/proxy_env.go b/fasthttpproxy/proxy_env.go index c5abe4a..9d5fe98 100644 --- a/fasthttpproxy/proxy_env.go +++ b/fasthttpproxy/proxy_env.go @@ -6,7 +6,7 @@ import ( "fmt" "net" "net/url" - "sync" + "sync/atomic" "time" "golang.org/x/net/http/httpproxy" @@ -32,16 +32,31 @@ func FasthttpProxyHTTPDialer() fasthttp.DialFunc { // c := &fasthttp.Client{ // Dial: FasthttpProxyHTTPDialerTimeout(time.Second * 2), // } + +const ( + httpsScheme = "https" + tlsPort = "443" +) + func FasthttpProxyHTTPDialerTimeout(timeout time.Duration) fasthttp.DialFunc { proxier := httpproxy.FromEnvironment().ProxyFunc() - // map on proxy URL and its encoded auth barrier - authBarriers := map[*url.URL]string{} - authBarriersLock := sync.RWMutex{} + // encoded auth barrier for http and https proxy. + authHTTPStorage := &atomic.Value{} + authHTTPSStorage := &atomic.Value{} return func(addr string) (net.Conn, error) { - proxyURL, err := proxier(&url.URL{Host: addr}) + port, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("unexpected addr format: %w", err) + } + + reqURL := &url.URL{Host: addr, Scheme: httpsScheme} + if port == tlsPort { + reqURL.Scheme = httpsScheme + } + proxyURL, err := proxier(reqURL) if err != nil { return nil, err } @@ -55,28 +70,30 @@ func FasthttpProxyHTTPDialerTimeout(timeout time.Duration) fasthttp.DialFunc { var conn net.Conn if timeout == 0 { - conn, err = fasthttp.Dial(proxyURL.String()) + conn, err = fasthttp.Dial(proxyURL.Host) } else { - conn, err = fasthttp.DialTimeout(proxyURL.String(), timeout) + conn, err = fasthttp.DialTimeout(proxyURL.Host, timeout) } if err != nil { return nil, err } req := "CONNECT " + addr + " HTTP/1.1\r\n" - if proxyURL.User != nil { - authBarriersLock.RLock() - barrier, ok := authBarriers[proxyURL] - authBarriersLock.RUnlock() - if !ok { - barrier = base64.StdEncoding.EncodeToString([]byte(proxyURL.User.String())) - authBarriersLock.Lock() - authBarriers[proxyURL] = barrier - authBarriersLock.Unlock() + if proxyURL.User != nil { + authBarrierStorage := authHTTPStorage + if port == tlsPort { + authBarrierStorage = authHTTPSStorage } - req += "Proxy-Authorization: Basic " + barrier + "\r\n" + auth := authBarrierStorage.Load() + if auth == nil { + authBarrier := base64.StdEncoding.EncodeToString([]byte(proxyURL.User.String())) + auth := &authBarrier + authBarrierStorage.Store(auth) + } + + req += "Proxy-Authorization: Basic " + *auth.(*string) + "\r\n" } req += "\r\n"