// SPDX-License-Identifier: BSD-2-Clause // // Copyright (c) 2025 The FreeBSD Foundation. // // This software was developed by Hayzam Sherif // of Alchemilla Ventures Pvt. Ltd. , // under sponsorship from the FreeBSD Foundation. package utils import ( "bytes" "compress/gzip" "context" "crypto/tls" "encoding/hex" "encoding/json" "errors" "fmt" "io" "net" "net/http" "strconv" "strings" "sync" "time" "github.com/gin-gonic/gin" ) var ( once sync.Once sharedClient *http.Client ) func GetTokenFromHeader(r http.Header) (string, error) { token := r.Get("Authorization") if token != "" { if len(token) < 8 || !strings.HasPrefix(token, "Bearer ") { return "", fmt.Errorf("invalid authorization header format") } return RemoveSpaces(token[7:]), nil } wsProtocol := r.Get("Sec-WebSocket-Protocol") if wsProtocol != "" { parts := strings.Split(wsProtocol, ",") if len(parts) == 2 && strings.TrimSpace(parts[0]) == "Bearer" { return RemoveSpaces(strings.TrimSpace(parts[1])), nil } return "", errors.New("invalid websocket protocol header format") } return "", errors.New("no token provided") } func GetClusterTokenFromHeader(r http.Header) (string, error) { if v := r.Get("ClusterToken"); v != "" { if len(v) < 8 || !strings.HasPrefix(v, "Bearer ") { return "", fmt.Errorf("invalid ClusterToken header format") } return RemoveSpaces(v[7:]), nil } if v := r.Get("X-Cluster-Authorization"); v != "" { if len(v) < 8 || !strings.HasPrefix(v, "Bearer ") { return "", fmt.Errorf("invalid X-Cluster-Authorization header format") } return RemoveSpaces(v[7:]), nil } if v := r.Get("X-Cluster-Token"); v != "" { if len(v) < 8 || !strings.HasPrefix(v, "Bearer ") { return "", fmt.Errorf("invalid X-Cluster-Token header format") } return RemoveSpaces(v[7:]), nil } if v := r.Get("Sec-WebSocket-Protocol"); v != "" { text := RemoveSpaces(v) data, err := hex.DecodeString(text) if err != nil { return "", fmt.Errorf("failed to decode hex: %w", err) } var obj struct { Hostname string `json:"hostname"` Token string `json:"token"` } if err := json.Unmarshal(data, &obj); err != nil { return "", fmt.Errorf("failed to unmarshal json: %w", err) } if obj.Token == "" { return "", errors.New("no_token_provided") } return obj.Token, nil } return "", errors.New("no cluster token provided") } func GetCurrentHostnameFromHeader(r http.Header, rC *http.Request) (string, error) { if v := r.Get("X-Current-Hostname"); v != "" { return RemoveSpaces(v), nil } if v := r.Get("Sec-WebSocket-Protocol"); v != "" { text := RemoveSpaces(v) data, err := hex.DecodeString(text) if err != nil { return "", fmt.Errorf("failed to decode hex: %w", err) } var obj struct { Hostname string `json:"hostname"` Token string `json:"token"` } if err := json.Unmarshal(data, &obj); err != nil { return "", fmt.Errorf("failed to unmarshal json: %w", err) } if obj.Hostname == "" { return "", errors.New("no_current_hostname_provided") } return obj.Hostname, nil } if v := rC.URL.Query().Get("auth"); v != "" { text := RemoveSpaces(v) data, err := hex.DecodeString(text) if err != nil { return "", fmt.Errorf("failed to decode hex: %w", err) } var obj struct { Hostname string `json:"hostname"` Token string `json:"token"` } if err := json.Unmarshal(data, &obj); err != nil { return "", fmt.Errorf("failed to unmarshal json: %w", err) } if obj.Hostname == "" { return "", errors.New("no_current_hostname_provided") } return obj.Hostname, nil } return "", errors.New("no_current_hostname_provided") } func GetIdFromParam(c *gin.Context) (int, error) { idStr := c.Param("id") id, err := strconv.Atoi(idStr) if err != nil { return 0, err } return id, nil } func FlatHeaders(c *gin.Context) map[string]string { var flatHeaders = make(map[string]string) for key, value := range c.Request.Header { if len(value) > 0 { flatHeaders[key] = value[0] } } return flatHeaders } func intraClusterClient() *http.Client { once.Do(func() { tr := &http.Transport{ MaxIdleConns: 200, MaxIdleConnsPerHost: 100, MaxConnsPerHost: 100, IdleConnTimeout: 90 * time.Second, DisableKeepAlives: false, DialContext: (&net.Dialer{ Timeout: 5 * time.Second, KeepAlive: 30 * time.Second, }).DialContext, TLSHandshakeTimeout: 5 * time.Second, ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } sharedClient = &http.Client{ Timeout: 8 * time.Second, Transport: tr, } }) return sharedClient } func HTTPPostJSON(url string, payload any, headers map[string]string) error { body, err := json.Marshal(payload) if err != nil { return err } req, err := http.NewRequest("POST", url, bytes.NewReader(body)) if err != nil { return err } for k, v := range headers { req.Header.Set(k, v) } if req.Header.Get("Accept-Encoding") == "" { req.Header.Set("Accept-Encoding", "gzip") } if req.Header.Get("Content-Type") == "" { req.Header.Set("Content-Type", "application/json") } resp, err := intraClusterClient().Do(req) if err != nil { return err } defer resp.Body.Close() var reader io.ReadCloser switch resp.Header.Get("Content-Encoding") { case "gzip": gz, err := gzip.NewReader(resp.Body) if err != nil { return fmt.Errorf("failed to create gzip reader: %w", err) } defer gz.Close() reader = gz default: reader = resp.Body } if resp.StatusCode < 200 || resp.StatusCode >= 300 { respBody, _ := io.ReadAll(reader) return fmt.Errorf("http error %d: %s", resp.StatusCode, string(respBody)) } return nil } func HTTPPostJSONRead(url string, payload any, headers map[string]string) ([]byte, int, error) { body, err := json.Marshal(payload) if err != nil { return nil, 0, err } req, err := http.NewRequest("POST", url, bytes.NewReader(body)) if err != nil { return nil, 0, err } for k, v := range headers { req.Header.Set(k, v) } if req.Header.Get("Accept-Encoding") == "" { req.Header.Set("Accept-Encoding", "gzip") } if req.Header.Get("Content-Type") == "" { req.Header.Set("Content-Type", "application/json") } resp, err := intraClusterClient().Do(req) if err != nil { return nil, 0, err } defer resp.Body.Close() var reader io.ReadCloser if resp.Header.Get("Content-Encoding") == "gzip" { gz, err := gzip.NewReader(resp.Body) if err != nil { return nil, resp.StatusCode, fmt.Errorf("failed to create gzip reader: %w", err) } defer gz.Close() reader = gz } else { reader = resp.Body } b, err := io.ReadAll(reader) if err != nil { return nil, resp.StatusCode, err } if resp.StatusCode < 200 || resp.StatusCode >= 300 { return nil, resp.StatusCode, fmt.Errorf("http error %d: %s", resp.StatusCode, string(b)) } return b, resp.StatusCode, nil } func HTTPGetJSONRead(url string, headers map[string]string) ([]byte, int, error) { req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, 0, err } for k, v := range headers { req.Header.Set(k, v) } if req.Header.Get("Accept-Encoding") == "" { req.Header.Set("Accept-Encoding", "gzip") } if req.Header.Get("Accept") == "" { req.Header.Set("Accept", "application/json") } resp, err := intraClusterClient().Do(req) if err != nil { return nil, 0, err } defer resp.Body.Close() var reader io.ReadCloser if resp.Header.Get("Content-Encoding") == "gzip" { gz, err := gzip.NewReader(resp.Body) if err != nil { return nil, resp.StatusCode, fmt.Errorf("failed to create gzip reader: %w", err) } defer gz.Close() reader = gz } else { reader = resp.Body } data, err := io.ReadAll(reader) if err != nil { return nil, resp.StatusCode, err } if resp.StatusCode < 200 || resp.StatusCode >= 300 { return nil, resp.StatusCode, fmt.Errorf("http error %d: %s", resp.StatusCode, string(data)) } return data, resp.StatusCode, nil } func HTTPGetStatus(url string, headers map[string]string) (int, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { return 0, err } for k, v := range headers { req.Header.Set(k, v) } resp, err := intraClusterClient().Do(req) if err != nil { return 0, err } defer resp.Body.Close() _, _ = io.Copy(io.Discard, resp.Body) if resp.StatusCode != http.StatusOK { return resp.StatusCode, fmt.Errorf("http error %d", resp.StatusCode) } return resp.StatusCode, nil } func ParamUint(c *gin.Context, name string) (uint, error) { param := c.Param(name) value, err := strconv.ParseUint(param, 10, 32) if err != nil { return 0, fmt.Errorf("invalid %s parameter: %w", name, err) } return uint(value), nil } func HTTPPostJSONWithTimeout(url string, payload []byte, headers map[string]string, timeout time.Duration) ([]byte, int, error) { if timeout <= 0 { timeout = 8 * time.Second } ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload)) if err != nil { return nil, 0, err } for k, v := range headers { req.Header.Set(k, v) } if req.Header.Get("Accept-Encoding") == "" { req.Header.Set("Accept-Encoding", "gzip") } if req.Header.Get("Content-Type") == "" { req.Header.Set("Content-Type", "application/json") } client := intraClusterClient() clientCopy := *client clientCopy.Timeout = 0 resp, err := clientCopy.Do(req) if err != nil { if errors.Is(err, context.DeadlineExceeded) { return nil, 0, fmt.Errorf("request timed out after %v", timeout) } return nil, 0, err } defer resp.Body.Close() var reader io.ReadCloser if resp.Header.Get("Content-Encoding") == "gzip" { gz, err := gzip.NewReader(resp.Body) if err != nil { return nil, resp.StatusCode, fmt.Errorf("failed to create gzip reader: %w", err) } defer gz.Close() reader = gz } else { reader = resp.Body } respBody, err := io.ReadAll(reader) if err != nil { return nil, resp.StatusCode, err } if resp.StatusCode < 200 || resp.StatusCode >= 300 { return nil, resp.StatusCode, fmt.Errorf("http error %d: %s", resp.StatusCode, string(respBody)) } return respBody, resp.StatusCode, nil }