From 262ea09c9edbee3fcbedeccfa1441074e06cb2ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Sat, 2 May 2026 13:12:08 +0200 Subject: [PATCH] feat(prefork): graceful shutdown, leak fixes, hook robustness MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses outstanding review concerns and several adjacent issues surfaced during a follow-up review pass. Lifecycle / supervision - Track every per-child Wait goroutine via sync.WaitGroup and unblock pending sigCh sends through context cancellation (context.WithCancel) so early-return paths (OnChildSpawn / OnMasterReady error, recovery doCommand error, ErrOverRecovery) can no longer leak goroutines or stall children. - Install signal.Notify(SIGTERM, SIGINT) in the master so deploy/rolling-restart signals enter the shutdown path instead of killing the master without graceful teardown. - Replace the unconditional SIGKILL defer with a SIGTERM-then-SIGKILL sequence gated by a configurable ShutdownGracePeriod (defaults to 5s, Windows path stays SIGKILL since Signal(SIGTERM) is unsupported). API - OnChildRecover now returns error so callers can implement recovery policies (circuit-breaker etc.); a panic in any master-side hook (OnChildSpawn, OnMasterReady, OnChildRecover) is recovered and surfaced as the returned error, with diagnostic logging. - Add RecoverInterval (optional crash-loop backoff) and ShutdownGracePeriod fields with safe zero-value defaults. - Export ErrCommandProducerNilCmd and ErrCommandProducerNotStarted sentinel errors so callers can errors.Is them. - Rename oldPid/newPid to oldPID/newPID per Go initialism convention. - Logger interface now declares an explicit compile-time compatibility check with fasthttp.Logger. Resource hygiene - Master closes both the original tcpListener and the duped fd in p.files when prefork() returns; previously the duped fd leaked once per call. - doCommand wraps every underlying error with fmt.Errorf + %w (the two CommandProducer sentinel errors are returned unwrapped) so caller-side diagnostics keep stage context. 
- Strip pre-existing FASTHTTP_PREFORK_CHILD entries before appending so child env never carries duplicate keys. - Extract magic numbers as package constants (inheritedListenerFD, masterPollInterval, defaultShutdownGracePeriod, preforkChildEnvValue). - Rename the inherited listener fd via os.NewFile so net.FileListener errors are diagnosable. Tests - Migrate to t.Setenv (drop the global setUp/tearDown helpers) — fixes the env-mutation-vs-parallel race. - Replace rand.Intn port helper with `:0` + Listener.Addr() to remove port-collision flakes under -count and parallel runs. - Collapse the three near-identical Test_ListenAndServe* tests into a single table-driven subtest that actually asserts the args forwarded to ServeFunc/ServeTLSFunc/ServeTLSEmbedFunc. - Add coverage for the previously untested branches: CommandProducer returning err / nil cmd / unstarted cmd, initial OnChildSpawn error, OnMasterReady error, hook panic surfacing, RecoverInterval enforcement. - noopChildProducer helper kills + waits any spawned child binaries during cleanup so failed tests no longer leave subprocesses around. --- prefork/prefork.go | 454 ++++++++++++++++++++++++---------- prefork/prefork_test.go | 534 ++++++++++++++++++++++++++-------------- 2 files changed, 674 insertions(+), 314 deletions(-) diff --git a/prefork/prefork.go b/prefork/prefork.go index 52a7a88..1964ddf 100644 --- a/prefork/prefork.go +++ b/prefork/prefork.go @@ -2,12 +2,17 @@ package prefork import ( + "context" "errors" + "fmt" "log" "net" "os" "os/exec" + "os/signal" "runtime" + "sync" + "syscall" "time" "github.com/valyala/fasthttp" @@ -16,25 +21,53 @@ import ( const ( preforkChildEnvVariable = "FASTHTTP_PREFORK_CHILD" + preforkChildEnvValue = "1" defaultNetwork = "tcp4" + + // inheritedListenerFD is the file descriptor used by the master to pass + // the bound listener to a child process via ExtraFiles. Children open the + // listener via os.NewFile(inheritedListenerFD, ...) when Reuseport is false. 
+ inheritedListenerFD = 3 + + // masterPollInterval is the period of the watchMaster ppid-poll on Unix. + masterPollInterval = 500 * time.Millisecond + + // defaultShutdownGracePeriod is how long the master waits for children to + // exit cleanly after sending SIGTERM before forcibly killing them. + defaultShutdownGracePeriod = 5 * time.Second ) var ( defaultLogger = Logger(log.New(os.Stderr, "", log.LstdFlags)) - // ErrOverRecovery is returned when the times of starting over child prefork processes exceed - // the threshold. + + // ErrOverRecovery is returned when child prefork process restarts exceed + // the value of RecoverThreshold. ErrOverRecovery = errors.New("exceeding the value of RecoverThreshold") - // ErrOnlyReuseportOnWindows is returned when Reuseport is false. + // ErrOnlyReuseportOnWindows is returned when running on Windows without Reuseport. ErrOnlyReuseportOnWindows = errors.New("windows only supports Reuseport = true") + + // ErrCommandProducerNilCmd is returned when a CommandProducer returns + // (nil, nil) instead of a started command. + ErrCommandProducerNilCmd = errors.New("prefork: CommandProducer returned nil command") + + // ErrCommandProducerNotStarted is returned when a CommandProducer returns + // an *exec.Cmd whose Process is nil (i.e. cmd.Start() was not called). + ErrCommandProducerNotStarted = errors.New("prefork: CommandProducer must return a started command") ) -// Logger is used for logging formatted messages. +// Logger is used for logging formatted messages. Its method set is intentionally +// identical to fasthttp.Logger so that *fasthttp.Server.Logger can be assigned +// directly. type Logger interface { // Printf must have the same semantics as log.Printf. Printf(format string, args ...any) } +// Compile-time check that fasthttp.Logger satisfies the local Logger interface; +// keeps the two types in sync if either side ever evolves. +var _ Logger = fasthttp.Logger(nil) + // Prefork implements fasthttp server prefork. 
// // Preforks master process (with all cores) between several child processes @@ -44,7 +77,8 @@ type Logger interface { // WARNING: using prefork prevents the use of any global state! // Things like in-memory caches won't work. type Prefork struct { - // By default standard logger from log package is used. + // Logger receives diagnostic output. By default the standard log package + // logger writing to stderr is used. Logger Logger ln net.Listener @@ -53,21 +87,29 @@ type Prefork struct { ServeTLSFunc func(ln net.Listener, certFile, keyFile string) error ServeTLSEmbedFunc func(ln net.Listener, certData, keyData []byte) error - // The network must be "tcp", "tcp4" or "tcp6". - // - // By default is "tcp4" + // Network must be "tcp", "tcp4" or "tcp6". Default is "tcp4". Network string files []*os.File - // Child prefork processes may exit with failure and will be started over until the times reach - // the value of RecoverThreshold, then it will return and terminate the server. + // RecoverThreshold caps how often crashed children are respawned before + // the master returns ErrOverRecovery. New() sets it to max(1, GOMAXPROCS/2). + // When constructing a Prefork directly without New(), a zero value will + // terminate the master after the very first child crash. RecoverThreshold int - // Flag to use a listener with reuseport, if not a file Listener will be used + // RecoverInterval, when > 0, makes the master sleep for the given duration + // before respawning a crashed child. Useful as crash-loop backoff. + RecoverInterval time.Duration + + // ShutdownGracePeriod is the time the master waits for children to exit + // after sending SIGTERM before falling back to SIGKILL. Defaults to 5s + // when zero. On Windows SIGTERM is not delivered, so this is unused there. + ShutdownGracePeriod time.Duration + + // Reuseport selects a reuseport listener instead of fd-passing. 
// See: https://www.nginx.com/blog/socket-sharding-nginx-release-1-9-1/ - // - // It's disabled by default + // Disabled by default. Reuseport bool // OnMasterDeath, when non-nil, enables monitoring of the master process @@ -76,25 +118,45 @@ type Prefork struct { // // It is recommended to set this to func() { os.Exit(1) } if no custom // cleanup is needed. + // + // Threading: invoked once from a watcher goroutine in the child. Must not + // block the goroutine for long; must not call Prefork methods. OnMasterDeath func() - // OnChildSpawn is called in the master process whenever a new child process is spawned. + // OnChildSpawn is called in the master after a new child process is + // successfully started, both during initial spawn and during recovery. // It receives the PID of the newly spawned child process. // - // If this callback returns an error, all child processes are killed and the - // prefork operation returns that error. + // If this callback returns an error, all already-running children are + // killed and the prefork operation returns that error. + // + // Threading: invoked synchronously from the master goroutine. Must not + // block; must not call Prefork methods. Panics are recovered and surfaced + // as the returned error. OnChildSpawn func(pid int) error - // OnMasterReady is called in the master process after all child processes have been spawned. - // It receives a slice of all child process PIDs. + // OnMasterReady is called in the master process exactly once, after all + // initial children have been spawned and before the supervision loop runs. + // It receives a slice of all initial child PIDs. // - // If this callback returns an error, the prefork operation will be aborted. + // If this callback returns an error, the prefork operation aborts and + // returns that error after killing the children. + // + // Threading: invoked synchronously from the master goroutine. The slice + // is owned by the caller after the call returns. 
Panics are recovered. OnMasterReady func(childPIDs []int) error - // OnChildRecover is called in the master process when a crashed child process - // has been replaced by a new one. It receives the PID of the old (crashed) - // process and the PID of the newly spawned replacement. - OnChildRecover func(oldPid, newPid int) + // OnChildRecover is called in the master after a crashed child has been + // replaced. It receives the PID of the old (crashed) process and the PID + // of its replacement. + // + // If this callback returns an error, all running children are killed and + // the prefork operation returns that error. + // + // Threading: invoked synchronously from the master goroutine, after + // OnChildSpawn for the new child. Panics are recovered and surfaced as + // the returned error. + OnChildRecover func(oldPID, newPID int) error // CommandProducer creates and starts a child process command. // If nil, the default implementation re-executes the current binary @@ -104,20 +166,23 @@ type Prefork struct { // A custom producer must: // - Set FASTHTTP_PREFORK_CHILD=1 in the child's environment // (otherwise IsChild() returns false and the child won't serve) - // - Call cmd.Start() before returning (the returned command must be started) + // - Call cmd.Start() before returning (the returned command must be + // started so cmd.Process is non-nil) // - Pass the provided files as cmd.ExtraFiles when Reuseport is false // - // This is primarily useful for testing (injecting dummy commands) - // or for frameworks that need custom child process setup. + // Primarily useful for testing (injecting dummy commands) or for + // frameworks that need custom child process setup. CommandProducer func(files []*os.File) (*exec.Cmd, error) } -// IsChild checks if the current thread/process is a child. +// IsChild reports whether the current process is a prefork child. 
func IsChild() bool { - return os.Getenv(preforkChildEnvVariable) == "1" + return os.Getenv(preforkChildEnvVariable) == preforkChildEnvValue } // New wraps the fasthttp server to run with preforked processes. +// It seeds Network and RecoverThreshold to sensible defaults; existing +// fields on s (Logger, Serve*) are captured. func New(s *fasthttp.Server) *Prefork { return &Prefork{ Network: defaultNetwork, @@ -140,20 +205,32 @@ func (p *Prefork) logger() Logger { return defaultLogger } +// invokeHook runs fn under a panic recovery, returning the panic as an error +// so a misbehaving callback never tears down the master. +func (p *Prefork) invokeHook(name string, fn func() error) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("prefork: %s panicked: %v", name, r) + p.logger().Printf("%v", err) + } + }() + return fn() +} + func (p *Prefork) watchMaster(masterPID int) { if runtime.GOOS == "windows" { // On Windows, os.Getppid() returns a static PID that doesn't change // when the parent exits (no reparenting). Use FindProcess+Wait instead. proc, err := os.FindProcess(masterPID) if err != nil { - p.logger().Printf("watchMaster: failed to find master process %d: %v\n", masterPID, err) + p.logger().Printf("watchMaster: failed to find master process %d: %v", masterPID, err) p.OnMasterDeath() return } if _, err = proc.Wait(); err != nil { - p.logger().Printf("watchMaster: error waiting for master process %d: %v\n", masterPID, err) + p.logger().Printf("watchMaster: error waiting for master process %d: %v", masterPID, err) } - p.logger().Printf("master process died\n") + p.logger().Printf("master process %d died", masterPID) p.OnMasterDeath() return } @@ -162,12 +239,12 @@ func (p *Prefork) watchMaster(masterPID int) { // to another process, causing Getppid() to change. Comparing against // the original masterPID (instead of hardcoding 1) ensures this works // correctly when the master itself is PID 1 (e.g. in Docker containers). 
- ticker := time.NewTicker(500 * time.Millisecond) + ticker := time.NewTicker(masterPollInterval) defer ticker.Stop() for range ticker.C { if os.Getppid() != masterPID { - p.logger().Printf("master process died\n") + p.logger().Printf("master process %d died", masterPID) p.OnMasterDeath() return } @@ -185,8 +262,10 @@ func (p *Prefork) listen(addr string) (net.Listener, error) { return reuseport.Listen(p.Network, addr) } - // File descriptor 3 is the first ExtraFiles entry passed by the master process. - return net.FileListener(os.NewFile(3, "")) + // fd inheritedListenerFD is the first ExtraFiles entry passed by the + // master process when Reuseport is false. Naming the file gives clearer + // errors from net.FileListener if the fd is invalid. + return net.FileListener(os.NewFile(inheritedListenerFD, "fasthttp-prefork-listener")) } // listenAsChild performs the common child process setup: creates the listener @@ -213,62 +292,142 @@ func (p *Prefork) setTCPListenerFiles(addr string) error { tcpAddr, err := net.ResolveTCPAddr(p.Network, addr) if err != nil { - return err + return fmt.Errorf("prefork: resolve %s/%s: %w", p.Network, addr, err) } - tcplistener, err := net.ListenTCP(p.Network, tcpAddr) + tcpListener, err := net.ListenTCP(p.Network, tcpAddr) if err != nil { - return err + return fmt.Errorf("prefork: listen tcp %s: %w", addr, err) } - p.ln = tcplistener + p.ln = tcpListener - fl, err := tcplistener.File() + listenerFile, err := tcpListener.File() if err != nil { - return err + return fmt.Errorf("prefork: dup listener fd: %w", err) } - p.files = []*os.File{fl} + p.files = []*os.File{listenerFile} return nil } +// childEnv returns os.Environ() with the prefork child marker variable set, +// stripping any pre-existing value to avoid duplicate keys with last-wins +// semantics. 
+func childEnv() []string { + src := os.Environ() + out := make([]string, 0, len(src)+1) + prefix := preforkChildEnvVariable + "=" + for _, kv := range src { + if len(kv) >= len(prefix) && kv[:len(prefix)] == prefix { + continue + } + out = append(out, kv) + } + out = append(out, preforkChildEnvVariable+"="+preforkChildEnvValue) + return out +} + func (p *Prefork) doCommand() (*exec.Cmd, error) { - // Use custom CommandProducer if provided if p.CommandProducer != nil { cmd, err := p.CommandProducer(p.files) if err != nil { - return nil, err + return nil, fmt.Errorf("prefork: CommandProducer: %w", err) } - if cmd == nil || cmd.Process == nil { - return nil, errors.New("prefork: CommandProducer must return a started command") + if cmd == nil { + return nil, ErrCommandProducerNilCmd + } + if cmd.Process == nil { + return nil, ErrCommandProducerNotStarted } return cmd, nil } - // Default implementation using os.Executable() for reliable path resolution executable, err := os.Executable() if err != nil { - return nil, err + return nil, fmt.Errorf("prefork: resolve executable: %w", err) } - args := make([]string, len(os.Args)) - args[0] = executable - copy(args[1:], os.Args[1:]) + args := append([]string{executable}, os.Args[1:]...) cmd := &exec.Cmd{ Path: executable, Args: args, Stdout: os.Stdout, Stderr: os.Stderr, - Env: append(os.Environ(), preforkChildEnvVariable+"=1"), + Env: childEnv(), ExtraFiles: p.files, } - err = cmd.Start() - return cmd, err + if err = cmd.Start(); err != nil { + return nil, fmt.Errorf("prefork: start child %q: %w", executable, err) + } + return cmd, nil } -func (p *Prefork) prefork(addr string) (err error) { +type childExit struct { + err error + pid int +} + +// shutdownChildren signals every entry in childProcs first with SIGTERM (on +// platforms where it is supported) and waits up to grace for them to exit. +// Survivors are then killed unconditionally. 
wg tracks the per-child Wait +// goroutines and must be drained before returning so no goroutine outlives +// prefork(). +func (p *Prefork) shutdownChildren( + childProcs map[int]*exec.Cmd, + wg *sync.WaitGroup, + cancel context.CancelFunc, + grace time.Duration, +) { + if grace <= 0 { + grace = defaultShutdownGracePeriod + } + + if runtime.GOOS != "windows" { + for pid, proc := range childProcs { + if proc == nil || proc.Process == nil { + continue + } + if termErr := proc.Process.Signal(syscall.SIGTERM); termErr != nil && + !errors.Is(termErr, os.ErrProcessDone) { + p.logger().Printf("prefork: SIGTERM child %d: %v", pid, termErr) + } + } + } + + // Wait for graceful exits, with a timeout fallback to SIGKILL. + graceful := make(chan struct{}) + go func() { + wg.Wait() + close(graceful) + }() + + timer := time.NewTimer(grace) + defer timer.Stop() + select { + case <-graceful: + case <-timer.C: + } + + for pid, proc := range childProcs { + if proc == nil || proc.Process == nil { + continue + } + if killErr := proc.Process.Kill(); killErr != nil && + !errors.Is(killErr, os.ErrProcessDone) { + p.logger().Printf("prefork: kill child %d: %v", pid, killErr) + } + } + + // Cancel the per-Wait goroutines' send-context so any still blocked on + // sigCh send unblock cleanly, then wait for all of them to exit. + cancel() + wg.Wait() +} + +func (p *Prefork) prefork(addr string) (err error) { //nolint:gocyclo if !p.Reuseport { if runtime.GOOS == "windows" { return ErrOnlyReuseportOnWindows @@ -278,37 +437,60 @@ func (p *Prefork) prefork(addr string) (err error) { return err } - // defer for closing the net.Listener opened by setTCPListenerFiles. + // Close listener fds opened by setTCPListenerFiles. Both the original + // tcpListener (p.ln) and the duped fd (p.files[0]) belong to the + // master only; children inherit independent dup'd copies via fork+exec. 
defer func() { - e := p.ln.Close() - if err == nil { - err = e + err = errors.Join(err, p.ln.Close()) + for _, f := range p.files { + if closeErr := f.Close(); closeErr != nil { + p.logger().Printf("prefork: close listener fd: %v", closeErr) + } + } + p.files = nil + }() + } + + // ctx cancels per-child Wait goroutines so they unblock from sigCh sends + // once the supervision loop is gone. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Catch SIGTERM/SIGINT in the master so we run our shutdown path instead + // of being killed by the OS without children getting a graceful chance. + signalCh := make(chan os.Signal, 1) + signal.Notify(signalCh, syscall.SIGTERM, syscall.SIGINT) + defer signal.Stop(signalCh) + + goMaxProcs := runtime.GOMAXPROCS(0) + // Buffer is sized to initial fleet; per-child goroutines fall back to a + // context-aware select on send so capacity is not load-bearing. + sigCh := make(chan childExit, goMaxProcs) + childProcs := make(map[int]*exec.Cmd, goMaxProcs) + + var wg sync.WaitGroup + startWait := func(cmd *exec.Cmd, pid int) { + wg.Add(1) + go func() { + defer wg.Done() + result := childExit{pid: pid, err: cmd.Wait()} + select { + case sigCh <- result: + case <-ctx.Done(): } }() } - type procSig struct { - err error - pid int - } - - goMaxProcs := runtime.GOMAXPROCS(0) - sigCh := make(chan procSig, goMaxProcs) - childProcs := make(map[int]*exec.Cmd, goMaxProcs) - defer func() { - for _, proc := range childProcs { - _ = proc.Process.Kill() - } + p.shutdownChildren(childProcs, &wg, cancel, p.ShutdownGracePeriod) }() - // Collect child PIDs for OnMasterReady callback childPIDs := make([]int, 0, goMaxProcs) for range goMaxProcs { var cmd *exec.Cmd if cmd, err = p.doCommand(); err != nil { - p.logger().Printf("failed to start a child prefork process, error: %v\n", err) + p.logger().Printf("prefork: failed to start a child process: %v", err) return err } @@ -316,71 +498,96 @@ func (p *Prefork) prefork(addr string) 
(err error) { childProcs[pid] = cmd childPIDs = append(childPIDs, pid) - // Start Wait goroutine before OnChildSpawn so that if OnChildSpawn - // fails and the deferred Kill() runs, Wait() will collect the child. - go func(c *exec.Cmd, pid int) { - sigCh <- procSig{pid: pid, err: c.Wait()} - }(cmd, pid) + // Start Wait goroutine before the user callback so a panic / error + // return from OnChildSpawn cannot leave a zombie behind. + startWait(cmd, pid) if p.OnChildSpawn != nil { - if err = p.OnChildSpawn(pid); err != nil { - p.logger().Printf("OnChildSpawn callback failed for PID %d: %v\n", pid, err) - return err + pid := pid + if hookErr := p.invokeHook("OnChildSpawn", func() error { + return p.OnChildSpawn(pid) + }); hookErr != nil { + p.logger().Printf("prefork: OnChildSpawn for PID %d: %v", pid, hookErr) + return hookErr } } } - // Call OnMasterReady callback after all children are spawned if p.OnMasterReady != nil { - if err = p.OnMasterReady(childPIDs); err != nil { - p.logger().Printf("OnMasterReady callback failed: %v\n", err) - return err + pids := append([]int(nil), childPIDs...) 
+ if hookErr := p.invokeHook("OnMasterReady", func() error { + return p.OnMasterReady(pids) + }); hookErr != nil { + p.logger().Printf("prefork: OnMasterReady: %v", hookErr) + return hookErr } } var exitedProcs int - for sig := range sigCh { - delete(childProcs, sig.pid) + for { + select { + case sig := <-signalCh: + p.logger().Printf("prefork: received signal %v, shutting down", sig) + return nil - p.logger().Printf("one of the child prefork processes exited with "+ - "error: %v", sig.err) + case sig := <-sigCh: + delete(childProcs, sig.pid) - exitedProcs++ - if exitedProcs > p.RecoverThreshold { - p.logger().Printf("child prefork processes exit too many times, "+ - "which exceeds the value of RecoverThreshold(%d), "+ - "exiting the master process.\n", p.RecoverThreshold) - err = ErrOverRecovery - break - } + if sig.err != nil { + p.logger().Printf("prefork: child PID %d exited: %v", sig.pid, sig.err) + } else { + p.logger().Printf("prefork: child PID %d exited cleanly", sig.pid) + } - var cmd *exec.Cmd - cmd, err = p.doCommand() - if err != nil { - break - } - newPid := cmd.Process.Pid - childProcs[newPid] = cmd + exitedProcs++ + if exitedProcs > p.RecoverThreshold { + p.logger().Printf( + "prefork: child exits (%d) exceed RecoverThreshold (%d), terminating master", + exitedProcs, p.RecoverThreshold, + ) + return ErrOverRecovery + } - // Start Wait goroutine before callbacks to avoid zombie processes - // if a callback fails and the deferred Kill() runs. 
- go func(c *exec.Cmd, pid int) { - sigCh <- procSig{pid: pid, err: c.Wait()} - }(cmd, newPid) + if p.RecoverInterval > 0 { + select { + case <-time.After(p.RecoverInterval): + case <-signalCh: + return nil + } + } - if p.OnChildSpawn != nil { - if err = p.OnChildSpawn(newPid); err != nil { - p.logger().Printf("OnChildSpawn callback failed for recovered PID %d: %v\n", newPid, err) - return err + cmd, doErr := p.doCommand() + if doErr != nil { + p.logger().Printf("prefork: recovery doCommand: %v", doErr) + return doErr + } + newPID := cmd.Process.Pid + childProcs[newPID] = cmd + + startWait(cmd, newPID) + + if p.OnChildSpawn != nil { + newPID := newPID + if hookErr := p.invokeHook("OnChildSpawn", func() error { + return p.OnChildSpawn(newPID) + }); hookErr != nil { + p.logger().Printf("prefork: OnChildSpawn for recovered PID %d: %v", newPID, hookErr) + return hookErr + } + } + + if p.OnChildRecover != nil { + oldPID := sig.pid + newPID := newPID + if hookErr := p.invokeHook("OnChildRecover", func() error { + return p.OnChildRecover(oldPID, newPID) + }); hookErr != nil { + p.logger().Printf("prefork: OnChildRecover (%d -> %d): %v", oldPID, newPID, hookErr) + return hookErr + } } } - - if p.OnChildRecover != nil { - p.OnChildRecover(sig.pid, newPid) - } } - - return err } // ListenAndServe serves HTTP requests from the given TCP addr. @@ -398,11 +605,10 @@ func (p *Prefork) ListenAndServe(addr string) error { // ListenAndServeTLS serves HTTPS requests from the given TCP addr. // -// certKey is the path to the TLS private key file. -// certFile is the path to the TLS certificate file. -// -// Note: parameter order is (addr, certKey, certFile) — key before cert. -// Internally forwards to ServeTLSFunc as (certFile, certKey). +// Note: parameter order is (addr, certKey, certFile) — key path comes +// before cert path. This is preserved for backward compatibility with +// existing callers and differs from fasthttp.Server.ListenAndServeTLS. 
+// New code should prefer ListenAndServeTLSEmbed. func (p *Prefork) ListenAndServeTLS(addr, certKey, certFile string) error { if IsChild() { ln, err := p.listenAsChild(addr) diff --git a/prefork/prefork_test.go b/prefork/prefork_test.go index d358b51..004450b 100644 --- a/prefork/prefork_test.go +++ b/prefork/prefork_test.go @@ -3,44 +3,78 @@ package prefork import ( "errors" "fmt" - "math/rand" "net" "os" "os/exec" - "reflect" "runtime" "sync" + "sync/atomic" "testing" + "time" "github.com/valyala/fasthttp" ) -func setUp() { - os.Setenv(preforkChildEnvVariable, "1") +// freeAddr returns a TCP4 address that the OS believes is free at call time. +// It avoids the rand-port collision flakes the previous helper exhibited. +func freeAddr(t testing.TB) string { + t.Helper() + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("freeAddr: %v", err) + } + addr := ln.Addr().String() + if closeErr := ln.Close(); closeErr != nil { + t.Fatalf("freeAddr: close: %v", closeErr) + } + return addr } -func tearDown() { - os.Unsetenv(preforkChildEnvVariable) -} +// noopChildProducer returns a CommandProducer that re-execs the test binary +// into a no-op subprocess. The returned cleanup must be deferred (or registered +// via t.Cleanup) so leaked subprocesses are reaped if the test fails midway. 
+func noopChildProducer(t testing.TB) (func(files []*os.File) (*exec.Cmd, error), func()) { + t.Helper() + var ( + mu sync.Mutex + spawned []*exec.Cmd + ) -func getAddr() string { - return fmt.Sprintf("127.0.0.1:%d", rand.Intn(9000-3000)+3000) + produce := func(_ []*os.File) (*exec.Cmd, error) { + cmd := exec.Command(os.Args[0], "-test.run=^$") + cmd.Env = append(os.Environ(), preforkChildEnvVariable+"="+preforkChildEnvValue) + if err := cmd.Start(); err != nil { + return nil, err + } + mu.Lock() + spawned = append(spawned, cmd) + mu.Unlock() + return cmd, nil + } + + cleanup := func() { + mu.Lock() + defer mu.Unlock() + for _, cmd := range spawned { + if cmd == nil || cmd.Process == nil { + continue + } + _ = cmd.Process.Kill() + _, _ = cmd.Process.Wait() + } + } + return produce, cleanup } func Test_IsChild(t *testing.T) { - // This test can't run parallel as it modifies the process environment. - - v := IsChild() - if v { - t.Errorf("IsChild() == %v, want %v", v, false) + // This test cannot run in parallel — IsChild() reads a process-global env var. 
+ if IsChild() { + t.Fatal("test starts as child unexpectedly") } - setUp() - defer tearDown() - - v = IsChild() - if !v { - t.Errorf("IsChild() == %v, want %v", v, true) + t.Setenv(preforkChildEnvVariable, preforkChildEnvValue) + if !IsChild() { + t.Errorf("IsChild() == false after Setenv, want true") } } @@ -53,51 +87,37 @@ func Test_New(t *testing.T) { if p.Network != defaultNetwork { t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork) } - - if reflect.ValueOf(p.ServeFunc).Pointer() != reflect.ValueOf(s.Serve).Pointer() { - t.Errorf("Prefork.ServeFunc == %p, want %p", p.ServeFunc, s.Serve) + if p.RecoverThreshold <= 0 { + t.Errorf("Prefork.RecoverThreshold == %d, want > 0", p.RecoverThreshold) } - - if reflect.ValueOf(p.ServeTLSFunc).Pointer() != reflect.ValueOf(s.ServeTLS).Pointer() { - t.Errorf("Prefork.ServeTLSFunc == %p, want %p", p.ServeTLSFunc, s.ServeTLS) - } - - if reflect.ValueOf(p.ServeTLSEmbedFunc).Pointer() != reflect.ValueOf(s.ServeTLSEmbed).Pointer() { - t.Errorf("Prefork.ServeTLSFunc == %p, want %p", p.ServeTLSEmbedFunc, s.ServeTLSEmbed) + if p.ServeFunc == nil || p.ServeTLSFunc == nil || p.ServeTLSEmbedFunc == nil { + t.Error("New() did not wire one of ServeFunc/ServeTLSFunc/ServeTLSEmbedFunc") } } -func Test_listen(t *testing.T) { +func Test_listen_Reuseport(t *testing.T) { prev := runtime.GOMAXPROCS(0) t.Cleanup(func() { runtime.GOMAXPROCS(prev) }) - p := &Prefork{ - Reuseport: true, - } - addr := getAddr() + p := &Prefork{Reuseport: true} + addr := freeAddr(t) ln, err := p.listen(addr) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatalf("listen: %v", err) } + t.Cleanup(func() { + _ = ln.Close() + }) - ln.Close() - - lnAddr := ln.Addr().String() - if lnAddr != addr { - t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr) + if got, want := ln.Addr().String(), addr; got != want { + t.Errorf("ln.Addr() == %q, want %q", got, want) } - if p.Network != defaultNetwork { t.Errorf("Prefork.Network == %q, want %q", 
p.Network, defaultNetwork) } - - procs := runtime.GOMAXPROCS(0) - if procs != 1 { - t.Errorf("GOMAXPROCS == %d, want %d", procs, 1) - } } func Test_setTCPListenerFiles(t *testing.T) { @@ -108,144 +128,187 @@ func Test_setTCPListenerFiles(t *testing.T) { } p := &Prefork{} - addr := getAddr() + addr := freeAddr(t) - err := p.setTCPListenerFiles(addr) - if err != nil { - t.Fatalf("Unexpected error: %v", err) + if err := p.setTCPListenerFiles(addr); err != nil { + t.Fatalf("setTCPListenerFiles: %v", err) } + t.Cleanup(func() { + _ = p.ln.Close() + for _, f := range p.files { + _ = f.Close() + } + }) if p.ln == nil { - t.Fatal("Prefork.ln is nil") + t.Fatal("p.ln is nil after setTCPListenerFiles") } - - p.ln.Close() - - lnAddr := p.ln.Addr().String() - if lnAddr != addr { - t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr) + if got, want := p.ln.Addr().String(), addr; got != want { + t.Errorf("p.ln.Addr() == %q, want %q", got, want) } - - if p.Network != defaultNetwork { - t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork) - } - if len(p.files) != 1 { - t.Errorf("Prefork.files == %d, want %d", len(p.files), 1) + t.Errorf("len(p.files) == %d, want 1", len(p.files)) } } -func Test_ListenAndServe(t *testing.T) { - // This test can't run parallel as it modifies the process environment. 
+func Test_setTCPListenerFiles_BadAddr(t *testing.T) { + t.Parallel() - setUp() - defer tearDown() - - s := &fasthttp.Server{} - p := New(s) - p.Reuseport = true - p.ServeFunc = func(ln net.Listener) error { - return nil + if runtime.GOOS == "windows" { + t.SkipNow() } - - addr := getAddr() - - err := p.ListenAndServe(addr) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - p.ln.Close() - - lnAddr := p.ln.Addr().String() - if lnAddr != addr { - t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr) - } - - if p.ln == nil { - t.Error("Prefork.ln is nil") + p := &Prefork{} + if err := p.setTCPListenerFiles("definitely not an address"); err == nil { + t.Fatal("expected error for malformed addr, got nil") } } -func Test_ListenAndServeTLS(t *testing.T) { - // This test can't run parallel as it modifies the process environment. +// Test_ListenAndServe_Stub_ChildPath drives the child branch of all three +// ListenAndServe* entry points using a stubbed Serve function. It replaces the +// previous trio of near-identical tests that only validated field assignment. +func Test_ListenAndServe_Stub_ChildPath(t *testing.T) { + // child env mutation precludes t.Parallel. 
+ t.Setenv(preforkChildEnvVariable, preforkChildEnvValue) - setUp() - defer tearDown() - - s := &fasthttp.Server{} - p := New(s) - p.Reuseport = true - p.ServeTLSFunc = func(ln net.Listener, certFile, keyFile string) error { - return nil + type call struct { + listener bool + certFile string + keyFile string + certData string + keyData string } - addr := getAddr() - - err := p.ListenAndServeTLS(addr, "./key", "./cert") - if err != nil { - t.Errorf("Unexpected error: %v", err) + tests := []struct { + name string + run func(t *testing.T, p *Prefork, addr string) error + want call + }{ + { + name: "ListenAndServe", + run: func(_ *testing.T, p *Prefork, addr string) error { return p.ListenAndServe(addr) }, + want: call{listener: true}, + }, + { + name: "ListenAndServeTLS", + run: func(_ *testing.T, p *Prefork, addr string) error { + return p.ListenAndServeTLS(addr, "./key", "./cert") + }, + want: call{listener: true, certFile: "./cert", keyFile: "./key"}, + }, + { + name: "ListenAndServeTLSEmbed", + run: func(_ *testing.T, p *Prefork, addr string) error { + return p.ListenAndServeTLSEmbed(addr, []byte("certPEM"), []byte("keyPEM")) + }, + want: call{listener: true, certData: "certPEM", keyData: "keyPEM"}, + }, } - p.ln.Close() + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var got call + p := New(&fasthttp.Server{}) + p.Reuseport = true + p.ServeFunc = func(ln net.Listener) error { + got.listener = ln != nil + return nil + } + p.ServeTLSFunc = func(ln net.Listener, certFile, keyFile string) error { + got.listener = ln != nil + got.certFile = certFile + got.keyFile = keyFile + return nil + } + p.ServeTLSEmbedFunc = func(ln net.Listener, certData, keyData []byte) error { + got.listener = ln != nil + got.certData = string(certData) + got.keyData = string(keyData) + return nil + } - lnAddr := p.ln.Addr().String() - if lnAddr != addr { - t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr) - } + addr := freeAddr(t) + if err := tc.run(t, p, addr); err 
!= nil { + t.Fatalf("%s: %v", tc.name, err) + } + t.Cleanup(func() { + if p.ln != nil { + _ = p.ln.Close() + } + }) - if p.ln == nil { - t.Error("Prefork.ln is nil") + if got != tc.want { + t.Errorf("%s call = %+v, want %+v", tc.name, got, tc.want) + } + }) } } -func Test_ListenAndServeTLSEmbed(t *testing.T) { - // This test can't run parallel as it modifies the process environment. +func Test_doCommand_CommandProducerErrors(t *testing.T) { + t.Parallel() - setUp() - defer tearDown() - - s := &fasthttp.Server{} - p := New(s) - p.Reuseport = true - p.ServeTLSEmbedFunc = func(ln net.Listener, certData, keyData []byte) error { - return nil + tests := []struct { + name string + produce func(files []*os.File) (*exec.Cmd, error) + wantErr error + }{ + { + name: "producer returns error", + produce: func([]*os.File) (*exec.Cmd, error) { + return nil, errors.New("boom") + }, + }, + { + name: "producer returns nil cmd", + //nolint:nilnil // intentionally tests the (nil, nil) misbehaviour guard + produce: func([]*os.File) (*exec.Cmd, error) { + return nil, nil + }, + wantErr: ErrCommandProducerNilCmd, + }, + { + name: "producer returns unstarted cmd", + produce: func([]*os.File) (*exec.Cmd, error) { + return &exec.Cmd{}, nil + }, + wantErr: ErrCommandProducerNotStarted, + }, } - addr := getAddr() - - err := p.ListenAndServeTLSEmbed(addr, []byte("key"), []byte("cert")) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - p.ln.Close() - - lnAddr := p.ln.Addr().String() - if lnAddr != addr { - t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr) - } - - if p.ln == nil { - t.Error("Prefork.ln is nil") + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + p := &Prefork{CommandProducer: tc.produce} + cmd, err := p.doCommand() + if cmd != nil { + t.Errorf("expected nil cmd on error, got %v", cmd) + } + if err == nil { + t.Fatal("expected error, got nil") + } + if tc.wantErr != nil && !errors.Is(err, tc.wantErr) { + t.Errorf("err = %v, want 
errors.Is %v", err, tc.wantErr) + } + }) } } type testLogger struct { + mu sync.Mutex messages []string } func (l *testLogger) Printf(format string, args ...any) { - l.messages = append(l.messages, fmt.Sprintf(format, args...)) + msg := fmt.Sprintf(format, args...) + l.mu.Lock() + l.messages = append(l.messages, msg) + l.mu.Unlock() } -// Test_Prefork_Lifecycle runs the full prefork lifecycle with a CommandProducer -// and verifies that callbacks are invoked in the correct order with the correct arguments. +// Test_Prefork_Lifecycle drives prefork() to ErrOverRecovery via +// short-lived no-op children and asserts the callback ordering / arguments. func Test_Prefork_Lifecycle(t *testing.T) { prev := runtime.GOMAXPROCS(2) - t.Cleanup(func() { - runtime.GOMAXPROCS(prev) - }) + t.Cleanup(func() { runtime.GOMAXPROCS(prev) }) type event struct { name string @@ -260,16 +323,14 @@ func Test_Prefork_Lifecycle(t *testing.T) { mu.Unlock() } + produce, cleanup := noopChildProducer(t) + t.Cleanup(cleanup) + p := &Prefork{ Reuseport: true, RecoverThreshold: 1, Logger: &testLogger{}, - CommandProducer: func(_ []*os.File) (*exec.Cmd, error) { - cmd := exec.Command(os.Args[0], "-test.run=^$") - cmd.Env = append(os.Environ(), preforkChildEnvVariable+"=1") - err := cmd.Start() - return cmd, err - }, + CommandProducer: produce, OnChildSpawn: func(pid int) error { record("spawn", pid) return nil @@ -278,12 +339,13 @@ func Test_Prefork_Lifecycle(t *testing.T) { record("ready", childPIDs...) 
return nil }, - OnChildRecover: func(oldPid, newPid int) { - record("recover", oldPid, newPid) + OnChildRecover: func(oldPID, newPID int) error { + record("recover", oldPID, newPID) + return nil }, } - err := p.prefork(getAddr()) + err := p.prefork(freeAddr(t)) if !errors.Is(err, ErrOverRecovery) { t.Fatalf("expected ErrOverRecovery, got: %v", err) } @@ -291,10 +353,9 @@ func Test_Prefork_Lifecycle(t *testing.T) { mu.Lock() defer mu.Unlock() - // Verify we got spawn events for initial children - var spawnCount int - var readyCount int - var recoverCount int + goMaxProcs := runtime.GOMAXPROCS(0) + + var spawnCount, readyCount, recoverCount int for _, e := range events { switch e.name { case "spawn": @@ -318,21 +379,17 @@ func Test_Prefork_Lifecycle(t *testing.T) { } } - goMaxProcs := runtime.GOMAXPROCS(0) - if readyCount != 1 { t.Errorf("OnMasterReady called %d times, want 1", readyCount) } - - // Initial spawns + at least one recovery spawn if spawnCount < goMaxProcs { t.Errorf("OnChildSpawn called %d times, want at least %d", spawnCount, goMaxProcs) } - if recoverCount == 0 { t.Error("OnChildRecover was never called") } + // ready must come after exactly goMaxProcs initial spawns. readyIdx := -1 spawnsBeforeReady := 0 for i, e := range events { @@ -344,7 +401,6 @@ func Test_Prefork_Lifecycle(t *testing.T) { spawnsBeforeReady++ } } - if readyIdx == -1 { t.Fatal("OnMasterReady was never called") } @@ -352,8 +408,8 @@ func Test_Prefork_Lifecycle(t *testing.T) { t.Errorf("OnMasterReady called after %d initial spawns, want %d", spawnsBeforeReady, goMaxProcs) } + // every recover event must be preceded by a spawn for the new PID. 
recoveredSpawnByPID := make(map[int]bool) - recoveredPIDs := make(map[int]bool) for _, e := range events[readyIdx+1:] { if e.name == "spawn" { recoveredSpawnByPID[e.pids[0]] = true @@ -362,56 +418,154 @@ func Test_Prefork_Lifecycle(t *testing.T) { if !recoveredSpawnByPID[e.pids[1]] { t.Errorf("OnChildRecover for PID %d happened before OnChildSpawn", e.pids[1]) } - recoveredPIDs[e.pids[1]] = true - } - } - for pid := range recoveredPIDs { - if !recoveredSpawnByPID[pid] { - t.Errorf("OnChildRecover for PID %d did not have a matching OnChildSpawn", pid) } } } -func Test_Prefork_RecoveredChildSpawnError(t *testing.T) { +func Test_Prefork_InitialChildSpawnError(t *testing.T) { prev := runtime.GOMAXPROCS(2) - t.Cleanup(func() { - runtime.GOMAXPROCS(prev) - }) + t.Cleanup(func() { runtime.GOMAXPROCS(prev) }) - expectedErr := errors.New("spawn failed") - var spawnCount int - var recoverCount int + produce, cleanup := noopChildProducer(t) + t.Cleanup(cleanup) + + expectedErr := errors.New("initial spawn rejected") + var calls atomic.Int32 p := &Prefork{ Reuseport: true, RecoverThreshold: 1, Logger: &testLogger{}, - CommandProducer: func(_ []*os.File) (*exec.Cmd, error) { - cmd := exec.Command(os.Args[0], "-test.run=^$") - cmd.Env = append(os.Environ(), preforkChildEnvVariable+"=1") - err := cmd.Start() - return cmd, err + CommandProducer: produce, + OnChildSpawn: func(_ int) error { + calls.Add(1) + return expectedErr }, + } + + err := p.prefork(freeAddr(t)) + if !errors.Is(err, expectedErr) { + t.Fatalf("expected %v, got: %v", expectedErr, err) + } + if calls.Load() == 0 { + t.Fatal("OnChildSpawn was never invoked") + } +} + +func Test_Prefork_OnMasterReadyError(t *testing.T) { + prev := runtime.GOMAXPROCS(2) + t.Cleanup(func() { runtime.GOMAXPROCS(prev) }) + + produce, cleanup := noopChildProducer(t) + t.Cleanup(cleanup) + + expectedErr := errors.New("ready rejected") + p := &Prefork{ + Reuseport: true, + RecoverThreshold: 1, + Logger: &testLogger{}, + CommandProducer: 
produce, + OnMasterReady: func([]int) error { + return expectedErr + }, + } + + err := p.prefork(freeAddr(t)) + if !errors.Is(err, expectedErr) { + t.Fatalf("expected %v, got: %v", expectedErr, err) + } +} + +func Test_Prefork_RecoveredChildSpawnError(t *testing.T) { + prev := runtime.GOMAXPROCS(2) + t.Cleanup(func() { runtime.GOMAXPROCS(prev) }) + + produce, cleanup := noopChildProducer(t) + t.Cleanup(cleanup) + + expectedErr := errors.New("spawn failed") + var spawnCount, recoverCount atomic.Int32 + + p := &Prefork{ + Reuseport: true, + RecoverThreshold: 1, + Logger: &testLogger{}, + CommandProducer: produce, OnChildSpawn: func(pid int) error { if pid <= 0 { t.Errorf("OnChildSpawn called with invalid PID: %d", pid) } - spawnCount++ - if spawnCount > runtime.GOMAXPROCS(0) { + n := spawnCount.Add(1) + if int(n) > runtime.GOMAXPROCS(0) { return expectedErr } return nil }, - OnChildRecover: func(_, _ int) { - recoverCount++ + OnChildRecover: func(_, _ int) error { + recoverCount.Add(1) + return nil }, } - err := p.prefork(getAddr()) + err := p.prefork(freeAddr(t)) if !errors.Is(err, expectedErr) { t.Fatalf("expected %v, got: %v", expectedErr, err) } - if recoverCount != 0 { - t.Fatalf("OnChildRecover called %d times, want 0", recoverCount) + if got := recoverCount.Load(); got != 0 { + t.Fatalf("OnChildRecover called %d times, want 0", got) + } +} + +func Test_Prefork_HookPanicSurfaces(t *testing.T) { + prev := runtime.GOMAXPROCS(2) + t.Cleanup(func() { runtime.GOMAXPROCS(prev) }) + + produce, cleanup := noopChildProducer(t) + t.Cleanup(cleanup) + + p := &Prefork{ + Reuseport: true, + RecoverThreshold: 1, + Logger: &testLogger{}, + CommandProducer: produce, + OnChildSpawn: func(_ int) error { + panic("user code blew up") + }, + } + + err := p.prefork(freeAddr(t)) + if err == nil { + t.Fatal("expected error from panicking hook, got nil") + } +} + +// Test_Prefork_RecoverInterval verifies the optional backoff actually delays +// the respawn — without rabbit-holing into 
hard wallclock assertions. +func Test_Prefork_RecoverInterval(t *testing.T) { + prev := runtime.GOMAXPROCS(2) + t.Cleanup(func() { runtime.GOMAXPROCS(prev) }) + + produce, cleanup := noopChildProducer(t) + t.Cleanup(cleanup) + + const interval = 50 * time.Millisecond + p := &Prefork{ + Reuseport: true, + RecoverThreshold: 1, + RecoverInterval: interval, + Logger: &testLogger{}, + CommandProducer: produce, + } + + start := time.Now() + err := p.prefork(freeAddr(t)) + elapsed := time.Since(start) + + if !errors.Is(err, ErrOverRecovery) { + t.Fatalf("expected ErrOverRecovery, got %v", err) + } + // At least one recover interval must have elapsed before threshold fired. + if elapsed < interval { + t.Errorf("elapsed %v < interval %v; backoff did not apply", elapsed, interval) } }