feat(prefork): graceful shutdown, leak fixes, hook robustness

Addresses outstanding review concerns and several adjacent issues
surfaced during a follow-up review pass.

Lifecycle / supervision
- Track every per-child Wait goroutine via sync.WaitGroup and unblock
  pending sigCh sends through a context.Cancel so early-return paths
  (OnChildSpawn / OnMasterReady error, recovery doCommand error,
  ErrOverRecovery) can no longer leak goroutines or stall children.
- Install signal.Notify(SIGTERM, SIGINT) in the master so deploy/
  rolling-restart signals enter the shutdown path instead of killing
  the master without graceful teardown.
- Replace the unconditional SIGKILL defer with a SIGTERM-then-SIGKILL
  sequence gated by a configurable ShutdownGracePeriod (defaults to 5s,
  Windows path stays SIGKILL since Signal(SIGTERM) is unsupported).

API
- OnChildRecover now returns error so callers can implement recovery
  policies (circuit-breaker etc.); panic in any hook is recovered and
  surfaced as the returned error, with diagnostic logging.
- Add RecoverInterval (optional crash-loop backoff) and
  ShutdownGracePeriod fields with safe zero-value defaults.
- Export ErrCommandProducerNilCmd and ErrCommandProducerNotStarted
  sentinel errors so callers can errors.Is them.
- Rename oldPid/newPid to oldPID/newPID per Go initialism convention.
- Logger interface now declares an explicit compile-time compatibility
  check with fasthttp.Logger.

Resource hygiene
- Master closes both the original tcpListener and the duped fd in
  p.files when prefork() returns; previously the duped fd leaked once
  per call.
- doCommand wraps every error path with %w + fmt.Errorf so caller-side
  diagnostics keep stage context.
- Strip pre-existing FASTHTTP_PREFORK_CHILD entries before appending so
  child env never carries duplicate keys.
- Extract magic numbers as package constants
  (inheritedListenerFD, masterPollInterval, defaultShutdownGracePeriod,
  preforkChildEnvValue).
- Rename the inherited listener fd via os.NewFile so net.FileListener
  errors are diagnosable.

Tests
- Migrate to t.Setenv (drop the global setUp/tearDown helpers) — fixes
  the env-mutation-vs-parallel race.
- Replace rand.Intn port helper with `:0` + Listener.Addr() to remove
  port-collision flakes under -count and parallel runs.
- Collapse the three near-identical Test_ListenAndServe* tests into a
  single table-driven subtest that actually asserts the args forwarded
  to ServeFunc/ServeTLSFunc/ServeTLSEmbedFunc.
- Add coverage for the previously untested branches:
  CommandProducer returning err / nil cmd / unstarted cmd,
  initial OnChildSpawn error, OnMasterReady error,
  hook panic surfacing, RecoverInterval enforcement.
- noopChildProducer helper kills + waits any spawned child binaries
  during cleanup so failed tests no longer leave subprocesses around.
