Files
fasthttp/prefork/prefork.go
T
René 262ea09c9e feat(prefork): graceful shutdown, leak fixes, hook robustness
Addresses outstanding review concerns and several adjacent issues
surfaced during a follow-up review pass.

Lifecycle / supervision
- Track every per-child Wait goroutine via sync.WaitGroup and unblock
  pending sigCh sends through a context.Cancel so early-return paths
  (OnChildSpawn / OnMasterReady error, recovery doCommand error,
  ErrOverRecovery) can no longer leak goroutines or stall children.
- Install signal.Notify(SIGTERM, SIGINT) in the master so deploy/
  rolling-restart signals enter the shutdown path instead of killing
  the master without graceful teardown.
- Replace the unconditional SIGKILL defer with a SIGTERM-then-SIGKILL
  sequence gated by a configurable ShutdownGracePeriod (defaults to 5s,
  Windows path stays SIGKILL since Signal(SIGTERM) is unsupported).

API
- OnChildRecover now returns error so callers can implement recovery
  policies (circuit-breaker etc.); panic in any hook is recovered and
  surfaced as the returned error, with diagnostic logging.
- Add RecoverInterval (optional crash-loop backoff) and
  ShutdownGracePeriod fields with safe zero-value defaults.
- Export ErrCommandProducerNilCmd and ErrCommandProducerNotStarted
  sentinel errors so callers can errors.Is them.
- Rename oldPid/newPid to oldPID/newPID per Go initialism convention.
- Logger interface now declares an explicit compile-time compatibility
  check with fasthttp.Logger.

Resource hygiene
- Master closes both the original tcpListener and the duped fd in
  p.files when prefork() returns; previously the duped fd leaked once
  per call.
- doCommand wraps every error path with %w + fmt.Errorf so caller-side
  diagnostics keep stage context.
- Strip pre-existing FASTHTTP_PREFORK_CHILD entries before appending so
  child env never carries duplicate keys.
- Extract magic numbers as package constants
  (inheritedListenerFD, masterPollInterval, defaultShutdownGracePeriod,
  preforkChildEnvValue).
- Name the inherited listener fd via os.NewFile so net.FileListener
  errors are diagnosable.

Tests
- Migrate to t.Setenv (drop the global setUp/tearDown helpers) — fixes
  the env-mutation-vs-parallel race.
- Replace rand.Intn port helper with `:0` + Listener.Addr() to remove
  port-collision flakes under -count and parallel runs.
- Collapse the three near-identical Test_ListenAndServe* tests into a
  single table-driven subtest that actually asserts the args forwarded
  to ServeFunc/ServeTLSFunc/ServeTLSEmbedFunc.
- Add coverage for the previously untested branches:
  CommandProducer returning err / nil cmd / unstarted cmd,
  initial OnChildSpawn error, OnMasterReady error,
  hook panic surfacing, RecoverInterval enforcement.
- noopChildProducer helper kills + waits any spawned child binaries
  during cleanup so failed tests no longer leave subprocesses around.
2026-05-02 13:12:08 +02:00

638 lines
19 KiB
Go

