feat(prefork): Enhance prefork management with WatchMaster, CommandProducer, and Windows support (#2180)

* feat(prefork): add WatchMaster and callback support for child process management

* feat(prefork): add CommandProducer for customizable child process commands

* refactor(prefork): improve comments and parameter order in ListenAndServeTLS

* refactor(prefork): enhance logging message and clarify OnChildRecover callback comment

* fix(prefork): add Windows support to watchMaster

On Windows, os.Getppid() returns a static PID that doesn't change when
the parent exits (no reparenting). Use FindProcess+Wait instead, which
correctly detects parent exit. Also document why masterPID comparison
works for Docker containers (master PID 1 case).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(prefork): extract listenAsChild to eliminate DRY violation

The three ListenAndServe* methods had identical child setup code
(listen, set ln, watch master). Extract to listenAsChild() for
cleaner code. Also add comment for the magic file descriptor number 3.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(prefork): restore upstream ListenAndServeTLS parameter order

Keep upstream's (addr, certKey, certFile) signature to avoid breaking
callers. Fix the doc comment to match the actual parameter order instead.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(prefork): address lint errors and review feedback

Lint fixes:
- Remove unused Reuseport field write in test (govet/unusedwrite)
- Replace fmt.Errorf with errors.New for static errors (perfsprint)

Review feedback (Copilot):
- Validate CommandProducer returns a started command (nil/Process check)
- Clarify ListenAndServeTLS doc: parameter order and internal forwarding
- Use hermetic test binary re-exec instead of external 'go' binary
- Rename misleading test to reflect what it actually asserts

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(prefork): address maintainer review feedback

- watchMaster: log errors from FindProcess/Wait instead of swallowing
- watchMaster: don't call OnMasterDeath if FindProcess fails
- OnChildRecover: change signature to func(pid int), drop unused error return
- OnChildSpawn: add comment clarifying deferred cleanup handles the child
- CommandProducer: improve docs describing contract and use cases

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* refactor(prefork): address erikdubbelboer review feedback

- OnChildRecover: signature changed to func(oldPid, newPid int) so
  callers can track which process was replaced
- OnChildSpawn: also called for recovered children (a recovered child
  is still a spawned child)
- watchMaster: call OnMasterDeath when FindProcess fails (process is
  most likely gone)
- CommandProducer: document that FASTHTTP_PREFORK_CHILD=1 must be set
  in the child env, and what the default does when nil

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(prefork): avoid zombie processes and replace shallow tests

- Move Wait() goroutine before OnChildSpawn so Kill()+Wait() works
  correctly if a callback fails and the deferred cleanup runs
- Add Wait() call in deferred cleanup after Kill() to reap children
- Same fix in recovery loop
- Remove shallow callback tests that only tested Go compiler
- Add Test_Prefork_Lifecycle: runs full prefork with CommandProducer,
  verifies callbacks fire in correct order with correct arguments

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(prefork): ensure recovery default stays positive

* test(prefork): isolate lifecycle tests

* fix(prefork): tighten recovery callback flow

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
RW
2026-05-02 13:17:27 +02:00
committed by GitHub
parent 805cd10465
commit 481e579af9
2 changed files with 900 additions and 238 deletions
+421 -104
View File
@@ -2,12 +2,17 @@
package prefork
import (
"context"
"errors"
"fmt"
"log"
"net"
"os"
"os/exec"
"os/signal"
"runtime"
"sync"
"syscall"
"time"
"github.com/valyala/fasthttp"
@@ -16,25 +21,53 @@ import (
const (
preforkChildEnvVariable = "FASTHTTP_PREFORK_CHILD"
preforkChildEnvValue = "1"
defaultNetwork = "tcp4"
// inheritedListenerFD is the file descriptor used by the master to pass
// the bound listener to a child process via ExtraFiles. Children open the
// listener via os.NewFile(inheritedListenerFD, ...) when Reuseport is false.
inheritedListenerFD = 3
// masterPollInterval is the period of the watchMaster ppid-poll on Unix.
masterPollInterval = 500 * time.Millisecond
// defaultShutdownGracePeriod is how long the master waits for children to
// exit cleanly after sending SIGTERM before forcibly killing them.
defaultShutdownGracePeriod = 5 * time.Second
)
var (
defaultLogger = Logger(log.New(os.Stderr, "", log.LstdFlags))
// ErrOverRecovery is returned when the times of starting over child prefork processes exceed
// the threshold.
// ErrOverRecovery is returned when child prefork process restarts exceed
// the value of RecoverThreshold.
ErrOverRecovery = errors.New("exceeding the value of RecoverThreshold")
// ErrOnlyReuseportOnWindows is returned when Reuseport is false.
// ErrOnlyReuseportOnWindows is returned when running on Windows without Reuseport.
ErrOnlyReuseportOnWindows = errors.New("windows only supports Reuseport = true")
// ErrCommandProducerNilCmd is returned when a CommandProducer returns
// (nil, nil) instead of a started command.
ErrCommandProducerNilCmd = errors.New("prefork: CommandProducer returned nil command")
// ErrCommandProducerNotStarted is returned when a CommandProducer returns
// an *exec.Cmd whose Process is nil (i.e. cmd.Start() was not called).
ErrCommandProducerNotStarted = errors.New("prefork: CommandProducer must return a started command")
)
// Logger is used for logging formatted messages.
// Logger is used for logging formatted messages. Its method set is intentionally
// identical to fasthttp.Logger so that *fasthttp.Server.Logger can be assigned
// directly.
type Logger interface {
// Printf must have the same semantics as log.Printf.
Printf(format string, args ...any)
}
// Compile-time check that fasthttp.Logger satisfies the local Logger interface;
// keeps the two types in sync if either side ever evolves.
var _ Logger = fasthttp.Logger(nil)
// Prefork implements fasthttp server prefork.
//
// Preforks master process (with all cores) between several child processes
@@ -44,7 +77,8 @@ type Logger interface {
// WARNING: using prefork prevents the use of any global state!
// Things like in-memory caches won't work.
type Prefork struct {
// By default standard logger from log package is used.
// Logger receives diagnostic output. By default the standard log package
// logger writing to stderr is used.
Logger Logger
ln net.Listener
@@ -53,21 +87,29 @@ type Prefork struct {
ServeTLSFunc func(ln net.Listener, certFile, keyFile string) error
ServeTLSEmbedFunc func(ln net.Listener, certData, keyData []byte) error
// The network must be "tcp", "tcp4" or "tcp6".
//
// By default is "tcp4"
// Network must be "tcp", "tcp4" or "tcp6". Default is "tcp4".
Network string
files []*os.File
// Child prefork processes may exit with failure and will be started over until the times reach
// the value of RecoverThreshold, then it will return and terminate the server.
// RecoverThreshold caps how often crashed children are respawned before
// the master returns ErrOverRecovery. New() sets it to max(1, GOMAXPROCS/2).
// When constructing a Prefork directly without New(), a zero value will
// terminate the master after the very first child crash.
RecoverThreshold int
// Flag to use a listener with reuseport, if not a file Listener will be used
// RecoverInterval, when > 0, makes the master sleep for the given duration
// before respawning a crashed child. Useful as crash-loop backoff.
RecoverInterval time.Duration
// ShutdownGracePeriod is the time the master waits for children to exit
// after sending SIGTERM before falling back to SIGKILL. Defaults to 5s
// when zero. On Windows SIGTERM is not delivered, so this is unused there.
ShutdownGracePeriod time.Duration
// Reuseport selects a reuseport listener instead of fd-passing.
// See: https://www.nginx.com/blog/socket-sharding-nginx-release-1-9-1/
//
// It's disabled by default
// Disabled by default.
Reuseport bool
// OnMasterDeath, when non-nil, enables monitoring of the master process
@@ -76,19 +118,75 @@ type Prefork struct {
//
// It is recommended to set this to func() { os.Exit(1) } if no custom
// cleanup is needed.
//
// Threading: invoked once from a watcher goroutine in the child. Must not
// block the goroutine for long; must not call Prefork methods.
OnMasterDeath func()
// OnChildSpawn is called in the master after a new child process is
// successfully started, both during initial spawn and during recovery.
// It receives the PID of the newly spawned child process.
//
// If this callback returns an error, all already-running children are
// killed and the prefork operation returns that error.
//
// Threading: invoked synchronously from the master goroutine. Must not
// block; must not call Prefork methods. Panics are recovered and surfaced
// as the returned error.
OnChildSpawn func(pid int) error
// OnMasterReady is called in the master process exactly once, after all
// initial children have been spawned and before the supervision loop runs.
// It receives a slice of all initial child PIDs.
//
// If this callback returns an error, the prefork operation aborts and
// returns that error after killing the children.
//
// Threading: invoked synchronously from the master goroutine. The slice
// is owned by the caller after the call returns. Panics are recovered.
OnMasterReady func(childPIDs []int) error
// OnChildRecover is called in the master after a crashed child has been
// replaced. It receives the PID of the old (crashed) process and the PID
// of its replacement.
//
// If this callback returns an error, all running children are killed and
// the prefork operation returns that error.
//
// Threading: invoked synchronously from the master goroutine, after
// OnChildSpawn for the new child. Panics are recovered and surfaced as
// the returned error.
OnChildRecover func(oldPID, newPID int) error
// CommandProducer creates and starts a child process command.
// If nil, the default implementation re-executes the current binary
// with FASTHTTP_PREFORK_CHILD=1 in the environment, stdout/stderr
// inherited from the parent, and the given files as ExtraFiles.
//
// A custom producer must:
// - Set FASTHTTP_PREFORK_CHILD=1 in the child's environment
// (otherwise IsChild() returns false and the child won't serve)
// - Call cmd.Start() before returning (the returned command must be
// started so cmd.Process is non-nil)
// - Pass the provided files as cmd.ExtraFiles when Reuseport is false
//
// Primarily useful for testing (injecting dummy commands) or for
// frameworks that need custom child process setup.
CommandProducer func(files []*os.File) (*exec.Cmd, error)
}
// IsChild checks if the current thread/process is a child.
// IsChild reports whether the current process is a prefork child.
func IsChild() bool {
return os.Getenv(preforkChildEnvVariable) == "1"
return os.Getenv(preforkChildEnvVariable) == preforkChildEnvValue
}
// New wraps the fasthttp server to run with preforked processes.
// It seeds Network and RecoverThreshold to sensible defaults; existing
// fields on s (Logger, Serve*) are captured.
func New(s *fasthttp.Server) *Prefork {
return &Prefork{
Network: defaultNetwork,
RecoverThreshold: runtime.GOMAXPROCS(0) / 2,
RecoverThreshold: defaultRecoverThreshold(),
Logger: s.Logger,
ServeFunc: s.Serve,
ServeTLSFunc: s.ServeTLS,
@@ -96,6 +194,10 @@ func New(s *fasthttp.Server) *Prefork {
}
}
func defaultRecoverThreshold() int {
return max(1, runtime.GOMAXPROCS(0)/2)
}
func (p *Prefork) logger() Logger {
if p.Logger != nil {
return p.Logger
@@ -103,13 +205,46 @@ func (p *Prefork) logger() Logger {
return defaultLogger
}
// invokeHook runs fn under a panic recovery, returning the panic as an error
// so a misbehaving callback never tears down the master.
func (p *Prefork) invokeHook(name string, fn func() error) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("prefork: %s panicked: %v", name, r)
p.logger().Printf("%v", err)
}
}()
return fn()
}
func (p *Prefork) watchMaster(masterPID int) {
ticker := time.NewTicker(500 * time.Millisecond)
if runtime.GOOS == "windows" {
// On Windows, os.Getppid() returns a static PID that doesn't change
// when the parent exits (no reparenting). Use FindProcess+Wait instead.
proc, err := os.FindProcess(masterPID)
if err != nil {
p.logger().Printf("watchMaster: failed to find master process %d: %v", masterPID, err)
p.OnMasterDeath()
return
}
if _, err = proc.Wait(); err != nil {
p.logger().Printf("watchMaster: error waiting for master process %d: %v", masterPID, err)
}
p.logger().Printf("master process %d died", masterPID)
p.OnMasterDeath()
return
}
// Unix/Linux/macOS: When the master exits, the OS reparents the child
// to another process, causing Getppid() to change. Comparing against
// the original masterPID (instead of hardcoding 1) ensures this works
// correctly when the master itself is PID 1 (e.g. in Docker containers).
ticker := time.NewTicker(masterPollInterval)
defer ticker.Stop()
for range ticker.C {
if os.Getppid() != masterPID {
p.logger().Printf("master process died\n")
p.logger().Printf("master process %d died", masterPID)
p.OnMasterDeath()
return
}
@@ -127,7 +262,27 @@ func (p *Prefork) listen(addr string) (net.Listener, error) {
return reuseport.Listen(p.Network, addr)
}
return net.FileListener(os.NewFile(3, ""))
// fd inheritedListenerFD is the first ExtraFiles entry passed by the
// master process when Reuseport is false. Naming the file gives clearer
// errors from net.FileListener if the fd is invalid.
return net.FileListener(os.NewFile(inheritedListenerFD, "fasthttp-prefork-listener"))
}
// listenAsChild performs the common child process setup: creates the listener
// and starts watching the master process if OnMasterDeath is configured.
func (p *Prefork) listenAsChild(addr string) (net.Listener, error) {
ln, err := p.listen(addr)
if err != nil {
return nil, err
}
p.ln = ln
if p.OnMasterDeath != nil {
go p.watchMaster(os.Getppid())
}
return ln, nil
}
func (p *Prefork) setTCPListenerFiles(addr string) error {
@@ -137,49 +292,142 @@ func (p *Prefork) setTCPListenerFiles(addr string) error {
tcpAddr, err := net.ResolveTCPAddr(p.Network, addr)
if err != nil {
return err
return fmt.Errorf("prefork: resolve %s/%s: %w", p.Network, addr, err)
}
tcplistener, err := net.ListenTCP(p.Network, tcpAddr)
tcpListener, err := net.ListenTCP(p.Network, tcpAddr)
if err != nil {
return err
return fmt.Errorf("prefork: listen tcp %s: %w", addr, err)
}
p.ln = tcplistener
p.ln = tcpListener
fl, err := tcplistener.File()
listenerFile, err := tcpListener.File()
if err != nil {
return err
return fmt.Errorf("prefork: dup listener fd: %w", err)
}
p.files = []*os.File{fl}
p.files = []*os.File{listenerFile}
return nil
}
// childEnv returns os.Environ() with the prefork child marker variable set,
// stripping any pre-existing value to avoid duplicate keys with last-wins
// semantics.
func childEnv() []string {
src := os.Environ()
out := make([]string, 0, len(src)+1)
prefix := preforkChildEnvVariable + "="
for _, kv := range src {
if len(kv) >= len(prefix) && kv[:len(prefix)] == prefix {
continue
}
out = append(out, kv)
}
out = append(out, preforkChildEnvVariable+"="+preforkChildEnvValue)
return out
}
func (p *Prefork) doCommand() (*exec.Cmd, error) {
executable, err := os.Executable()
if err != nil {
return nil, err
if p.CommandProducer != nil {
cmd, err := p.CommandProducer(p.files)
if err != nil {
return nil, fmt.Errorf("prefork: CommandProducer: %w", err)
}
if cmd == nil {
return nil, ErrCommandProducerNilCmd
}
if cmd.Process == nil {
return nil, ErrCommandProducerNotStarted
}
return cmd, nil
}
args := make([]string, len(os.Args))
args[0] = executable
copy(args[1:], os.Args[1:])
executable, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("prefork: resolve executable: %w", err)
}
args := append([]string{executable}, os.Args[1:]...)
cmd := &exec.Cmd{
Path: executable,
Args: args,
Stdout: os.Stdout,
Stderr: os.Stderr,
Env: append(os.Environ(), preforkChildEnvVariable+"=1"),
Env: childEnv(),
ExtraFiles: p.files,
}
err = cmd.Start()
return cmd, err
if err = cmd.Start(); err != nil {
return nil, fmt.Errorf("prefork: start child %q: %w", executable, err)
}
return cmd, nil
}
func (p *Prefork) prefork(addr string) (err error) {
type childExit struct {
err error
pid int
}
// shutdownChildren signals every entry in childProcs first with SIGTERM (on
// platforms where it is supported) and waits up to grace for them to exit.
// Survivors are then killed unconditionally. wg tracks the per-child Wait
// goroutines and must be drained before returning so no goroutine outlives
// prefork().
func (p *Prefork) shutdownChildren(
childProcs map[int]*exec.Cmd,
wg *sync.WaitGroup,
cancel context.CancelFunc,
grace time.Duration,
) {
if grace <= 0 {
grace = defaultShutdownGracePeriod
}
if runtime.GOOS != "windows" {
for pid, proc := range childProcs {
if proc == nil || proc.Process == nil {
continue
}
if termErr := proc.Process.Signal(syscall.SIGTERM); termErr != nil &&
!errors.Is(termErr, os.ErrProcessDone) {
p.logger().Printf("prefork: SIGTERM child %d: %v", pid, termErr)
}
}
}
// Wait for graceful exits, with a timeout fallback to SIGKILL.
graceful := make(chan struct{})
go func() {
wg.Wait()
close(graceful)
}()
timer := time.NewTimer(grace)
defer timer.Stop()
select {
case <-graceful:
case <-timer.C:
}
for pid, proc := range childProcs {
if proc == nil || proc.Process == nil {
continue
}
if killErr := proc.Process.Kill(); killErr != nil &&
!errors.Is(killErr, os.ErrProcessDone) {
p.logger().Printf("prefork: kill child %d: %v", pid, killErr)
}
}
// Cancel the per-Wait goroutines' send-context so any still blocked on
// sigCh send unblock cleanly, then wait for all of them to exit.
cancel()
wg.Wait()
}
func (p *Prefork) prefork(addr string) (err error) { //nolint:gocyclo
if !p.Reuseport {
if runtime.GOOS == "windows" {
return ErrOnlyReuseportOnWindows
@@ -189,86 +437,166 @@ func (p *Prefork) prefork(addr string) (err error) {
return err
}
// defer for closing the net.Listener opened by setTCPListenerFiles.
// Close listener fds opened by setTCPListenerFiles. Both the original
// tcpListener (p.ln) and the duped fd (p.files[0]) belong to the
// master only; children inherit independent dup'd copies via fork+exec.
defer func() {
e := p.ln.Close()
if err == nil {
err = e
err = errors.Join(err, p.ln.Close())
for _, f := range p.files {
if closeErr := f.Close(); closeErr != nil {
p.logger().Printf("prefork: close listener fd: %v", closeErr)
}
}
p.files = nil
}()
}
// ctx cancels per-child Wait goroutines so they unblock from sigCh sends
// once the supervision loop is gone.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Catch SIGTERM/SIGINT in the master so we run our shutdown path instead
// of being killed by the OS without children getting a graceful chance.
signalCh := make(chan os.Signal, 1)
signal.Notify(signalCh, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(signalCh)
goMaxProcs := runtime.GOMAXPROCS(0)
// Buffer is sized to initial fleet; per-child goroutines fall back to a
// context-aware select on send so capacity is not load-bearing.
sigCh := make(chan childExit, goMaxProcs)
childProcs := make(map[int]*exec.Cmd, goMaxProcs)
var wg sync.WaitGroup
startWait := func(cmd *exec.Cmd, pid int) {
wg.Add(1)
go func() {
defer wg.Done()
result := childExit{pid: pid, err: cmd.Wait()}
select {
case sigCh <- result:
case <-ctx.Done():
}
}()
}
type procSig struct {
err error
pid int
}
goMaxProcs := runtime.GOMAXPROCS(0)
sigCh := make(chan procSig, goMaxProcs)
childProcs := make(map[int]*exec.Cmd)
defer func() {
for _, proc := range childProcs {
_ = proc.Process.Kill()
}
p.shutdownChildren(childProcs, &wg, cancel, p.ShutdownGracePeriod)
}()
childPIDs := make([]int, 0, goMaxProcs)
for range goMaxProcs {
var cmd *exec.Cmd
if cmd, err = p.doCommand(); err != nil {
p.logger().Printf("failed to start a child prefork process, error: %v\n", err)
p.logger().Printf("prefork: failed to start a child process: %v", err)
return err
}
childProcs[cmd.Process.Pid] = cmd
go func() {
sigCh <- procSig{pid: cmd.Process.Pid, err: cmd.Wait()}
}()
pid := cmd.Process.Pid
childProcs[pid] = cmd
childPIDs = append(childPIDs, pid)
// Start Wait goroutine before the user callback so a panic / error
// return from OnChildSpawn cannot leave a zombie behind.
startWait(cmd, pid)
if p.OnChildSpawn != nil {
pid := pid
if hookErr := p.invokeHook("OnChildSpawn", func() error {
return p.OnChildSpawn(pid)
}); hookErr != nil {
p.logger().Printf("prefork: OnChildSpawn for PID %d: %v", pid, hookErr)
return hookErr
}
}
}
if p.OnMasterReady != nil {
pids := append([]int(nil), childPIDs...)
if hookErr := p.invokeHook("OnMasterReady", func() error {
return p.OnMasterReady(pids)
}); hookErr != nil {
p.logger().Printf("prefork: OnMasterReady: %v", hookErr)
return hookErr
}
}
var exitedProcs int
for sig := range sigCh {
delete(childProcs, sig.pid)
for {
select {
case sig := <-signalCh:
p.logger().Printf("prefork: received signal %v, shutting down", sig)
return nil
p.logger().Printf("one of the child prefork processes exited with "+
"error: %v", sig.err)
case sig := <-sigCh:
delete(childProcs, sig.pid)
exitedProcs++
if exitedProcs > p.RecoverThreshold {
p.logger().Printf("child prefork processes exit too many times, "+
"which exceeds the value of RecoverThreshold(%d), "+
"exiting the master process.\n", exitedProcs)
err = ErrOverRecovery
break
if sig.err != nil {
p.logger().Printf("prefork: child PID %d exited: %v", sig.pid, sig.err)
} else {
p.logger().Printf("prefork: child PID %d exited cleanly", sig.pid)
}
exitedProcs++
if exitedProcs > p.RecoverThreshold {
p.logger().Printf(
"prefork: child exits (%d) exceed RecoverThreshold (%d), terminating master",
exitedProcs, p.RecoverThreshold,
)
return ErrOverRecovery
}
if p.RecoverInterval > 0 {
select {
case <-time.After(p.RecoverInterval):
case <-signalCh:
return nil
}
}
cmd, doErr := p.doCommand()
if doErr != nil {
p.logger().Printf("prefork: recovery doCommand: %v", doErr)
return doErr
}
newPID := cmd.Process.Pid
childProcs[newPID] = cmd
startWait(cmd, newPID)
if p.OnChildSpawn != nil {
newPID := newPID
if hookErr := p.invokeHook("OnChildSpawn", func() error {
return p.OnChildSpawn(newPID)
}); hookErr != nil {
p.logger().Printf("prefork: OnChildSpawn for recovered PID %d: %v", newPID, hookErr)
return hookErr
}
}
if p.OnChildRecover != nil {
oldPID := sig.pid
newPID := newPID
if hookErr := p.invokeHook("OnChildRecover", func() error {
return p.OnChildRecover(oldPID, newPID)
}); hookErr != nil {
p.logger().Printf("prefork: OnChildRecover (%d -> %d): %v", oldPID, newPID, hookErr)
return hookErr
}
}
}
var cmd *exec.Cmd
if cmd, err = p.doCommand(); err != nil {
break
}
childProcs[cmd.Process.Pid] = cmd
go func() {
sigCh <- procSig{pid: cmd.Process.Pid, err: cmd.Wait()}
}()
}
return err
}
// ListenAndServe serves HTTP requests from the given TCP addr.
func (p *Prefork) ListenAndServe(addr string) error {
if IsChild() {
ln, err := p.listen(addr)
ln, err := p.listenAsChild(addr)
if err != nil {
return err
}
p.ln = ln
if p.OnMasterDeath != nil {
go p.watchMaster(os.Getppid())
}
return p.ServeFunc(ln)
}
@@ -277,20 +605,16 @@ func (p *Prefork) ListenAndServe(addr string) error {
// ListenAndServeTLS serves HTTPS requests from the given TCP addr.
//
// certFile and keyFile are paths to TLS certificate and key files.
// Note: parameter order is (addr, certKey, certFile) — key path comes
// before cert path. This is preserved for backward compatibility with
// existing callers and differs from fasthttp.Server.ListenAndServeTLS.
// New code should prefer ListenAndServeTLSEmbed.
func (p *Prefork) ListenAndServeTLS(addr, certKey, certFile string) error {
if IsChild() {
ln, err := p.listen(addr)
ln, err := p.listenAsChild(addr)
if err != nil {
return err
}
p.ln = ln
if p.OnMasterDeath != nil {
go p.watchMaster(os.Getppid())
}
return p.ServeTLSFunc(ln, certFile, certKey)
}
@@ -302,17 +626,10 @@ func (p *Prefork) ListenAndServeTLS(addr, certKey, certFile string) error {
// certData and keyData must contain valid TLS certificate and key data.
func (p *Prefork) ListenAndServeTLSEmbed(addr string, certData, keyData []byte) error {
if IsChild() {
ln, err := p.listen(addr)
ln, err := p.listenAsChild(addr)
if err != nil {
return err
}
p.ln = ln
if p.OnMasterDeath != nil {
go p.watchMaster(os.Getppid())
}
return p.ServeTLSEmbedFunc(ln, certData, keyData)
}
+479 -134
View File
@@ -1,43 +1,80 @@
package prefork
import (
"errors"
"fmt"
"math/rand"
"net"
"os"
"reflect"
"os/exec"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/valyala/fasthttp"
)
func setUp() {
os.Setenv(preforkChildEnvVariable, "1")
// freeAddr returns a TCP4 address that the OS believes is free at call time.
// It avoids the rand-port collision flakes the previous helper exhibited.
func freeAddr(t testing.TB) string {
t.Helper()
ln, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("freeAddr: %v", err)
}
addr := ln.Addr().String()
if closeErr := ln.Close(); closeErr != nil {
t.Fatalf("freeAddr: close: %v", closeErr)
}
return addr
}
func tearDown() {
os.Unsetenv(preforkChildEnvVariable)
}
// noopChildProducer returns a CommandProducer that re-execs the test binary
// into a no-op subprocess. The returned cleanup must be deferred (or registered
// via t.Cleanup) so leaked subprocesses are reaped if the test fails midway.
func noopChildProducer(t testing.TB) (func(files []*os.File) (*exec.Cmd, error), func()) {
t.Helper()
var (
mu sync.Mutex
spawned []*exec.Cmd
)
func getAddr() string {
return fmt.Sprintf("127.0.0.1:%d", rand.Intn(9000-3000)+3000)
produce := func(_ []*os.File) (*exec.Cmd, error) {
cmd := exec.Command(os.Args[0], "-test.run=^$")
cmd.Env = append(os.Environ(), preforkChildEnvVariable+"="+preforkChildEnvValue)
if err := cmd.Start(); err != nil {
return nil, err
}
mu.Lock()
spawned = append(spawned, cmd)
mu.Unlock()
return cmd, nil
}
cleanup := func() {
mu.Lock()
defer mu.Unlock()
for _, cmd := range spawned {
if cmd == nil || cmd.Process == nil {
continue
}
_ = cmd.Process.Kill()
_, _ = cmd.Process.Wait()
}
}
return produce, cleanup
}
func Test_IsChild(t *testing.T) {
// This test can't run parallel as it modifies os.Args.
v := IsChild()
if v {
t.Errorf("IsChild() == %v, want %v", v, false)
// This test cannot run in parallel — IsChild() reads a process-global env var.
if IsChild() {
t.Fatal("test starts as child unexpectedly")
}
setUp()
defer tearDown()
v = IsChild()
if !v {
t.Errorf("IsChild() == %v, want %v", v, true)
t.Setenv(preforkChildEnvVariable, preforkChildEnvValue)
if !IsChild() {
t.Errorf("IsChild() == false after Setenv, want true")
}
}
@@ -50,48 +87,37 @@ func Test_New(t *testing.T) {
if p.Network != defaultNetwork {
t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork)
}
if reflect.ValueOf(p.ServeFunc).Pointer() != reflect.ValueOf(s.Serve).Pointer() {
t.Errorf("Prefork.ServeFunc == %p, want %p", p.ServeFunc, s.Serve)
if p.RecoverThreshold <= 0 {
t.Errorf("Prefork.RecoverThreshold == %d, want > 0", p.RecoverThreshold)
}
if reflect.ValueOf(p.ServeTLSFunc).Pointer() != reflect.ValueOf(s.ServeTLS).Pointer() {
t.Errorf("Prefork.ServeTLSFunc == %p, want %p", p.ServeTLSFunc, s.ServeTLS)
}
if reflect.ValueOf(p.ServeTLSEmbedFunc).Pointer() != reflect.ValueOf(s.ServeTLSEmbed).Pointer() {
t.Errorf("Prefork.ServeTLSFunc == %p, want %p", p.ServeTLSEmbedFunc, s.ServeTLSEmbed)
if p.ServeFunc == nil || p.ServeTLSFunc == nil || p.ServeTLSEmbedFunc == nil {
t.Error("New() did not wire one of ServeFunc/ServeTLSFunc/ServeTLSEmbedFunc")
}
}
func Test_listen(t *testing.T) {
t.Parallel()
func Test_listen_Reuseport(t *testing.T) {
prev := runtime.GOMAXPROCS(0)
t.Cleanup(func() {
runtime.GOMAXPROCS(prev)
})
p := &Prefork{
Reuseport: true,
}
addr := getAddr()
p := &Prefork{Reuseport: true}
addr := freeAddr(t)
ln, err := p.listen(addr)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
t.Fatalf("listen: %v", err)
}
t.Cleanup(func() {
_ = ln.Close()
})
ln.Close()
lnAddr := ln.Addr().String()
if lnAddr != addr {
t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
if got, want := ln.Addr().String(), addr; got != want {
t.Errorf("ln.Addr() == %q, want %q", got, want)
}
if p.Network != defaultNetwork {
t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork)
}
procs := runtime.GOMAXPROCS(0)
if procs != 1 {
t.Errorf("GOMAXPROCS == %d, want %d", procs, 1)
}
}
func Test_setTCPListenerFiles(t *testing.T) {
@@ -102,125 +128,444 @@ func Test_setTCPListenerFiles(t *testing.T) {
}
p := &Prefork{}
addr := getAddr()
addr := freeAddr(t)
err := p.setTCPListenerFiles(addr)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
if err := p.setTCPListenerFiles(addr); err != nil {
t.Fatalf("setTCPListenerFiles: %v", err)
}
t.Cleanup(func() {
_ = p.ln.Close()
for _, f := range p.files {
_ = f.Close()
}
})
if p.ln == nil {
t.Fatal("Prefork.ln is nil")
t.Fatal("p.ln is nil after setTCPListenerFiles")
}
p.ln.Close()
lnAddr := p.ln.Addr().String()
if lnAddr != addr {
t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
if got, want := p.ln.Addr().String(), addr; got != want {
t.Errorf("p.ln.Addr() == %q, want %q", got, want)
}
if p.Network != defaultNetwork {
t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork)
}
if len(p.files) != 1 {
t.Errorf("Prefork.files == %d, want %d", len(p.files), 1)
t.Errorf("len(p.files) == %d, want 1", len(p.files))
}
}
func Test_ListenAndServe(t *testing.T) {
// This test can't run parallel as it modifies os.Args.
func Test_setTCPListenerFiles_BadAddr(t *testing.T) {
t.Parallel()
setUp()
defer tearDown()
s := &fasthttp.Server{}
p := New(s)
p.Reuseport = true
p.ServeFunc = func(ln net.Listener) error {
return nil
if runtime.GOOS == "windows" {
t.SkipNow()
}
addr := getAddr()
err := p.ListenAndServe(addr)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
p.ln.Close()
lnAddr := p.ln.Addr().String()
if lnAddr != addr {
t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
}
if p.ln == nil {
t.Error("Prefork.ln is nil")
p := &Prefork{}
if err := p.setTCPListenerFiles("definitely not an address"); err == nil {
t.Fatal("expected error for malformed addr, got nil")
}
}
func Test_ListenAndServeTLS(t *testing.T) {
// This test can't run parallel as it modifies os.Args.
// Test_ListenAndServe_Stub_ChildPath drives the child branch of all three
// ListenAndServe* entry points using a stubbed Serve function. It replaces the
// previous trio of near-identical tests that only validated field assignment.
func Test_ListenAndServe_Stub_ChildPath(t *testing.T) {
// child env mutation precludes t.Parallel.
t.Setenv(preforkChildEnvVariable, preforkChildEnvValue)
setUp()
defer tearDown()
s := &fasthttp.Server{}
p := New(s)
p.Reuseport = true
p.ServeTLSFunc = func(ln net.Listener, certFile, keyFile string) error {
return nil
type call struct {
listener bool
certFile string
keyFile string
certData string
keyData string
}
addr := getAddr()
err := p.ListenAndServeTLS(addr, "./key", "./cert")
if err != nil {
t.Errorf("Unexpected error: %v", err)
tests := []struct {
name string
run func(t *testing.T, p *Prefork, addr string) error
want call
}{
{
name: "ListenAndServe",
run: func(_ *testing.T, p *Prefork, addr string) error { return p.ListenAndServe(addr) },
want: call{listener: true},
},
{
name: "ListenAndServeTLS",
run: func(_ *testing.T, p *Prefork, addr string) error {
return p.ListenAndServeTLS(addr, "./key", "./cert")
},
want: call{listener: true, certFile: "./cert", keyFile: "./key"},
},
{
name: "ListenAndServeTLSEmbed",
run: func(_ *testing.T, p *Prefork, addr string) error {
return p.ListenAndServeTLSEmbed(addr, []byte("certPEM"), []byte("keyPEM"))
},
want: call{listener: true, certData: "certPEM", keyData: "keyPEM"},
},
}
p.ln.Close()
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var got call
p := New(&fasthttp.Server{})
p.Reuseport = true
p.ServeFunc = func(ln net.Listener) error {
got.listener = ln != nil
return nil
}
p.ServeTLSFunc = func(ln net.Listener, certFile, keyFile string) error {
got.listener = ln != nil
got.certFile = certFile
got.keyFile = keyFile
return nil
}
p.ServeTLSEmbedFunc = func(ln net.Listener, certData, keyData []byte) error {
got.listener = ln != nil
got.certData = string(certData)
got.keyData = string(keyData)
return nil
}
lnAddr := p.ln.Addr().String()
if lnAddr != addr {
t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
}
addr := freeAddr(t)
if err := tc.run(t, p, addr); err != nil {
t.Fatalf("%s: %v", tc.name, err)
}
t.Cleanup(func() {
if p.ln != nil {
_ = p.ln.Close()
}
})
if p.ln == nil {
t.Error("Prefork.ln is nil")
if got != tc.want {
t.Errorf("%s call = %+v, want %+v", tc.name, got, tc.want)
}
})
}
}
func Test_ListenAndServeTLSEmbed(t *testing.T) {
// This test can't run parallel as it modifies os.Args.
func Test_doCommand_CommandProducerErrors(t *testing.T) {
t.Parallel()
setUp()
defer tearDown()
s := &fasthttp.Server{}
p := New(s)
p.Reuseport = true
p.ServeTLSEmbedFunc = func(ln net.Listener, certData, keyData []byte) error {
return nil
tests := []struct {
name string
produce func(files []*os.File) (*exec.Cmd, error)
wantErr error
}{
{
name: "producer returns error",
produce: func([]*os.File) (*exec.Cmd, error) {
return nil, errors.New("boom")
},
},
{
name: "producer returns nil cmd",
//nolint:nilnil // intentionally tests the (nil, nil) misbehaviour guard
produce: func([]*os.File) (*exec.Cmd, error) {
return nil, nil
},
wantErr: ErrCommandProducerNilCmd,
},
{
name: "producer returns unstarted cmd",
produce: func([]*os.File) (*exec.Cmd, error) {
return &exec.Cmd{}, nil
},
wantErr: ErrCommandProducerNotStarted,
},
}
addr := getAddr()
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
p := &Prefork{CommandProducer: tc.produce}
cmd, err := p.doCommand()
if cmd != nil {
t.Errorf("expected nil cmd on error, got %v", cmd)
}
if err == nil {
t.Fatal("expected error, got nil")
}
if tc.wantErr != nil && !errors.Is(err, tc.wantErr) {
t.Errorf("err = %v, want errors.Is %v", err, tc.wantErr)
}
})
}
}
type testLogger struct {
mu sync.Mutex
messages []string
}
func (l *testLogger) Printf(format string, args ...any) {
msg := fmt.Sprintf(format, args...)
l.mu.Lock()
l.messages = append(l.messages, msg)
l.mu.Unlock()
}
// Test_Prefork_Lifecycle drives prefork() to ErrOverRecovery via
// short-lived no-op children and asserts the callback ordering / arguments.
func Test_Prefork_Lifecycle(t *testing.T) {
prev := runtime.GOMAXPROCS(2)
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
type event struct {
name string
pids []int
}
var mu sync.Mutex
var events []event
record := func(name string, pids ...int) {
mu.Lock()
events = append(events, event{name, pids})
mu.Unlock()
}
produce, cleanup := noopChildProducer(t)
t.Cleanup(cleanup)
p := &Prefork{
Reuseport: true,
RecoverThreshold: 1,
Logger: &testLogger{},
CommandProducer: produce,
OnChildSpawn: func(pid int) error {
record("spawn", pid)
return nil
},
OnMasterReady: func(childPIDs []int) error {
record("ready", childPIDs...)
return nil
},
OnChildRecover: func(oldPID, newPID int) error {
record("recover", oldPID, newPID)
return nil
},
}
err := p.prefork(freeAddr(t))
if !errors.Is(err, ErrOverRecovery) {
t.Fatalf("expected ErrOverRecovery, got: %v", err)
}
mu.Lock()
defer mu.Unlock()
goMaxProcs := runtime.GOMAXPROCS(0)
var spawnCount, readyCount, recoverCount int
for _, e := range events {
switch e.name {
case "spawn":
spawnCount++
if len(e.pids) != 1 || e.pids[0] <= 0 {
t.Errorf("spawn event has invalid PID: %v", e.pids)
}
case "ready":
readyCount++
if len(e.pids) == 0 {
t.Error("ready event received empty PID list")
}
case "recover":
recoverCount++
if len(e.pids) != 2 || e.pids[0] <= 0 || e.pids[1] <= 0 {
t.Errorf("recover event has invalid PIDs: %v", e.pids)
}
if e.pids[0] == e.pids[1] {
t.Error("recover old and new PID should differ")
}
}
}
if readyCount != 1 {
t.Errorf("OnMasterReady called %d times, want 1", readyCount)
}
if spawnCount < goMaxProcs {
t.Errorf("OnChildSpawn called %d times, want at least %d", spawnCount, goMaxProcs)
}
if recoverCount == 0 {
t.Error("OnChildRecover was never called")
}
// ready must come after exactly goMaxProcs initial spawns.
readyIdx := -1
spawnsBeforeReady := 0
for i, e := range events {
if e.name == "ready" {
readyIdx = i
break
}
if e.name == "spawn" {
spawnsBeforeReady++
}
}
if readyIdx == -1 {
t.Fatal("OnMasterReady was never called")
}
if spawnsBeforeReady != goMaxProcs {
t.Errorf("OnMasterReady called after %d initial spawns, want %d", spawnsBeforeReady, goMaxProcs)
}
// every recover event must be preceded by a spawn for the new PID.
recoveredSpawnByPID := make(map[int]bool)
for _, e := range events[readyIdx+1:] {
if e.name == "spawn" {
recoveredSpawnByPID[e.pids[0]] = true
}
if e.name == "recover" {
if !recoveredSpawnByPID[e.pids[1]] {
t.Errorf("OnChildRecover for PID %d happened before OnChildSpawn", e.pids[1])
}
}
}
}
func Test_Prefork_InitialChildSpawnError(t *testing.T) {
prev := runtime.GOMAXPROCS(2)
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
err := p.ListenAndServeTLSEmbed(addr, []byte("key"), []byte("cert"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
produce, cleanup := noopChildProducer(t)
t.Cleanup(cleanup)
expectedErr := errors.New("initial spawn rejected")
var calls atomic.Int32
p := &Prefork{
Reuseport: true,
RecoverThreshold: 1,
Logger: &testLogger{},
CommandProducer: produce,
OnChildSpawn: func(_ int) error {
calls.Add(1)
return expectedErr
},
}
err := p.prefork(freeAddr(t))
if !errors.Is(err, expectedErr) {
t.Fatalf("expected %v, got: %v", expectedErr, err)
}
if calls.Load() == 0 {
t.Fatal("OnChildSpawn was never invoked")
}
}
func Test_Prefork_OnMasterReadyError(t *testing.T) {
prev := runtime.GOMAXPROCS(2)
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
produce, cleanup := noopChildProducer(t)
t.Cleanup(cleanup)
expectedErr := errors.New("ready rejected")
p := &Prefork{
Reuseport: true,
RecoverThreshold: 1,
Logger: &testLogger{},
CommandProducer: produce,
OnMasterReady: func([]int) error {
return expectedErr
},
}
err := p.prefork(freeAddr(t))
if !errors.Is(err, expectedErr) {
t.Fatalf("expected %v, got: %v", expectedErr, err)
}
}
func Test_Prefork_RecoveredChildSpawnError(t *testing.T) {
prev := runtime.GOMAXPROCS(2)
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
produce, cleanup := noopChildProducer(t)
t.Cleanup(cleanup)
expectedErr := errors.New("spawn failed")
var spawnCount, recoverCount atomic.Int32
p := &Prefork{
Reuseport: true,
RecoverThreshold: 1,
Logger: &testLogger{},
CommandProducer: produce,
OnChildSpawn: func(pid int) error {
if pid <= 0 {
t.Errorf("OnChildSpawn called with invalid PID: %d", pid)
}
n := spawnCount.Add(1)
if int(n) > runtime.GOMAXPROCS(0) {
return expectedErr
}
return nil
},
OnChildRecover: func(_, _ int) error {
recoverCount.Add(1)
return nil
},
}
err := p.prefork(freeAddr(t))
if !errors.Is(err, expectedErr) {
t.Fatalf("expected %v, got: %v", expectedErr, err)
}
if got := recoverCount.Load(); got != 0 {
t.Fatalf("OnChildRecover called %d times, want 0", got)
}
}
func Test_Prefork_HookPanicSurfaces(t *testing.T) {
prev := runtime.GOMAXPROCS(2)
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
produce, cleanup := noopChildProducer(t)
t.Cleanup(cleanup)
p := &Prefork{
Reuseport: true,
RecoverThreshold: 1,
Logger: &testLogger{},
CommandProducer: produce,
OnChildSpawn: func(_ int) error {
panic("user code blew up")
},
}
err := p.prefork(freeAddr(t))
if err == nil {
t.Fatal("expected error from panicking hook, got nil")
}
}
p.ln.Close()
// Test_Prefork_RecoverInterval verifies the optional backoff actually delays
// the respawn — without rabbit-holing into hard wallclock assertions.
func Test_Prefork_RecoverInterval(t *testing.T) {
prev := runtime.GOMAXPROCS(2)
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
lnAddr := p.ln.Addr().String()
if lnAddr != addr {
t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
produce, cleanup := noopChildProducer(t)
t.Cleanup(cleanup)
const interval = 50 * time.Millisecond
p := &Prefork{
Reuseport: true,
RecoverThreshold: 1,
RecoverInterval: interval,
Logger: &testLogger{},
CommandProducer: produce,
}
start := time.Now()
err := p.prefork(freeAddr(t))
elapsed := time.Since(start)
if p.ln == nil {
t.Error("Prefork.ln is nil")
if !errors.Is(err, ErrOverRecovery) {
t.Fatalf("expected ErrOverRecovery, got %v", err)
}
// At least one recover interval must have elapsed before threshold fired.
if elapsed < interval {
t.Errorf("elapsed %v < interval %v; backoff did not apply", elapsed, interval)
}
}