diff --git a/server.go b/server.go index ca24e64..33f287a 100644 --- a/server.go +++ b/server.go @@ -64,6 +64,17 @@ func ServeTLS(ln net.Listener, certFile, keyFile string, handler RequestHandler) return s.ServeTLS(ln, certFile, keyFile) } +// ServeTLSEmbed serves HTTPS requests from the given net.Listener +// using the given handler. +// +// certData and keyData must contain valid TLS certificate and key data. +func ServeTLSEmbed(ln net.Listener, certData, keyData []byte, handler RequestHandler) error { + s := &Server{ + Handler: handler, + } + return s.ServeTLSEmbed(ln, certData, keyData) +} + // ListenAndServe serves HTTP requests from the given TCP addr // using the given handler. func ListenAndServe(addr string, handler RequestHandler) error { @@ -97,6 +108,17 @@ func ListenAndServeTLS(addr, certFile, keyFile string, handler RequestHandler) e return s.ListenAndServeTLS(addr, certFile, keyFile) } +// ListenAndServeTLSEmbed serves HTTPS requests from the given TCP addr +// using the given handler. +// +// certData and keyData must contain valid TLS certificate and key data. +func ListenAndServeTLSEmbed(addr string, certData, keyData []byte, handler RequestHandler) error { + s := &Server{ + Handler: handler, + } + return s.ListenAndServeTLSEmbed(addr, certData, keyData) +} + // RequestHandler must process incoming requests. // // RequestHandler must call ctx.TimeoutError() before returning @@ -1066,6 +1088,17 @@ func (s *Server) ListenAndServeTLS(addr, certFile, keyFile string) error { return s.ServeTLS(ln, certFile, keyFile) } +// ListenAndServeTLSEmbed serves HTTPS requests from the given TCP addr. +// +// certData and keyData must contain valid TLS certificate and key data. +func (s *Server) ListenAndServeTLSEmbed(addr string, certData, keyData []byte) error { + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + return s.ServeTLSEmbed(ln, certData, keyData) +} + // ServeTLS serves HTTPS requests from the given listener. // // certFile and keyFile are paths to TLS certificate and key files. @@ -1077,15 +1110,39 @@ func (s *Server) ServeTLS(ln net.Listener, certFile, keyFile string) error { return s.Serve(lnTLS) } +// ServeTLSEmbed serves HTTPS requests from the given listener. +// +// certData and keyData must contain valid TLS certificate and key data. +func (s *Server) ServeTLSEmbed(ln net.Listener, certData, keyData []byte) error { + lnTLS, err := newTLSListenerEmbed(ln, certData, keyData) + if err != nil { + return err + } + return s.Serve(lnTLS) +} + func newTLSListener(ln net.Listener, certFile, keyFile string) (net.Listener, error) { cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return nil, fmt.Errorf("cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err) } - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, + return newCertListener(ln, &cert), nil +} + +func newTLSListenerEmbed(ln net.Listener, certData, keyData []byte) (net.Listener, error) { + cert, err := tls.X509KeyPair(certData, keyData) + if err != nil { + return nil, fmt.Errorf("cannot load TLS key pair from the provided certData(%d) and keyData(%d): %s", + len(certData), len(keyData), err) } - return tls.NewListener(ln, tlsConfig), nil + return newCertListener(ln, &cert), nil +} + +func newCertListener(ln net.Listener, cert *tls.Certificate) net.Listener { + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{*cert}, + } + return tls.NewListener(ln, tlsConfig) } // DefaultConcurrency is the maximum number of concurrent connections diff --git a/server_test.go b/server_test.go index d6356f9..63be3c2 100644 --- a/server_test.go +++ b/server_test.go @@ -3,6 +3,7 @@ package fasthttp import ( "bufio" "bytes" + "crypto/tls" "fmt" "io" "io/ioutil" @@ -15,6 +16,78 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) +func TestServerServeTLSEmbed(t *testing.T) { + ln := fasthttputil.NewInmemoryListener() + + certFile := "./ssl-cert-snakeoil.pem" + keyFile := "./ssl-cert-snakeoil.key" + + certData, err := ioutil.ReadFile(certFile) + if err != nil { + t.Fatalf("unexpected error when reading %q: %s", certFile, err) + } + keyData, err := ioutil.ReadFile(keyFile) + if err != nil { + t.Fatalf("unexpected error when reading %q: %s", keyFile, err) + } + + // start the server + ch := make(chan struct{}) + go func() { + err := ServeTLSEmbed(ln, certData, keyData, func(ctx *RequestCtx) { + ctx.WriteString("success") + }) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + close(ch) + }() + + // establish connection to the server + conn, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tlsConn := tls.Client(conn, &tls.Config{ + InsecureSkipVerify: true, + }) + + // send request + if _, err = tlsConn.Write([]byte("GET / HTTP/1.1\r\nHost: aaa\r\n\r\n")); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + // read response + respCh := make(chan struct{}) + go func() { + br := bufio.NewReader(tlsConn) + var resp Response + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error") + } + body := resp.Body() + if string(body) != "success" { + t.Fatalf("unexpected response body %q. Expecting %q", body, "success") + } + close(respCh) + }() + select { + case <-respCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + + // close the server + if err = ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + select { + case <-ch: + case <-time.After(time.Second): + t.Fatalf("timeout") + } +} + func TestServerMultipartFormDataRequest(t *testing.T) { reqS := `POST /upload HTTP/1.1 Host: qwerty.com