// Package prefork provides a way to prefork a fasthttp server.
package prefork
import (
"context"
"errors"
"fmt"
"log"
"net"
"os"
"os/exec"
"os/signal"
"runtime"
"sync"
"syscall"
"time"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/reuseport"
)
const (
// preforkChildEnvVariable is the environment variable that marks a process
// as a prefork child; IsChild compares its value to preforkChildEnvValue.
preforkChildEnvVariable = "FASTHTTP_PREFORK_CHILD"
// preforkChildEnvValue is the value preforkChildEnvVariable must hold for
// IsChild to report true.
preforkChildEnvValue = "1"
// defaultNetwork is used whenever Prefork.Network is left empty.
defaultNetwork = "tcp4"
// inheritedListenerFD is the file descriptor used by the master to pass
// the bound listener to a child process via ExtraFiles. Children open the
// listener via os.NewFile(inheritedListenerFD, ...) when Reuseport is false.
inheritedListenerFD = 3
// masterPollInterval is the period of the watchMaster ppid-poll on Unix.
masterPollInterval = 500 * time.Millisecond
// defaultShutdownGracePeriod is how long the master waits for children to
// exit cleanly after sending SIGTERM before forcibly killing them.
defaultShutdownGracePeriod = 5 * time.Second
)
var (
// defaultLogger is the fallback returned by (*Prefork).logger when
// Prefork.Logger is nil. It writes to stderr with the standard log flags.
defaultLogger = Logger(log.New(os.Stderr, "", log.LstdFlags))
// ErrOverRecovery is returned when child prefork process restarts exceed
// the value of RecoverThreshold.
ErrOverRecovery = errors.New("exceeding the value of RecoverThreshold")
// ErrOnlyReuseportOnWindows is returned when running on Windows without Reuseport.
ErrOnlyReuseportOnWindows = errors.New("windows only supports Reuseport = true")
// ErrCommandProducerNilCmd is returned when a CommandProducer returns
// (nil, nil) instead of a started command.
ErrCommandProducerNilCmd = errors.New("prefork: CommandProducer returned nil command")
// ErrCommandProducerNotStarted is returned when a CommandProducer returns
// an *exec.Cmd whose Process is nil (i.e. cmd.Start() was not called).
ErrCommandProducerNotStarted = errors.New("prefork: CommandProducer must return a started command")
)
// Logger is used for logging formatted messages. Its method set is intentionally
// identical to fasthttp.Logger so that *fasthttp.Server.Logger can be assigned
// directly. Any value with a matching Printf method (for example *log.Logger,
// which backs the package default) satisfies it.
type Logger interface {
// Printf must have the same semantics as log.Printf.
Printf(format string, args ...any)
}
// Compile-time check that a fasthttp.Logger value is assignable to the local
// Logger interface, which is the direction New relies on when it copies
// s.Logger. The check is one-way: it fails if the local interface gains a
// method that fasthttp.Logger lacks, but it does not detect methods added only
// to fasthttp.Logger.
var _ Logger = fasthttp.Logger(nil)
// Prefork implements fasthttp server prefork.
//
// Preforks master process (with all cores) between several child processes
// increases performance significantly, because Go doesn't have to share
// and manage memory between cores.
//
// WARNING: using prefork prevents the use of any global state!
// Things like in-memory caches won't work.
type Prefork struct {
// Logger receives diagnostic output. By default the standard log package
// logger writing to stderr is used.
Logger Logger
// ln is the listener owned by this process: in a child it is the listener
// returned by listen; in a master without Reuseport it is the TCP listener
// bound by setTCPListenerFiles.
ln net.Listener
// ServeFunc is called in a child by ListenAndServe with the child's
// listener. New sets it to the wrapped server's Serve method.
ServeFunc func(ln net.Listener) error
// ServeTLSFunc is called in a child by ListenAndServeTLS with the child's
// listener and the certificate and key file paths. New sets it to the
// wrapped server's ServeTLS method.
ServeTLSFunc func(ln net.Listener, certFile, keyFile string) error
// ServeTLSEmbedFunc is called in a child by ListenAndServeTLSEmbed with the
// child's listener and the in-memory certificate and key. New sets it to
// the wrapped server's ServeTLSEmbed method.
ServeTLSEmbedFunc func(ln net.Listener, certData, keyData []byte) error
// Network must be "tcp", "tcp4" or "tcp6". Default is "tcp4".
Network string
// files holds the duplicated listener file created by setTCPListenerFiles.
// It is handed to CommandProducer, or used as ExtraFiles by the default
// command, and is closed and reset when prefork returns.
files []*os.File
// RecoverThreshold caps how often crashed children are respawned before
// the master returns ErrOverRecovery. New() sets it to max(1, GOMAXPROCS/2).
// When constructing a Prefork directly without New(), a zero value will
// terminate the master after the very first child crash.
// Every child exit is counted, including exits without an error.
RecoverThreshold int
// RecoverInterval, when > 0, makes the master sleep for the given duration
// before respawning a crashed child. Useful as crash-loop backoff.
RecoverInterval time.Duration
// ShutdownGracePeriod is the time the master waits for children to exit
// after sending SIGTERM before falling back to SIGKILL. Defaults to 5s
// when zero. On Windows SIGTERM is not delivered, so this is unused there.
ShutdownGracePeriod time.Duration
// Reuseport selects a reuseport listener instead of fd-passing.
// See: https://www.nginx.com/blog/socket-sharding-nginx-release-1-9-1/
// Disabled by default.
Reuseport bool
// OnMasterDeath, when non-nil, enables monitoring of the master process
// in child processes. If the master process dies unexpectedly, this
// callback is invoked. This allows custom cleanup before shutdown.
//
// It is recommended to set this to func() { os.Exit(1) } if no custom
// cleanup is needed.
//
// Threading: invoked once from a watcher goroutine in the child. Must not
// block the goroutine for long; must not call Prefork methods.
OnMasterDeath func()
// OnChildSpawn is called in the master after a new child process is
// successfully started, both during initial spawn and during recovery.
// It receives the PID of the newly spawned child process.
//
// If this callback returns an error, all already-running children are
// killed and the prefork operation returns that error.
//
// Threading: invoked synchronously from the master goroutine. Must not
// block; must not call Prefork methods. Panics are recovered and surfaced
// as the returned error.
OnChildSpawn func(pid int) error
// OnMasterReady is called in the master process exactly once, after all
// initial children have been spawned and before the supervision loop runs.
// It receives a slice of all initial child PIDs.
//
// If this callback returns an error, the prefork operation aborts and
// returns that error after killing the children.
//
// Threading: invoked synchronously from the master goroutine. The slice
// is owned by the caller after the call returns. Panics are recovered.
OnMasterReady func(childPIDs []int) error
// OnChildRecover is called in the master after a crashed child has been
// replaced. It receives the PID of the old (crashed) process and the PID
// of its replacement.
//
// If this callback returns an error, all running children are killed and
// the prefork operation returns that error.
//
// Threading: invoked synchronously from the master goroutine, after
// OnChildSpawn for the new child. Panics are recovered and surfaced as
// the returned error.
OnChildRecover func(oldPID, newPID int) error
// CommandProducer creates and starts a child process command.
// If nil, the default implementation re-executes the current binary
// with FASTHTTP_PREFORK_CHILD=1 in the environment, stdout/stderr
// inherited from the parent, and the given files as ExtraFiles.
//
// A custom producer must:
// - Set FASTHTTP_PREFORK_CHILD=1 in the child's environment
// (otherwise IsChild() returns false and the child won't serve)
// - Call cmd.Start() before returning (the returned command must be
// started so cmd.Process is non-nil)
// - Pass the provided files as cmd.ExtraFiles when Reuseport is false
//
// Primarily useful for testing (injecting dummy commands) or for
// frameworks that need custom child process setup.
CommandProducer func(files []*os.File) (*exec.Cmd, error)
}
// IsChild reports whether the current process was started as a prefork child,
// i.e. whether the child marker environment variable carries the marker value.
func IsChild() bool {
	marker, present := os.LookupEnv(preforkChildEnvVariable)
	return present && marker == preforkChildEnvValue
}
// New builds a Prefork around the given fasthttp server.
//
// Network and RecoverThreshold are seeded with the package defaults, and the
// server's Logger, Serve, ServeTLS and ServeTLSEmbed are captured as the
// corresponding Prefork fields.
func New(s *fasthttp.Server) *Prefork {
	p := new(Prefork)
	p.Network = defaultNetwork
	p.RecoverThreshold = defaultRecoverThreshold()
	p.Logger = s.Logger
	p.ServeFunc = s.Serve
	p.ServeTLSFunc = s.ServeTLS
	p.ServeTLSEmbedFunc = s.ServeTLSEmbed
	return p
}
// defaultRecoverThreshold returns half of the current GOMAXPROCS setting,
// but never less than 1.
func defaultRecoverThreshold() int {
	if half := runtime.GOMAXPROCS(0) / 2; half > 1 {
		return half
	}
	return 1
}
// logger returns the configured Logger, or the package default when the
// Logger field is nil.
func (p *Prefork) logger() Logger {
	if p.Logger == nil {
		return defaultLogger
	}
	return p.Logger
}
// invokeHook calls fn and returns its error. If fn panics, the panic is
// recovered, converted into an error that names the hook, logged, and returned
// instead, so a misbehaving callback cannot crash the master.
func (p *Prefork) invokeHook(name string, fn func() error) (err error) {
	defer func() {
		// recover must be called directly from the deferred function.
		recovered := recover()
		if recovered == nil {
			return
		}
		err = fmt.Errorf("prefork: %s panicked: %v", name, recovered)
		p.logger().Printf("%v", err)
	}()
	err = fn()
	return err
}
// watchMaster blocks until the master process identified by masterPID is gone
// and then calls OnMasterDeath exactly once. It is meant to run in its own
// goroutine inside a child and assumes OnMasterDeath is non-nil.
func (p *Prefork) watchMaster(masterPID int) {
	if runtime.GOOS == "windows" {
		// On Windows, os.Getppid() keeps returning the same PID after the
		// parent exits (there is no reparenting), so wait on the process
		// handle instead of polling.
		master, findErr := os.FindProcess(masterPID)
		if findErr != nil {
			p.logger().Printf("watchMaster: failed to find master process %d: %v", masterPID, findErr)
			p.OnMasterDeath()
			return
		}
		_, waitErr := master.Wait()
		if waitErr != nil {
			p.logger().Printf("watchMaster: error waiting for master process %d: %v", masterPID, waitErr)
		}
		p.logger().Printf("master process %d died", masterPID)
		p.OnMasterDeath()
		return
	}

	// Unix-like systems reparent an orphaned child, so Getppid() changes once
	// the master is gone. Comparing with the recorded masterPID rather than a
	// hardcoded 1 keeps this correct when the master itself runs as PID 1
	// (e.g. inside a container).
	poll := time.NewTicker(masterPollInterval)
	defer poll.Stop()
	for {
		<-poll.C
		if os.Getppid() == masterPID {
			continue
		}
		p.logger().Printf("master process %d died", masterPID)
		p.OnMasterDeath()
		return
	}
}
// listen creates the listener a child process serves on.
//
// It pins the child to a single OS thread of execution via GOMAXPROCS(1) and
// defaults Network to "tcp4" when empty. With Reuseport it binds a fresh
// reuseport listener on addr; otherwise it adopts the listener inherited from
// the master on fd inheritedListenerFD (addr is ignored in that mode).
func (p *Prefork) listen(addr string) (net.Listener, error) {
	runtime.GOMAXPROCS(1)
	if p.Network == "" {
		p.Network = defaultNetwork
	}
	if p.Reuseport {
		return reuseport.Listen(p.Network, addr)
	}

	// fd inheritedListenerFD is the first ExtraFiles entry passed by the
	// master process when Reuseport is false. Naming the file gives clearer
	// errors from net.FileListener if the fd is invalid.
	inherited := os.NewFile(inheritedListenerFD, "fasthttp-prefork-listener")
	ln, err := net.FileListener(inherited)

	// net.FileListener works on a duplicate of the descriptor, so the
	// inherited file is no longer needed. Close it explicitly instead of
	// leaving the descriptor open until a finalizer happens to run.
	closeErr := inherited.Close()
	if err != nil {
		// The listener error is the interesting one; a close failure on an
		// already unusable descriptor adds nothing.
		return nil, err
	}
	if closeErr != nil {
		p.logger().Printf("prefork: close inherited listener fd: %v", closeErr)
	}
	return ln, nil
}
// listenAsChild does the setup shared by all child entry points: it obtains
// the child's listener, remembers it on p, and, when OnMasterDeath is set,
// starts a goroutine that watches the parent process.
func (p *Prefork) listenAsChild(addr string) (net.Listener, error) {
	childLn, listenErr := p.listen(addr)
	if listenErr != nil {
		return nil, listenErr
	}
	p.ln = childLn

	if p.OnMasterDeath == nil {
		return childLn, nil
	}
	go p.watchMaster(os.Getppid())
	return childLn, nil
}
// setTCPListenerFiles binds the master's TCP listener on addr and stores both
// the listener (p.ln) and a duplicated file for it (p.files) so the file can
// be handed to children.
//
// Network defaults to "tcp4" when empty. On any error nothing is stored on p
// and no listener is left open: the caller only installs its cleanup after a
// successful return, so a listener bound here must be closed here on failure.
func (p *Prefork) setTCPListenerFiles(addr string) error {
	if p.Network == "" {
		p.Network = defaultNetwork
	}

	tcpAddr, err := net.ResolveTCPAddr(p.Network, addr)
	if err != nil {
		return fmt.Errorf("prefork: resolve %s/%s: %w", p.Network, addr, err)
	}

	tcpListener, err := net.ListenTCP(p.Network, tcpAddr)
	if err != nil {
		return fmt.Errorf("prefork: listen tcp %s: %w", addr, err)
	}

	listenerFile, err := tcpListener.File()
	if err != nil {
		err = fmt.Errorf("prefork: dup listener fd: %w", err)
		// Release the bound socket; otherwise it would stay open (and the
		// port stay occupied) for the lifetime of the process.
		if closeErr := tcpListener.Close(); closeErr != nil {
			err = errors.Join(err, fmt.Errorf("prefork: close listener: %w", closeErr))
		}
		return err
	}

	p.ln = tcpListener
	p.files = []*os.File{listenerFile}
	return nil
}
// childEnv builds the environment for a child process: a copy of the current
// environment with every existing entry for the child marker variable removed
// and a single fresh marker entry appended, so the key never appears twice.
func childEnv() []string {
	const marker = preforkChildEnvVariable + "="

	inherited := os.Environ()
	env := make([]string, 0, len(inherited)+1)
	for i := range inherited {
		entry := inherited[i]
		// Keep everything that does not start with "<marker variable>=".
		if len(entry) < len(marker) || entry[:len(marker)] != marker {
			env = append(env, entry)
		}
	}
	return append(env, marker+preforkChildEnvValue)
}
// doCommand starts one child process and returns its command.
//
// When CommandProducer is set it is used and its result validated: a producer
// error is wrapped, a nil command yields ErrCommandProducerNilCmd, and a
// command without a Process yields ErrCommandProducerNotStarted. Otherwise the
// current executable is re-run with the same arguments, the child marker in
// its environment, stdout/stderr shared with the master and p.files passed as
// ExtraFiles.
func (p *Prefork) doCommand() (*exec.Cmd, error) {
	if produce := p.CommandProducer; produce != nil {
		produced, produceErr := produce(p.files)
		switch {
		case produceErr != nil:
			return nil, fmt.Errorf("prefork: CommandProducer: %w", produceErr)
		case produced == nil:
			return nil, ErrCommandProducerNilCmd
		case produced.Process == nil:
			return nil, ErrCommandProducerNotStarted
		}
		return produced, nil
	}

	self, err := os.Executable()
	if err != nil {
		return nil, fmt.Errorf("prefork: resolve executable: %w", err)
	}

	child := &exec.Cmd{
		Path:       self,
		Args:       append([]string{self}, os.Args[1:]...),
		Stdout:     os.Stdout,
		Stderr:     os.Stderr,
		Env:        childEnv(),
		ExtraFiles: p.files,
	}
	if startErr := child.Start(); startErr != nil {
		return nil, fmt.Errorf("prefork: start child %q: %w", self, startErr)
	}
	return child, nil
}
// childExit is the message a per-child Wait goroutine sends to the master's
// supervision loop when its child process has terminated.
type childExit struct {
// err is the result of cmd.Wait() for the child; nil means a clean exit.
err error
// pid is the process ID of the child that exited.
pid int
}
// shutdownChildren stops every process in childProcs and waits for the
// per-child Wait goroutines tracked by wg to finish, so that no goroutine
// outlives prefork().
//
// On platforms other than Windows each child first receives SIGTERM and the
// master waits up to grace (defaultShutdownGracePeriod when grace <= 0) for
// all Wait goroutines to finish. On Windows no termination request can be
// delivered, so the grace wait is skipped entirely: waiting there would only
// delay shutdown without giving children a chance to exit cleanly.
//
// Afterwards every remaining child is killed unconditionally, cancel is called
// to release Wait goroutines blocked on reporting their result, and wg is
// drained. Errors other than os.ErrProcessDone are logged, not returned.
func (p *Prefork) shutdownChildren(
	childProcs map[int]*exec.Cmd,
	wg *sync.WaitGroup,
	cancel context.CancelFunc,
	grace time.Duration,
) {
	if grace <= 0 {
		grace = defaultShutdownGracePeriod
	}

	if runtime.GOOS != "windows" {
		for pid, proc := range childProcs {
			if proc == nil || proc.Process == nil {
				continue
			}
			if termErr := proc.Process.Signal(syscall.SIGTERM); termErr != nil &&
				!errors.Is(termErr, os.ErrProcessDone) {
				p.logger().Printf("prefork: SIGTERM child %d: %v", pid, termErr)
			}
		}

		// Wait for graceful exits, with a timeout fallback to SIGKILL. This
		// only makes sense when SIGTERM was actually sent above.
		graceful := make(chan struct{})
		go func() {
			wg.Wait()
			close(graceful)
		}()
		timer := time.NewTimer(grace)
		defer timer.Stop()
		select {
		case <-graceful:
		case <-timer.C:
		}
	}

	// Kill whatever is left. Children that already exited report
	// os.ErrProcessDone, which is expected and not logged.
	for pid, proc := range childProcs {
		if proc == nil || proc.Process == nil {
			continue
		}
		if killErr := proc.Process.Kill(); killErr != nil &&
			!errors.Is(killErr, os.ErrProcessDone) {
			p.logger().Printf("prefork: kill child %d: %v", pid, killErr)
		}
	}

	// Cancel the per-Wait goroutines' send-context so any still blocked on
	// sigCh send unblock cleanly, then wait for all of them to exit.
	cancel()
	wg.Wait()
}
// prefork runs the master side: it spawns GOMAXPROCS children, supervises
// them, respawns exited ones, and tears everything down on return.
//
// Without Reuseport the master binds addr itself and passes the listener to
// children as an inherited file; that mode is rejected on Windows with
// ErrOnlyReuseportOnWindows. The function returns:
//   - nil when SIGTERM or SIGINT is received by the master,
//   - ErrOverRecovery when more than RecoverThreshold children have exited,
//   - the error from starting a child, or from an OnChildSpawn, OnMasterReady
//     or OnChildRecover hook (hook panics are converted to errors).
//
// The returned error is left unwrapped unless closing the master's listener
// also fails, in which case both errors are joined. On every return path the
// remaining children are shut down via shutdownChildren.
func (p *Prefork) prefork(addr string) (err error) { //nolint:gocyclo
	if !p.Reuseport {
		if runtime.GOOS == "windows" {
			return ErrOnlyReuseportOnWindows
		}
		if err = p.setTCPListenerFiles(addr); err != nil {
			return err
		}
		// Close listener fds opened by setTCPListenerFiles. Both the original
		// tcpListener (p.ln) and the duped fd (p.files[0]) belong to the
		// master only; children inherit independent dup'd copies via fork+exec.
		defer func() {
			// Join only when Close fails: errors.Join wraps even a single
			// non-nil error, which would hide sentinels such as
			// ErrOverRecovery from callers that compare with ==.
			if closeErr := p.ln.Close(); closeErr != nil {
				err = errors.Join(err, closeErr)
			}
			for _, f := range p.files {
				if closeErr := f.Close(); closeErr != nil {
					p.logger().Printf("prefork: close listener fd: %v", closeErr)
				}
			}
			p.files = nil
		}()
	}

	// ctx cancels per-child Wait goroutines so they unblock from sigCh sends
	// once the supervision loop is gone.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Catch SIGTERM/SIGINT in the master so we run our shutdown path instead
	// of being killed by the OS without children getting a graceful chance.
	signalCh := make(chan os.Signal, 1)
	signal.Notify(signalCh, syscall.SIGTERM, syscall.SIGINT)
	defer signal.Stop(signalCh)

	goMaxProcs := runtime.GOMAXPROCS(0)
	// Buffer is sized to initial fleet; per-child goroutines fall back to a
	// context-aware select on send so capacity is not load-bearing.
	sigCh := make(chan childExit, goMaxProcs)
	childProcs := make(map[int]*exec.Cmd, goMaxProcs)

	var wg sync.WaitGroup
	// startWait reaps cmd in a tracked goroutine and reports its exit to the
	// supervision loop, giving up on the report once ctx is cancelled.
	startWait := func(cmd *exec.Cmd, pid int) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			result := childExit{pid: pid, err: cmd.Wait()}
			select {
			case sigCh <- result:
			case <-ctx.Done():
			}
		}()
	}

	// Registered last so it runs first on return: children are stopped and
	// Wait goroutines drained before signal handling and listeners go away.
	defer func() {
		p.shutdownChildren(childProcs, &wg, cancel, p.ShutdownGracePeriod)
	}()

	childPIDs := make([]int, 0, goMaxProcs)
	for range goMaxProcs {
		var cmd *exec.Cmd
		if cmd, err = p.doCommand(); err != nil {
			p.logger().Printf("prefork: failed to start a child process: %v", err)
			return err
		}
		pid := cmd.Process.Pid
		childProcs[pid] = cmd
		childPIDs = append(childPIDs, pid)
		// Start Wait goroutine before the user callback so a panic / error
		// return from OnChildSpawn cannot leave a zombie behind.
		startWait(cmd, pid)

		if p.OnChildSpawn != nil {
			if hookErr := p.invokeHook("OnChildSpawn", func() error {
				return p.OnChildSpawn(pid)
			}); hookErr != nil {
				p.logger().Printf("prefork: OnChildSpawn for PID %d: %v", pid, hookErr)
				return hookErr
			}
		}
	}

	if p.OnMasterReady != nil {
		// Hand the hook its own copy so it may keep or modify the slice.
		pids := append([]int(nil), childPIDs...)
		if hookErr := p.invokeHook("OnMasterReady", func() error {
			return p.OnMasterReady(pids)
		}); hookErr != nil {
			p.logger().Printf("prefork: OnMasterReady: %v", hookErr)
			return hookErr
		}
	}

	var exitedProcs int
	for {
		select {
		case sig := <-signalCh:
			p.logger().Printf("prefork: received signal %v, shutting down", sig)
			return nil

		case exit := <-sigCh:
			delete(childProcs, exit.pid)
			if exit.err != nil {
				p.logger().Printf("prefork: child PID %d exited: %v", exit.pid, exit.err)
			} else {
				p.logger().Printf("prefork: child PID %d exited cleanly", exit.pid)
			}

			// Every exit counts towards the threshold, clean or not.
			exitedProcs++
			if exitedProcs > p.RecoverThreshold {
				p.logger().Printf(
					"prefork: child exits (%d) exceed RecoverThreshold (%d), terminating master",
					exitedProcs, p.RecoverThreshold,
				)
				return ErrOverRecovery
			}

			// Optional crash-loop backoff; a shutdown signal cuts it short.
			if p.RecoverInterval > 0 {
				select {
				case <-time.After(p.RecoverInterval):
				case <-signalCh:
					return nil
				}
			}

			cmd, doErr := p.doCommand()
			if doErr != nil {
				p.logger().Printf("prefork: recovery doCommand: %v", doErr)
				return doErr
			}
			newPID := cmd.Process.Pid
			childProcs[newPID] = cmd
			startWait(cmd, newPID)

			if p.OnChildSpawn != nil {
				if hookErr := p.invokeHook("OnChildSpawn", func() error {
					return p.OnChildSpawn(newPID)
				}); hookErr != nil {
					p.logger().Printf("prefork: OnChildSpawn for recovered PID %d: %v", newPID, hookErr)
					return hookErr
				}
			}

			if p.OnChildRecover != nil {
				oldPID := exit.pid
				if hookErr := p.invokeHook("OnChildRecover", func() error {
					return p.OnChildRecover(oldPID, newPID)
				}); hookErr != nil {
					p.logger().Printf("prefork: OnChildRecover (%d -> %d): %v", oldPID, newPID, hookErr)
					return hookErr
				}
			}
		}
	}
}
// ListenAndServe serves HTTP requests from the given TCP addr.
//
// In the master process it spawns and supervises the children; in a child
// process it obtains the listener and hands it to ServeFunc.
func (p *Prefork) ListenAndServe(addr string) error {
	if !IsChild() {
		return p.prefork(addr)
	}

	childLn, listenErr := p.listenAsChild(addr)
	if listenErr != nil {
		return listenErr
	}
	return p.ServeFunc(childLn)
}
// ListenAndServeTLS serves HTTPS requests from the given TCP addr.
//
// Note the parameter order: (addr, certKey, certFile), i.e. the key path
// precedes the certificate path. That order is kept for backward
// compatibility with existing callers and differs from
// fasthttp.Server.ListenAndServeTLS; ServeTLSFunc is still invoked with the
// certificate path first and the key path second. New code should prefer
// ListenAndServeTLSEmbed.
//
// In the master process it spawns and supervises the children; in a child
// process it obtains the listener and hands it to ServeTLSFunc.
func (p *Prefork) ListenAndServeTLS(addr, certKey, certFile string) error {
	if !IsChild() {
		return p.prefork(addr)
	}

	childLn, listenErr := p.listenAsChild(addr)
	if listenErr != nil {
		return listenErr
	}
	return p.ServeTLSFunc(childLn, certFile, certKey)
}
// ListenAndServeTLSEmbed serves HTTPS requests from the given TCP addr.
//
// certData and keyData must contain valid TLS certificate and key data.
//
// In the master process it spawns and supervises the children; in a child
// process it obtains the listener and hands it to ServeTLSEmbedFunc.
func (p *Prefork) ListenAndServeTLSEmbed(addr string, certData, keyData []byte) error {
	if !IsChild() {
		return p.prefork(addr)
	}

	childLn, listenErr := p.listenAsChild(addr)
	if listenErr != nil {
		return listenErr
	}
	return p.ServeTLSEmbedFunc(childLn, certData, keyData)
}