mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-14 15:56:44 +03:00
feat(prefork): graceful shutdown, leak fixes, hook robustness
Addresses outstanding review concerns and several adjacent issues surfaced during a follow-up review pass.

Lifecycle / supervision
- Track every per-child Wait goroutine via sync.WaitGroup and unblock pending sigCh sends through a context.CancelFunc so early-return paths (OnChildSpawn / OnMasterReady error, recovery doCommand error, ErrOverRecovery) can no longer leak goroutines or stall children.
- Install signal.Notify(SIGTERM, SIGINT) in the master so deploy/rolling-restart signals enter the shutdown path instead of killing the master without graceful teardown.
- Replace the unconditional SIGKILL defer with a SIGTERM-then-SIGKILL sequence gated by a configurable ShutdownGracePeriod (defaults to 5s, Windows path stays SIGKILL since Signal(SIGTERM) is unsupported).

API
- OnChildRecover now returns error so callers can implement recovery policies (circuit-breaker etc.); a panic in any hook is recovered and surfaced as the returned error, with diagnostic logging.
- Add RecoverInterval (optional crash-loop backoff) and ShutdownGracePeriod fields with safe zero-value defaults.
- Export ErrCommandProducerNilCmd and ErrCommandProducerNotStarted sentinel errors so callers can match them with errors.Is.
- Rename oldPid/newPid to oldPID/newPID per Go initialism convention.
- The Logger interface now has an explicit compile-time compatibility check against fasthttp.Logger.

Resource hygiene
- Master closes both the original tcpListener and the duped fd in p.files when prefork() returns; previously the duped fd leaked once per call.
- doCommand wraps every underlying error with fmt.Errorf and %w so caller-side diagnostics keep stage context.
- Strip pre-existing FASTHTTP_PREFORK_CHILD entries before appending so the child env never carries duplicate keys.
- Extract magic numbers as package constants (inheritedListenerFD, masterPollInterval, defaultShutdownGracePeriod, preforkChildEnvValue).
- Name the inherited listener fd via os.NewFile so net.FileListener errors are diagnosable.

