From 2c1590038fa5d0e04896c8b0aafc138ce286bfa7 Mon Sep 17 00:00:00 2001
From: RW
Date: Sun, 21 Jun 2026 10:26:43 +0200
Subject: [PATCH] feat(prefork): graceful shutdown, leak fixes, hook robustness (re-open of #2180 follow-up) (#2199)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* feat(prefork): graceful shutdown, leak fixes, hook robustness

Addresses outstanding review concerns and several adjacent issues
surfaced during a follow-up review pass.

Lifecycle / supervision
- Track every per-child Wait goroutine via sync.WaitGroup and unblock
  pending sigCh sends through context cancellation so early-return paths
  (OnChildSpawn / OnMasterReady error, recovery doCommand error,
  ErrOverRecovery) can no longer leak goroutines or stall children.
- Install signal.Notify(SIGTERM, SIGINT) in the master so
  deploy/rolling-restart signals enter the shutdown path instead of
  killing the master without graceful teardown.
- Replace the unconditional SIGKILL defer with a SIGTERM-then-SIGKILL
  sequence gated by a configurable ShutdownGracePeriod (defaults to 5s,
  Windows path stays SIGKILL since Signal(SIGTERM) is unsupported).

API
- OnChildRecover now returns error so callers can implement recovery
  policies (circuit-breaker etc.); panic in any hook is recovered and
  surfaced as the returned error, with diagnostic logging.
- Add RecoverInterval (optional crash-loop backoff) and
  ShutdownGracePeriod fields with safe zero-value defaults.
- Export ErrCommandProducerNilCmd and ErrCommandProducerNotStarted
  sentinel errors so callers can errors.Is them.
- Rename oldPid/newPid to oldPID/newPID per Go initialism convention.
- Logger interface now declares an explicit compile-time compatibility
  check with fasthttp.Logger.

Resource hygiene
- Master closes both the original tcpListener and the duped fd in
  p.files when prefork() returns; previously the duped fd leaked once
  per call.
- doCommand wraps every error path with %w + fmt.Errorf so caller-side
  diagnostics keep stage context.
- Strip pre-existing FASTHTTP_PREFORK_CHILD entries before appending so
  child env never carries duplicate keys.
- Extract magic numbers as package constants (inheritedListenerFD,
  masterPollInterval, defaultShutdownGracePeriod, preforkChildEnvValue).
- Name the inherited listener fd via os.NewFile so net.FileListener
  errors are diagnosable.

Tests
- Migrate to t.Setenv (drop the global setUp/tearDown helpers); fixes
  the env-mutation-vs-parallel race.
- Replace rand.Intn port helper with `:0` + Listener.Addr() to remove
  port-collision flakes under -count and parallel runs.
- Collapse the three near-identical Test_ListenAndServe* tests into a
  single table-driven subtest that actually asserts the args forwarded
  to ServeFunc/ServeTLSFunc/ServeTLSEmbedFunc.
- Add coverage for the previously untested branches: CommandProducer
  returning err / nil cmd / unstarted cmd, initial OnChildSpawn error,
  OnMasterReady error, hook panic surfacing, RecoverInterval
  enforcement.
- noopChildProducer helper kills + waits any spawned child binaries
  during cleanup so failed tests no longer leave subprocesses around.

* fix(prefork): address Copilot review on #2199

- listen(): the *os.File wrapping the inherited fd was never closed.
  net.FileListener dups the fd, so the original was leaking on every
  child startup. Close it explicitly and return the dup'd listener.
- setTCPListenerFiles(): if tcpListener.File() failed, the bound
  net.Listener stayed open and p.ln pointed at it. Close the listener on
  the error path and only assign p.ln after the dup succeeds.
- prefork(): replace time.After in the RecoverInterval branch with a
  time.NewTimer that we Stop+drain when a shutdown signal wins the
  select, so the timer goroutine and channel allocation don't linger
  during crash-loop shutdown.
- invokeHook(): drop the panic log line. The hook caller logs the
  returned error already, so logging in the recover block produced
  duplicate output for the same panic.

* fix(prefork): harden recovery timer cleanup

* fix(prefork): align follow-up with review feedback

* fix(prefork): simplify child exit loop

* fix(prefork): address review on slice copy and recover backoff

- OnMasterReady now receives the internal childPIDs slice directly
  instead of a defensive copy; the doc states the slice is only valid
  for the call so callers copy it if they need to keep it.
- RecoverInterval backoff moves from an inline time.Sleep in the
  supervision loop into the per-child Wait goroutine. Concurrent crashes
  now each restart RecoverInterval after they exit instead of
  serializing the wait through the loop. The wait is interruptible via
  ctx so shutdown does not have to outlast a full interval.

Co-Authored-By: Claude Opus 4.7 (1M context)

* fix(prefork): silence modernize WaitGroup.Go suggestion

The CI linter analyzes at go1.26 and modernize suggests
sync.WaitGroup.Go, but that API needs go1.25 while the module targets
go1.24 (so it would fail the build/vet there). Suppress the suggestion
with an inline nolint instead of bumping the module's minimum Go
version.

Co-Authored-By: Claude Opus 4.7 (1M context)

* fix(prefork): cancel Wait goroutines up front so shutdown is not blocked

Move cancel() to the top of shutdownChildren so a child that already
exited while parked on its RecoverInterval backoff (or a sigCh send)
cannot delay teardown for a full RecoverInterval or ShutdownGracePeriod.
cmd.Wait() is independent of the context, so the graceful SIGTERM wait
still tracks the real process exits. Return early on the graceful path
so the kill loop is skipped when every child is already gone.

Follow-up review polish:
- Clarify the RecoverInterval and OnChildRecover doc comments and
  document the previously undocumented Serve* fields.
- Use sync.WaitGroup.Go (module targets go1.25) and drop the now-stale
  nolint:modernize.
- Tests: replace the unreachable listener-close assertion with the real
  p.ln==nil contract, restore GOMAXPROCS in the child-path test, and add
  Test_childEnv plus Test_Prefork_ShutdownDoesNotBlockOnRecoverInterval.

Co-Authored-By: Claude Opus 4.8 (1M context)

---------

Co-authored-by: Claude Opus 4.7 (1M context)
---
 prefork/prefork.go      | 451 ++++++++++++++++++++-------
 prefork/prefork_test.go | 676 ++++++++++++++++++++++++----------------
 2 files changed, 740 insertions(+), 387 deletions(-)

diff --git a/prefork/prefork.go b/prefork/prefork.go
index c4ee3e2..9199fc4 100644
--- a/prefork/prefork.go
+++ b/prefork/prefork.go
@@ -2,12 +2,16 @@ package prefork import ( + "context" "errors" + "fmt" "log" "net" "os" "os/exec" "runtime" + "sync" + "syscall" "time" "github.com/valyala/fasthttp" @@ -16,26 +20,57 @@ import ( const ( preforkChildEnvVariable = "FASTHTTP_PREFORK_CHILD" + preforkChildEnvValue = "1" defaultNetwork = "tcp4" + + // inheritedListenerFD is the file descriptor used by the master to pass + // the bound listener to a child process via ExtraFiles. Children open the + // listener via os.NewFile(inheritedListenerFD, ...) when Reuseport is false. + inheritedListenerFD = 3 + + // masterPollInterval is the period of the watchMaster ppid-poll on Unix. + masterPollInterval = 500 * time.Millisecond + + // defaultShutdownGracePeriod is how long the master waits for children to + // exit cleanly after sending SIGTERM before forcibly killing them. + defaultShutdownGracePeriod = 5 * time.Second ) var ( - defaultLogger = Logger(log.New(os.Stderr, "", log.LstdFlags)) + defaultLogger = Logger(log.New(os.Stderr, "", log.LstdFlags)) + + // tcpListenerFile is a hook for (*net.TCPListener).File so tests can + // inject failure paths without binding a real socket. tcpListenerFile = (*net.TCPListener).File - // ErrOverRecovery is returned when the times of starting over child prefork processes exceed - // the threshold. 
+ + // ErrOverRecovery is returned when child prefork process restarts exceed + // the value of RecoverThreshold. ErrOverRecovery = errors.New("exceeding the value of RecoverThreshold") - // ErrOnlyReuseportOnWindows is returned when Reuseport is false. + // ErrOnlyReuseportOnWindows is returned when running on Windows without Reuseport. ErrOnlyReuseportOnWindows = errors.New("windows only supports Reuseport = true") + + // ErrCommandProducerNilCmd is returned when a CommandProducer returns + // (nil, nil) instead of a started command. + ErrCommandProducerNilCmd = errors.New("prefork: CommandProducer returned nil command") + + // ErrCommandProducerNotStarted is returned when a CommandProducer returns + // an *exec.Cmd whose Process is nil (i.e. cmd.Start() was not called). + ErrCommandProducerNotStarted = errors.New("prefork: CommandProducer must return a started command") ) -// Logger is used for logging formatted messages. +// Logger is used for logging formatted messages. Its method set is intentionally +// identical to fasthttp.Logger so that *fasthttp.Server.Logger can be assigned +// directly. type Logger interface { // Printf must have the same semantics as log.Printf. Printf(format string, args ...any) } +// Compile-time check that fasthttp.Logger satisfies the local Logger interface; +// keeps the two types in sync if either side ever evolves. +var _ Logger = fasthttp.Logger(nil) + // Prefork implements fasthttp server prefork. // // Preforks master process (with all cores) between several child processes @@ -45,30 +80,49 @@ type Logger interface { // WARNING: using prefork prevents the use of any global state! // Things like in-memory caches won't work. type Prefork struct { - // By default standard logger from log package is used. + // Logger receives diagnostic output. By default the standard log package + // logger writing to stderr is used. 
Logger Logger ln net.Listener + // ServeFunc, ServeTLSFunc and ServeTLSEmbedFunc serve the inherited + // listener inside each child process. New() wires them to the matching + // *fasthttp.Server methods. When constructing a Prefork directly, set the + // one matching the ListenAndServe* entry point you call; otherwise the + // child panics on a nil call. ServeFunc func(ln net.Listener) error ServeTLSFunc func(ln net.Listener, certFile, keyFile string) error ServeTLSEmbedFunc func(ln net.Listener, certData, keyData []byte) error - // The network must be "tcp", "tcp4" or "tcp6". - // - // By default is "tcp4" + // Network must be "tcp", "tcp4" or "tcp6". Default is "tcp4". Network string files []*os.File - // Child prefork processes may exit with failure and will be started over until the times reach - // the value of RecoverThreshold, then it will return and terminate the server. + // RecoverThreshold caps how often crashed children are respawned before + // the master returns ErrOverRecovery. New() sets it to max(1, GOMAXPROCS/2). + // When constructing a Prefork directly without New(), a zero value will + // terminate the master after the very first child crash. RecoverThreshold int - // Flag to use a listener with reuseport, if not a file Listener will be used + // RecoverInterval, when > 0, delays the respawn of a crashed child by the + // given duration. The delay is applied per child in that child's own Wait + // goroutine before its exit is reported, so simultaneous crashes are not + // serialized: each child is respawned roughly RecoverInterval after it + // exits, independently of the others. The wait is interruptible, so a + // shutdown does not have to outlast a pending interval. Useful as + // crash-loop backoff. + RecoverInterval time.Duration + + // ShutdownGracePeriod is the time the master waits for children to exit + // after sending SIGTERM before falling back to SIGKILL. Defaults to 5s + // when zero. 
On Windows SIGTERM is not delivered, so this is unused there. + ShutdownGracePeriod time.Duration + + // Reuseport selects a reuseport listener instead of fd-passing. // See: https://www.nginx.com/blog/socket-sharding-nginx-release-1-9-1/ - // - // It's disabled by default + // Disabled by default. Reuseport bool // OnMasterDeath, when non-nil, enables monitoring of the master process @@ -77,25 +131,43 @@ type Prefork struct { // // It is recommended to set this to func() { os.Exit(1) } if no custom // cleanup is needed. + // + // Threading: invoked once from a watcher goroutine in the child. Must not + // block the goroutine for long; must not call Prefork methods. OnMasterDeath func() - // OnChildSpawn is called in the master process whenever a new child process is spawned. + // OnChildSpawn is called in the master after a new child process is + // successfully started, both during initial spawn and during recovery. // It receives the PID of the newly spawned child process. // - // If this callback returns an error, all child processes are killed and the - // prefork operation returns that error. + // If this callback returns an error, all already-running children are + // killed and the prefork operation returns that error. + // + // Threading: invoked synchronously from the master goroutine. Must not + // block; must not call Prefork methods. OnChildSpawn func(pid int) error - // OnMasterReady is called in the master process after all child processes have been spawned. - // It receives a slice of all child process PIDs. + // OnMasterReady is called in the master process exactly once, after all + // initial children have been spawned and before the supervision loop runs. + // It receives a slice of all initial child PIDs. // - // If this callback returns an error, the prefork operation will be aborted. + // If this callback returns an error, the prefork operation aborts and + // returns that error after killing the children. 
+ // + // Threading: invoked synchronously from the master goroutine. The slice + // is only valid for the duration of the call; copy it if you need to + // retain the PIDs after the callback returns. OnMasterReady func(childPIDs []int) error - // OnChildRecover is called in the master process when a crashed child process - // has been replaced by a new one. It receives the PID of the old (crashed) - // process and the PID of the newly spawned replacement. - OnChildRecover func(oldPid, newPid int) + // OnChildRecover is called in the master after a crashed child has been + // replaced. It receives the PID of the old (crashed) process and the PID + // of its replacement. This is a notification only: unlike OnChildSpawn it + // cannot abort recovery. + // + // Threading: invoked synchronously from the master goroutine, after + // OnChildSpawn for the new child. Must not block; must not call Prefork + // methods. + OnChildRecover func(oldPID, newPID int) // CommandProducer creates and starts a child process command. // If nil, the default implementation re-executes the current binary @@ -105,20 +177,23 @@ type Prefork struct { // A custom producer must: // - Set FASTHTTP_PREFORK_CHILD=1 in the child's environment // (otherwise IsChild() returns false and the child won't serve) - // - Call cmd.Start() before returning (the returned command must be started) + // - Call cmd.Start() before returning (the returned command must be + // started so cmd.Process is non-nil) // - Pass the provided files as cmd.ExtraFiles when Reuseport is false // - // This is primarily useful for testing (injecting dummy commands) - // or for frameworks that need custom child process setup. + // Primarily useful for testing (injecting dummy commands) or for + // frameworks that need custom child process setup. CommandProducer func(files []*os.File) (*exec.Cmd, error) } -// IsChild checks if the current thread/process is a child. 
+// IsChild reports whether the current process is a prefork child. func IsChild() bool { - return os.Getenv(preforkChildEnvVariable) == "1" + return os.Getenv(preforkChildEnvVariable) == preforkChildEnvValue } // New wraps the fasthttp server to run with preforked processes. +// It seeds Network and RecoverThreshold to sensible defaults; existing +// fields on s (Logger, Serve*) are captured. func New(s *fasthttp.Server) *Prefork { return &Prefork{ Network: defaultNetwork, @@ -147,14 +222,14 @@ func (p *Prefork) watchMaster(masterPID int) { // when the parent exits (no reparenting). Use FindProcess+Wait instead. proc, err := os.FindProcess(masterPID) if err != nil { - p.logger().Printf("watchMaster: failed to find master process %d: %v\n", masterPID, err) + p.logger().Printf("watchMaster: failed to find master process %d: %v", masterPID, err) p.OnMasterDeath() return } if _, err = proc.Wait(); err != nil { - p.logger().Printf("watchMaster: error waiting for master process %d: %v\n", masterPID, err) + p.logger().Printf("watchMaster: error waiting for master process %d: %v", masterPID, err) } - p.logger().Printf("master process died\n") + p.logger().Printf("master process %d died", masterPID) p.OnMasterDeath() return } @@ -163,12 +238,12 @@ func (p *Prefork) watchMaster(masterPID int) { // to another process, causing Getppid() to change. Comparing against // the original masterPID (instead of hardcoding 1) ensures this works // correctly when the master itself is PID 1 (e.g. in Docker containers). 
- ticker := time.NewTicker(500 * time.Millisecond) + ticker := time.NewTicker(masterPollInterval) defer ticker.Stop() for range ticker.C { if os.Getppid() != masterPID { - p.logger().Printf("master process died\n") + p.logger().Printf("master process %d died", masterPID) p.OnMasterDeath() return } @@ -186,8 +261,25 @@ func (p *Prefork) listen(addr string) (net.Listener, error) { return reuseport.Listen(p.Network, addr) } - // File descriptor 3 is the first ExtraFiles entry passed by the master process. - return net.FileListener(os.NewFile(3, "")) + // fd inheritedListenerFD is the first ExtraFiles entry passed by the + // master process when Reuseport is false. Naming the file gives clearer + // errors from net.FileListener if the fd is invalid. + // + // net.FileListener dups the fd, so we close the wrapping *os.File after + // it returns to avoid leaking the original descriptor. The returned + // listener owns its own dup'd fd and is unaffected by this close. + f := os.NewFile(inheritedListenerFD, "fasthttp-prefork-listener") + ln, err := net.FileListener(f) + if closeErr := f.Close(); closeErr != nil && err == nil { + err = fmt.Errorf("prefork: close inherited listener fd: %w", closeErr) + } + if err != nil { + if ln != nil { + _ = ln.Close() + } + return nil, err + } + return ln, nil } // listenAsChild performs the common child process setup: creates the listener @@ -214,64 +306,165 @@ func (p *Prefork) setTCPListenerFiles(addr string) error { tcpAddr, err := net.ResolveTCPAddr(p.Network, addr) if err != nil { - return err + return fmt.Errorf("prefork: resolve %s/%s: %w", p.Network, addr, err) } - tcplistener, err := net.ListenTCP(p.Network, tcpAddr) + tcpListener, err := net.ListenTCP(p.Network, tcpAddr) if err != nil { - return err + return fmt.Errorf("prefork: listen tcp %s: %w", addr, err) } - p.ln = tcplistener - - fl, err := tcpListenerFile(tcplistener) + listenerFile, err := tcpListenerFile(tcpListener) if err != nil { - _ = tcplistener.Close() - p.ln = 
nil - return err + // Close the bound listener so we don't leak the socket/fd when + // File() fails. p.ln is intentionally only assigned after this + // point so the caller never sees a half-initialised state. + _ = tcpListener.Close() + return fmt.Errorf("prefork: dup listener fd: %w", err) } - p.files = []*os.File{fl} + p.ln = tcpListener + p.files = []*os.File{listenerFile} return nil } +// childEnv returns os.Environ() with the prefork child marker variable set, +// stripping any pre-existing value to avoid duplicate keys with last-wins +// semantics. +func childEnv() []string { + src := os.Environ() + out := make([]string, 0, len(src)+1) + prefix := preforkChildEnvVariable + "=" + for _, kv := range src { + if len(kv) >= len(prefix) && kv[:len(prefix)] == prefix { + continue + } + out = append(out, kv) + } + out = append(out, preforkChildEnvVariable+"="+preforkChildEnvValue) + return out +} + func (p *Prefork) doCommand() (*exec.Cmd, error) { - // Use custom CommandProducer if provided if p.CommandProducer != nil { cmd, err := p.CommandProducer(p.files) if err != nil { - return nil, err + return nil, fmt.Errorf("prefork: CommandProducer: %w", err) } - if cmd == nil || cmd.Process == nil { - return nil, errors.New("prefork: CommandProducer must return a started command") + if cmd == nil { + return nil, ErrCommandProducerNilCmd + } + if cmd.Process == nil { + return nil, ErrCommandProducerNotStarted } return cmd, nil } - // Default implementation using os.Executable() for reliable path resolution executable, err := os.Executable() if err != nil { - return nil, err + return nil, fmt.Errorf("prefork: resolve executable: %w", err) } - args := make([]string, len(os.Args)) - args[0] = executable - copy(args[1:], os.Args[1:]) + args := append([]string{executable}, os.Args[1:]...) 
cmd := &exec.Cmd{ Path: executable, Args: args, Stdout: os.Stdout, Stderr: os.Stderr, - Env: append(os.Environ(), preforkChildEnvVariable+"=1"), + Env: childEnv(), ExtraFiles: p.files, } - err = cmd.Start() - return cmd, err + if err = cmd.Start(); err != nil { + return nil, fmt.Errorf("prefork: start child %q: %w", executable, err) + } + return cmd, nil } -func (p *Prefork) prefork(addr string) (err error) { +type childExit struct { + err error + pid int +} + +// shutdownChildren tears down every entry in childProcs. It first cancels the +// per-child Wait goroutines' context so any parked on a RecoverInterval backoff +// or a sigCh send return immediately; cmd.Wait() is not tied to the context, so +// this only strips the artificial delay from the shutdown path while still +// letting us wait for the children to actually exit. Children are then sent +// SIGTERM (on platforms where it is supported) and given up to grace to exit +// before survivors are killed unconditionally. wg tracks the per-child Wait +// goroutines and is drained before returning so no goroutine outlives prefork(). +func (p *Prefork) shutdownChildren( + childProcs map[int]*exec.Cmd, + wg *sync.WaitGroup, + cancel context.CancelFunc, + grace time.Duration, +) { + if grace <= 0 { + grace = defaultShutdownGracePeriod + } + + // Cancel up front, before waiting on wg. A child that already exited may + // have its Wait goroutine parked on the RecoverInterval backoff or a sigCh + // send; without this, shutdown could block for a full RecoverInterval (or + // until ShutdownGracePeriod expires) even though that child is already gone. + // cmd.Wait() is independent of ctx, so the graceful wait below still tracks + // the real process exits. 
+ cancel() + + if runtime.GOOS == "windows" { + for pid, proc := range childProcs { + p.killChild(pid, proc) + } + wg.Wait() + return + } + + for pid, proc := range childProcs { + if proc == nil || proc.Process == nil { + continue + } + if termErr := proc.Process.Signal(syscall.SIGTERM); termErr != nil && + !errors.Is(termErr, os.ErrProcessDone) { + p.logger().Printf("prefork: SIGTERM child %d: %v", pid, termErr) + } + } + + // Wait for graceful exits, with a timeout fallback to SIGKILL. + graceful := make(chan struct{}) + go func() { + wg.Wait() + close(graceful) + }() + + timer := time.NewTimer(grace) + defer timer.Stop() + select { + case <-graceful: + // All children exited within grace; nothing left to kill. + return + case <-timer.C: + } + + for pid, proc := range childProcs { + p.killChild(pid, proc) + } + + wg.Wait() +} + +func (p *Prefork) killChild(pid int, proc *exec.Cmd) { + if proc == nil || proc.Process == nil { + return + } + if killErr := proc.Process.Kill(); killErr != nil && + !errors.Is(killErr, os.ErrProcessDone) { + p.logger().Printf("prefork: kill child %d: %v", pid, killErr) + } +} + +func (p *Prefork) prefork(addr string) (err error) { //nolint:gocyclo if !p.Reuseport { if runtime.GOOS == "windows" { return ErrOnlyReuseportOnWindows @@ -281,42 +474,69 @@ func (p *Prefork) prefork(addr string) (err error) { return err } - // defer for closing the net.Listener opened by setTCPListenerFiles. + // Close listener fds opened by setTCPListenerFiles. Both the original + // tcpListener (p.ln) and the duped fd (p.files[0]) belong to the + // master only; children inherit independent dup'd copies via fork+exec. 
defer func() { + err = errors.Join(err, p.ln.Close()) for _, f := range p.files { - _ = f.Close() + if closeErr := f.Close(); closeErr != nil { + p.logger().Printf("prefork: close listener fd: %v", closeErr) + } } p.files = nil - - e := p.ln.Close() - if err == nil { - err = e - } }() } - type procSig struct { - err error - pid int - } + // ctx cancels per-child Wait goroutines so they unblock from sigCh sends + // once the supervision loop is gone. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() goMaxProcs := runtime.GOMAXPROCS(0) - sigCh := make(chan procSig, goMaxProcs) + // Buffer is sized to initial fleet; per-child goroutines fall back to a + // context-aware select on send so capacity is not load-bearing. + sigCh := make(chan childExit, goMaxProcs) childProcs := make(map[int]*exec.Cmd, goMaxProcs) + var wg sync.WaitGroup + startWait := func(cmd *exec.Cmd, pid int) { + wg.Go(func() { + result := childExit{pid: pid, err: cmd.Wait()} + + // Apply the crash-loop backoff here, per child, before reporting + // the exit. Sleeping in the supervision loop instead would + // serialize the wait across all crashed children, so N children + // dying at once would restart RecoverInterval apart rather than + // each one RecoverInterval after it quit. The wait is interruptible + // so shutdown does not have to outlast a full interval. 
+ if p.RecoverInterval > 0 { + timer := time.NewTimer(p.RecoverInterval) + select { + case <-timer.C: + case <-ctx.Done(): + timer.Stop() + return + } + } + + select { + case sigCh <- result: + case <-ctx.Done(): + } + }) + } + defer func() { - for _, proc := range childProcs { - _ = proc.Process.Kill() - } + p.shutdownChildren(childProcs, &wg, cancel, p.ShutdownGracePeriod) }() - // Collect child PIDs for OnMasterReady callback childPIDs := make([]int, 0, goMaxProcs) for range goMaxProcs { var cmd *exec.Cmd if cmd, err = p.doCommand(); err != nil { - p.logger().Printf("failed to start a child prefork process, error: %v\n", err) + p.logger().Printf("prefork: failed to start a child process: %v", err) return err } @@ -324,25 +544,22 @@ func (p *Prefork) prefork(addr string) (err error) { childProcs[pid] = cmd childPIDs = append(childPIDs, pid) - // Start Wait goroutine before OnChildSpawn so that if OnChildSpawn - // fails and the deferred Kill() runs, Wait() will collect the child. - go func(c *exec.Cmd, pid int) { - sigCh <- procSig{pid: pid, err: c.Wait()} - }(cmd, pid) + // Start Wait goroutine before the user callback so a panic / error + // return from OnChildSpawn cannot leave a zombie behind. 
+ startWait(cmd, pid) if p.OnChildSpawn != nil { - if err = p.OnChildSpawn(pid); err != nil { - p.logger().Printf("OnChildSpawn callback failed for PID %d: %v\n", pid, err) - return err + if hookErr := p.OnChildSpawn(pid); hookErr != nil { + p.logger().Printf("prefork: OnChildSpawn for PID %d: %v", pid, hookErr) + return hookErr } } } - // Call OnMasterReady callback after all children are spawned if p.OnMasterReady != nil { - if err = p.OnMasterReady(childPIDs); err != nil { - p.logger().Printf("OnMasterReady callback failed: %v\n", err) - return err + if hookErr := p.OnMasterReady(childPIDs); hookErr != nil { + p.logger().Printf("prefork: OnMasterReady: %v", hookErr) + return hookErr } } @@ -350,45 +567,46 @@ func (p *Prefork) prefork(addr string) (err error) { for sig := range sigCh { delete(childProcs, sig.pid) - p.logger().Printf("one of the child prefork processes exited with "+ - "error: %v", sig.err) + if sig.err != nil { + p.logger().Printf("prefork: child PID %d exited: %v", sig.pid, sig.err) + } else { + p.logger().Printf("prefork: child PID %d exited cleanly", sig.pid) + } exitedProcs++ if exitedProcs > p.RecoverThreshold { - p.logger().Printf("child prefork processes exit too many times, "+ - "which exceeds the value of RecoverThreshold(%d), "+ - "exiting the master process.\n", p.RecoverThreshold) - err = ErrOverRecovery - break + p.logger().Printf( + "prefork: child exits (%d) exceed RecoverThreshold (%d), terminating master", + exitedProcs, p.RecoverThreshold, + ) + return ErrOverRecovery } - var cmd *exec.Cmd - cmd, err = p.doCommand() - if err != nil { - break + // The RecoverInterval backoff is applied in the per-child Wait + // goroutine (see startWait) before the exit is reported, so it does + // not block recovery of other children here. 
+ cmd, doErr := p.doCommand() + if doErr != nil { + p.logger().Printf("prefork: recovery doCommand: %v", doErr) + return doErr } - newPid := cmd.Process.Pid - childProcs[newPid] = cmd + newPID := cmd.Process.Pid + childProcs[newPID] = cmd - // Start Wait goroutine before callbacks to avoid zombie processes - // if a callback fails and the deferred Kill() runs. - go func(c *exec.Cmd, pid int) { - sigCh <- procSig{pid: pid, err: c.Wait()} - }(cmd, newPid) + startWait(cmd, newPID) if p.OnChildSpawn != nil { - if err = p.OnChildSpawn(newPid); err != nil { - p.logger().Printf("OnChildSpawn callback failed for recovered PID %d: %v\n", newPid, err) - return err + if hookErr := p.OnChildSpawn(newPID); hookErr != nil { + p.logger().Printf("prefork: OnChildSpawn for recovered PID %d: %v", newPID, hookErr) + return hookErr } } if p.OnChildRecover != nil { - p.OnChildRecover(sig.pid, newPid) + p.OnChildRecover(sig.pid, newPID) } } - - return err + return nil } // ListenAndServe serves HTTP requests from the given TCP addr. @@ -406,11 +624,10 @@ func (p *Prefork) ListenAndServe(addr string) error { // ListenAndServeTLS serves HTTPS requests from the given TCP addr. // -// certKey is the path to the TLS private key file. -// certFile is the path to the TLS certificate file. -// -// Note: parameter order is (addr, certKey, certFile) — key before cert. -// Internally forwards to ServeTLSFunc as (certFile, certKey). +// Note: parameter order is (addr, certKey, certFile) — key path comes +// before cert path. This is preserved for backward compatibility with +// existing callers and differs from fasthttp.Server.ListenAndServeTLS. +// New code should prefer ListenAndServeTLSEmbed. 
func (p *Prefork) ListenAndServeTLS(addr, certKey, certFile string) error { if IsChild() { ln, err := p.listenAsChild(addr) diff --git a/prefork/prefork_test.go b/prefork/prefork_test.go index 37856e5..707de0f 100644 --- a/prefork/prefork_test.go +++ b/prefork/prefork_test.go @@ -3,101 +3,63 @@ package prefork import ( "errors" "fmt" - "math/rand" "net" "os" "os/exec" - "reflect" "runtime" "sync" + "sync/atomic" "testing" "time" "github.com/valyala/fasthttp" ) -func setUp() { - os.Setenv(preforkChildEnvVariable, "1") -} +// noopChildProducer returns a CommandProducer that re-execs the test binary +// into a no-op subprocess. The returned cleanup must be deferred (or registered +// via t.Cleanup) so leaked subprocesses are reaped if the test fails midway. +func noopChildProducer(t testing.TB) (func(files []*os.File) (*exec.Cmd, error), func()) { + t.Helper() + var ( + mu sync.Mutex + spawned []*exec.Cmd + ) -func tearDown() { - os.Unsetenv(preforkChildEnvVariable) -} + produce := func(_ []*os.File) (*exec.Cmd, error) { + cmd := exec.Command(os.Args[0], "-test.run=^$") + cmd.Env = append(os.Environ(), preforkChildEnvVariable+"="+preforkChildEnvValue) + if err := cmd.Start(); err != nil { + return nil, err + } + mu.Lock() + spawned = append(spawned, cmd) + mu.Unlock() + return cmd, nil + } -func getAddr() string { - return fmt.Sprintf("127.0.0.1:%d", rand.Intn(9000-3000)+3000) + cleanup := func() { + mu.Lock() + defer mu.Unlock() + for _, cmd := range spawned { + if cmd == nil || cmd.Process == nil { + continue + } + _ = cmd.Process.Kill() + _, _ = cmd.Process.Wait() + } + } + return produce, cleanup } func Test_IsChild(t *testing.T) { - // This test can't run parallel as it modifies the process environment. - - v := IsChild() - if v { - t.Errorf("IsChild() == %v, want %v", v, false) + // This test cannot run in parallel — IsChild() reads a process-global env var. 
+	if IsChild() {
+		t.Fatal("test starts as child unexpectedly")
 	}
 
-	setUp()
-	defer tearDown()
-
-	v = IsChild()
-	if !v {
-		t.Errorf("IsChild() == %v, want %v", v, true)
-	}
-}
-
-func Test_New(t *testing.T) {
-	t.Parallel()
-
-	s := &fasthttp.Server{}
-	p := New(s)
-
-	if p.Network != defaultNetwork {
-		t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork)
-	}
-
-	if reflect.ValueOf(p.ServeFunc).Pointer() != reflect.ValueOf(s.Serve).Pointer() {
-		t.Errorf("Prefork.ServeFunc == %p, want %p", p.ServeFunc, s.Serve)
-	}
-
-	if reflect.ValueOf(p.ServeTLSFunc).Pointer() != reflect.ValueOf(s.ServeTLS).Pointer() {
-		t.Errorf("Prefork.ServeTLSFunc == %p, want %p", p.ServeTLSFunc, s.ServeTLS)
-	}
-
-	if reflect.ValueOf(p.ServeTLSEmbedFunc).Pointer() != reflect.ValueOf(s.ServeTLSEmbed).Pointer() {
-		t.Errorf("Prefork.ServeTLSFunc == %p, want %p", p.ServeTLSEmbedFunc, s.ServeTLSEmbed)
-	}
-}
-
-func Test_listen(t *testing.T) {
-	prev := runtime.GOMAXPROCS(0)
-	t.Cleanup(func() {
-		runtime.GOMAXPROCS(prev)
-	})
-
-	p := &Prefork{
-		Reuseport: true,
-	}
-	addr := getAddr()
-
-	ln, err := p.listen(addr)
-	if err != nil {
-		t.Fatalf("Unexpected error: %v", err)
-	}
-
-	ln.Close()
-
-	lnAddr := ln.Addr().String()
-	if lnAddr != addr {
-		t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
-	}
-
-	if p.Network != defaultNetwork {
-		t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork)
-	}
-
-	procs := runtime.GOMAXPROCS(0)
-	if procs != 1 {
-		t.Errorf("GOMAXPROCS == %d, want %d", procs, 1)
+	t.Setenv(preforkChildEnvVariable, preforkChildEnvValue)
+	if !IsChild() {
+		t.Errorf("IsChild() == false after Setenv, want true")
 	}
 }
 
@@ -109,33 +71,98 @@ func Test_setTCPListenerFiles(t *testing.T) {
 	}
 
 	p := &Prefork{}
-	addr := getAddr()
+	addr := "127.0.0.1:0"
 
-	err := p.setTCPListenerFiles(addr)
-	if err != nil {
-		t.Fatalf("Unexpected error: %v", err)
+	if err := p.setTCPListenerFiles(addr); err != nil {
+		t.Fatalf("setTCPListenerFiles: %v", err)
 	}
+	t.Cleanup(func() {
+		_ = p.ln.Close()
+		for _, f := range p.files {
+			_ = f.Close()
+		}
+	})
 
 	if p.ln == nil {
-		t.Fatal("Prefork.ln is nil")
+		t.Fatal("p.ln is nil after setTCPListenerFiles")
 	}
-
-	p.ln.Close()
-
-	lnAddr := p.ln.Addr().String()
-	if lnAddr != addr {
-		t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
+	if got := p.ln.Addr().String(); got == "" {
+		t.Error("p.ln.Addr() is empty")
 	}
-
-	if p.Network != defaultNetwork {
-		t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork)
-	}
-
 	if len(p.files) != 1 {
-		t.Errorf("Prefork.files == %d, want %d", len(p.files), 1)
+		t.Errorf("len(p.files) == %d, want 1", len(p.files))
 	}
 }
 
+// Test_setTCPListenerFilesClosesListenerOnFileError verifies that
+// setTCPListenerFiles closes the bound listener and leaves Prefork.files nil
+// when the duplicate-fd step fails. Injects the failure through tcpListenerFile
+// so no real socket is leaked.
+func Test_setTCPListenerFilesClosesListenerOnFileError(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.SkipNow()
+	}
+
+	fileErr := errors.New("file error")
+	oldTCPListenerFile := tcpListenerFile
+	tcpListenerFile = func(*net.TCPListener) (*os.File, error) {
+		return nil, fileErr
+	}
+	t.Cleanup(func() {
+		tcpListenerFile = oldTCPListenerFile
+	})
+
+	p := &Prefork{}
+	err := p.setTCPListenerFiles("127.0.0.1:0")
+	if !errors.Is(err, fileErr) {
+		t.Fatalf("Unexpected error: %v. Expecting %v", err, fileErr)
+	}
+	if p.files != nil {
+		t.Fatalf("Prefork.files = %v, want nil", p.files)
+	}
+	// setTCPListenerFiles assigns p.ln only after File() succeeds, so on the
+	// File() error path the bound listener must have been closed and p.ln left
+	// nil rather than exposing a half-initialised state.
+	if p.ln != nil {
+		t.Fatalf("Prefork.ln = %v, want nil after File error", p.ln)
+	}
+}
+
+// Test_childEnv verifies the child environment carries exactly one canonical
+// prefork marker: a pre-existing marker (any value) is stripped and replaced,
+// while an unrelated variable that merely shares the prefix is left untouched.
+func Test_childEnv(t *testing.T) {
+	t.Setenv(preforkChildEnvVariable, "stale")
+	t.Setenv(preforkChildEnvVariable+"_SIBLING", "keep")
+
+	prefix := preforkChildEnvVariable + "="
+	want := prefix + preforkChildEnvValue
+
+	var markerCount, canonicalCount, siblingCount int
+	for _, kv := range childEnv() {
+		switch {
+		case len(kv) >= len(prefix) && kv[:len(prefix)] == prefix:
+			markerCount++
+			if kv == want {
+				canonicalCount++
+			}
+		case kv == preforkChildEnvVariable+"_SIBLING=keep":
+			siblingCount++
+		}
+	}
+
+	if markerCount != 1 || canonicalCount != 1 {
+		t.Fatalf("childEnv() marker entries = %d (canonical %d), want exactly 1 canonical %q",
+			markerCount, canonicalCount, want)
+	}
+	if siblingCount != 1 {
+		t.Fatalf("childEnv() sibling-prefixed var count = %d, want 1 (must not be stripped)", siblingCount)
+	}
+}
+
+// Test_preforkClosesParentListenerFiles asserts the master closes the duped
+// listener fd it passed to the child after prefork() returns, so the parent
+// process does not leak file descriptors across restarts.
 func Test_preforkClosesParentListenerFiles(t *testing.T) {
 	if runtime.GOOS == "windows" {
 		t.SkipNow()
@@ -159,16 +186,16 @@ func Test_preforkClosesParentListenerFiles(t *testing.T) {
 			parentFile = files[0]
 
 			cmd := exec.Command(os.Args[0], "-test.run=^$")
-			cmd.Env = append(os.Environ(), preforkChildEnvVariable+"=1")
+			cmd.Env = append(os.Environ(), preforkChildEnvVariable+"="+preforkChildEnvValue)
 			err := cmd.Start()
 			return cmd, err
 		},
-		OnChildSpawn: func(pid int) error {
+		OnChildSpawn: func(_ int) error {
 			return stopErr
 		},
 	}
 
-	if err := p.prefork(getAddr()); !errors.Is(err, stopErr) {
+	if err := p.prefork("127.0.0.1:0"); !errors.Is(err, stopErr) {
 		t.Fatalf("Unexpected error: %v. Expecting %v", err, stopErr)
 	}
 	if parentFile == nil {
@@ -179,157 +206,159 @@ func Test_preforkClosesParentListenerFiles(t *testing.T) {
 	}
 }
 
-func Test_setTCPListenerFilesClosesListenerOnFileError(t *testing.T) {
-	if runtime.GOOS == "windows" {
-		t.SkipNow()
+// Test_ListenAndServe_Stub_ChildPath drives the child branch of all three
+// ListenAndServe* entry points using a stubbed Serve function. It replaces the
+// previous trio of near-identical tests that only validated field assignment.
+func Test_ListenAndServe_Stub_ChildPath(t *testing.T) {
+	// child env mutation precludes t.Parallel.
+	t.Setenv(preforkChildEnvVariable, preforkChildEnvValue)
+
+	// listen() sets GOMAXPROCS(1) on the child path; restore it so the value
+	// does not leak into later tests and make their fleet size ordering-dependent.
+	prevGOMAXPROCS := runtime.GOMAXPROCS(0)
+	t.Cleanup(func() { runtime.GOMAXPROCS(prevGOMAXPROCS) })
+
+	type call struct {
+		listener bool
+		certFile string
+		keyFile  string
+		certData string
+		keyData  string
 	}
 
-	fileErr := errors.New("file error")
-	oldTCPListenerFile := tcpListenerFile
-	tcpListenerFile = func(*net.TCPListener) (*os.File, error) {
-		return nil, fileErr
-	}
-	t.Cleanup(func() {
-		tcpListenerFile = oldTCPListenerFile
-	})
-
-	p := &Prefork{}
-	err := p.setTCPListenerFiles(getAddr())
-	if !errors.Is(err, fileErr) {
-		t.Fatalf("Unexpected error: %v. Expecting %v", err, fileErr)
-	}
-	if p.files != nil {
-		t.Fatalf("Prefork.files = %v, want nil", p.files)
-	}
-	if p.ln == nil {
-		return
+	tests := []struct {
+		name string
+		run  func(t *testing.T, p *Prefork, addr string) error
+		want call
+	}{
+		{
+			name: "ListenAndServe",
+			run:  func(_ *testing.T, p *Prefork, addr string) error { return p.ListenAndServe(addr) },
+			want: call{listener: true},
+		},
+		{
+			name: "ListenAndServeTLS",
+			run: func(_ *testing.T, p *Prefork, addr string) error {
+				return p.ListenAndServeTLS(addr, "./key", "./cert")
+			},
+			want: call{listener: true, certFile: "./cert", keyFile: "./key"},
+		},
+		{
+			name: "ListenAndServeTLSEmbed",
+			run: func(_ *testing.T, p *Prefork, addr string) error {
+				return p.ListenAndServeTLSEmbed(addr, []byte("certPEM"), []byte("keyPEM"))
+			},
+			want: call{listener: true, certData: "certPEM", keyData: "keyPEM"},
+		},
 	}
 
-	closeErrCh := make(chan error, 1)
-	go func() {
-		closeErrCh <- p.ln.Close()
-	}()
-	select {
-	case err := <-closeErrCh:
-		if err == nil {
-			t.Fatal("listener remained open after File error")
-		}
-	case <-time.After(time.Second):
-		t.Fatal("timeout closing listener")
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			var got call
+			p := New(&fasthttp.Server{})
+			p.Reuseport = true
+			p.ServeFunc = func(ln net.Listener) error {
+				got.listener = ln != nil
+				return nil
+			}
+			p.ServeTLSFunc = func(ln net.Listener, certFile, keyFile string) error {
+				got.listener = ln != nil
+				got.certFile = certFile
+				got.keyFile = keyFile
+				return nil
+			}
+			p.ServeTLSEmbedFunc = func(ln net.Listener, certData, keyData []byte) error {
+				got.listener = ln != nil
+				got.certData = string(certData)
+				got.keyData = string(keyData)
+				return nil
+			}
+
+			addr := "127.0.0.1:0"
+			if err := tc.run(t, p, addr); err != nil {
+				t.Fatalf("%s: %v", tc.name, err)
+			}
+			t.Cleanup(func() {
+				if p.ln != nil {
+					_ = p.ln.Close()
+				}
+			})
+
+			if got != tc.want {
+				t.Errorf("%s call = %+v, want %+v", tc.name, got, tc.want)
+			}
+		})
 	}
 }
 
-func Test_ListenAndServe(t *testing.T) {
-	// This test can't run parallel as it modifies the process environment.
+func Test_doCommand_CommandProducerErrors(t *testing.T) {
+	t.Parallel()
 
-	setUp()
-	defer tearDown()
-
-	s := &fasthttp.Server{}
-	p := New(s)
-	p.Reuseport = true
-	p.ServeFunc = func(ln net.Listener) error {
-		return nil
+	producerErr := errors.New("boom")
+	tests := []struct {
+		name    string
+		produce func(files []*os.File) (*exec.Cmd, error)
+		wantErr error
+	}{
+		{
+			name: "producer returns error",
+			produce: func([]*os.File) (*exec.Cmd, error) {
+				return nil, producerErr
+			},
+			wantErr: producerErr,
+		},
+		{
+			name: "producer returns nil cmd",
+			//nolint:nilnil // intentionally tests the (nil, nil) misbehaviour guard
+			produce: func([]*os.File) (*exec.Cmd, error) {
+				return nil, nil
+			},
+			wantErr: ErrCommandProducerNilCmd,
+		},
+		{
+			name: "producer returns unstarted cmd",
+			produce: func([]*os.File) (*exec.Cmd, error) {
+				return &exec.Cmd{}, nil
+			},
+			wantErr: ErrCommandProducerNotStarted,
+		},
 	}
 
-	addr := getAddr()
-
-	err := p.ListenAndServe(addr)
-	if err != nil {
-		t.Errorf("Unexpected error: %v", err)
-	}
-
-	p.ln.Close()
-
-	lnAddr := p.ln.Addr().String()
-	if lnAddr != addr {
-		t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
-	}
-
-	if p.ln == nil {
-		t.Error("Prefork.ln is nil")
-	}
-}
-
-func Test_ListenAndServeTLS(t *testing.T) {
-	// This test can't run parallel as it modifies the process environment.
-
-	setUp()
-	defer tearDown()
-
-	s := &fasthttp.Server{}
-	p := New(s)
-	p.Reuseport = true
-	p.ServeTLSFunc = func(ln net.Listener, certFile, keyFile string) error {
-		return nil
-	}
-
-	addr := getAddr()
-
-	err := p.ListenAndServeTLS(addr, "./key", "./cert")
-	if err != nil {
-		t.Errorf("Unexpected error: %v", err)
-	}
-
-	p.ln.Close()
-
-	lnAddr := p.ln.Addr().String()
-	if lnAddr != addr {
-		t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
-	}
-
-	if p.ln == nil {
-		t.Error("Prefork.ln is nil")
-	}
-}
-
-func Test_ListenAndServeTLSEmbed(t *testing.T) {
-	// This test can't run parallel as it modifies the process environment.
-
-	setUp()
-	defer tearDown()
-
-	s := &fasthttp.Server{}
-	p := New(s)
-	p.Reuseport = true
-	p.ServeTLSEmbedFunc = func(ln net.Listener, certData, keyData []byte) error {
-		return nil
-	}
-
-	addr := getAddr()
-
-	err := p.ListenAndServeTLSEmbed(addr, []byte("key"), []byte("cert"))
-	if err != nil {
-		t.Errorf("Unexpected error: %v", err)
-	}
-
-	p.ln.Close()
-
-	lnAddr := p.ln.Addr().String()
-	if lnAddr != addr {
-		t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
-	}
-
-	if p.ln == nil {
-		t.Error("Prefork.ln is nil")
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+			p := &Prefork{CommandProducer: tc.produce}
+			cmd, err := p.doCommand()
+			if cmd != nil {
+				t.Errorf("expected nil cmd on error, got %v", cmd)
+			}
+			if err == nil {
+				t.Fatal("expected error, got nil")
+			}
+			if tc.wantErr != nil && !errors.Is(err, tc.wantErr) {
+				t.Errorf("err = %v, want errors.Is %v", err, tc.wantErr)
+			}
+		})
 	}
 }
 
 type testLogger struct {
+	mu       sync.Mutex
 	messages []string
 }
 
 func (l *testLogger) Printf(format string, args ...any) {
-	l.messages = append(l.messages, fmt.Sprintf(format, args...))
+	msg := fmt.Sprintf(format, args...)
+	l.mu.Lock()
+	l.messages = append(l.messages, msg)
+	l.mu.Unlock()
 }
 
-// Test_Prefork_Lifecycle runs the full prefork lifecycle with a CommandProducer
-// and verifies that callbacks are invoked in the correct order with the correct arguments.
+// Test_Prefork_Lifecycle drives prefork() to ErrOverRecovery via
+// short-lived no-op children and asserts the callback ordering / arguments.
 func Test_Prefork_Lifecycle(t *testing.T) {
 	prev := runtime.GOMAXPROCS(2)
-	t.Cleanup(func() {
-		runtime.GOMAXPROCS(prev)
-	})
+	t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
 
 	type event struct {
 		name string
@@ -344,16 +373,14 @@ func Test_Prefork_Lifecycle(t *testing.T) {
 		mu.Unlock()
 	}
 
+	produce, cleanup := noopChildProducer(t)
+	t.Cleanup(cleanup)
+
 	p := &Prefork{
 		Reuseport:        true,
 		RecoverThreshold: 1,
 		Logger:           &testLogger{},
-		CommandProducer: func(_ []*os.File) (*exec.Cmd, error) {
-			cmd := exec.Command(os.Args[0], "-test.run=^$")
-			cmd.Env = append(os.Environ(), preforkChildEnvVariable+"=1")
-			err := cmd.Start()
-			return cmd, err
-		},
+		CommandProducer:  produce,
 		OnChildSpawn: func(pid int) error {
 			record("spawn", pid)
 			return nil
@@ -362,12 +389,12 @@ func Test_Prefork_Lifecycle(t *testing.T) {
 			record("ready", childPIDs...)
 			return nil
 		},
-		OnChildRecover: func(oldPid, newPid int) {
-			record("recover", oldPid, newPid)
+		OnChildRecover: func(oldPID, newPID int) {
+			record("recover", oldPID, newPID)
 		},
 	}
 
-	err := p.prefork(getAddr())
+	err := p.prefork("127.0.0.1:0")
 	if !errors.Is(err, ErrOverRecovery) {
 		t.Fatalf("expected ErrOverRecovery, got: %v", err)
 	}
@@ -375,10 +402,9 @@ func Test_Prefork_Lifecycle(t *testing.T) {
 	mu.Lock()
 	defer mu.Unlock()
 
-	// Verify we got spawn events for initial children
-	var spawnCount int
-	var readyCount int
-	var recoverCount int
+	goMaxProcs := runtime.GOMAXPROCS(0)
+
+	var spawnCount, readyCount, recoverCount int
 	for _, e := range events {
 		switch e.name {
 		case "spawn":
@@ -402,21 +428,17 @@ func Test_Prefork_Lifecycle(t *testing.T) {
 		}
 	}
 
-	goMaxProcs := runtime.GOMAXPROCS(0)
-
 	if readyCount != 1 {
 		t.Errorf("OnMasterReady called %d times, want 1", readyCount)
 	}
-
-	// Initial spawns + at least one recovery spawn
 	if spawnCount < goMaxProcs {
 		t.Errorf("OnChildSpawn called %d times, want at least %d", spawnCount, goMaxProcs)
 	}
-
 	if recoverCount == 0 {
 		t.Error("OnChildRecover was never called")
 	}
 
+	// ready must come after exactly goMaxProcs initial spawns.
 	readyIdx := -1
 	spawnsBeforeReady := 0
 	for i, e := range events {
@@ -428,7 +450,6 @@ func Test_Prefork_Lifecycle(t *testing.T) {
 			spawnsBeforeReady++
 		}
 	}
-
 	if readyIdx == -1 {
 		t.Fatal("OnMasterReady was never called")
 	}
@@ -436,8 +457,8 @@ func Test_Prefork_Lifecycle(t *testing.T) {
 		t.Errorf("OnMasterReady called after %d initial spawns, want %d", spawnsBeforeReady, goMaxProcs)
 	}
 
+	// every recover event must be preceded by a spawn for the new PID.
 	recoveredSpawnByPID := make(map[int]bool)
-	recoveredPIDs := make(map[int]bool)
 	for _, e := range events[readyIdx+1:] {
 		if e.name == "spawn" {
 			recoveredSpawnByPID[e.pids[0]] = true
@@ -446,56 +467,171 @@ func Test_Prefork_Lifecycle(t *testing.T) {
 			if !recoveredSpawnByPID[e.pids[1]] {
 				t.Errorf("OnChildRecover for PID %d happened before OnChildSpawn", e.pids[1])
 			}
-			recoveredPIDs[e.pids[1]] = true
-		}
-	}
-	for pid := range recoveredPIDs {
-		if !recoveredSpawnByPID[pid] {
-			t.Errorf("OnChildRecover for PID %d did not have a matching OnChildSpawn", pid)
 		}
 	}
 }
 
-func Test_Prefork_RecoveredChildSpawnError(t *testing.T) {
+func Test_Prefork_InitialChildSpawnError(t *testing.T) {
 	prev := runtime.GOMAXPROCS(2)
-	t.Cleanup(func() {
-		runtime.GOMAXPROCS(prev)
-	})
+	t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
 
-	expectedErr := errors.New("spawn failed")
-	var spawnCount int
-	var recoverCount int
+	produce, cleanup := noopChildProducer(t)
+	t.Cleanup(cleanup)
+
+	expectedErr := errors.New("initial spawn rejected")
+	var calls atomic.Int32
 
 	p := &Prefork{
 		Reuseport:        true,
 		RecoverThreshold: 1,
 		Logger:           &testLogger{},
-		CommandProducer: func(_ []*os.File) (*exec.Cmd, error) {
-			cmd := exec.Command(os.Args[0], "-test.run=^$")
-			cmd.Env = append(os.Environ(), preforkChildEnvVariable+"=1")
-			err := cmd.Start()
-			return cmd, err
+		CommandProducer:  produce,
+		OnChildSpawn: func(_ int) error {
+			calls.Add(1)
+			return expectedErr
 		},
+	}
+
+	err := p.prefork("127.0.0.1:0")
+	if !errors.Is(err, expectedErr) {
+		t.Fatalf("expected %v, got: %v", expectedErr, err)
+	}
+	if calls.Load() == 0 {
+		t.Fatal("OnChildSpawn was never invoked")
+	}
+}
+
+func Test_Prefork_OnMasterReadyError(t *testing.T) {
+	prev := runtime.GOMAXPROCS(2)
+	t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
+
+	produce, cleanup := noopChildProducer(t)
+	t.Cleanup(cleanup)
+
+	expectedErr := errors.New("ready rejected")
+	p := &Prefork{
+		Reuseport:        true,
+		RecoverThreshold: 1,
+		Logger:           &testLogger{},
+		CommandProducer:  produce,
+		OnMasterReady: func([]int) error {
+			return expectedErr
+		},
+	}
+
+	err := p.prefork("127.0.0.1:0")
+	if !errors.Is(err, expectedErr) {
+		t.Fatalf("expected %v, got: %v", expectedErr, err)
+	}
+}
+
+func Test_Prefork_RecoveredChildSpawnError(t *testing.T) {
+	prev := runtime.GOMAXPROCS(2)
+	t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
+
+	produce, cleanup := noopChildProducer(t)
+	t.Cleanup(cleanup)
+
+	expectedErr := errors.New("spawn failed")
+	var spawnCount, recoverCount atomic.Int32
+
+	p := &Prefork{
+		Reuseport:        true,
+		RecoverThreshold: 1,
+		Logger:           &testLogger{},
+		CommandProducer:  produce,
 		OnChildSpawn: func(pid int) error {
 			if pid <= 0 {
 				t.Errorf("OnChildSpawn called with invalid PID: %d", pid)
 			}
-			spawnCount++
-			if spawnCount > runtime.GOMAXPROCS(0) {
+			n := spawnCount.Add(1)
+			if int(n) > runtime.GOMAXPROCS(0) {
 				return expectedErr
 			}
 			return nil
 		},
 		OnChildRecover: func(_, _ int) {
-			recoverCount++
+			recoverCount.Add(1)
 		},
 	}
 
-	err := p.prefork(getAddr())
+	err := p.prefork("127.0.0.1:0")
 	if !errors.Is(err, expectedErr) {
 		t.Fatalf("expected %v, got: %v", expectedErr, err)
 	}
-	if recoverCount != 0 {
-		t.Fatalf("OnChildRecover called %d times, want 0", recoverCount)
+	if got := recoverCount.Load(); got != 0 {
+		t.Fatalf("OnChildRecover called %d times, want 0", got)
+	}
+}
+
+// Test_Prefork_RecoverInterval verifies the optional backoff delays the respawn.
+func Test_Prefork_RecoverInterval(t *testing.T) {
+	prev := runtime.GOMAXPROCS(2)
+	t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
+
+	produce, cleanup := noopChildProducer(t)
+	t.Cleanup(cleanup)
+
+	const interval = 50 * time.Millisecond
+	p := &Prefork{
+		Reuseport:        true,
+		RecoverThreshold: 1,
+		RecoverInterval:  interval,
+		Logger:           &testLogger{},
+		CommandProducer:  produce,
+	}
+
+	start := time.Now()
+	err := p.prefork("127.0.0.1:0")
+	elapsed := time.Since(start)
+
+	if !errors.Is(err, ErrOverRecovery) {
+		t.Fatalf("expected ErrOverRecovery, got %v", err)
+	}
+	// At least one recover interval must have elapsed before threshold fired.
+	if elapsed < interval {
+		t.Errorf("elapsed %v < interval %v; backoff did not apply", elapsed, interval)
+	}
+}
+
+// Test_Prefork_ShutdownDoesNotBlockOnRecoverInterval guards against a child
+// that already exited (and whose Wait goroutine is parked on the RecoverInterval
+// backoff) holding up shutdown. prefork() returns immediately via an
+// OnMasterReady error while the noop children are sitting in a long backoff;
+// shutdown must cancel the backoff and return promptly instead of waiting for
+// RecoverInterval or ShutdownGracePeriod to elapse.
+func Test_Prefork_ShutdownDoesNotBlockOnRecoverInterval(t *testing.T) {
+	prev := runtime.GOMAXPROCS(2)
+	t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
+
+	produce, cleanup := noopChildProducer(t)
+	t.Cleanup(cleanup)
+
+	const longWait = 10 * time.Second
+	expectedErr := errors.New("ready rejected")
+	p := &Prefork{
+		Reuseport:           true,
+		RecoverThreshold:    1,
+		RecoverInterval:     longWait,
+		ShutdownGracePeriod: longWait,
+		Logger:              &testLogger{},
+		CommandProducer:     produce,
+		OnMasterReady: func([]int) error {
+			return expectedErr
+		},
+	}
+
+	start := time.Now()
+	err := p.prefork("127.0.0.1:0")
+	elapsed := time.Since(start)
+
+	if !errors.Is(err, expectedErr) {
+		t.Fatalf("expected %v, got: %v", expectedErr, err)
+	}
+	// With the backoff cancelled up front, shutdown is near-instant. Allow a
+	// generous margin for subprocess spawn/teardown but well below longWait so a
+	// regression (blocking on RecoverInterval / ShutdownGracePeriod) is caught.
+	if elapsed >= longWait/2 {
+		t.Fatalf("shutdown took %v; it blocked on the RecoverInterval backoff", elapsed)
 	}
 }