From db78ffe6f154ccd75e06806bc028d05f4d86c3cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9?= Date: Sat, 2 May 2026 12:12:45 +0200 Subject: [PATCH] fix(prefork): tighten recovery callback flow --- prefork/prefork.go | 14 +++---- prefork/prefork_test.go | 82 +++++++++++++++++++++++------------------ 2 files changed, 54 insertions(+), 42 deletions(-) diff --git a/prefork/prefork.go b/prefork/prefork.go index 0b6e6b5..52a7a88 100644 --- a/prefork/prefork.go +++ b/prefork/prefork.go @@ -81,7 +81,8 @@ type Prefork struct { // OnChildSpawn is called in the master process whenever a new child process is spawned. // It receives the PID of the newly spawned child process. // - // If this callback returns an error, the prefork operation will be aborted. + // If this callback returns an error, all child processes are killed and the + // prefork operation returns that error. OnChildSpawn func(pid int) error // OnMasterReady is called in the master process after all child processes have been spawned. @@ -298,7 +299,6 @@ func (p *Prefork) prefork(addr string) (err error) { defer func() { for _, proc := range childProcs { _ = proc.Process.Kill() - _, _ = proc.Process.Wait() // avoid zombie processes after Kill } }() @@ -368,16 +368,16 @@ func (p *Prefork) prefork(addr string) (err error) { sigCh <- procSig{pid: pid, err: c.Wait()} }(cmd, newPid) - if p.OnChildRecover != nil { - p.OnChildRecover(sig.pid, newPid) - } - if p.OnChildSpawn != nil { if err = p.OnChildSpawn(newPid); err != nil { p.logger().Printf("OnChildSpawn callback failed for recovered PID %d: %v\n", newPid, err) - break + return err } } + + if p.OnChildRecover != nil { + p.OnChildRecover(sig.pid, newPid) + } } return err diff --git a/prefork/prefork_test.go b/prefork/prefork_test.go index 795e641..d358b51 100644 --- a/prefork/prefork_test.go +++ b/prefork/prefork_test.go @@ -54,10 +54,6 @@ func Test_New(t *testing.T) { t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork) } - if p.RecoverThreshold != defaultRecoverThreshold() { - t.Errorf("Prefork.RecoverThreshold == %d, want %d", p.RecoverThreshold, defaultRecoverThreshold()) - } - if reflect.ValueOf(p.ServeFunc).Pointer() != reflect.ValueOf(s.Serve).Pointer() { t.Errorf("Prefork.ServeFunc == %p, want %p", p.ServeFunc, s.Serve) } @@ -71,17 +67,6 @@ func Test_New(t *testing.T) { } } -func Test_defaultRecoverThreshold_SingleCore(t *testing.T) { - prev := runtime.GOMAXPROCS(1) - t.Cleanup(func() { - runtime.GOMAXPROCS(prev) - }) - - if threshold := defaultRecoverThreshold(); threshold != 1 { - t.Errorf("defaultRecoverThreshold() == %d, want 1", threshold) - } -} - func Test_listen(t *testing.T) { prev := runtime.GOMAXPROCS(0) t.Cleanup(func() { @@ -246,26 +231,6 @@ func Test_ListenAndServeTLSEmbed(t *testing.T) { } } -func Test_Prefork_Logger(t *testing.T) { - t.Parallel() - - s := &fasthttp.Server{} - p := New(s) - - // Test default logger - logger := p.logger() - if logger == nil { - t.Error("Default logger should not be nil") - } - - // Test custom logger - customLogger := &testLogger{} - p.Logger = customLogger - if p.logger() != customLogger { - t.Error("Custom logger should be returned") - } -} - type testLogger struct { messages []string } @@ -394,6 +359,9 @@ func Test_Prefork_Lifecycle(t *testing.T) { recoveredSpawnByPID[e.pids[0]] = true } if e.name == "recover" { + if !recoveredSpawnByPID[e.pids[1]] { + t.Errorf("OnChildRecover for PID %d happened before OnChildSpawn", e.pids[1]) + } recoveredPIDs[e.pids[1]] = true } } @@ -403,3 +371,47 @@ func Test_Prefork_Lifecycle(t *testing.T) { } } } + +func Test_Prefork_RecoveredChildSpawnError(t *testing.T) { + prev := runtime.GOMAXPROCS(2) + t.Cleanup(func() { + runtime.GOMAXPROCS(prev) + }) + + expectedErr := errors.New("spawn failed") + var spawnCount int + var recoverCount int + + p := &Prefork{ + Reuseport: true, + RecoverThreshold: 1, + Logger: &testLogger{}, + CommandProducer: func(_ []*os.File) (*exec.Cmd, error) { + cmd := exec.Command(os.Args[0], "-test.run=^$") + cmd.Env = append(os.Environ(), preforkChildEnvVariable+"=1") + err := cmd.Start() + return cmd, err + }, + OnChildSpawn: func(pid int) error { + if pid <= 0 { + t.Errorf("OnChildSpawn called with invalid PID: %d", pid) + } + spawnCount++ + if spawnCount > runtime.GOMAXPROCS(0) { + return expectedErr + } + return nil + }, + OnChildRecover: func(_, _ int) { + recoverCount++ + }, + } + + err := p.prefork(getAddr()) + if !errors.Is(err, expectedErr) { + t.Fatalf("expected %v, got: %v", expectedErr, err) + } + if recoverCount != 0 { + t.Fatalf("OnChildRecover called %d times, want 0", recoverCount) + } +}