This commit is contained in:
René
2026-05-02 13:12:08 +02:00
parent db78ffe6f1
commit 262ea09c9e
2 changed files with 674 additions and 314 deletions
+330 -124
View File
@@ -2,12 +2,17 @@
package prefork
import (
"context"
"errors"
"fmt"
"log"
"net"
"os"
"os/exec"
"os/signal"
"runtime"
"sync"
"syscall"
"time"
"github.com/valyala/fasthttp"
@@ -16,25 +21,53 @@ import (
const (
preforkChildEnvVariable = "FASTHTTP_PREFORK_CHILD"
preforkChildEnvValue = "1"
defaultNetwork = "tcp4"
// inheritedListenerFD is the file descriptor used by the master to pass
// the bound listener to a child process via ExtraFiles. Children open the
// listener via os.NewFile(inheritedListenerFD, ...) when Reuseport is false.
inheritedListenerFD = 3
// masterPollInterval is the period of the watchMaster ppid-poll on Unix.
masterPollInterval = 500 * time.Millisecond
// defaultShutdownGracePeriod is how long the master waits for children to
// exit cleanly after sending SIGTERM before forcibly killing them.
defaultShutdownGracePeriod = 5 * time.Second
)
var (
defaultLogger = Logger(log.New(os.Stderr, "", log.LstdFlags))
// ErrOverRecovery is returned when the times of starting over child prefork processes exceed
// the threshold.
// ErrOverRecovery is returned when child prefork process restarts exceed
// the value of RecoverThreshold.
ErrOverRecovery = errors.New("exceeding the value of RecoverThreshold")
// ErrOnlyReuseportOnWindows is returned when Reuseport is false.
// ErrOnlyReuseportOnWindows is returned when running on Windows without Reuseport.
ErrOnlyReuseportOnWindows = errors.New("windows only supports Reuseport = true")
// ErrCommandProducerNilCmd is returned when a CommandProducer returns
// (nil, nil) instead of a started command.
ErrCommandProducerNilCmd = errors.New("prefork: CommandProducer returned nil command")
// ErrCommandProducerNotStarted is returned when a CommandProducer returns
// an *exec.Cmd whose Process is nil (i.e. cmd.Start() was not called).
ErrCommandProducerNotStarted = errors.New("prefork: CommandProducer must return a started command")
)
// Logger is used for logging formatted messages.
// Logger is used for logging formatted messages. Its method set is intentionally
// identical to fasthttp.Logger so that *fasthttp.Server.Logger can be assigned
// directly.
type Logger interface {
// Printf must have the same semantics as log.Printf.
Printf(format string, args ...any)
}
// Compile-time check that fasthttp.Logger satisfies the local Logger interface;
// keeps the two types in sync if either side ever evolves.
var _ Logger = fasthttp.Logger(nil)
// Prefork implements fasthttp server prefork.
//
// Preforks master process (with all cores) between several child processes
@@ -44,7 +77,8 @@ type Logger interface {
// WARNING: using prefork prevents the use of any global state!
// Things like in-memory caches won't work.
type Prefork struct {
// By default standard logger from log package is used.
// Logger receives diagnostic output. By default the standard log package
// logger writing to stderr is used.
Logger Logger
ln net.Listener
@@ -53,21 +87,29 @@ type Prefork struct {
ServeTLSFunc func(ln net.Listener, certFile, keyFile string) error
ServeTLSEmbedFunc func(ln net.Listener, certData, keyData []byte) error
// The network must be "tcp", "tcp4" or "tcp6".
//
// By default is "tcp4"
// Network must be "tcp", "tcp4" or "tcp6". Default is "tcp4".
Network string
files []*os.File
// Child prefork processes may exit with failure and will be started over until the times reach
// the value of RecoverThreshold, then it will return and terminate the server.
// RecoverThreshold caps how often crashed children are respawned before
// the master returns ErrOverRecovery. New() sets it to max(1, GOMAXPROCS/2).
// When constructing a Prefork directly without New(), a zero value will
// terminate the master after the very first child crash.
RecoverThreshold int
// Flag to use a listener with reuseport, if not a file Listener will be used
// RecoverInterval, when > 0, makes the master sleep for the given duration
// before respawning a crashed child. Useful as crash-loop backoff.
RecoverInterval time.Duration
// ShutdownGracePeriod is the time the master waits for children to exit
// after sending SIGTERM before falling back to SIGKILL. Defaults to 5s
// when zero. On Windows SIGTERM is not delivered, so this is unused there.
ShutdownGracePeriod time.Duration
// Reuseport selects a reuseport listener instead of fd-passing.
// See: https://www.nginx.com/blog/socket-sharding-nginx-release-1-9-1/
//
// It's disabled by default
// Disabled by default.
Reuseport bool
// OnMasterDeath, when non-nil, enables monitoring of the master process
@@ -76,25 +118,45 @@ type Prefork struct {
//
// It is recommended to set this to func() { os.Exit(1) } if no custom
// cleanup is needed.
//
// Threading: invoked once from a watcher goroutine in the child. Must not
// block the goroutine for long; must not call Prefork methods.
OnMasterDeath func()
// OnChildSpawn is called in the master process whenever a new child process is spawned.
// OnChildSpawn is called in the master after a new child process is
// successfully started, both during initial spawn and during recovery.
// It receives the PID of the newly spawned child process.
//
// If this callback returns an error, all child processes are killed and the
// prefork operation returns that error.
// If this callback returns an error, all already-running children are
// killed and the prefork operation returns that error.
//
// Threading: invoked synchronously from the master goroutine. Must not
// block; must not call Prefork methods. Panics are recovered and surfaced
// as the returned error.
OnChildSpawn func(pid int) error
// OnMasterReady is called in the master process after all child processes have been spawned.
// It receives a slice of all child process PIDs.
// OnMasterReady is called in the master process exactly once, after all
// initial children have been spawned and before the supervision loop runs.
// It receives a slice of all initial child PIDs.
//
// If this callback returns an error, the prefork operation will be aborted.
// If this callback returns an error, the prefork operation aborts and
// returns that error after killing the children.
//
// Threading: invoked synchronously from the master goroutine. The slice
// is owned by the caller after the call returns. Panics are recovered.
OnMasterReady func(childPIDs []int) error
// OnChildRecover is called in the master process when a crashed child process
// has been replaced by a new one. It receives the PID of the old (crashed)
// process and the PID of the newly spawned replacement.
OnChildRecover func(oldPid, newPid int)
// OnChildRecover is called in the master after a crashed child has been
// replaced. It receives the PID of the old (crashed) process and the PID
// of its replacement.
//
// If this callback returns an error, all running children are killed and
// the prefork operation returns that error.
//
// Threading: invoked synchronously from the master goroutine, after
// OnChildSpawn for the new child. Panics are recovered and surfaced as
// the returned error.
OnChildRecover func(oldPID, newPID int) error
// CommandProducer creates and starts a child process command.
// If nil, the default implementation re-executes the current binary
@@ -104,20 +166,23 @@ type Prefork struct {
// A custom producer must:
// - Set FASTHTTP_PREFORK_CHILD=1 in the child's environment
// (otherwise IsChild() returns false and the child won't serve)
// - Call cmd.Start() before returning (the returned command must be started)
// - Call cmd.Start() before returning (the returned command must be
// started so cmd.Process is non-nil)
// - Pass the provided files as cmd.ExtraFiles when Reuseport is false
//
// This is primarily useful for testing (injecting dummy commands)
// or for frameworks that need custom child process setup.
// Primarily useful for testing (injecting dummy commands) or for
// frameworks that need custom child process setup.
CommandProducer func(files []*os.File) (*exec.Cmd, error)
}
// IsChild checks if the current thread/process is a child.
// IsChild reports whether the current process is a prefork child.
func IsChild() bool {
return os.Getenv(preforkChildEnvVariable) == "1"
return os.Getenv(preforkChildEnvVariable) == preforkChildEnvValue
}
// New wraps the fasthttp server to run with preforked processes.
// It seeds Network and RecoverThreshold to sensible defaults; existing
// fields on s (Logger, Serve*) are captured.
func New(s *fasthttp.Server) *Prefork {
return &Prefork{
Network: defaultNetwork,
@@ -140,20 +205,32 @@ func (p *Prefork) logger() Logger {
return defaultLogger
}
// invokeHook runs fn under a panic recovery, returning the panic as an error
// so a misbehaving callback never tears down the master.
func (p *Prefork) invokeHook(name string, fn func() error) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("prefork: %s panicked: %v", name, r)
p.logger().Printf("%v", err)
}
}()
return fn()
}
func (p *Prefork) watchMaster(masterPID int) {
if runtime.GOOS == "windows" {
// On Windows, os.Getppid() returns a static PID that doesn't change
// when the parent exits (no reparenting). Use FindProcess+Wait instead.
proc, err := os.FindProcess(masterPID)
if err != nil {
p.logger().Printf("watchMaster: failed to find master process %d: %v\n", masterPID, err)
p.logger().Printf("watchMaster: failed to find master process %d: %v", masterPID, err)
p.OnMasterDeath()
return
}
if _, err = proc.Wait(); err != nil {
p.logger().Printf("watchMaster: error waiting for master process %d: %v\n", masterPID, err)
p.logger().Printf("watchMaster: error waiting for master process %d: %v", masterPID, err)
}
p.logger().Printf("master process died\n")
p.logger().Printf("master process %d died", masterPID)
p.OnMasterDeath()
return
}
@@ -162,12 +239,12 @@ func (p *Prefork) watchMaster(masterPID int) {
// to another process, causing Getppid() to change. Comparing against
// the original masterPID (instead of hardcoding 1) ensures this works
// correctly when the master itself is PID 1 (e.g. in Docker containers).
ticker := time.NewTicker(500 * time.Millisecond)
ticker := time.NewTicker(masterPollInterval)
defer ticker.Stop()
for range ticker.C {
if os.Getppid() != masterPID {
p.logger().Printf("master process died\n")
p.logger().Printf("master process %d died", masterPID)
p.OnMasterDeath()
return
}
@@ -185,8 +262,10 @@ func (p *Prefork) listen(addr string) (net.Listener, error) {
return reuseport.Listen(p.Network, addr)
}
// File descriptor 3 is the first ExtraFiles entry passed by the master process.
return net.FileListener(os.NewFile(3, ""))
// fd inheritedListenerFD is the first ExtraFiles entry passed by the
// master process when Reuseport is false. Naming the file gives clearer
// errors from net.FileListener if the fd is invalid.
return net.FileListener(os.NewFile(inheritedListenerFD, "fasthttp-prefork-listener"))
}
// listenAsChild performs the common child process setup: creates the listener
@@ -213,62 +292,142 @@ func (p *Prefork) setTCPListenerFiles(addr string) error {
tcpAddr, err := net.ResolveTCPAddr(p.Network, addr)
if err != nil {
return err
return fmt.Errorf("prefork: resolve %s/%s: %w", p.Network, addr, err)
}
tcplistener, err := net.ListenTCP(p.Network, tcpAddr)
tcpListener, err := net.ListenTCP(p.Network, tcpAddr)
if err != nil {
return err
return fmt.Errorf("prefork: listen tcp %s: %w", addr, err)
}
p.ln = tcplistener
p.ln = tcpListener
fl, err := tcplistener.File()
listenerFile, err := tcpListener.File()
if err != nil {
return err
return fmt.Errorf("prefork: dup listener fd: %w", err)
}
p.files = []*os.File{fl}
p.files = []*os.File{listenerFile}
return nil
}
// childEnv returns os.Environ() with the prefork child marker variable set,
// stripping any pre-existing value to avoid duplicate keys with last-wins
// semantics.
func childEnv() []string {
src := os.Environ()
out := make([]string, 0, len(src)+1)
prefix := preforkChildEnvVariable + "="
for _, kv := range src {
if len(kv) >= len(prefix) && kv[:len(prefix)] == prefix {
continue
}
out = append(out, kv)
}
out = append(out, preforkChildEnvVariable+"="+preforkChildEnvValue)
return out
}
func (p *Prefork) doCommand() (*exec.Cmd, error) {
// Use custom CommandProducer if provided
if p.CommandProducer != nil {
cmd, err := p.CommandProducer(p.files)
if err != nil {
return nil, err
return nil, fmt.Errorf("prefork: CommandProducer: %w", err)
}
if cmd == nil || cmd.Process == nil {
return nil, errors.New("prefork: CommandProducer must return a started command")
if cmd == nil {
return nil, ErrCommandProducerNilCmd
}
if cmd.Process == nil {
return nil, ErrCommandProducerNotStarted
}
return cmd, nil
}
// Default implementation using os.Executable() for reliable path resolution
executable, err := os.Executable()
if err != nil {
return nil, err
return nil, fmt.Errorf("prefork: resolve executable: %w", err)
}
args := make([]string, len(os.Args))
args[0] = executable
copy(args[1:], os.Args[1:])
args := append([]string{executable}, os.Args[1:]...)
cmd := &exec.Cmd{
Path: executable,
Args: args,
Stdout: os.Stdout,
Stderr: os.Stderr,
Env: append(os.Environ(), preforkChildEnvVariable+"=1"),
Env: childEnv(),
ExtraFiles: p.files,
}
err = cmd.Start()
return cmd, err
if err = cmd.Start(); err != nil {
return nil, fmt.Errorf("prefork: start child %q: %w", executable, err)
}
return cmd, nil
}
func (p *Prefork) prefork(addr string) (err error) {
type childExit struct {
err error
pid int
}
// shutdownChildren signals every entry in childProcs first with SIGTERM (on
// platforms where it is supported) and waits up to grace for them to exit.
// Survivors are then killed unconditionally. wg tracks the per-child Wait
// goroutines and must be drained before returning so no goroutine outlives
// prefork().
func (p *Prefork) shutdownChildren(
childProcs map[int]*exec.Cmd,
wg *sync.WaitGroup,
cancel context.CancelFunc,
grace time.Duration,
) {
if grace <= 0 {
grace = defaultShutdownGracePeriod
}
if runtime.GOOS != "windows" {
for pid, proc := range childProcs {
if proc == nil || proc.Process == nil {
continue
}
if termErr := proc.Process.Signal(syscall.SIGTERM); termErr != nil &&
!errors.Is(termErr, os.ErrProcessDone) {
p.logger().Printf("prefork: SIGTERM child %d: %v", pid, termErr)
}
}
}
// Wait for graceful exits, with a timeout fallback to SIGKILL.
graceful := make(chan struct{})
go func() {
wg.Wait()
close(graceful)
}()
timer := time.NewTimer(grace)
defer timer.Stop()
select {
case <-graceful:
case <-timer.C:
}
for pid, proc := range childProcs {
if proc == nil || proc.Process == nil {
continue
}
if killErr := proc.Process.Kill(); killErr != nil &&
!errors.Is(killErr, os.ErrProcessDone) {
p.logger().Printf("prefork: kill child %d: %v", pid, killErr)
}
}
// Cancel the per-Wait goroutines' send-context so any still blocked on
// sigCh send unblock cleanly, then wait for all of them to exit.
cancel()
wg.Wait()
}
func (p *Prefork) prefork(addr string) (err error) { //nolint:gocyclo
if !p.Reuseport {
if runtime.GOOS == "windows" {
return ErrOnlyReuseportOnWindows
@@ -278,37 +437,60 @@ func (p *Prefork) prefork(addr string) (err error) {
return err
}
// defer for closing the net.Listener opened by setTCPListenerFiles.
// Close listener fds opened by setTCPListenerFiles. Both the original
// tcpListener (p.ln) and the duped fd (p.files[0]) belong to the
// master only; children inherit independent dup'd copies via fork+exec.
defer func() {
e := p.ln.Close()
if err == nil {
err = e
err = errors.Join(err, p.ln.Close())
for _, f := range p.files {
if closeErr := f.Close(); closeErr != nil {
p.logger().Printf("prefork: close listener fd: %v", closeErr)
}
}
p.files = nil
}()
}
// ctx cancels per-child Wait goroutines so they unblock from sigCh sends
// once the supervision loop is gone.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Catch SIGTERM/SIGINT in the master so we run our shutdown path instead
// of being killed by the OS without children getting a graceful chance.
signalCh := make(chan os.Signal, 1)
signal.Notify(signalCh, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(signalCh)
goMaxProcs := runtime.GOMAXPROCS(0)
// Buffer is sized to initial fleet; per-child goroutines fall back to a
// context-aware select on send so capacity is not load-bearing.
sigCh := make(chan childExit, goMaxProcs)
childProcs := make(map[int]*exec.Cmd, goMaxProcs)
var wg sync.WaitGroup
startWait := func(cmd *exec.Cmd, pid int) {
wg.Add(1)
go func() {
defer wg.Done()
result := childExit{pid: pid, err: cmd.Wait()}
select {
case sigCh <- result:
case <-ctx.Done():
}
}()
}
type procSig struct {
err error
pid int
}
goMaxProcs := runtime.GOMAXPROCS(0)
sigCh := make(chan procSig, goMaxProcs)
childProcs := make(map[int]*exec.Cmd, goMaxProcs)
defer func() {
for _, proc := range childProcs {
_ = proc.Process.Kill()
}
p.shutdownChildren(childProcs, &wg, cancel, p.ShutdownGracePeriod)
}()
// Collect child PIDs for OnMasterReady callback
childPIDs := make([]int, 0, goMaxProcs)
for range goMaxProcs {
var cmd *exec.Cmd
if cmd, err = p.doCommand(); err != nil {
p.logger().Printf("failed to start a child prefork process, error: %v\n", err)
p.logger().Printf("prefork: failed to start a child process: %v", err)
return err
}
@@ -316,71 +498,96 @@ func (p *Prefork) prefork(addr string) (err error) {
childProcs[pid] = cmd
childPIDs = append(childPIDs, pid)
// Start Wait goroutine before OnChildSpawn so that if OnChildSpawn
// fails and the deferred Kill() runs, Wait() will collect the child.
go func(c *exec.Cmd, pid int) {
sigCh <- procSig{pid: pid, err: c.Wait()}
}(cmd, pid)
// Start Wait goroutine before the user callback so a panic / error
// return from OnChildSpawn cannot leave a zombie behind.
startWait(cmd, pid)
if p.OnChildSpawn != nil {
if err = p.OnChildSpawn(pid); err != nil {
p.logger().Printf("OnChildSpawn callback failed for PID %d: %v\n", pid, err)
return err
pid := pid
if hookErr := p.invokeHook("OnChildSpawn", func() error {
return p.OnChildSpawn(pid)
}); hookErr != nil {
p.logger().Printf("prefork: OnChildSpawn for PID %d: %v", pid, hookErr)
return hookErr
}
}
}
// Call OnMasterReady callback after all children are spawned
if p.OnMasterReady != nil {
if err = p.OnMasterReady(childPIDs); err != nil {
p.logger().Printf("OnMasterReady callback failed: %v\n", err)
return err
pids := append([]int(nil), childPIDs...)
if hookErr := p.invokeHook("OnMasterReady", func() error {
return p.OnMasterReady(pids)
}); hookErr != nil {
p.logger().Printf("prefork: OnMasterReady: %v", hookErr)
return hookErr
}
}
var exitedProcs int
for sig := range sigCh {
delete(childProcs, sig.pid)
for {
select {
case sig := <-signalCh:
p.logger().Printf("prefork: received signal %v, shutting down", sig)
return nil
p.logger().Printf("one of the child prefork processes exited with "+
"error: %v", sig.err)
case sig := <-sigCh:
delete(childProcs, sig.pid)
exitedProcs++
if exitedProcs > p.RecoverThreshold {
p.logger().Printf("child prefork processes exit too many times, "+
"which exceeds the value of RecoverThreshold(%d), "+
"exiting the master process.\n", p.RecoverThreshold)
err = ErrOverRecovery
break
}
if sig.err != nil {
p.logger().Printf("prefork: child PID %d exited: %v", sig.pid, sig.err)
} else {
p.logger().Printf("prefork: child PID %d exited cleanly", sig.pid)
}
var cmd *exec.Cmd
cmd, err = p.doCommand()
if err != nil {
break
}
newPid := cmd.Process.Pid
childProcs[newPid] = cmd
exitedProcs++
if exitedProcs > p.RecoverThreshold {
p.logger().Printf(
"prefork: child exits (%d) exceed RecoverThreshold (%d), terminating master",
exitedProcs, p.RecoverThreshold,
)
return ErrOverRecovery
}
// Start Wait goroutine before callbacks to avoid zombie processes
// if a callback fails and the deferred Kill() runs.
go func(c *exec.Cmd, pid int) {
sigCh <- procSig{pid: pid, err: c.Wait()}
}(cmd, newPid)
if p.RecoverInterval > 0 {
select {
case <-time.After(p.RecoverInterval):
case <-signalCh:
return nil
}
}
if p.OnChildSpawn != nil {
if err = p.OnChildSpawn(newPid); err != nil {
p.logger().Printf("OnChildSpawn callback failed for recovered PID %d: %v\n", newPid, err)
return err
cmd, doErr := p.doCommand()
if doErr != nil {
p.logger().Printf("prefork: recovery doCommand: %v", doErr)
return doErr
}
newPID := cmd.Process.Pid
childProcs[newPID] = cmd
startWait(cmd, newPID)
if p.OnChildSpawn != nil {
newPID := newPID
if hookErr := p.invokeHook("OnChildSpawn", func() error {
return p.OnChildSpawn(newPID)
}); hookErr != nil {
p.logger().Printf("prefork: OnChildSpawn for recovered PID %d: %v", newPID, hookErr)
return hookErr
}
}
if p.OnChildRecover != nil {
oldPID := sig.pid
newPID := newPID
if hookErr := p.invokeHook("OnChildRecover", func() error {
return p.OnChildRecover(oldPID, newPID)
}); hookErr != nil {
p.logger().Printf("prefork: OnChildRecover (%d -> %d): %v", oldPID, newPID, hookErr)
return hookErr
}
}
}
if p.OnChildRecover != nil {
p.OnChildRecover(sig.pid, newPid)
}
}
return err
}
// ListenAndServe serves HTTP requests from the given TCP addr.
@@ -398,11 +605,10 @@ func (p *Prefork) ListenAndServe(addr string) error {
// ListenAndServeTLS serves HTTPS requests from the given TCP addr.
//
// certKey is the path to the TLS private key file.
// certFile is the path to the TLS certificate file.
//
// Note: parameter order is (addr, certKey, certFile) — key before cert.
// Internally forwards to ServeTLSFunc as (certFile, certKey).
// Note: parameter order is (addr, certKey, certFile) — key path comes
// before cert path. This is preserved for backward compatibility with
// existing callers and differs from fasthttp.Server.ListenAndServeTLS.
// New code should prefer ListenAndServeTLSEmbed.
func (p *Prefork) ListenAndServeTLS(addr, certKey, certFile string) error {
if IsChild() {
ln, err := p.listenAsChild(addr)
+344 -190
View File
@@ -3,44 +3,78 @@ package prefork
import (
"errors"
"fmt"
"math/rand"
"net"
"os"
"os/exec"
"reflect"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/valyala/fasthttp"
)
func setUp() {
os.Setenv(preforkChildEnvVariable, "1")
// freeAddr returns a TCP4 address that the OS believes is free at call time.
// It avoids the rand-port collision flakes the previous helper exhibited.
func freeAddr(t testing.TB) string {
t.Helper()
ln, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("freeAddr: %v", err)
}
addr := ln.Addr().String()
if closeErr := ln.Close(); closeErr != nil {
t.Fatalf("freeAddr: close: %v", closeErr)
}
return addr
}
func tearDown() {
os.Unsetenv(preforkChildEnvVariable)
}
// noopChildProducer returns a CommandProducer that re-execs the test binary
// into a no-op subprocess. The returned cleanup must be deferred (or registered
// via t.Cleanup) so leaked subprocesses are reaped if the test fails midway.
func noopChildProducer(t testing.TB) (func(files []*os.File) (*exec.Cmd, error), func()) {
t.Helper()
var (
mu sync.Mutex
spawned []*exec.Cmd
)
func getAddr() string {
return fmt.Sprintf("127.0.0.1:%d", rand.Intn(9000-3000)+3000)
produce := func(_ []*os.File) (*exec.Cmd, error) {
cmd := exec.Command(os.Args[0], "-test.run=^$")
cmd.Env = append(os.Environ(), preforkChildEnvVariable+"="+preforkChildEnvValue)
if err := cmd.Start(); err != nil {
return nil, err
}
mu.Lock()
spawned = append(spawned, cmd)
mu.Unlock()
return cmd, nil
}
cleanup := func() {
mu.Lock()
defer mu.Unlock()
for _, cmd := range spawned {
if cmd == nil || cmd.Process == nil {
continue
}
_ = cmd.Process.Kill()
_, _ = cmd.Process.Wait()
}
}
return produce, cleanup
}
func Test_IsChild(t *testing.T) {
// This test can't run parallel as it modifies the process environment.
v := IsChild()
if v {
t.Errorf("IsChild() == %v, want %v", v, false)
// This test cannot run in parallel — IsChild() reads a process-global env var.
if IsChild() {
t.Fatal("test starts as child unexpectedly")
}
setUp()
defer tearDown()
v = IsChild()
if !v {
t.Errorf("IsChild() == %v, want %v", v, true)
t.Setenv(preforkChildEnvVariable, preforkChildEnvValue)
if !IsChild() {
t.Errorf("IsChild() == false after Setenv, want true")
}
}
@@ -53,51 +87,37 @@ func Test_New(t *testing.T) {
if p.Network != defaultNetwork {
t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork)
}
if reflect.ValueOf(p.ServeFunc).Pointer() != reflect.ValueOf(s.Serve).Pointer() {
t.Errorf("Prefork.ServeFunc == %p, want %p", p.ServeFunc, s.Serve)
if p.RecoverThreshold <= 0 {
t.Errorf("Prefork.RecoverThreshold == %d, want > 0", p.RecoverThreshold)
}
if reflect.ValueOf(p.ServeTLSFunc).Pointer() != reflect.ValueOf(s.ServeTLS).Pointer() {
t.Errorf("Prefork.ServeTLSFunc == %p, want %p", p.ServeTLSFunc, s.ServeTLS)
}
if reflect.ValueOf(p.ServeTLSEmbedFunc).Pointer() != reflect.ValueOf(s.ServeTLSEmbed).Pointer() {
t.Errorf("Prefork.ServeTLSFunc == %p, want %p", p.ServeTLSEmbedFunc, s.ServeTLSEmbed)
if p.ServeFunc == nil || p.ServeTLSFunc == nil || p.ServeTLSEmbedFunc == nil {
t.Error("New() did not wire one of ServeFunc/ServeTLSFunc/ServeTLSEmbedFunc")
}
}
func Test_listen(t *testing.T) {
func Test_listen_Reuseport(t *testing.T) {
prev := runtime.GOMAXPROCS(0)
t.Cleanup(func() {
runtime.GOMAXPROCS(prev)
})
p := &Prefork{
Reuseport: true,
}
addr := getAddr()
p := &Prefork{Reuseport: true}
addr := freeAddr(t)
ln, err := p.listen(addr)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
t.Fatalf("listen: %v", err)
}
t.Cleanup(func() {
_ = ln.Close()
})
ln.Close()
lnAddr := ln.Addr().String()
if lnAddr != addr {
t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
if got, want := ln.Addr().String(), addr; got != want {
t.Errorf("ln.Addr() == %q, want %q", got, want)
}
if p.Network != defaultNetwork {
t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork)
}
procs := runtime.GOMAXPROCS(0)
if procs != 1 {
t.Errorf("GOMAXPROCS == %d, want %d", procs, 1)
}
}
func Test_setTCPListenerFiles(t *testing.T) {
@@ -108,144 +128,187 @@ func Test_setTCPListenerFiles(t *testing.T) {
}
p := &Prefork{}
addr := getAddr()
addr := freeAddr(t)
err := p.setTCPListenerFiles(addr)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
if err := p.setTCPListenerFiles(addr); err != nil {
t.Fatalf("setTCPListenerFiles: %v", err)
}
t.Cleanup(func() {
_ = p.ln.Close()
for _, f := range p.files {
_ = f.Close()
}
})
if p.ln == nil {
t.Fatal("Prefork.ln is nil")
t.Fatal("p.ln is nil after setTCPListenerFiles")
}
p.ln.Close()
lnAddr := p.ln.Addr().String()
if lnAddr != addr {
t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
if got, want := p.ln.Addr().String(), addr; got != want {
t.Errorf("p.ln.Addr() == %q, want %q", got, want)
}
if p.Network != defaultNetwork {
t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork)
}
if len(p.files) != 1 {
t.Errorf("Prefork.files == %d, want %d", len(p.files), 1)
t.Errorf("len(p.files) == %d, want 1", len(p.files))
}
}
func Test_ListenAndServe(t *testing.T) {
// This test can't run parallel as it modifies the process environment.
func Test_setTCPListenerFiles_BadAddr(t *testing.T) {
t.Parallel()
setUp()
defer tearDown()
s := &fasthttp.Server{}
p := New(s)
p.Reuseport = true
p.ServeFunc = func(ln net.Listener) error {
return nil
if runtime.GOOS == "windows" {
t.SkipNow()
}
addr := getAddr()
err := p.ListenAndServe(addr)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
p.ln.Close()
lnAddr := p.ln.Addr().String()
if lnAddr != addr {
t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
}
if p.ln == nil {
t.Error("Prefork.ln is nil")
p := &Prefork{}
if err := p.setTCPListenerFiles("definitely not an address"); err == nil {
t.Fatal("expected error for malformed addr, got nil")
}
}
func Test_ListenAndServeTLS(t *testing.T) {
// This test can't run parallel as it modifies the process environment.
// Test_ListenAndServe_Stub_ChildPath drives the child branch of all three
// ListenAndServe* entry points using a stubbed Serve function. It replaces the
// previous trio of near-identical tests that only validated field assignment.
func Test_ListenAndServe_Stub_ChildPath(t *testing.T) {
// child env mutation precludes t.Parallel.
t.Setenv(preforkChildEnvVariable, preforkChildEnvValue)
setUp()
defer tearDown()
s := &fasthttp.Server{}
p := New(s)
p.Reuseport = true
p.ServeTLSFunc = func(ln net.Listener, certFile, keyFile string) error {
return nil
type call struct {
listener bool
certFile string
keyFile string
certData string
keyData string
}
addr := getAddr()
err := p.ListenAndServeTLS(addr, "./key", "./cert")
if err != nil {
t.Errorf("Unexpected error: %v", err)
tests := []struct {
name string
run func(t *testing.T, p *Prefork, addr string) error
want call
}{
{
name: "ListenAndServe",
run: func(_ *testing.T, p *Prefork, addr string) error { return p.ListenAndServe(addr) },
want: call{listener: true},
},
{
name: "ListenAndServeTLS",
run: func(_ *testing.T, p *Prefork, addr string) error {
return p.ListenAndServeTLS(addr, "./key", "./cert")
},
want: call{listener: true, certFile: "./cert", keyFile: "./key"},
},
{
name: "ListenAndServeTLSEmbed",
run: func(_ *testing.T, p *Prefork, addr string) error {
return p.ListenAndServeTLSEmbed(addr, []byte("certPEM"), []byte("keyPEM"))
},
want: call{listener: true, certData: "certPEM", keyData: "keyPEM"},
},
}
p.ln.Close()
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var got call
p := New(&fasthttp.Server{})
p.Reuseport = true
p.ServeFunc = func(ln net.Listener) error {
got.listener = ln != nil
return nil
}
p.ServeTLSFunc = func(ln net.Listener, certFile, keyFile string) error {
got.listener = ln != nil
got.certFile = certFile
got.keyFile = keyFile
return nil
}
p.ServeTLSEmbedFunc = func(ln net.Listener, certData, keyData []byte) error {
got.listener = ln != nil
got.certData = string(certData)
got.keyData = string(keyData)
return nil
}
lnAddr := p.ln.Addr().String()
if lnAddr != addr {
t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
}
addr := freeAddr(t)
if err := tc.run(t, p, addr); err != nil {
t.Fatalf("%s: %v", tc.name, err)
}
t.Cleanup(func() {
if p.ln != nil {
_ = p.ln.Close()
}
})
if p.ln == nil {
t.Error("Prefork.ln is nil")
if got != tc.want {
t.Errorf("%s call = %+v, want %+v", tc.name, got, tc.want)
}
})
}
}
func Test_ListenAndServeTLSEmbed(t *testing.T) {
// This test can't run parallel as it modifies the process environment.
func Test_doCommand_CommandProducerErrors(t *testing.T) {
t.Parallel()
setUp()
defer tearDown()
s := &fasthttp.Server{}
p := New(s)
p.Reuseport = true
p.ServeTLSEmbedFunc = func(ln net.Listener, certData, keyData []byte) error {
return nil
tests := []struct {
name string
produce func(files []*os.File) (*exec.Cmd, error)
wantErr error
}{
{
name: "producer returns error",
produce: func([]*os.File) (*exec.Cmd, error) {
return nil, errors.New("boom")
},
},
{
name: "producer returns nil cmd",
//nolint:nilnil // intentionally tests the (nil, nil) misbehaviour guard
produce: func([]*os.File) (*exec.Cmd, error) {
return nil, nil
},
wantErr: ErrCommandProducerNilCmd,
},
{
name: "producer returns unstarted cmd",
produce: func([]*os.File) (*exec.Cmd, error) {
return &exec.Cmd{}, nil
},
wantErr: ErrCommandProducerNotStarted,
},
}
addr := getAddr()
err := p.ListenAndServeTLSEmbed(addr, []byte("key"), []byte("cert"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
p.ln.Close()
lnAddr := p.ln.Addr().String()
if lnAddr != addr {
t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
}
if p.ln == nil {
t.Error("Prefork.ln is nil")
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
p := &Prefork{CommandProducer: tc.produce}
cmd, err := p.doCommand()
if cmd != nil {
t.Errorf("expected nil cmd on error, got %v", cmd)
}
if err == nil {
t.Fatal("expected error, got nil")
}
if tc.wantErr != nil && !errors.Is(err, tc.wantErr) {
t.Errorf("err = %v, want errors.Is %v", err, tc.wantErr)
}
})
}
}
type testLogger struct {
mu sync.Mutex
messages []string
}
func (l *testLogger) Printf(format string, args ...any) {
l.messages = append(l.messages, fmt.Sprintf(format, args...))
msg := fmt.Sprintf(format, args...)
l.mu.Lock()
l.messages = append(l.messages, msg)
l.mu.Unlock()
}
// Test_Prefork_Lifecycle runs the full prefork lifecycle with a CommandProducer
// and verifies that callbacks are invoked in the correct order with the correct arguments.
// Test_Prefork_Lifecycle drives prefork() to ErrOverRecovery via
// short-lived no-op children and asserts the callback ordering / arguments.
func Test_Prefork_Lifecycle(t *testing.T) {
prev := runtime.GOMAXPROCS(2)
t.Cleanup(func() {
runtime.GOMAXPROCS(prev)
})
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
type event struct {
name string
@@ -260,16 +323,14 @@ func Test_Prefork_Lifecycle(t *testing.T) {
mu.Unlock()
}
produce, cleanup := noopChildProducer(t)
t.Cleanup(cleanup)
p := &Prefork{
Reuseport: true,
RecoverThreshold: 1,
Logger: &testLogger{},
CommandProducer: func(_ []*os.File) (*exec.Cmd, error) {
cmd := exec.Command(os.Args[0], "-test.run=^$")
cmd.Env = append(os.Environ(), preforkChildEnvVariable+"=1")
err := cmd.Start()
return cmd, err
},
CommandProducer: produce,
OnChildSpawn: func(pid int) error {
record("spawn", pid)
return nil
@@ -278,12 +339,13 @@ func Test_Prefork_Lifecycle(t *testing.T) {
record("ready", childPIDs...)
return nil
},
OnChildRecover: func(oldPid, newPid int) {
record("recover", oldPid, newPid)
OnChildRecover: func(oldPID, newPID int) error {
record("recover", oldPID, newPID)
return nil
},
}
err := p.prefork(getAddr())
err := p.prefork(freeAddr(t))
if !errors.Is(err, ErrOverRecovery) {
t.Fatalf("expected ErrOverRecovery, got: %v", err)
}
@@ -291,10 +353,9 @@ func Test_Prefork_Lifecycle(t *testing.T) {
mu.Lock()
defer mu.Unlock()
// Verify we got spawn events for initial children
var spawnCount int
var readyCount int
var recoverCount int
goMaxProcs := runtime.GOMAXPROCS(0)
var spawnCount, readyCount, recoverCount int
for _, e := range events {
switch e.name {
case "spawn":
@@ -318,21 +379,17 @@ func Test_Prefork_Lifecycle(t *testing.T) {
}
}
goMaxProcs := runtime.GOMAXPROCS(0)
if readyCount != 1 {
t.Errorf("OnMasterReady called %d times, want 1", readyCount)
}
// Initial spawns + at least one recovery spawn
if spawnCount < goMaxProcs {
t.Errorf("OnChildSpawn called %d times, want at least %d", spawnCount, goMaxProcs)
}
if recoverCount == 0 {
t.Error("OnChildRecover was never called")
}
// ready must come after exactly goMaxProcs initial spawns.
readyIdx := -1
spawnsBeforeReady := 0
for i, e := range events {
@@ -344,7 +401,6 @@ func Test_Prefork_Lifecycle(t *testing.T) {
spawnsBeforeReady++
}
}
if readyIdx == -1 {
t.Fatal("OnMasterReady was never called")
}
@@ -352,8 +408,8 @@ func Test_Prefork_Lifecycle(t *testing.T) {
t.Errorf("OnMasterReady called after %d initial spawns, want %d", spawnsBeforeReady, goMaxProcs)
}
// every recover event must be preceded by a spawn for the new PID.
recoveredSpawnByPID := make(map[int]bool)
recoveredPIDs := make(map[int]bool)
for _, e := range events[readyIdx+1:] {
if e.name == "spawn" {
recoveredSpawnByPID[e.pids[0]] = true
@@ -362,56 +418,154 @@ func Test_Prefork_Lifecycle(t *testing.T) {
if !recoveredSpawnByPID[e.pids[1]] {
t.Errorf("OnChildRecover for PID %d happened before OnChildSpawn", e.pids[1])
}
recoveredPIDs[e.pids[1]] = true
}
}
for pid := range recoveredPIDs {
if !recoveredSpawnByPID[pid] {
t.Errorf("OnChildRecover for PID %d did not have a matching OnChildSpawn", pid)
}
}
}
func Test_Prefork_RecoveredChildSpawnError(t *testing.T) {
func Test_Prefork_InitialChildSpawnError(t *testing.T) {
prev := runtime.GOMAXPROCS(2)
t.Cleanup(func() {
runtime.GOMAXPROCS(prev)
})
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
expectedErr := errors.New("spawn failed")
var spawnCount int
var recoverCount int
produce, cleanup := noopChildProducer(t)
t.Cleanup(cleanup)
expectedErr := errors.New("initial spawn rejected")
var calls atomic.Int32
p := &Prefork{
Reuseport: true,
RecoverThreshold: 1,
Logger: &testLogger{},
CommandProducer: func(_ []*os.File) (*exec.Cmd, error) {
cmd := exec.Command(os.Args[0], "-test.run=^$")
cmd.Env = append(os.Environ(), preforkChildEnvVariable+"=1")
err := cmd.Start()
return cmd, err
CommandProducer: produce,
OnChildSpawn: func(_ int) error {
calls.Add(1)
return expectedErr
},
}
err := p.prefork(freeAddr(t))
if !errors.Is(err, expectedErr) {
t.Fatalf("expected %v, got: %v", expectedErr, err)
}
if calls.Load() == 0 {
t.Fatal("OnChildSpawn was never invoked")
}
}
func Test_Prefork_OnMasterReadyError(t *testing.T) {
prev := runtime.GOMAXPROCS(2)
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
produce, cleanup := noopChildProducer(t)
t.Cleanup(cleanup)
expectedErr := errors.New("ready rejected")
p := &Prefork{
Reuseport: true,
RecoverThreshold: 1,
Logger: &testLogger{},
CommandProducer: produce,
OnMasterReady: func([]int) error {
return expectedErr
},
}
err := p.prefork(freeAddr(t))
if !errors.Is(err, expectedErr) {
t.Fatalf("expected %v, got: %v", expectedErr, err)
}
}
func Test_Prefork_RecoveredChildSpawnError(t *testing.T) {
prev := runtime.GOMAXPROCS(2)
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
produce, cleanup := noopChildProducer(t)
t.Cleanup(cleanup)
expectedErr := errors.New("spawn failed")
var spawnCount, recoverCount atomic.Int32
p := &Prefork{
Reuseport: true,
RecoverThreshold: 1,
Logger: &testLogger{},
CommandProducer: produce,
OnChildSpawn: func(pid int) error {
if pid <= 0 {
t.Errorf("OnChildSpawn called with invalid PID: %d", pid)
}
spawnCount++
if spawnCount > runtime.GOMAXPROCS(0) {
n := spawnCount.Add(1)
if int(n) > runtime.GOMAXPROCS(0) {
return expectedErr
}
return nil
},
OnChildRecover: func(_, _ int) {
recoverCount++
OnChildRecover: func(_, _ int) error {
recoverCount.Add(1)
return nil
},
}
err := p.prefork(getAddr())
err := p.prefork(freeAddr(t))
if !errors.Is(err, expectedErr) {
t.Fatalf("expected %v, got: %v", expectedErr, err)
}
if recoverCount != 0 {
t.Fatalf("OnChildRecover called %d times, want 0", recoverCount)
if got := recoverCount.Load(); got != 0 {
t.Fatalf("OnChildRecover called %d times, want 0", got)
}
}
func Test_Prefork_HookPanicSurfaces(t *testing.T) {
prev := runtime.GOMAXPROCS(2)
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
produce, cleanup := noopChildProducer(t)
t.Cleanup(cleanup)
p := &Prefork{
Reuseport: true,
RecoverThreshold: 1,
Logger: &testLogger{},
CommandProducer: produce,
OnChildSpawn: func(_ int) error {
panic("user code blew up")
},
}
err := p.prefork(freeAddr(t))
if err == nil {
t.Fatal("expected error from panicking hook, got nil")
}
}
// Test_Prefork_RecoverInterval verifies the optional backoff actually delays
// the respawn — without rabbit-holing into hard wallclock assertions.
func Test_Prefork_RecoverInterval(t *testing.T) {
prev := runtime.GOMAXPROCS(2)
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
produce, cleanup := noopChildProducer(t)
t.Cleanup(cleanup)
const interval = 50 * time.Millisecond
p := &Prefork{
Reuseport: true,
RecoverThreshold: 1,
RecoverInterval: interval,
Logger: &testLogger{},
CommandProducer: produce,
}
start := time.Now()
err := p.prefork(freeAddr(t))
elapsed := time.Since(start)
if !errors.Is(err, ErrOverRecovery) {
t.Fatalf("expected ErrOverRecovery, got %v", err)
}
// At least one recover interval must have elapsed before threshold fired.
if elapsed < interval {
t.Errorf("elapsed %v < interval %v; backoff did not apply", elapsed, interval)
}
}