From 481e579af9e7d79f9ce27909edd2c42ef9dce173 Mon Sep 17 00:00:00 2001 From: RW Date: Sat, 2 May 2026 13:17:27 +0200 Subject: [PATCH] feat(prefork): Enhance prefork management with WatchMaster, CommandProducer, and Windows support (#2180) * feat(prefork): add WatchMaster and callback support for child process management * feat(prefork): add CommandProducer for customizable child process commands * refactor(prefork): improve comments and parameter order in ListenAndServeTLS * refactor(prefork): enhance logging message and clarify OnChildRecover callback comment * fix(prefork): add Windows support to watchMaster On Windows, os.Getppid() returns a static PID that doesn't change when the parent exits (no reparenting). Use FindProcess+Wait instead, which correctly detects parent exit. Also document why masterPID comparison works for Docker containers (master PID 1 case). Co-Authored-By: Claude Opus 4.6 (1M context) * refactor(prefork): extract listenAsChild to eliminate DRY violation The three ListenAndServe* methods had identical child setup code (listen, set ln, watch master). Extract to listenAsChild() for cleaner code. Also add comment for the magic file descriptor number 3. Co-Authored-By: Claude Opus 4.6 (1M context) * fix(prefork): restore upstream ListenAndServeTLS parameter order Keep upstream's (addr, certKey, certFile) signature to avoid breaking callers. Fix the doc comment to match the actual parameter order instead. Co-Authored-By: Claude Opus 4.6 (1M context) * fix(prefork): address lint errors and review feedback Lint fixes: - Remove unused Reuseport field write in test (govet/unusedwrite) - Replace fmt.Errorf with errors.New for static errors (perfsprint) Review feedback (Copilot): - Validate CommandProducer returns a started command (nil/Process check) - Clarify ListenAndServeTLS doc: parameter order and internal forwarding - Use hermetic test binary re-exec instead of external 'go' binary - Rename misleading test to reflect what it actually asserts Co-Authored-By: Claude Opus 4.6 (1M context) * refactor(prefork): address maintainer review feedback - watchMaster: log errors from FindProcess/Wait instead of swallowing - watchMaster: don't call OnMasterDeath if FindProcess fails - OnChildRecover: change signature to func(pid int), drop unused error return - OnChildSpawn: add comment clarifying deferred cleanup handles the child - CommandProducer: improve docs describing contract and use cases Co-Authored-By: Claude Opus 4.6 (1M context) * refactor(prefork): address erikdubbelboer review feedback - OnChildRecover: signature changed to func(oldPid, newPid int) so callers can track which process was replaced - OnChildSpawn: also called for recovered children (a recovered child is still a spawned child) - watchMaster: call OnMasterDeath when FindProcess fails (process is most likely gone) - CommandProducer: document that FASTHTTP_PREFORK_CHILD=1 must be set in the child env, and what the default does when nil Co-Authored-By: Claude Opus 4.6 (1M context) * fix(prefork): avoid zombie processes and replace shallow tests - Move Wait() goroutine before OnChildSpawn so Kill()+Wait() works correctly if a callback fails and the deferred cleanup runs - Add Wait() call in deferred cleanup after Kill() to reap children - Same fix in recovery loop - Remove shallow callback tests that only tested Go compiler - Add Test_Prefork_Lifecycle: runs full prefork with CommandProducer, verifies callbacks fire in correct order with correct arguments Co-Authored-By: Claude Opus 4.6 (1M context) * fix(prefork): ensure recovery default stays positive * test(prefork): isolate lifecycle tests * fix(prefork): tighten recovery callback flow --------- Co-authored-by: Claude Opus 4.6 (1M context) --- prefork/prefork.go | 525 +++++++++++++++++++++++++++------- prefork/prefork_test.go | 613 +++++++++++++++++++++++++++++++--------- 2 files changed, 900 insertions(+), 238 deletions(-) diff --git a/prefork/prefork.go b/prefork/prefork.go index d883640..1964ddf 100644 --- a/prefork/prefork.go +++ b/prefork/prefork.go @@ -2,12 +2,17 @@ package prefork import ( + "context" "errors" + "fmt" "log" "net" "os" "os/exec" + "os/signal" "runtime" + "sync" + "syscall" "time" "github.com/valyala/fasthttp" @@ -16,25 +21,53 @@ import ( const ( preforkChildEnvVariable = "FASTHTTP_PREFORK_CHILD" + preforkChildEnvValue = "1" defaultNetwork = "tcp4" + + // inheritedListenerFD is the file descriptor used by the master to pass + // the bound listener to a child process via ExtraFiles. Children open the + // listener via os.NewFile(inheritedListenerFD, ...) when Reuseport is false. + inheritedListenerFD = 3 + + // masterPollInterval is the period of the watchMaster ppid-poll on Unix. + masterPollInterval = 500 * time.Millisecond + + // defaultShutdownGracePeriod is how long the master waits for children to + // exit cleanly after sending SIGTERM before forcibly killing them. + defaultShutdownGracePeriod = 5 * time.Second ) var ( defaultLogger = Logger(log.New(os.Stderr, "", log.LstdFlags)) - // ErrOverRecovery is returned when the times of starting over child prefork processes exceed - // the threshold. + + // ErrOverRecovery is returned when child prefork process restarts exceed + // the value of RecoverThreshold. ErrOverRecovery = errors.New("exceeding the value of RecoverThreshold") - // ErrOnlyReuseportOnWindows is returned when Reuseport is false. + // ErrOnlyReuseportOnWindows is returned when running on Windows without Reuseport. ErrOnlyReuseportOnWindows = errors.New("windows only supports Reuseport = true") + + // ErrCommandProducerNilCmd is returned when a CommandProducer returns + // (nil, nil) instead of a started command. + ErrCommandProducerNilCmd = errors.New("prefork: CommandProducer returned nil command") + + // ErrCommandProducerNotStarted is returned when a CommandProducer returns + // an *exec.Cmd whose Process is nil (i.e. cmd.Start() was not called). + ErrCommandProducerNotStarted = errors.New("prefork: CommandProducer must return a started command") ) -// Logger is used for logging formatted messages. +// Logger is used for logging formatted messages. Its method set is intentionally +// identical to fasthttp.Logger so that *fasthttp.Server.Logger can be assigned +// directly. type Logger interface { // Printf must have the same semantics as log.Printf. Printf(format string, args ...any) } +// Compile-time check that fasthttp.Logger satisfies the local Logger interface; +// keeps the two types in sync if either side ever evolves. +var _ Logger = fasthttp.Logger(nil) + // Prefork implements fasthttp server prefork. // // Preforks master process (with all cores) between several child processes @@ -44,7 +77,8 @@ type Logger interface { // WARNING: using prefork prevents the use of any global state! // Things like in-memory caches won't work. type Prefork struct { - // By default standard logger from log package is used. + // Logger receives diagnostic output. By default the standard log package + // logger writing to stderr is used. Logger Logger ln net.Listener @@ -53,21 +87,29 @@ type Prefork struct { ServeTLSFunc func(ln net.Listener, certFile, keyFile string) error ServeTLSEmbedFunc func(ln net.Listener, certData, keyData []byte) error - // The network must be "tcp", "tcp4" or "tcp6". - // - // By default is "tcp4" + // Network must be "tcp", "tcp4" or "tcp6". Default is "tcp4". Network string files []*os.File - // Child prefork processes may exit with failure and will be started over until the times reach - // the value of RecoverThreshold, then it will return and terminate the server. + // RecoverThreshold caps how often crashed children are respawned before + // the master returns ErrOverRecovery. New() sets it to max(1, GOMAXPROCS/2). + // When constructing a Prefork directly without New(), a zero value will + // terminate the master after the very first child crash. RecoverThreshold int - // Flag to use a listener with reuseport, if not a file Listener will be used + // RecoverInterval, when > 0, makes the master sleep for the given duration + // before respawning a crashed child. Useful as crash-loop backoff. + RecoverInterval time.Duration + + // ShutdownGracePeriod is the time the master waits for children to exit + // after sending SIGTERM before falling back to SIGKILL. Defaults to 5s + // when zero. On Windows SIGTERM is not delivered, so this is unused there. + ShutdownGracePeriod time.Duration + + // Reuseport selects a reuseport listener instead of fd-passing. // See: https://www.nginx.com/blog/socket-sharding-nginx-release-1-9-1/ - // - // It's disabled by default + // Disabled by default. Reuseport bool // OnMasterDeath, when non-nil, enables monitoring of the master process @@ -76,19 +118,75 @@ type Prefork struct { // // It is recommended to set this to func() { os.Exit(1) } if no custom // cleanup is needed. + // + // Threading: invoked once from a watcher goroutine in the child. Must not + // block the goroutine for long; must not call Prefork methods. OnMasterDeath func() + + // OnChildSpawn is called in the master after a new child process is + // successfully started, both during initial spawn and during recovery. + // It receives the PID of the newly spawned child process. + // + // If this callback returns an error, all already-running children are + // killed and the prefork operation returns that error. + // + // Threading: invoked synchronously from the master goroutine. Must not + // block; must not call Prefork methods. Panics are recovered and surfaced + // as the returned error. + OnChildSpawn func(pid int) error + + // OnMasterReady is called in the master process exactly once, after all + // initial children have been spawned and before the supervision loop runs. + // It receives a slice of all initial child PIDs. + // + // If this callback returns an error, the prefork operation aborts and + // returns that error after killing the children. + // + // Threading: invoked synchronously from the master goroutine. The slice + // is owned by the caller after the call returns. Panics are recovered. + OnMasterReady func(childPIDs []int) error + + // OnChildRecover is called in the master after a crashed child has been + // replaced. It receives the PID of the old (crashed) process and the PID + // of its replacement. + // + // If this callback returns an error, all running children are killed and + // the prefork operation returns that error. + // + // Threading: invoked synchronously from the master goroutine, after + // OnChildSpawn for the new child. Panics are recovered and surfaced as + // the returned error. + OnChildRecover func(oldPID, newPID int) error + + // CommandProducer creates and starts a child process command. + // If nil, the default implementation re-executes the current binary + // with FASTHTTP_PREFORK_CHILD=1 in the environment, stdout/stderr + // inherited from the parent, and the given files as ExtraFiles. + // + // A custom producer must: + // - Set FASTHTTP_PREFORK_CHILD=1 in the child's environment + // (otherwise IsChild() returns false and the child won't serve) + // - Call cmd.Start() before returning (the returned command must be + // started so cmd.Process is non-nil) + // - Pass the provided files as cmd.ExtraFiles when Reuseport is false + // + // Primarily useful for testing (injecting dummy commands) or for + // frameworks that need custom child process setup. + CommandProducer func(files []*os.File) (*exec.Cmd, error) } -// IsChild checks if the current thread/process is a child. +// IsChild reports whether the current process is a prefork child. func IsChild() bool { - return os.Getenv(preforkChildEnvVariable) == "1" + return os.Getenv(preforkChildEnvVariable) == preforkChildEnvValue } // New wraps the fasthttp server to run with preforked processes. +// It seeds Network and RecoverThreshold to sensible defaults; existing +// fields on s (Logger, Serve*) are captured. func New(s *fasthttp.Server) *Prefork { return &Prefork{ Network: defaultNetwork, - RecoverThreshold: runtime.GOMAXPROCS(0) / 2, + RecoverThreshold: defaultRecoverThreshold(), Logger: s.Logger, ServeFunc: s.Serve, ServeTLSFunc: s.ServeTLS, @@ -96,6 +194,10 @@ func New(s *fasthttp.Server) *Prefork { } } +func defaultRecoverThreshold() int { + return max(1, runtime.GOMAXPROCS(0)/2) +} + func (p *Prefork) logger() Logger { if p.Logger != nil { return p.Logger @@ -103,13 +205,46 @@ func (p *Prefork) logger() Logger { return defaultLogger } +// invokeHook runs fn under a panic recovery, returning the panic as an error +// so a misbehaving callback never tears down the master. +func (p *Prefork) invokeHook(name string, fn func() error) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("prefork: %s panicked: %v", name, r) + p.logger().Printf("%v", err) + } + }() + return fn() +} + func (p *Prefork) watchMaster(masterPID int) { - ticker := time.NewTicker(500 * time.Millisecond) + if runtime.GOOS == "windows" { + // On Windows, os.Getppid() returns a static PID that doesn't change + // when the parent exits (no reparenting). Use FindProcess+Wait instead. + proc, err := os.FindProcess(masterPID) + if err != nil { + p.logger().Printf("watchMaster: failed to find master process %d: %v", masterPID, err) + p.OnMasterDeath() + return + } + if _, err = proc.Wait(); err != nil { + p.logger().Printf("watchMaster: error waiting for master process %d: %v", masterPID, err) + } + p.logger().Printf("master process %d died", masterPID) + p.OnMasterDeath() + return + } + + // Unix/Linux/macOS: When the master exits, the OS reparents the child + // to another process, causing Getppid() to change. Comparing against + // the original masterPID (instead of hardcoding 1) ensures this works + // correctly when the master itself is PID 1 (e.g. in Docker containers). + ticker := time.NewTicker(masterPollInterval) defer ticker.Stop() for range ticker.C { if os.Getppid() != masterPID { - p.logger().Printf("master process died\n") + p.logger().Printf("master process %d died", masterPID) p.OnMasterDeath() return } @@ -127,7 +262,27 @@ func (p *Prefork) listen(addr string) (net.Listener, error) { return reuseport.Listen(p.Network, addr) } - return net.FileListener(os.NewFile(3, "")) + // fd inheritedListenerFD is the first ExtraFiles entry passed by the + // master process when Reuseport is false. Naming the file gives clearer + // errors from net.FileListener if the fd is invalid. + return net.FileListener(os.NewFile(inheritedListenerFD, "fasthttp-prefork-listener")) +} + +// listenAsChild performs the common child process setup: creates the listener +// and starts watching the master process if OnMasterDeath is configured. +func (p *Prefork) listenAsChild(addr string) (net.Listener, error) { + ln, err := p.listen(addr) + if err != nil { + return nil, err + } + + p.ln = ln + + if p.OnMasterDeath != nil { + go p.watchMaster(os.Getppid()) + } + + return ln, nil } func (p *Prefork) setTCPListenerFiles(addr string) error { @@ -137,49 +292,142 @@ func (p *Prefork) setTCPListenerFiles(addr string) error { tcpAddr, err := net.ResolveTCPAddr(p.Network, addr) if err != nil { - return err + return fmt.Errorf("prefork: resolve %s/%s: %w", p.Network, addr, err) } - tcplistener, err := net.ListenTCP(p.Network, tcpAddr) + tcpListener, err := net.ListenTCP(p.Network, tcpAddr) if err != nil { - return err + return fmt.Errorf("prefork: listen tcp %s: %w", addr, err) } - p.ln = tcplistener + p.ln = tcpListener - fl, err := tcplistener.File() + listenerFile, err := tcpListener.File() if err != nil { - return err + return fmt.Errorf("prefork: dup listener fd: %w", err) } - p.files = []*os.File{fl} + p.files = []*os.File{listenerFile} return nil } +// childEnv returns os.Environ() with the prefork child marker variable set, +// stripping any pre-existing value to avoid duplicate keys with last-wins +// semantics. +func childEnv() []string { + src := os.Environ() + out := make([]string, 0, len(src)+1) + prefix := preforkChildEnvVariable + "=" + for _, kv := range src { + if len(kv) >= len(prefix) && kv[:len(prefix)] == prefix { + continue + } + out = append(out, kv) + } + out = append(out, preforkChildEnvVariable+"="+preforkChildEnvValue) + return out +} + func (p *Prefork) doCommand() (*exec.Cmd, error) { - executable, err := os.Executable() - if err != nil { - return nil, err + if p.CommandProducer != nil { + cmd, err := p.CommandProducer(p.files) + if err != nil { + return nil, fmt.Errorf("prefork: CommandProducer: %w", err) + } + if cmd == nil { + return nil, ErrCommandProducerNilCmd + } + if cmd.Process == nil { + return nil, ErrCommandProducerNotStarted + } + return cmd, nil } - args := make([]string, len(os.Args)) - args[0] = executable - copy(args[1:], os.Args[1:]) + executable, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("prefork: resolve executable: %w", err) + } + + args := append([]string{executable}, os.Args[1:]...) cmd := &exec.Cmd{ Path: executable, Args: args, Stdout: os.Stdout, Stderr: os.Stderr, - Env: append(os.Environ(), preforkChildEnvVariable+"=1"), + Env: childEnv(), ExtraFiles: p.files, } - err = cmd.Start() - return cmd, err + if err = cmd.Start(); err != nil { + return nil, fmt.Errorf("prefork: start child %q: %w", executable, err) + } + return cmd, nil } -func (p *Prefork) prefork(addr string) (err error) { +type childExit struct { + err error + pid int +} + +// shutdownChildren signals every entry in childProcs first with SIGTERM (on +// platforms where it is supported) and waits up to grace for them to exit. +// Survivors are then killed unconditionally. wg tracks the per-child Wait +// goroutines and must be drained before returning so no goroutine outlives +// prefork(). +func (p *Prefork) shutdownChildren( + childProcs map[int]*exec.Cmd, + wg *sync.WaitGroup, + cancel context.CancelFunc, + grace time.Duration, +) { + if grace <= 0 { + grace = defaultShutdownGracePeriod + } + + if runtime.GOOS != "windows" { + for pid, proc := range childProcs { + if proc == nil || proc.Process == nil { + continue + } + if termErr := proc.Process.Signal(syscall.SIGTERM); termErr != nil && + !errors.Is(termErr, os.ErrProcessDone) { + p.logger().Printf("prefork: SIGTERM child %d: %v", pid, termErr) + } + } + } + + // Wait for graceful exits, with a timeout fallback to SIGKILL. + graceful := make(chan struct{}) + go func() { + wg.Wait() + close(graceful) + }() + + timer := time.NewTimer(grace) + defer timer.Stop() + select { + case <-graceful: + case <-timer.C: + } + + for pid, proc := range childProcs { + if proc == nil || proc.Process == nil { + continue + } + if killErr := proc.Process.Kill(); killErr != nil && + !errors.Is(killErr, os.ErrProcessDone) { + p.logger().Printf("prefork: kill child %d: %v", pid, killErr) + } + } + + // Cancel the per-Wait goroutines' send-context so any still blocked on + // sigCh send unblock cleanly, then wait for all of them to exit. + cancel() + wg.Wait() +} + +func (p *Prefork) prefork(addr string) (err error) { //nolint:gocyclo if !p.Reuseport { if runtime.GOOS == "windows" { return ErrOnlyReuseportOnWindows @@ -189,86 +437,166 @@ func (p *Prefork) prefork(addr string) (err error) { return err } - // defer for closing the net.Listener opened by setTCPListenerFiles. + // Close listener fds opened by setTCPListenerFiles. Both the original + // tcpListener (p.ln) and the duped fd (p.files[0]) belong to the + // master only; children inherit independent dup'd copies via fork+exec. defer func() { - e := p.ln.Close() - if err == nil { - err = e + err = errors.Join(err, p.ln.Close()) + for _, f := range p.files { + if closeErr := f.Close(); closeErr != nil { + p.logger().Printf("prefork: close listener fd: %v", closeErr) + } + } + p.files = nil + }() + } + + // ctx cancels per-child Wait goroutines so they unblock from sigCh sends + // once the supervision loop is gone. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Catch SIGTERM/SIGINT in the master so we run our shutdown path instead + // of being killed by the OS without children getting a graceful chance. + signalCh := make(chan os.Signal, 1) + signal.Notify(signalCh, syscall.SIGTERM, syscall.SIGINT) + defer signal.Stop(signalCh) + + goMaxProcs := runtime.GOMAXPROCS(0) + // Buffer is sized to initial fleet; per-child goroutines fall back to a + // context-aware select on send so capacity is not load-bearing. + sigCh := make(chan childExit, goMaxProcs) + childProcs := make(map[int]*exec.Cmd, goMaxProcs) + + var wg sync.WaitGroup + startWait := func(cmd *exec.Cmd, pid int) { + wg.Add(1) + go func() { + defer wg.Done() + result := childExit{pid: pid, err: cmd.Wait()} + select { + case sigCh <- result: + case <-ctx.Done(): } }() } - type procSig struct { - err error - pid int - } - - goMaxProcs := runtime.GOMAXPROCS(0) - sigCh := make(chan procSig, goMaxProcs) - childProcs := make(map[int]*exec.Cmd) - defer func() { - for _, proc := range childProcs { - _ = proc.Process.Kill() - } + p.shutdownChildren(childProcs, &wg, cancel, p.ShutdownGracePeriod) }() + childPIDs := make([]int, 0, goMaxProcs) + for range goMaxProcs { var cmd *exec.Cmd if cmd, err = p.doCommand(); err != nil { - p.logger().Printf("failed to start a child prefork process, error: %v\n", err) + p.logger().Printf("prefork: failed to start a child process: %v", err) return err } - childProcs[cmd.Process.Pid] = cmd - go func() { - sigCh <- procSig{pid: cmd.Process.Pid, err: cmd.Wait()} - }() + pid := cmd.Process.Pid + childProcs[pid] = cmd + childPIDs = append(childPIDs, pid) + + // Start Wait goroutine before the user callback so a panic / error + // return from OnChildSpawn cannot leave a zombie behind. + startWait(cmd, pid) + + if p.OnChildSpawn != nil { + pid := pid + if hookErr := p.invokeHook("OnChildSpawn", func() error { + return p.OnChildSpawn(pid) + }); hookErr != nil { + p.logger().Printf("prefork: OnChildSpawn for PID %d: %v", pid, hookErr) + return hookErr + } + } + } + + if p.OnMasterReady != nil { + pids := append([]int(nil), childPIDs...) + if hookErr := p.invokeHook("OnMasterReady", func() error { + return p.OnMasterReady(pids) + }); hookErr != nil { + p.logger().Printf("prefork: OnMasterReady: %v", hookErr) + return hookErr + } } var exitedProcs int - for sig := range sigCh { - delete(childProcs, sig.pid) + for { + select { + case sig := <-signalCh: + p.logger().Printf("prefork: received signal %v, shutting down", sig) + return nil - p.logger().Printf("one of the child prefork processes exited with "+ - "error: %v", sig.err) + case sig := <-sigCh: + delete(childProcs, sig.pid) - exitedProcs++ - if exitedProcs > p.RecoverThreshold { - p.logger().Printf("child prefork processes exit too many times, "+ - "which exceeds the value of RecoverThreshold(%d), "+ - "exiting the master process.\n", exitedProcs) - err = ErrOverRecovery - break + if sig.err != nil { + p.logger().Printf("prefork: child PID %d exited: %v", sig.pid, sig.err) + } else { + p.logger().Printf("prefork: child PID %d exited cleanly", sig.pid) + } + + exitedProcs++ + if exitedProcs > p.RecoverThreshold { + p.logger().Printf( + "prefork: child exits (%d) exceed RecoverThreshold (%d), terminating master", + exitedProcs, p.RecoverThreshold, + ) + return ErrOverRecovery + } + + if p.RecoverInterval > 0 { + select { + case <-time.After(p.RecoverInterval): + case <-signalCh: + return nil + } + } + + cmd, doErr := p.doCommand() + if doErr != nil { + p.logger().Printf("prefork: recovery doCommand: %v", doErr) + return doErr + } + newPID := cmd.Process.Pid + childProcs[newPID] = cmd + + startWait(cmd, newPID) + + if p.OnChildSpawn != nil { + newPID := newPID + if hookErr := p.invokeHook("OnChildSpawn", func() error { + return p.OnChildSpawn(newPID) + }); hookErr != nil { + p.logger().Printf("prefork: OnChildSpawn for recovered PID %d: %v", newPID, hookErr) + return hookErr + } + } + + if p.OnChildRecover != nil { + oldPID := sig.pid + newPID := newPID + if hookErr := p.invokeHook("OnChildRecover", func() error { + return p.OnChildRecover(oldPID, newPID) + }); hookErr != nil { + p.logger().Printf("prefork: OnChildRecover (%d -> %d): %v", oldPID, newPID, hookErr) + return hookErr + } + } } - - var cmd *exec.Cmd - if cmd, err = p.doCommand(); err != nil { - break - } - childProcs[cmd.Process.Pid] = cmd - go func() { - sigCh <- procSig{pid: cmd.Process.Pid, err: cmd.Wait()} - }() } - - return err } // ListenAndServe serves HTTP requests from the given TCP addr. func (p *Prefork) ListenAndServe(addr string) error { if IsChild() { - ln, err := p.listen(addr) + ln, err := p.listenAsChild(addr) if err != nil { return err } - - p.ln = ln - - if p.OnMasterDeath != nil { - go p.watchMaster(os.Getppid()) - } - return p.ServeFunc(ln) } @@ -277,20 +605,16 @@ func (p *Prefork) ListenAndServe(addr string) error { // ListenAndServeTLS serves HTTPS requests from the given TCP addr. // -// certFile and keyFile are paths to TLS certificate and key files. +// Note: parameter order is (addr, certKey, certFile) — key path comes +// before cert path. This is preserved for backward compatibility with +// existing callers and differs from fasthttp.Server.ListenAndServeTLS. +// New code should prefer ListenAndServeTLSEmbed. func (p *Prefork) ListenAndServeTLS(addr, certKey, certFile string) error { if IsChild() { - ln, err := p.listen(addr) + ln, err := p.listenAsChild(addr) if err != nil { return err } - - p.ln = ln - - if p.OnMasterDeath != nil { - go p.watchMaster(os.Getppid()) - } - return p.ServeTLSFunc(ln, certFile, certKey) } @@ -302,17 +626,10 @@ func (p *Prefork) ListenAndServeTLS(addr, certKey, certFile string) error { // certData and keyData must contain valid TLS certificate and key data. func (p *Prefork) ListenAndServeTLSEmbed(addr string, certData, keyData []byte) error { if IsChild() { - ln, err := p.listen(addr) + ln, err := p.listenAsChild(addr) if err != nil { return err } - - p.ln = ln - - if p.OnMasterDeath != nil { - go p.watchMaster(os.Getppid()) - } - return p.ServeTLSEmbedFunc(ln, certData, keyData) } diff --git a/prefork/prefork_test.go b/prefork/prefork_test.go index 8236e12..004450b 100644 --- a/prefork/prefork_test.go +++ b/prefork/prefork_test.go @@ -1,43 +1,80 @@ package prefork import ( + "errors" "fmt" - "math/rand" "net" "os" - "reflect" + "os/exec" "runtime" + "sync" + "sync/atomic" "testing" + "time" "github.com/valyala/fasthttp" ) -func setUp() { - os.Setenv(preforkChildEnvVariable, "1") +// freeAddr returns a TCP4 address that the OS believes is free at call time. +// It avoids the rand-port collision flakes the previous helper exhibited. +func freeAddr(t testing.TB) string { + t.Helper() + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("freeAddr: %v", err) + } + addr := ln.Addr().String() + if closeErr := ln.Close(); closeErr != nil { + t.Fatalf("freeAddr: close: %v", closeErr) + } + return addr } -func tearDown() { - os.Unsetenv(preforkChildEnvVariable) -} +// noopChildProducer returns a CommandProducer that re-execs the test binary +// into a no-op subprocess. The returned cleanup must be deferred (or registered +// via t.Cleanup) so leaked subprocesses are reaped if the test fails midway. +func noopChildProducer(t testing.TB) (func(files []*os.File) (*exec.Cmd, error), func()) { + t.Helper() + var ( + mu sync.Mutex + spawned []*exec.Cmd + ) -func getAddr() string { - return fmt.Sprintf("127.0.0.1:%d", rand.Intn(9000-3000)+3000) + produce := func(_ []*os.File) (*exec.Cmd, error) { + cmd := exec.Command(os.Args[0], "-test.run=^$") + cmd.Env = append(os.Environ(), preforkChildEnvVariable+"="+preforkChildEnvValue) + if err := cmd.Start(); err != nil { + return nil, err + } + mu.Lock() + spawned = append(spawned, cmd) + mu.Unlock() + return cmd, nil + } + + cleanup := func() { + mu.Lock() + defer mu.Unlock() + for _, cmd := range spawned { + if cmd == nil || cmd.Process == nil { + continue + } + _ = cmd.Process.Kill() + _, _ = cmd.Process.Wait() + } + } + return produce, cleanup } func Test_IsChild(t *testing.T) { - // This test can't run parallel as it modifies os.Args. - - v := IsChild() - if v { - t.Errorf("IsChild() == %v, want %v", v, false) + // This test cannot run in parallel — IsChild() reads a process-global env var. + if IsChild() { + t.Fatal("test starts as child unexpectedly") } - setUp() - defer tearDown() - - v = IsChild() - if !v { - t.Errorf("IsChild() == %v, want %v", v, true) + t.Setenv(preforkChildEnvVariable, preforkChildEnvValue) + if !IsChild() { + t.Errorf("IsChild() == false after Setenv, want true") } } @@ -50,48 +87,37 @@ func Test_New(t *testing.T) { if p.Network != defaultNetwork { t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork) } - - if reflect.ValueOf(p.ServeFunc).Pointer() != reflect.ValueOf(s.Serve).Pointer() { - t.Errorf("Prefork.ServeFunc == %p, want %p", p.ServeFunc, s.Serve) + if p.RecoverThreshold <= 0 { + t.Errorf("Prefork.RecoverThreshold == %d, want > 0", p.RecoverThreshold) } - - if reflect.ValueOf(p.ServeTLSFunc).Pointer() != reflect.ValueOf(s.ServeTLS).Pointer() { - t.Errorf("Prefork.ServeTLSFunc == %p, want %p", p.ServeTLSFunc, s.ServeTLS) - } - - if reflect.ValueOf(p.ServeTLSEmbedFunc).Pointer() != reflect.ValueOf(s.ServeTLSEmbed).Pointer() { - t.Errorf("Prefork.ServeTLSFunc == %p, want %p", p.ServeTLSEmbedFunc, s.ServeTLSEmbed) + if p.ServeFunc == nil || p.ServeTLSFunc == nil || p.ServeTLSEmbedFunc == nil { + t.Error("New() did not wire one of ServeFunc/ServeTLSFunc/ServeTLSEmbedFunc") } } -func Test_listen(t *testing.T) { - t.Parallel() +func Test_listen_Reuseport(t *testing.T) { + prev := runtime.GOMAXPROCS(0) + t.Cleanup(func() { + runtime.GOMAXPROCS(prev) + }) - p := &Prefork{ - Reuseport: true, - } - addr := getAddr() + p := &Prefork{Reuseport: true} + addr := freeAddr(t) ln, err := p.listen(addr) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Fatalf("listen: %v", err) } + t.Cleanup(func() { + _ = ln.Close() + }) - ln.Close() - - lnAddr := ln.Addr().String() - if lnAddr != addr { - t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr) + if got, want := ln.Addr().String(), addr; got != want { + t.Errorf("ln.Addr() == %q, want %q", got, want) } - if p.Network != defaultNetwork { t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork) } - - procs := runtime.GOMAXPROCS(0) - if procs != 1 { - t.Errorf("GOMAXPROCS == %d, want %d", procs, 1) - } } func Test_setTCPListenerFiles(t *testing.T) { @@ -102,125 +128,444 @@ func Test_setTCPListenerFiles(t *testing.T) { } p := &Prefork{} - addr := getAddr() + addr := freeAddr(t) - err := p.setTCPListenerFiles(addr) - if err != nil { - t.Fatalf("Unexpected error: %v", err) + if err := p.setTCPListenerFiles(addr); err != nil { + t.Fatalf("setTCPListenerFiles: %v", err) } + t.Cleanup(func() { + _ = p.ln.Close() + for _, f := range p.files { + _ = f.Close() + } + }) if p.ln == nil { - t.Fatal("Prefork.ln is nil") + t.Fatal("p.ln is nil after setTCPListenerFiles") } - - p.ln.Close() - - lnAddr := p.ln.Addr().String() - if lnAddr != addr { - t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr) + if got, want := p.ln.Addr().String(), addr; got != want { + t.Errorf("p.ln.Addr() == %q, want %q", got, want) } - - if p.Network != defaultNetwork { - t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork) - } - if len(p.files) != 1 { - t.Errorf("Prefork.files == %d, want %d", len(p.files), 1) + t.Errorf("len(p.files) == %d, want 1", len(p.files)) } } -func Test_ListenAndServe(t *testing.T) { - // This test can't run parallel as it modifies os.Args. +func Test_setTCPListenerFiles_BadAddr(t *testing.T) { + t.Parallel() - setUp() - defer tearDown() - - s := &fasthttp.Server{} - p := New(s) - p.Reuseport = true - p.ServeFunc = func(ln net.Listener) error { - return nil + if runtime.GOOS == "windows" { + t.SkipNow() } - - addr := getAddr() - - err := p.ListenAndServe(addr) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - p.ln.Close() - - lnAddr := p.ln.Addr().String() - if lnAddr != addr { - t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr) - } - - if p.ln == nil { - t.Error("Prefork.ln is nil") + p := &Prefork{} + if err := p.setTCPListenerFiles("definitely not an address"); err == nil { + t.Fatal("expected error for malformed addr, got nil") } } -func Test_ListenAndServeTLS(t *testing.T) { - // This test can't run parallel as it modifies os.Args. +// Test_ListenAndServe_Stub_ChildPath drives the child branch of all three +// ListenAndServe* entry points using a stubbed Serve function. It replaces the +// previous trio of near-identical tests that only validated field assignment. +func Test_ListenAndServe_Stub_ChildPath(t *testing.T) { + // child env mutation precludes t.Parallel. + t.Setenv(preforkChildEnvVariable, preforkChildEnvValue) - setUp() - defer tearDown() - - s := &fasthttp.Server{} - p := New(s) - p.Reuseport = true - p.ServeTLSFunc = func(ln net.Listener, certFile, keyFile string) error { - return nil + type call struct { + listener bool + certFile string + keyFile string + certData string + keyData string } - addr := getAddr() - - err := p.ListenAndServeTLS(addr, "./key", "./cert") - if err != nil { - t.Errorf("Unexpected error: %v", err) + tests := []struct { + name string + run func(t *testing.T, p *Prefork, addr string) error + want call + }{ + { + name: "ListenAndServe", + run: func(_ *testing.T, p *Prefork, addr string) error { return p.ListenAndServe(addr) }, + want: call{listener: true}, + }, + { + name: "ListenAndServeTLS", + run: func(_ *testing.T, p *Prefork, addr string) error { + return p.ListenAndServeTLS(addr, "./key", "./cert") + }, + want: call{listener: true, certFile: "./cert", keyFile: "./key"}, + }, + { + name: "ListenAndServeTLSEmbed", + run: func(_ *testing.T, p *Prefork, addr string) error { + return p.ListenAndServeTLSEmbed(addr, []byte("certPEM"), []byte("keyPEM")) + }, + want: call{listener: true, certData: "certPEM", keyData: "keyPEM"}, + }, } - p.ln.Close() + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var got call + p := New(&fasthttp.Server{}) + p.Reuseport = true + p.ServeFunc = func(ln net.Listener) error { + got.listener = ln != nil + return nil + } + p.ServeTLSFunc = func(ln net.Listener, certFile, keyFile string) error { + got.listener = ln != nil + got.certFile = certFile + got.keyFile = keyFile + return nil + } + p.ServeTLSEmbedFunc = func(ln net.Listener, certData, keyData []byte) error { + got.listener = ln != nil + got.certData = string(certData) + got.keyData = string(keyData) + return nil + } - lnAddr := p.ln.Addr().String() - if lnAddr != addr { - t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr) - } + addr := freeAddr(t) + if err := tc.run(t, p, addr); err != nil { + t.Fatalf("%s: %v", tc.name, err) + } + t.Cleanup(func() { + if p.ln != nil { + _ = p.ln.Close() + } + }) - if p.ln == nil { - t.Error("Prefork.ln is nil") + if got != tc.want { + t.Errorf("%s call = %+v, want %+v", tc.name, got, tc.want) + } + }) } } -func Test_ListenAndServeTLSEmbed(t *testing.T) { - // This test can't run parallel as it modifies os.Args. +func Test_doCommand_CommandProducerErrors(t *testing.T) { + t.Parallel() - setUp() - defer tearDown() - - s := &fasthttp.Server{} - p := New(s) - p.Reuseport = true - p.ServeTLSEmbedFunc = func(ln net.Listener, certData, keyData []byte) error { - return nil + tests := []struct { + name string + produce func(files []*os.File) (*exec.Cmd, error) + wantErr error + }{ + { + name: "producer returns error", + produce: func([]*os.File) (*exec.Cmd, error) { + return nil, errors.New("boom") + }, + }, + { + name: "producer returns nil cmd", + //nolint:nilnil // intentionally tests the (nil, nil) misbehaviour guard + produce: func([]*os.File) (*exec.Cmd, error) { + return nil, nil + }, + wantErr: ErrCommandProducerNilCmd, + }, + { + name: "producer returns unstarted cmd", + produce: func([]*os.File) (*exec.Cmd, error) { + return &exec.Cmd{}, nil + }, + wantErr: ErrCommandProducerNotStarted, + }, } - addr := getAddr() + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + p := &Prefork{CommandProducer: tc.produce} + cmd, err := p.doCommand() + if cmd != nil { + t.Errorf("expected nil cmd on error, got %v", cmd) + } + if err == nil { + t.Fatal("expected error, got nil") + } + if tc.wantErr != nil && !errors.Is(err, tc.wantErr) { + t.Errorf("err = %v, want errors.Is %v", err, tc.wantErr) + } + }) + } +} + +type testLogger struct { + mu sync.Mutex + messages []string +} + +func (l *testLogger) Printf(format string, args ...any) { + msg := fmt.Sprintf(format, args...) + l.mu.Lock() + l.messages = append(l.messages, msg) + l.mu.Unlock() +} + +// Test_Prefork_Lifecycle drives prefork() to ErrOverRecovery via +// short-lived no-op children and asserts the callback ordering / arguments. +func Test_Prefork_Lifecycle(t *testing.T) { + prev := runtime.GOMAXPROCS(2) + t.Cleanup(func() { runtime.GOMAXPROCS(prev) }) + + type event struct { + name string + pids []int + } + + var mu sync.Mutex + var events []event + record := func(name string, pids ...int) { + mu.Lock() + events = append(events, event{name, pids}) + mu.Unlock() + } + + produce, cleanup := noopChildProducer(t) + t.Cleanup(cleanup) + + p := &Prefork{ + Reuseport: true, + RecoverThreshold: 1, + Logger: &testLogger{}, + CommandProducer: produce, + OnChildSpawn: func(pid int) error { + record("spawn", pid) + return nil + }, + OnMasterReady: func(childPIDs []int) error { + record("ready", childPIDs...) + return nil + }, + OnChildRecover: func(oldPID, newPID int) error { + record("recover", oldPID, newPID) + return nil + }, + } + + err := p.prefork(freeAddr(t)) + if !errors.Is(err, ErrOverRecovery) { + t.Fatalf("expected ErrOverRecovery, got: %v", err) + } + + mu.Lock() + defer mu.Unlock() + + goMaxProcs := runtime.GOMAXPROCS(0) + + var spawnCount, readyCount, recoverCount int + for _, e := range events { + switch e.name { + case "spawn": + spawnCount++ + if len(e.pids) != 1 || e.pids[0] <= 0 { + t.Errorf("spawn event has invalid PID: %v", e.pids) + } + case "ready": + readyCount++ + if len(e.pids) == 0 { + t.Error("ready event received empty PID list") + } + case "recover": + recoverCount++ + if len(e.pids) != 2 || e.pids[0] <= 0 || e.pids[1] <= 0 { + t.Errorf("recover event has invalid PIDs: %v", e.pids) + } + if e.pids[0] == e.pids[1] { + t.Error("recover old and new PID should differ") + } + } + } + + if readyCount != 1 { + t.Errorf("OnMasterReady called %d times, want 1", readyCount) + } + if spawnCount < goMaxProcs { + t.Errorf("OnChildSpawn called %d times, want at least %d", spawnCount, goMaxProcs) + } + if recoverCount == 0 { + t.Error("OnChildRecover was never called") + } + + // ready must come after exactly goMaxProcs initial spawns. + readyIdx := -1 + spawnsBeforeReady := 0 + for i, e := range events { + if e.name == "ready" { + readyIdx = i + break + } + if e.name == "spawn" { + spawnsBeforeReady++ + } + } + if readyIdx == -1 { + t.Fatal("OnMasterReady was never called") + } + if spawnsBeforeReady != goMaxProcs { + t.Errorf("OnMasterReady called after %d initial spawns, want %d", spawnsBeforeReady, goMaxProcs) + } + + // every recover event must be preceded by a spawn for the new PID. + recoveredSpawnByPID := make(map[int]bool) + for _, e := range events[readyIdx+1:] { + if e.name == "spawn" { + recoveredSpawnByPID[e.pids[0]] = true + } + if e.name == "recover" { + if !recoveredSpawnByPID[e.pids[1]] { + t.Errorf("OnChildRecover for PID %d happened before OnChildSpawn", e.pids[1]) + } + } + } +} + +func Test_Prefork_InitialChildSpawnError(t *testing.T) { + prev := runtime.GOMAXPROCS(2) + t.Cleanup(func() { runtime.GOMAXPROCS(prev) }) - err := p.ListenAndServeTLSEmbed(addr, []byte("key"), []byte("cert")) - if err != nil { - t.Errorf("Unexpected error: %v", err) + produce, cleanup := noopChildProducer(t) + t.Cleanup(cleanup) + + expectedErr := errors.New("initial spawn rejected") + var calls atomic.Int32 + + p := &Prefork{ + Reuseport: true, + RecoverThreshold: 1, + Logger: &testLogger{}, + CommandProducer: produce, + OnChildSpawn: func(_ int) error { + calls.Add(1) + return expectedErr + }, + } + + err := p.prefork(freeAddr(t)) + if !errors.Is(err, expectedErr) { + t.Fatalf("expected %v, got: %v", expectedErr, err) + } + if calls.Load() == 0 { + t.Fatal("OnChildSpawn was never invoked") + } +} + +func Test_Prefork_OnMasterReadyError(t *testing.T) { + prev := runtime.GOMAXPROCS(2) + t.Cleanup(func() { runtime.GOMAXPROCS(prev) }) + + produce, cleanup := noopChildProducer(t) + t.Cleanup(cleanup) + + expectedErr := errors.New("ready rejected") + p := &Prefork{ + Reuseport: true, + RecoverThreshold: 1, + Logger: &testLogger{}, + CommandProducer: produce, + OnMasterReady: func([]int) error { + return expectedErr + }, + } + + err := p.prefork(freeAddr(t)) + if !errors.Is(err, expectedErr) { + t.Fatalf("expected %v, got: %v", expectedErr, err) + } +} + +func Test_Prefork_RecoveredChildSpawnError(t *testing.T) { + prev := runtime.GOMAXPROCS(2) + t.Cleanup(func() { runtime.GOMAXPROCS(prev) }) + + produce, cleanup := noopChildProducer(t) + t.Cleanup(cleanup) + + expectedErr := errors.New("spawn failed") + var spawnCount, recoverCount atomic.Int32 + + p := &Prefork{ + Reuseport: true, + RecoverThreshold: 1, + Logger: &testLogger{}, + CommandProducer: produce, + OnChildSpawn: func(pid int) error { + if pid <= 0 { + t.Errorf("OnChildSpawn called with invalid PID: %d", pid) + } + n := spawnCount.Add(1) + if int(n) > runtime.GOMAXPROCS(0) { + return expectedErr + } + return nil + }, + OnChildRecover: func(_, _ int) error { + recoverCount.Add(1) + return nil + }, + } + + err := p.prefork(freeAddr(t)) + if !errors.Is(err, expectedErr) { + t.Fatalf("expected %v, got: %v", expectedErr, err) + } + if got := recoverCount.Load(); got != 0 { + t.Fatalf("OnChildRecover called %d times, want 0", got) + } +} + +func Test_Prefork_HookPanicSurfaces(t *testing.T) { + prev := runtime.GOMAXPROCS(2) + t.Cleanup(func() { runtime.GOMAXPROCS(prev) }) + + produce, cleanup := noopChildProducer(t) + t.Cleanup(cleanup) + + p := &Prefork{ + Reuseport: true, + RecoverThreshold: 1, + Logger: &testLogger{}, + CommandProducer: produce, + OnChildSpawn: func(_ int) error { + panic("user code blew up") + }, + } + + err := p.prefork(freeAddr(t)) + if err == nil { + t.Fatal("expected error from panicking hook, got nil") } +} - p.ln.Close() +// Test_Prefork_RecoverInterval verifies the optional backoff actually delays +// the respawn — without rabbit-holing into hard wallclock assertions. +func Test_Prefork_RecoverInterval(t *testing.T) { + prev := runtime.GOMAXPROCS(2) + t.Cleanup(func() { runtime.GOMAXPROCS(prev) }) - lnAddr := p.ln.Addr().String() - if lnAddr != addr { - t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr) + produce, cleanup := noopChildProducer(t) + t.Cleanup(cleanup) + + const interval = 50 * time.Millisecond + p := &Prefork{ + Reuseport: true, + RecoverThreshold: 1, + RecoverInterval: interval, + Logger: &testLogger{}, + CommandProducer: produce, } + + start := time.Now() + err := p.prefork(freeAddr(t)) + elapsed := time.Since(start) - if p.ln == nil { - t.Error("Prefork.ln is nil") + if !errors.Is(err, ErrOverRecovery) { + t.Fatalf("expected ErrOverRecovery, got %v", err) + } + // At least one recover interval must have elapsed before threshold fired. + if elapsed < interval { + t.Errorf("elapsed %v < interval %v; backoff did not apply", elapsed, interval) } }