Tests
- Migrate to t.Setenv (drop the global setUp/tearDown helpers) — fixes the env-mutation-vs-parallel race.
- Replace the rand.Intn port helper with `:0` + Listener.Addr() to remove port-collision flakes under -count and parallel runs.
- Collapse the three near-identical Test_ListenAndServe* tests into a single table-driven subtest that actually asserts the args forwarded to ServeFunc/ServeTLSFunc/ServeTLSEmbedFunc.
- Add coverage for the previously untested branches: CommandProducer returning err / nil cmd / unstarted cmd, initial OnChildSpawn error, OnMasterReady error, hook panic surfacing, RecoverInterval enforcement.
- The noopChildProducer helper kills and waits on any spawned child processes during cleanup so failed tests no longer leave subprocesses around.
This commit is contained in:
+330
-124
@@ -2,12 +2,17 @@
|
||||
package prefork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
@@ -16,25 +21,53 @@ import (
|
||||
|
||||
const (
|
||||
preforkChildEnvVariable = "FASTHTTP_PREFORK_CHILD"
|
||||
preforkChildEnvValue = "1"
|
||||
defaultNetwork = "tcp4"
|
||||
|
||||
// inheritedListenerFD is the file descriptor used by the master to pass
|
||||
// the bound listener to a child process via ExtraFiles. Children open the
|
||||
// listener via os.NewFile(inheritedListenerFD, ...) when Reuseport is false.
|
||||
inheritedListenerFD = 3
|
||||
|
||||
// masterPollInterval is the period of the watchMaster ppid-poll on Unix.
|
||||
masterPollInterval = 500 * time.Millisecond
|
||||
|
||||
// defaultShutdownGracePeriod is how long the master waits for children to
|
||||
// exit cleanly after sending SIGTERM before forcibly killing them.
|
||||
defaultShutdownGracePeriod = 5 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
defaultLogger = Logger(log.New(os.Stderr, "", log.LstdFlags))
|
||||
// ErrOverRecovery is returned when the times of starting over child prefork processes exceed
|
||||
// the threshold.
|
||||
|
||||
// ErrOverRecovery is returned when child prefork process restarts exceed
|
||||
// the value of RecoverThreshold.
|
||||
ErrOverRecovery = errors.New("exceeding the value of RecoverThreshold")
|
||||
|
||||
// ErrOnlyReuseportOnWindows is returned when Reuseport is false.
|
||||
// ErrOnlyReuseportOnWindows is returned when running on Windows without Reuseport.
|
||||
ErrOnlyReuseportOnWindows = errors.New("windows only supports Reuseport = true")
|
||||
|
||||
// ErrCommandProducerNilCmd is returned when a CommandProducer returns
|
||||
// (nil, nil) instead of a started command.
|
||||
ErrCommandProducerNilCmd = errors.New("prefork: CommandProducer returned nil command")
|
||||
|
||||
// ErrCommandProducerNotStarted is returned when a CommandProducer returns
|
||||
// an *exec.Cmd whose Process is nil (i.e. cmd.Start() was not called).
|
||||
ErrCommandProducerNotStarted = errors.New("prefork: CommandProducer must return a started command")
|
||||
)
|
||||
|
||||
// Logger is used for logging formatted messages.
|
||||
// Logger is used for logging formatted messages. Its method set is intentionally
|
||||
// identical to fasthttp.Logger so that *fasthttp.Server.Logger can be assigned
|
||||
// directly.
|
||||
type Logger interface {
|
||||
// Printf must have the same semantics as log.Printf.
|
||||
Printf(format string, args ...any)
|
||||
}
|
||||
|
||||
// Compile-time check that fasthttp.Logger satisfies the local Logger interface;
|
||||
// the build breaks if fasthttp.Logger ever stops providing Logger's method set.
|
||||
var _ Logger = fasthttp.Logger(nil)
|
||||
|
||||
// Prefork implements fasthttp server prefork.
|
||||
//
|
||||
// Preforks master process (with all cores) between several child processes
|
||||
@@ -44,7 +77,8 @@ type Logger interface {
|
||||
// WARNING: using prefork prevents the use of any global state!
|
||||
// Things like in-memory caches won't work.
|
||||
type Prefork struct {
|
||||
// By default standard logger from log package is used.
|
||||
// Logger receives diagnostic output. By default the standard log package
|
||||
// logger writing to stderr is used.
|
||||
Logger Logger
|
||||
|
||||
ln net.Listener
|
||||
@@ -53,21 +87,29 @@ type Prefork struct {
|
||||
ServeTLSFunc func(ln net.Listener, certFile, keyFile string) error
|
||||
ServeTLSEmbedFunc func(ln net.Listener, certData, keyData []byte) error
|
||||
|
||||
// The network must be "tcp", "tcp4" or "tcp6".
|
||||
//
|
||||
// By default is "tcp4"
|
||||
// Network must be "tcp", "tcp4" or "tcp6". Default is "tcp4".
|
||||
Network string
|
||||
|
||||
files []*os.File
|
||||
|
||||
// Child prefork processes may exit with failure and will be started over until the times reach
|
||||
// the value of RecoverThreshold, then it will return and terminate the server.
|
||||
// RecoverThreshold caps how often crashed children are respawned before
|
||||
// the master returns ErrOverRecovery. New() sets it to max(1, GOMAXPROCS/2).
|
||||
// When constructing a Prefork directly without New(), a zero value will
|
||||
// terminate the master after the very first child crash.
|
||||
RecoverThreshold int
|
||||
|
||||
// Flag to use a listener with reuseport, if not a file Listener will be used
|
||||
// RecoverInterval, when > 0, makes the master sleep for the given duration
|
||||
// before respawning a crashed child. Useful as crash-loop backoff.
|
||||
RecoverInterval time.Duration
|
||||
|
||||
// ShutdownGracePeriod is the time the master waits for children to exit
|
||||
// after sending SIGTERM before falling back to SIGKILL. Defaults to 5s
|
||||
// when zero. On Windows no SIGTERM is sent, but the master still waits up to this period for children to exit before killing them.
|
||||
ShutdownGracePeriod time.Duration
|
||||
|
||||
// Reuseport selects a reuseport listener instead of fd-passing.
|
||||
// See: https://www.nginx.com/blog/socket-sharding-nginx-release-1-9-1/
|
||||
//
|
||||
// It's disabled by default
|
||||
// Disabled by default.
|
||||
Reuseport bool
|
||||
|
||||
// OnMasterDeath, when non-nil, enables monitoring of the master process
|
||||
@@ -76,25 +118,45 @@ type Prefork struct {
|
||||
//
|
||||
// It is recommended to set this to func() { os.Exit(1) } if no custom
|
||||
// cleanup is needed.
|
||||
//
|
||||
// Threading: invoked once from a watcher goroutine in the child. Must not
|
||||
// block the goroutine for long; must not call Prefork methods.
|
||||
OnMasterDeath func()
|
||||
|
||||
// OnChildSpawn is called in the master process whenever a new child process is spawned.
|
||||
// OnChildSpawn is called in the master after a new child process is
|
||||
// successfully started, both during initial spawn and during recovery.
|
||||
// It receives the PID of the newly spawned child process.
|
||||
//
|
||||
// If this callback returns an error, all child processes are killed and the
|
||||
// prefork operation returns that error.
|
||||
// If this callback returns an error, all already-running children are
|
||||
// killed and the prefork operation returns that error.
|
||||
//
|
||||
// Threading: invoked synchronously from the master goroutine. Must not
|
||||
// block; must not call Prefork methods. Panics are recovered and surfaced
|
||||
// as the returned error.
|
||||
OnChildSpawn func(pid int) error
|
||||
|
||||
// OnMasterReady is called in the master process after all child processes have been spawned.
|
||||
// It receives a slice of all child process PIDs.
|
||||
// OnMasterReady is called in the master process exactly once, after all
|
||||
// initial children have been spawned and before the supervision loop runs.
|
||||
// It receives a slice of all initial child PIDs.
|
||||
//
|
||||
// If this callback returns an error, the prefork operation will be aborted.
|
||||
// If this callback returns an error, the prefork operation aborts and
|
||||
// returns that error after killing the children.
|
||||
//
|
||||
// Threading: invoked synchronously from the master goroutine. The slice
|
||||
// is a private copy that the callback may retain or modify. Panics are recovered.
|
||||
OnMasterReady func(childPIDs []int) error
|
||||
|
||||
// OnChildRecover is called in the master process when a crashed child process
|
||||
// has been replaced by a new one. It receives the PID of the old (crashed)
|
||||
// process and the PID of the newly spawned replacement.
|
||||
OnChildRecover func(oldPid, newPid int)
|
||||
// OnChildRecover is called in the master after a crashed child has been
|
||||
// replaced. It receives the PID of the old (crashed) process and the PID
|
||||
// of its replacement.
|
||||
//
|
||||
// If this callback returns an error, all running children are killed and
|
||||
// the prefork operation returns that error.
|
||||
//
|
||||
// Threading: invoked synchronously from the master goroutine, after
|
||||
// OnChildSpawn for the new child. Panics are recovered and surfaced as
|
||||
// the returned error.
|
||||
OnChildRecover func(oldPID, newPID int) error
|
||||
|
||||
// CommandProducer creates and starts a child process command.
|
||||
// If nil, the default implementation re-executes the current binary
|
||||
@@ -104,20 +166,23 @@ type Prefork struct {
|
||||
// A custom producer must:
|
||||
// - Set FASTHTTP_PREFORK_CHILD=1 in the child's environment
|
||||
// (otherwise IsChild() returns false and the child won't serve)
|
||||
// - Call cmd.Start() before returning (the returned command must be started)
|
||||
// - Call cmd.Start() before returning (the returned command must be
|
||||
// started so cmd.Process is non-nil)
|
||||
// - Pass the provided files as cmd.ExtraFiles when Reuseport is false
|
||||
//
|
||||
// This is primarily useful for testing (injecting dummy commands)
|
||||
// or for frameworks that need custom child process setup.
|
||||
// Primarily useful for testing (injecting dummy commands) or for
|
||||
// frameworks that need custom child process setup.
|
||||
CommandProducer func(files []*os.File) (*exec.Cmd, error)
|
||||
}
|
||||
|
||||
// IsChild checks if the current thread/process is a child.
|
||||
// IsChild reports whether the current process is a prefork child.
|
||||
func IsChild() bool {
|
||||
return os.Getenv(preforkChildEnvVariable) == "1"
|
||||
return os.Getenv(preforkChildEnvVariable) == preforkChildEnvValue
|
||||
}
|
||||
|
||||
// New wraps the fasthttp server to run with preforked processes.
|
||||
// It seeds Network and RecoverThreshold to sensible defaults; existing
|
||||
// fields on s (Logger, Serve*) are captured.
|
||||
func New(s *fasthttp.Server) *Prefork {
|
||||
return &Prefork{
|
||||
Network: defaultNetwork,
|
||||
@@ -140,20 +205,32 @@ func (p *Prefork) logger() Logger {
|
||||
return defaultLogger
|
||||
}
|
||||
|
||||
// invokeHook runs fn under a panic recovery, returning the panic as an error
|
||||
// so a misbehaving callback never tears down the master.
|
||||
func (p *Prefork) invokeHook(name string, fn func() error) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("prefork: %s panicked: %v", name, r)
|
||||
p.logger().Printf("%v", err)
|
||||
}
|
||||
}()
|
||||
return fn()
|
||||
}
|
||||
|
||||
func (p *Prefork) watchMaster(masterPID int) {
|
||||
if runtime.GOOS == "windows" {
|
||||
// On Windows, os.Getppid() returns a static PID that doesn't change
|
||||
// when the parent exits (no reparenting). Use FindProcess+Wait instead.
|
||||
proc, err := os.FindProcess(masterPID)
|
||||
if err != nil {
|
||||
p.logger().Printf("watchMaster: failed to find master process %d: %v\n", masterPID, err)
|
||||
p.logger().Printf("watchMaster: failed to find master process %d: %v", masterPID, err)
|
||||
p.OnMasterDeath()
|
||||
return
|
||||
}
|
||||
if _, err = proc.Wait(); err != nil {
|
||||
p.logger().Printf("watchMaster: error waiting for master process %d: %v\n", masterPID, err)
|
||||
p.logger().Printf("watchMaster: error waiting for master process %d: %v", masterPID, err)
|
||||
}
|
||||
p.logger().Printf("master process died\n")
|
||||
p.logger().Printf("master process %d died", masterPID)
|
||||
p.OnMasterDeath()
|
||||
return
|
||||
}
|
||||
@@ -162,12 +239,12 @@ func (p *Prefork) watchMaster(masterPID int) {
|
||||
// to another process, causing Getppid() to change. Comparing against
|
||||
// the original masterPID (instead of hardcoding 1) ensures this works
|
||||
// correctly when the master itself is PID 1 (e.g. in Docker containers).
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
ticker := time.NewTicker(masterPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if os.Getppid() != masterPID {
|
||||
p.logger().Printf("master process died\n")
|
||||
p.logger().Printf("master process %d died", masterPID)
|
||||
p.OnMasterDeath()
|
||||
return
|
||||
}
|
||||
@@ -185,8 +262,10 @@ func (p *Prefork) listen(addr string) (net.Listener, error) {
|
||||
return reuseport.Listen(p.Network, addr)
|
||||
}
|
||||
|
||||
// File descriptor 3 is the first ExtraFiles entry passed by the master process.
|
||||
return net.FileListener(os.NewFile(3, ""))
|
||||
// fd inheritedListenerFD is the first ExtraFiles entry passed by the
|
||||
// master process when Reuseport is false. Naming the file gives clearer
|
||||
// errors from net.FileListener if the fd is invalid.
|
||||
return net.FileListener(os.NewFile(inheritedListenerFD, "fasthttp-prefork-listener"))
|
||||
}
|
||||
|
||||
// listenAsChild performs the common child process setup: creates the listener
|
||||
@@ -213,62 +292,142 @@ func (p *Prefork) setTCPListenerFiles(addr string) error {
|
||||
|
||||
tcpAddr, err := net.ResolveTCPAddr(p.Network, addr)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("prefork: resolve %s/%s: %w", p.Network, addr, err)
|
||||
}
|
||||
|
||||
tcplistener, err := net.ListenTCP(p.Network, tcpAddr)
|
||||
tcpListener, err := net.ListenTCP(p.Network, tcpAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("prefork: listen tcp %s: %w", addr, err)
|
||||
}
|
||||
|
||||
p.ln = tcplistener
|
||||
p.ln = tcpListener
|
||||
|
||||
fl, err := tcplistener.File()
|
||||
listenerFile, err := tcpListener.File()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("prefork: dup listener fd: %w", err)
|
||||
}
|
||||
|
||||
p.files = []*os.File{fl}
|
||||
p.files = []*os.File{listenerFile}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// childEnv returns os.Environ() with the prefork child marker variable set,
|
||||
// stripping any pre-existing value to avoid duplicate keys with last-wins
|
||||
// semantics.
|
||||
func childEnv() []string {
|
||||
src := os.Environ()
|
||||
out := make([]string, 0, len(src)+1)
|
||||
prefix := preforkChildEnvVariable + "="
|
||||
for _, kv := range src {
|
||||
if len(kv) >= len(prefix) && kv[:len(prefix)] == prefix {
|
||||
continue
|
||||
}
|
||||
out = append(out, kv)
|
||||
}
|
||||
out = append(out, preforkChildEnvVariable+"="+preforkChildEnvValue)
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *Prefork) doCommand() (*exec.Cmd, error) {
|
||||
// Use custom CommandProducer if provided
|
||||
if p.CommandProducer != nil {
|
||||
cmd, err := p.CommandProducer(p.files)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("prefork: CommandProducer: %w", err)
|
||||
}
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return nil, errors.New("prefork: CommandProducer must return a started command")
|
||||
if cmd == nil {
|
||||
return nil, ErrCommandProducerNilCmd
|
||||
}
|
||||
if cmd.Process == nil {
|
||||
return nil, ErrCommandProducerNotStarted
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// Default implementation using os.Executable() for reliable path resolution
|
||||
executable, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("prefork: resolve executable: %w", err)
|
||||
}
|
||||
|
||||
args := make([]string, len(os.Args))
|
||||
args[0] = executable
|
||||
copy(args[1:], os.Args[1:])
|
||||
args := append([]string{executable}, os.Args[1:]...)
|
||||
|
||||
cmd := &exec.Cmd{
|
||||
Path: executable,
|
||||
Args: args,
|
||||
Stdout: os.Stdout,
|
||||
Stderr: os.Stderr,
|
||||
Env: append(os.Environ(), preforkChildEnvVariable+"=1"),
|
||||
Env: childEnv(),
|
||||
ExtraFiles: p.files,
|
||||
}
|
||||
err = cmd.Start()
|
||||
return cmd, err
|
||||
if err = cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("prefork: start child %q: %w", executable, err)
|
||||
}
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
func (p *Prefork) prefork(addr string) (err error) {
|
||||
type childExit struct {
|
||||
err error
|
||||
pid int
|
||||
}
|
||||
|
||||
// shutdownChildren signals every entry in childProcs first with SIGTERM (on
|
||||
// platforms where it is supported) and waits up to grace for them to exit.
|
||||
// Survivors are then killed unconditionally. wg tracks the per-child Wait
|
||||
// goroutines and must be drained before returning so no goroutine outlives
|
||||
// prefork().
|
||||
func (p *Prefork) shutdownChildren(
|
||||
childProcs map[int]*exec.Cmd,
|
||||
wg *sync.WaitGroup,
|
||||
cancel context.CancelFunc,
|
||||
grace time.Duration,
|
||||
) {
|
||||
if grace <= 0 {
|
||||
grace = defaultShutdownGracePeriod
|
||||
}
|
||||
|
||||
if runtime.GOOS != "windows" {
|
||||
for pid, proc := range childProcs {
|
||||
if proc == nil || proc.Process == nil {
|
||||
continue
|
||||
}
|
||||
if termErr := proc.Process.Signal(syscall.SIGTERM); termErr != nil &&
|
||||
!errors.Is(termErr, os.ErrProcessDone) {
|
||||
p.logger().Printf("prefork: SIGTERM child %d: %v", pid, termErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for graceful exits, with a timeout fallback to SIGKILL.
|
||||
graceful := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(graceful)
|
||||
}()
|
||||
|
||||
timer := time.NewTimer(grace)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-graceful:
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
for pid, proc := range childProcs {
|
||||
if proc == nil || proc.Process == nil {
|
||||
continue
|
||||
}
|
||||
if killErr := proc.Process.Kill(); killErr != nil &&
|
||||
!errors.Is(killErr, os.ErrProcessDone) {
|
||||
p.logger().Printf("prefork: kill child %d: %v", pid, killErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel the per-Wait goroutines' send-context so any still blocked on
|
||||
// sigCh send unblock cleanly, then wait for all of them to exit.
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (p *Prefork) prefork(addr string) (err error) { //nolint:gocyclo
|
||||
if !p.Reuseport {
|
||||
if runtime.GOOS == "windows" {
|
||||
return ErrOnlyReuseportOnWindows
|
||||
@@ -278,37 +437,60 @@ func (p *Prefork) prefork(addr string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// defer for closing the net.Listener opened by setTCPListenerFiles.
|
||||
// Close listener fds opened by setTCPListenerFiles. Both the original
|
||||
// tcpListener (p.ln) and the duped fd (p.files[0]) belong to the
|
||||
// master only; children inherit independent dup'd copies via fork+exec.
|
||||
defer func() {
|
||||
e := p.ln.Close()
|
||||
if err == nil {
|
||||
err = e
|
||||
err = errors.Join(err, p.ln.Close())
|
||||
for _, f := range p.files {
|
||||
if closeErr := f.Close(); closeErr != nil {
|
||||
p.logger().Printf("prefork: close listener fd: %v", closeErr)
|
||||
}
|
||||
}
|
||||
p.files = nil
|
||||
}()
|
||||
}
|
||||
|
||||
// ctx cancels per-child Wait goroutines so they unblock from sigCh sends
|
||||
// once the supervision loop is gone.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Catch SIGTERM/SIGINT in the master so we run our shutdown path instead
|
||||
// of being killed by the OS without children getting a graceful chance.
|
||||
signalCh := make(chan os.Signal, 1)
|
||||
signal.Notify(signalCh, syscall.SIGTERM, syscall.SIGINT)
|
||||
defer signal.Stop(signalCh)
|
||||
|
||||
goMaxProcs := runtime.GOMAXPROCS(0)
|
||||
// Buffer is sized to initial fleet; per-child goroutines fall back to a
|
||||
// context-aware select on send so capacity is not load-bearing.
|
||||
sigCh := make(chan childExit, goMaxProcs)
|
||||
childProcs := make(map[int]*exec.Cmd, goMaxProcs)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
startWait := func(cmd *exec.Cmd, pid int) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
result := childExit{pid: pid, err: cmd.Wait()}
|
||||
select {
|
||||
case sigCh <- result:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
type procSig struct {
|
||||
err error
|
||||
pid int
|
||||
}
|
||||
|
||||
goMaxProcs := runtime.GOMAXPROCS(0)
|
||||
sigCh := make(chan procSig, goMaxProcs)
|
||||
childProcs := make(map[int]*exec.Cmd, goMaxProcs)
|
||||
|
||||
defer func() {
|
||||
for _, proc := range childProcs {
|
||||
_ = proc.Process.Kill()
|
||||
}
|
||||
p.shutdownChildren(childProcs, &wg, cancel, p.ShutdownGracePeriod)
|
||||
}()
|
||||
|
||||
// Collect child PIDs for OnMasterReady callback
|
||||
childPIDs := make([]int, 0, goMaxProcs)
|
||||
|
||||
for range goMaxProcs {
|
||||
var cmd *exec.Cmd
|
||||
if cmd, err = p.doCommand(); err != nil {
|
||||
p.logger().Printf("failed to start a child prefork process, error: %v\n", err)
|
||||
p.logger().Printf("prefork: failed to start a child process: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -316,71 +498,96 @@ func (p *Prefork) prefork(addr string) (err error) {
|
||||
childProcs[pid] = cmd
|
||||
childPIDs = append(childPIDs, pid)
|
||||
|
||||
// Start Wait goroutine before OnChildSpawn so that if OnChildSpawn
|
||||
// fails and the deferred Kill() runs, Wait() will collect the child.
|
||||
go func(c *exec.Cmd, pid int) {
|
||||
sigCh <- procSig{pid: pid, err: c.Wait()}
|
||||
}(cmd, pid)
|
||||
// Start Wait goroutine before the user callback so a panic / error
|
||||
// return from OnChildSpawn cannot leave a zombie behind.
|
||||
startWait(cmd, pid)
|
||||
|
||||
if p.OnChildSpawn != nil {
|
||||
if err = p.OnChildSpawn(pid); err != nil {
|
||||
p.logger().Printf("OnChildSpawn callback failed for PID %d: %v\n", pid, err)
|
||||
return err
|
||||
pid := pid
|
||||
if hookErr := p.invokeHook("OnChildSpawn", func() error {
|
||||
return p.OnChildSpawn(pid)
|
||||
}); hookErr != nil {
|
||||
p.logger().Printf("prefork: OnChildSpawn for PID %d: %v", pid, hookErr)
|
||||
return hookErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Call OnMasterReady callback after all children are spawned
|
||||
if p.OnMasterReady != nil {
|
||||
if err = p.OnMasterReady(childPIDs); err != nil {
|
||||
p.logger().Printf("OnMasterReady callback failed: %v\n", err)
|
||||
return err
|
||||
pids := append([]int(nil), childPIDs...)
|
||||
if hookErr := p.invokeHook("OnMasterReady", func() error {
|
||||
return p.OnMasterReady(pids)
|
||||
}); hookErr != nil {
|
||||
p.logger().Printf("prefork: OnMasterReady: %v", hookErr)
|
||||
return hookErr
|
||||
}
|
||||
}
|
||||
|
||||
var exitedProcs int
|
||||
for sig := range sigCh {
|
||||
delete(childProcs, sig.pid)
|
||||
for {
|
||||
select {
|
||||
case sig := <-signalCh:
|
||||
p.logger().Printf("prefork: received signal %v, shutting down", sig)
|
||||
return nil
|
||||
|
||||
p.logger().Printf("one of the child prefork processes exited with "+
|
||||
"error: %v", sig.err)
|
||||
case sig := <-sigCh:
|
||||
delete(childProcs, sig.pid)
|
||||
|
||||
exitedProcs++
|
||||
if exitedProcs > p.RecoverThreshold {
|
||||
p.logger().Printf("child prefork processes exit too many times, "+
|
||||
"which exceeds the value of RecoverThreshold(%d), "+
|
||||
"exiting the master process.\n", p.RecoverThreshold)
|
||||
err = ErrOverRecovery
|
||||
break
|
||||
}
|
||||
if sig.err != nil {
|
||||
p.logger().Printf("prefork: child PID %d exited: %v", sig.pid, sig.err)
|
||||
} else {
|
||||
p.logger().Printf("prefork: child PID %d exited cleanly", sig.pid)
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
cmd, err = p.doCommand()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
newPid := cmd.Process.Pid
|
||||
childProcs[newPid] = cmd
|
||||
exitedProcs++
|
||||
if exitedProcs > p.RecoverThreshold {
|
||||
p.logger().Printf(
|
||||
"prefork: child exits (%d) exceed RecoverThreshold (%d), terminating master",
|
||||
exitedProcs, p.RecoverThreshold,
|
||||
)
|
||||
return ErrOverRecovery
|
||||
}
|
||||
|
||||
// Start Wait goroutine before callbacks to avoid zombie processes
|
||||
// if a callback fails and the deferred Kill() runs.
|
||||
go func(c *exec.Cmd, pid int) {
|
||||
sigCh <- procSig{pid: pid, err: c.Wait()}
|
||||
}(cmd, newPid)
|
||||
if p.RecoverInterval > 0 {
|
||||
select {
|
||||
case <-time.After(p.RecoverInterval):
|
||||
case <-signalCh:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if p.OnChildSpawn != nil {
|
||||
if err = p.OnChildSpawn(newPid); err != nil {
|
||||
p.logger().Printf("OnChildSpawn callback failed for recovered PID %d: %v\n", newPid, err)
|
||||
return err
|
||||
cmd, doErr := p.doCommand()
|
||||
if doErr != nil {
|
||||
p.logger().Printf("prefork: recovery doCommand: %v", doErr)
|
||||
return doErr
|
||||
}
|
||||
newPID := cmd.Process.Pid
|
||||
childProcs[newPID] = cmd
|
||||
|
||||
startWait(cmd, newPID)
|
||||
|
||||
if p.OnChildSpawn != nil {
|
||||
newPID := newPID
|
||||
if hookErr := p.invokeHook("OnChildSpawn", func() error {
|
||||
return p.OnChildSpawn(newPID)
|
||||
}); hookErr != nil {
|
||||
p.logger().Printf("prefork: OnChildSpawn for recovered PID %d: %v", newPID, hookErr)
|
||||
return hookErr
|
||||
}
|
||||
}
|
||||
|
||||
if p.OnChildRecover != nil {
|
||||
oldPID := sig.pid
|
||||
newPID := newPID
|
||||
if hookErr := p.invokeHook("OnChildRecover", func() error {
|
||||
return p.OnChildRecover(oldPID, newPID)
|
||||
}); hookErr != nil {
|
||||
p.logger().Printf("prefork: OnChildRecover (%d -> %d): %v", oldPID, newPID, hookErr)
|
||||
return hookErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if p.OnChildRecover != nil {
|
||||
p.OnChildRecover(sig.pid, newPid)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ListenAndServe serves HTTP requests from the given TCP addr.
|
||||
@@ -398,11 +605,10 @@ func (p *Prefork) ListenAndServe(addr string) error {
|
||||
|
||||
// ListenAndServeTLS serves HTTPS requests from the given TCP addr.
|
||||
//
|
||||
// certKey is the path to the TLS private key file.
|
||||
// certFile is the path to the TLS certificate file.
|
||||
//
|
||||
// Note: parameter order is (addr, certKey, certFile) — key before cert.
|
||||
// Internally forwards to ServeTLSFunc as (certFile, certKey).
|
||||
// Note: parameter order is (addr, certKey, certFile) — key path comes
|
||||
// before cert path. This is preserved for backward compatibility with
|
||||
// existing callers and differs from fasthttp.Server.ListenAndServeTLS.
|
||||
// New code should prefer ListenAndServeTLSEmbed.
|
||||
func (p *Prefork) ListenAndServeTLS(addr, certKey, certFile string) error {
|
||||
if IsChild() {
|
||||
ln, err := p.listenAsChild(addr)
|
||||
|
||||
+344
-190
@@ -3,44 +3,78 @@ package prefork
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func setUp() {
|
||||
os.Setenv(preforkChildEnvVariable, "1")
|
||||
// freeAddr returns a TCP4 address that the OS believes is free at call time.
|
||||
// It avoids the rand-port collision flakes the previous helper exhibited.
|
||||
func freeAddr(t testing.TB) string {
|
||||
t.Helper()
|
||||
ln, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("freeAddr: %v", err)
|
||||
}
|
||||
addr := ln.Addr().String()
|
||||
if closeErr := ln.Close(); closeErr != nil {
|
||||
t.Fatalf("freeAddr: close: %v", closeErr)
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
func tearDown() {
|
||||
os.Unsetenv(preforkChildEnvVariable)
|
||||
}
|
||||
// noopChildProducer returns a CommandProducer that re-execs the test binary
|
||||
// into a no-op subprocess. The returned cleanup must be deferred (or registered
|
||||
// via t.Cleanup) so leaked subprocesses are reaped if the test fails midway.
|
||||
func noopChildProducer(t testing.TB) (func(files []*os.File) (*exec.Cmd, error), func()) {
|
||||
t.Helper()
|
||||
var (
|
||||
mu sync.Mutex
|
||||
spawned []*exec.Cmd
|
||||
)
|
||||
|
||||
func getAddr() string {
|
||||
return fmt.Sprintf("127.0.0.1:%d", rand.Intn(9000-3000)+3000)
|
||||
produce := func(_ []*os.File) (*exec.Cmd, error) {
|
||||
cmd := exec.Command(os.Args[0], "-test.run=^$")
|
||||
cmd.Env = append(os.Environ(), preforkChildEnvVariable+"="+preforkChildEnvValue)
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mu.Lock()
|
||||
spawned = append(spawned, cmd)
|
||||
mu.Unlock()
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
for _, cmd := range spawned {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
continue
|
||||
}
|
||||
_ = cmd.Process.Kill()
|
||||
_, _ = cmd.Process.Wait()
|
||||
}
|
||||
}
|
||||
return produce, cleanup
|
||||
}
|
||||
|
||||
func Test_IsChild(t *testing.T) {
|
||||
// This test can't run parallel as it modifies the process environment.
|
||||
|
||||
v := IsChild()
|
||||
if v {
|
||||
t.Errorf("IsChild() == %v, want %v", v, false)
|
||||
// This test cannot run in parallel — IsChild() reads a process-global env var.
|
||||
if IsChild() {
|
||||
t.Fatal("test starts as child unexpectedly")
|
||||
}
|
||||
|
||||
setUp()
|
||||
defer tearDown()
|
||||
|
||||
v = IsChild()
|
||||
if !v {
|
||||
t.Errorf("IsChild() == %v, want %v", v, true)
|
||||
t.Setenv(preforkChildEnvVariable, preforkChildEnvValue)
|
||||
if !IsChild() {
|
||||
t.Errorf("IsChild() == false after Setenv, want true")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,51 +87,37 @@ func Test_New(t *testing.T) {
|
||||
if p.Network != defaultNetwork {
|
||||
t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork)
|
||||
}
|
||||
|
||||
if reflect.ValueOf(p.ServeFunc).Pointer() != reflect.ValueOf(s.Serve).Pointer() {
|
||||
t.Errorf("Prefork.ServeFunc == %p, want %p", p.ServeFunc, s.Serve)
|
||||
if p.RecoverThreshold <= 0 {
|
||||
t.Errorf("Prefork.RecoverThreshold == %d, want > 0", p.RecoverThreshold)
|
||||
}
|
||||
|
||||
if reflect.ValueOf(p.ServeTLSFunc).Pointer() != reflect.ValueOf(s.ServeTLS).Pointer() {
|
||||
t.Errorf("Prefork.ServeTLSFunc == %p, want %p", p.ServeTLSFunc, s.ServeTLS)
|
||||
}
|
||||
|
||||
if reflect.ValueOf(p.ServeTLSEmbedFunc).Pointer() != reflect.ValueOf(s.ServeTLSEmbed).Pointer() {
|
||||
t.Errorf("Prefork.ServeTLSFunc == %p, want %p", p.ServeTLSEmbedFunc, s.ServeTLSEmbed)
|
||||
if p.ServeFunc == nil || p.ServeTLSFunc == nil || p.ServeTLSEmbedFunc == nil {
|
||||
t.Error("New() did not wire one of ServeFunc/ServeTLSFunc/ServeTLSEmbedFunc")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_listen(t *testing.T) {
|
||||
func Test_listen_Reuseport(t *testing.T) {
|
||||
prev := runtime.GOMAXPROCS(0)
|
||||
t.Cleanup(func() {
|
||||
runtime.GOMAXPROCS(prev)
|
||||
})
|
||||
|
||||
p := &Prefork{
|
||||
Reuseport: true,
|
||||
}
|
||||
addr := getAddr()
|
||||
p := &Prefork{Reuseport: true}
|
||||
addr := freeAddr(t)
|
||||
|
||||
ln, err := p.listen(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = ln.Close()
|
||||
})
|
||||
|
||||
ln.Close()
|
||||
|
||||
lnAddr := ln.Addr().String()
|
||||
if lnAddr != addr {
|
||||
t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
|
||||
if got, want := ln.Addr().String(), addr; got != want {
|
||||
t.Errorf("ln.Addr() == %q, want %q", got, want)
|
||||
}
|
||||
|
||||
if p.Network != defaultNetwork {
|
||||
t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork)
|
||||
}
|
||||
|
||||
procs := runtime.GOMAXPROCS(0)
|
||||
if procs != 1 {
|
||||
t.Errorf("GOMAXPROCS == %d, want %d", procs, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_setTCPListenerFiles(t *testing.T) {
|
||||
@@ -108,144 +128,187 @@ func Test_setTCPListenerFiles(t *testing.T) {
|
||||
}
|
||||
|
||||
p := &Prefork{}
|
||||
addr := getAddr()
|
||||
addr := freeAddr(t)
|
||||
|
||||
err := p.setTCPListenerFiles(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
if err := p.setTCPListenerFiles(addr); err != nil {
|
||||
t.Fatalf("setTCPListenerFiles: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = p.ln.Close()
|
||||
for _, f := range p.files {
|
||||
_ = f.Close()
|
||||
}
|
||||
})
|
||||
|
||||
if p.ln == nil {
|
||||
t.Fatal("Prefork.ln is nil")
|
||||
t.Fatal("p.ln is nil after setTCPListenerFiles")
|
||||
}
|
||||
|
||||
p.ln.Close()
|
||||
|
||||
lnAddr := p.ln.Addr().String()
|
||||
if lnAddr != addr {
|
||||
t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
|
||||
if got, want := p.ln.Addr().String(), addr; got != want {
|
||||
t.Errorf("p.ln.Addr() == %q, want %q", got, want)
|
||||
}
|
||||
|
||||
if p.Network != defaultNetwork {
|
||||
t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork)
|
||||
}
|
||||
|
||||
if len(p.files) != 1 {
|
||||
t.Errorf("Prefork.files == %d, want %d", len(p.files), 1)
|
||||
t.Errorf("len(p.files) == %d, want 1", len(p.files))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ListenAndServe(t *testing.T) {
|
||||
// This test can't run parallel as it modifies the process environment.
|
||||
func Test_setTCPListenerFiles_BadAddr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setUp()
|
||||
defer tearDown()
|
||||
|
||||
s := &fasthttp.Server{}
|
||||
p := New(s)
|
||||
p.Reuseport = true
|
||||
p.ServeFunc = func(ln net.Listener) error {
|
||||
return nil
|
||||
if runtime.GOOS == "windows" {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
addr := getAddr()
|
||||
|
||||
err := p.ListenAndServe(addr)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
p.ln.Close()
|
||||
|
||||
lnAddr := p.ln.Addr().String()
|
||||
if lnAddr != addr {
|
||||
t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
|
||||
}
|
||||
|
||||
if p.ln == nil {
|
||||
t.Error("Prefork.ln is nil")
|
||||
p := &Prefork{}
|
||||
if err := p.setTCPListenerFiles("definitely not an address"); err == nil {
|
||||
t.Fatal("expected error for malformed addr, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ListenAndServeTLS(t *testing.T) {
|
||||
// This test can't run parallel as it modifies the process environment.
|
||||
// Test_ListenAndServe_Stub_ChildPath drives the child branch of all three
|
||||
// ListenAndServe* entry points using a stubbed Serve function. It replaces the
|
||||
// previous trio of near-identical tests that only validated field assignment.
|
||||
func Test_ListenAndServe_Stub_ChildPath(t *testing.T) {
|
||||
// child env mutation precludes t.Parallel.
|
||||
t.Setenv(preforkChildEnvVariable, preforkChildEnvValue)
|
||||
|
||||
setUp()
|
||||
defer tearDown()
|
||||
|
||||
s := &fasthttp.Server{}
|
||||
p := New(s)
|
||||
p.Reuseport = true
|
||||
p.ServeTLSFunc = func(ln net.Listener, certFile, keyFile string) error {
|
||||
return nil
|
||||
type call struct {
|
||||
listener bool
|
||||
certFile string
|
||||
keyFile string
|
||||
certData string
|
||||
keyData string
|
||||
}
|
||||
|
||||
addr := getAddr()
|
||||
|
||||
err := p.ListenAndServeTLS(addr, "./key", "./cert")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(t *testing.T, p *Prefork, addr string) error
|
||||
want call
|
||||
}{
|
||||
{
|
||||
name: "ListenAndServe",
|
||||
run: func(_ *testing.T, p *Prefork, addr string) error { return p.ListenAndServe(addr) },
|
||||
want: call{listener: true},
|
||||
},
|
||||
{
|
||||
name: "ListenAndServeTLS",
|
||||
run: func(_ *testing.T, p *Prefork, addr string) error {
|
||||
return p.ListenAndServeTLS(addr, "./key", "./cert")
|
||||
},
|
||||
want: call{listener: true, certFile: "./cert", keyFile: "./key"},
|
||||
},
|
||||
{
|
||||
name: "ListenAndServeTLSEmbed",
|
||||
run: func(_ *testing.T, p *Prefork, addr string) error {
|
||||
return p.ListenAndServeTLSEmbed(addr, []byte("certPEM"), []byte("keyPEM"))
|
||||
},
|
||||
want: call{listener: true, certData: "certPEM", keyData: "keyPEM"},
|
||||
},
|
||||
}
|
||||
|
||||
p.ln.Close()
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var got call
|
||||
p := New(&fasthttp.Server{})
|
||||
p.Reuseport = true
|
||||
p.ServeFunc = func(ln net.Listener) error {
|
||||
got.listener = ln != nil
|
||||
return nil
|
||||
}
|
||||
p.ServeTLSFunc = func(ln net.Listener, certFile, keyFile string) error {
|
||||
got.listener = ln != nil
|
||||
got.certFile = certFile
|
||||
got.keyFile = keyFile
|
||||
return nil
|
||||
}
|
||||
p.ServeTLSEmbedFunc = func(ln net.Listener, certData, keyData []byte) error {
|
||||
got.listener = ln != nil
|
||||
got.certData = string(certData)
|
||||
got.keyData = string(keyData)
|
||||
return nil
|
||||
}
|
||||
|
||||
lnAddr := p.ln.Addr().String()
|
||||
if lnAddr != addr {
|
||||
t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
|
||||
}
|
||||
addr := freeAddr(t)
|
||||
if err := tc.run(t, p, addr); err != nil {
|
||||
t.Fatalf("%s: %v", tc.name, err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if p.ln != nil {
|
||||
_ = p.ln.Close()
|
||||
}
|
||||
})
|
||||
|
||||
if p.ln == nil {
|
||||
t.Error("Prefork.ln is nil")
|
||||
if got != tc.want {
|
||||
t.Errorf("%s call = %+v, want %+v", tc.name, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ListenAndServeTLSEmbed(t *testing.T) {
|
||||
// This test can't run parallel as it modifies the process environment.
|
||||
func Test_doCommand_CommandProducerErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setUp()
|
||||
defer tearDown()
|
||||
|
||||
s := &fasthttp.Server{}
|
||||
p := New(s)
|
||||
p.Reuseport = true
|
||||
p.ServeTLSEmbedFunc = func(ln net.Listener, certData, keyData []byte) error {
|
||||
return nil
|
||||
tests := []struct {
|
||||
name string
|
||||
produce func(files []*os.File) (*exec.Cmd, error)
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "producer returns error",
|
||||
produce: func([]*os.File) (*exec.Cmd, error) {
|
||||
return nil, errors.New("boom")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "producer returns nil cmd",
|
||||
//nolint:nilnil // intentionally tests the (nil, nil) misbehaviour guard
|
||||
produce: func([]*os.File) (*exec.Cmd, error) {
|
||||
return nil, nil
|
||||
},
|
||||
wantErr: ErrCommandProducerNilCmd,
|
||||
},
|
||||
{
|
||||
name: "producer returns unstarted cmd",
|
||||
produce: func([]*os.File) (*exec.Cmd, error) {
|
||||
return &exec.Cmd{}, nil
|
||||
},
|
||||
wantErr: ErrCommandProducerNotStarted,
|
||||
},
|
||||
}
|
||||
|
||||
addr := getAddr()
|
||||
|
||||
err := p.ListenAndServeTLSEmbed(addr, []byte("key"), []byte("cert"))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
p.ln.Close()
|
||||
|
||||
lnAddr := p.ln.Addr().String()
|
||||
if lnAddr != addr {
|
||||
t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
|
||||
}
|
||||
|
||||
if p.ln == nil {
|
||||
t.Error("Prefork.ln is nil")
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := &Prefork{CommandProducer: tc.produce}
|
||||
cmd, err := p.doCommand()
|
||||
if cmd != nil {
|
||||
t.Errorf("expected nil cmd on error, got %v", cmd)
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if tc.wantErr != nil && !errors.Is(err, tc.wantErr) {
|
||||
t.Errorf("err = %v, want errors.Is %v", err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testLogger is a Printf-style logger that records every formatted message in
// memory.
type testLogger struct {
	mu       sync.Mutex // guards messages
	messages []string   // formatted messages, in call order
}

// Printf formats the message and appends it to l.messages exactly once.
//
// Formatting happens before the lock is taken so the critical section covers
// only the slice append.
func (l *testLogger) Printf(format string, args ...any) {
	msg := fmt.Sprintf(format, args...)

	l.mu.Lock()
	defer l.mu.Unlock()
	l.messages = append(l.messages, msg)
}
|
||||
|
||||
// Test_Prefork_Lifecycle runs the full prefork lifecycle with a CommandProducer
|
||||
// and verifies that callbacks are invoked in the correct order with the correct arguments.
|
||||
// Test_Prefork_Lifecycle drives prefork() to ErrOverRecovery via
|
||||
// short-lived no-op children and asserts the callback ordering / arguments.
|
||||
func Test_Prefork_Lifecycle(t *testing.T) {
|
||||
prev := runtime.GOMAXPROCS(2)
|
||||
t.Cleanup(func() {
|
||||
runtime.GOMAXPROCS(prev)
|
||||
})
|
||||
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
|
||||
|
||||
type event struct {
|
||||
name string
|
||||
@@ -260,16 +323,14 @@ func Test_Prefork_Lifecycle(t *testing.T) {
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
produce, cleanup := noopChildProducer(t)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
p := &Prefork{
|
||||
Reuseport: true,
|
||||
RecoverThreshold: 1,
|
||||
Logger: &testLogger{},
|
||||
CommandProducer: func(_ []*os.File) (*exec.Cmd, error) {
|
||||
cmd := exec.Command(os.Args[0], "-test.run=^$")
|
||||
cmd.Env = append(os.Environ(), preforkChildEnvVariable+"=1")
|
||||
err := cmd.Start()
|
||||
return cmd, err
|
||||
},
|
||||
CommandProducer: produce,
|
||||
OnChildSpawn: func(pid int) error {
|
||||
record("spawn", pid)
|
||||
return nil
|
||||
@@ -278,12 +339,13 @@ func Test_Prefork_Lifecycle(t *testing.T) {
|
||||
record("ready", childPIDs...)
|
||||
return nil
|
||||
},
|
||||
OnChildRecover: func(oldPid, newPid int) {
|
||||
record("recover", oldPid, newPid)
|
||||
OnChildRecover: func(oldPID, newPID int) error {
|
||||
record("recover", oldPID, newPID)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
err := p.prefork(getAddr())
|
||||
err := p.prefork(freeAddr(t))
|
||||
if !errors.Is(err, ErrOverRecovery) {
|
||||
t.Fatalf("expected ErrOverRecovery, got: %v", err)
|
||||
}
|
||||
@@ -291,10 +353,9 @@ func Test_Prefork_Lifecycle(t *testing.T) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Verify we got spawn events for initial children
|
||||
var spawnCount int
|
||||
var readyCount int
|
||||
var recoverCount int
|
||||
goMaxProcs := runtime.GOMAXPROCS(0)
|
||||
|
||||
var spawnCount, readyCount, recoverCount int
|
||||
for _, e := range events {
|
||||
switch e.name {
|
||||
case "spawn":
|
||||
@@ -318,21 +379,17 @@ func Test_Prefork_Lifecycle(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
goMaxProcs := runtime.GOMAXPROCS(0)
|
||||
|
||||
if readyCount != 1 {
|
||||
t.Errorf("OnMasterReady called %d times, want 1", readyCount)
|
||||
}
|
||||
|
||||
// Initial spawns + at least one recovery spawn
|
||||
if spawnCount < goMaxProcs {
|
||||
t.Errorf("OnChildSpawn called %d times, want at least %d", spawnCount, goMaxProcs)
|
||||
}
|
||||
|
||||
if recoverCount == 0 {
|
||||
t.Error("OnChildRecover was never called")
|
||||
}
|
||||
|
||||
// ready must come after exactly goMaxProcs initial spawns.
|
||||
readyIdx := -1
|
||||
spawnsBeforeReady := 0
|
||||
for i, e := range events {
|
||||
@@ -344,7 +401,6 @@ func Test_Prefork_Lifecycle(t *testing.T) {
|
||||
spawnsBeforeReady++
|
||||
}
|
||||
}
|
||||
|
||||
if readyIdx == -1 {
|
||||
t.Fatal("OnMasterReady was never called")
|
||||
}
|
||||
@@ -352,8 +408,8 @@ func Test_Prefork_Lifecycle(t *testing.T) {
|
||||
t.Errorf("OnMasterReady called after %d initial spawns, want %d", spawnsBeforeReady, goMaxProcs)
|
||||
}
|
||||
|
||||
// every recover event must be preceded by a spawn for the new PID.
|
||||
recoveredSpawnByPID := make(map[int]bool)
|
||||
recoveredPIDs := make(map[int]bool)
|
||||
for _, e := range events[readyIdx+1:] {
|
||||
if e.name == "spawn" {
|
||||
recoveredSpawnByPID[e.pids[0]] = true
|
||||
@@ -362,56 +418,154 @@ func Test_Prefork_Lifecycle(t *testing.T) {
|
||||
if !recoveredSpawnByPID[e.pids[1]] {
|
||||
t.Errorf("OnChildRecover for PID %d happened before OnChildSpawn", e.pids[1])
|
||||
}
|
||||
recoveredPIDs[e.pids[1]] = true
|
||||
}
|
||||
}
|
||||
for pid := range recoveredPIDs {
|
||||
if !recoveredSpawnByPID[pid] {
|
||||
t.Errorf("OnChildRecover for PID %d did not have a matching OnChildSpawn", pid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Prefork_RecoveredChildSpawnError(t *testing.T) {
|
||||
func Test_Prefork_InitialChildSpawnError(t *testing.T) {
|
||||
prev := runtime.GOMAXPROCS(2)
|
||||
t.Cleanup(func() {
|
||||
runtime.GOMAXPROCS(prev)
|
||||
})
|
||||
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
|
||||
|
||||
expectedErr := errors.New("spawn failed")
|
||||
var spawnCount int
|
||||
var recoverCount int
|
||||
produce, cleanup := noopChildProducer(t)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
expectedErr := errors.New("initial spawn rejected")
|
||||
var calls atomic.Int32
|
||||
|
||||
p := &Prefork{
|
||||
Reuseport: true,
|
||||
RecoverThreshold: 1,
|
||||
Logger: &testLogger{},
|
||||
CommandProducer: func(_ []*os.File) (*exec.Cmd, error) {
|
||||
cmd := exec.Command(os.Args[0], "-test.run=^$")
|
||||
cmd.Env = append(os.Environ(), preforkChildEnvVariable+"=1")
|
||||
err := cmd.Start()
|
||||
return cmd, err
|
||||
CommandProducer: produce,
|
||||
OnChildSpawn: func(_ int) error {
|
||||
calls.Add(1)
|
||||
return expectedErr
|
||||
},
|
||||
}
|
||||
|
||||
err := p.prefork(freeAddr(t))
|
||||
if !errors.Is(err, expectedErr) {
|
||||
t.Fatalf("expected %v, got: %v", expectedErr, err)
|
||||
}
|
||||
if calls.Load() == 0 {
|
||||
t.Fatal("OnChildSpawn was never invoked")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Prefork_OnMasterReadyError(t *testing.T) {
|
||||
prev := runtime.GOMAXPROCS(2)
|
||||
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
|
||||
|
||||
produce, cleanup := noopChildProducer(t)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
expectedErr := errors.New("ready rejected")
|
||||
p := &Prefork{
|
||||
Reuseport: true,
|
||||
RecoverThreshold: 1,
|
||||
Logger: &testLogger{},
|
||||
CommandProducer: produce,
|
||||
OnMasterReady: func([]int) error {
|
||||
return expectedErr
|
||||
},
|
||||
}
|
||||
|
||||
err := p.prefork(freeAddr(t))
|
||||
if !errors.Is(err, expectedErr) {
|
||||
t.Fatalf("expected %v, got: %v", expectedErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Prefork_RecoveredChildSpawnError(t *testing.T) {
|
||||
prev := runtime.GOMAXPROCS(2)
|
||||
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
|
||||
|
||||
produce, cleanup := noopChildProducer(t)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
expectedErr := errors.New("spawn failed")
|
||||
var spawnCount, recoverCount atomic.Int32
|
||||
|
||||
p := &Prefork{
|
||||
Reuseport: true,
|
||||
RecoverThreshold: 1,
|
||||
Logger: &testLogger{},
|
||||
CommandProducer: produce,
|
||||
OnChildSpawn: func(pid int) error {
|
||||
if pid <= 0 {
|
||||
t.Errorf("OnChildSpawn called with invalid PID: %d", pid)
|
||||
}
|
||||
spawnCount++
|
||||
if spawnCount > runtime.GOMAXPROCS(0) {
|
||||
n := spawnCount.Add(1)
|
||||
if int(n) > runtime.GOMAXPROCS(0) {
|
||||
return expectedErr
|
||||
}
|
||||
return nil
|
||||
},
|
||||
OnChildRecover: func(_, _ int) {
|
||||
recoverCount++
|
||||
OnChildRecover: func(_, _ int) error {
|
||||
recoverCount.Add(1)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
err := p.prefork(getAddr())
|
||||
err := p.prefork(freeAddr(t))
|
||||
if !errors.Is(err, expectedErr) {
|
||||
t.Fatalf("expected %v, got: %v", expectedErr, err)
|
||||
}
|
||||
if recoverCount != 0 {
|
||||
t.Fatalf("OnChildRecover called %d times, want 0", recoverCount)
|
||||
if got := recoverCount.Load(); got != 0 {
|
||||
t.Fatalf("OnChildRecover called %d times, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Prefork_HookPanicSurfaces(t *testing.T) {
|
||||
prev := runtime.GOMAXPROCS(2)
|
||||
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
|
||||
|
||||
produce, cleanup := noopChildProducer(t)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
p := &Prefork{
|
||||
Reuseport: true,
|
||||
RecoverThreshold: 1,
|
||||
Logger: &testLogger{},
|
||||
CommandProducer: produce,
|
||||
OnChildSpawn: func(_ int) error {
|
||||
panic("user code blew up")
|
||||
},
|
||||
}
|
||||
|
||||
err := p.prefork(freeAddr(t))
|
||||
if err == nil {
|
||||
t.Fatal("expected error from panicking hook, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// Test_Prefork_RecoverInterval verifies the optional backoff actually delays
|
||||
// the respawn — without rabbit-holing into hard wallclock assertions.
|
||||
func Test_Prefork_RecoverInterval(t *testing.T) {
|
||||
prev := runtime.GOMAXPROCS(2)
|
||||
t.Cleanup(func() { runtime.GOMAXPROCS(prev) })
|
||||
|
||||
produce, cleanup := noopChildProducer(t)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
const interval = 50 * time.Millisecond
|
||||
p := &Prefork{
|
||||
Reuseport: true,
|
||||
RecoverThreshold: 1,
|
||||
RecoverInterval: interval,
|
||||
Logger: &testLogger{},
|
||||
CommandProducer: produce,
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
err := p.prefork(freeAddr(t))
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if !errors.Is(err, ErrOverRecovery) {
|
||||
t.Fatalf("expected ErrOverRecovery, got %v", err)
|
||||
}
|
||||
// At least one recover interval must have elapsed before threshold fired.
|
||||
if elapsed < interval {
|
||||
t.Errorf("elapsed %v < interval %v; backoff did not apply", elapsed, interval)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user