diff --git a/client.go b/client.go index 1fa072d..9a0ded6 100644 --- a/client.go +++ b/client.go @@ -446,6 +446,10 @@ func (c *Client) DoRedirects(req *Request, resp *Response, maxRedirectsCount int // and AcquireResponse in performance-critical code. func (c *Client) Do(req *Request, resp *Response) error { uri := req.URI() + if uri == nil { + return ErrorInvalidURI + } + host := uri.Host() isTLS := false @@ -910,7 +914,9 @@ func doRequestFollowRedirects(req *Request, resp *Response, url string, maxRedir for { req.SetRequestURI(url) - req.parseURI() + if err := req.parseURI(); err != nil { + return 0, nil, err + } if err = c.Do(req, resp); err != nil { break diff --git a/http.go b/http.go index a2a0c95..28ba09d 100644 --- a/http.go +++ b/http.go @@ -729,13 +729,13 @@ func (req *Request) URI() *URI { return &req.uri } -func (req *Request) parseURI() { +func (req *Request) parseURI() error { if req.parsedURI { - return + return nil } req.parsedURI = true - req.uri.parse(req.Header.Host(), req.Header.RequestURI(), req.isTLS) + return req.uri.parse(req.Header.Host(), req.Header.RequestURI(), req.isTLS) } // PostArgs returns POST arguments. diff --git a/uri.go b/uri.go index a8b915a..0e6eb1b 100644 --- a/uri.go +++ b/uri.go @@ -2,6 +2,7 @@ package fasthttp import ( "bytes" + "errors" "io" "sync" ) @@ -250,21 +251,25 @@ func (u *URI) SetHostBytes(host []byte) { lowercaseBytes(u.host) } +var ( + ErrorInvalidURI = errors.New("invalid uri") +) + // Parse initializes URI from the given host and uri. // // host may be nil. In this case uri must contain fully qualified uri, // i.e. with scheme and host. http is assumed if scheme is omitted. // // uri may contain e.g. RequestURI without scheme and host if host is non-empty. -func (u *URI) Parse(host, uri []byte) { - u.parse(host, uri, false) +func (u *URI) Parse(host, uri []byte) error { + return u.parse(host, uri, false) } -func (u *URI) parse(host, uri []byte, isTLS bool) { +func (u *URI) parse(host, uri []byte, isTLS bool) error { u.Reset() if stringContainsCTLByte(uri) { - return + return ErrorInvalidURI } if len(host) == 0 || bytes.Contains(uri, strColonSlashSlash) { @@ -306,7 +311,7 @@ func (u *URI) parse(host, uri []byte, isTLS bool) { if queryIndex < 0 && fragmentIndex < 0 { u.pathOriginal = append(u.pathOriginal, b...) u.path = normalizePath(u.path, u.pathOriginal) - return + return nil } if queryIndex >= 0 { @@ -320,7 +325,7 @@ func (u *URI) parse(host, uri []byte, isTLS bool) { u.queryString = append(u.queryString, b[queryIndex+1:fragmentIndex]...) u.hash = append(u.hash, b[fragmentIndex+1:]...) } - return + return nil } // fragmentIndex >= 0 && queryIndex < 0 @@ -328,6 +333,8 @@ func (u *URI) parse(host, uri []byte, isTLS bool) { u.pathOriginal = append(u.pathOriginal, b[:fragmentIndex]...) u.path = normalizePath(u.path, u.pathOriginal) u.hash = append(u.hash, b[fragmentIndex+1:]...) + + return nil } func normalizePath(dst, src []byte) []byte { @@ -470,7 +477,9 @@ func (u *URI) updateBytes(newURI, buf []byte) []byte { if len(u.scheme) > 0 { schemeOriginal = append([]byte(nil), u.scheme...) } - u.Parse(nil, newURI) + if err := u.Parse(nil, newURI); err != nil { + return nil + } if len(schemeOriginal) > 0 && len(u.scheme) == 0 { u.scheme = append(u.scheme[:0], schemeOriginal...) } @@ -481,7 +490,9 @@ func (u *URI) updateBytes(newURI, buf []byte) []byte { // uri without host buf = u.appendSchemeHost(buf[:0]) buf = append(buf, newURI...) - u.Parse(nil, buf) + if err := u.Parse(nil, buf); err != nil { + return nil + } return buf } @@ -505,7 +516,9 @@ func (u *URI) updateBytes(newURI, buf []byte) []byte { buf = u.appendSchemeHost(buf[:0]) buf = appendQuotedPath(buf, path[:n+1]) buf = append(buf, newURI...) - u.Parse(nil, buf) + if err := u.Parse(nil, buf); err != nil { + return nil + } return buf } }