fix(prefork): tighten recovery callback flow

This commit is contained in:
René
2026-05-02 12:12:45 +02:00
parent b5233e2b48
commit db78ffe6f1
2 changed files with 54 additions and 42 deletions
+7 -7
View File
@@ -81,7 +81,8 @@ type Prefork struct {
// OnChildSpawn is called in the master process whenever a new child process is spawned.
// It receives the PID of the newly spawned child process.
//
// If this callback returns an error, the prefork operation will be aborted.
// If this callback returns an error, all child processes are killed and the
// prefork operation returns that error.
OnChildSpawn func(pid int) error
// OnMasterReady is called in the master process after all child processes have been spawned.
@@ -298,7 +299,6 @@ func (p *Prefork) prefork(addr string) (err error) {
defer func() {
for _, proc := range childProcs {
_ = proc.Process.Kill()
_, _ = proc.Process.Wait() // avoid zombie processes after Kill
}
}()
@@ -368,16 +368,16 @@ func (p *Prefork) prefork(addr string) (err error) {
sigCh <- procSig{pid: pid, err: c.Wait()}
}(cmd, newPid)
if p.OnChildRecover != nil {
p.OnChildRecover(sig.pid, newPid)
}
if p.OnChildSpawn != nil {
if err = p.OnChildSpawn(newPid); err != nil {
p.logger().Printf("OnChildSpawn callback failed for recovered PID %d: %v\n", newPid, err)
break
return err
}
}
if p.OnChildRecover != nil {
p.OnChildRecover(sig.pid, newPid)
}
}
return err
+47 -35
View File
@@ -54,10 +54,6 @@ func Test_New(t *testing.T) {
t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork)
}
if p.RecoverThreshold != defaultRecoverThreshold() {
t.Errorf("Prefork.RecoverThreshold == %d, want %d", p.RecoverThreshold, defaultRecoverThreshold())
}
if reflect.ValueOf(p.ServeFunc).Pointer() != reflect.ValueOf(s.Serve).Pointer() {
t.Errorf("Prefork.ServeFunc == %p, want %p", p.ServeFunc, s.Serve)
}
@@ -71,17 +67,6 @@ func Test_New(t *testing.T) {
}
}
func Test_defaultRecoverThreshold_SingleCore(t *testing.T) {
prev := runtime.GOMAXPROCS(1)
t.Cleanup(func() {
runtime.GOMAXPROCS(prev)
})
if threshold := defaultRecoverThreshold(); threshold != 1 {
t.Errorf("defaultRecoverThreshold() == %d, want 1", threshold)
}
}
func Test_listen(t *testing.T) {
prev := runtime.GOMAXPROCS(0)
t.Cleanup(func() {
@@ -246,26 +231,6 @@ func Test_ListenAndServeTLSEmbed(t *testing.T) {
}
}
func Test_Prefork_Logger(t *testing.T) {
t.Parallel()
s := &fasthttp.Server{}
p := New(s)
// Test default logger
logger := p.logger()
if logger == nil {
t.Error("Default logger should not be nil")
}
// Test custom logger
customLogger := &testLogger{}
p.Logger = customLogger
if p.logger() != customLogger {
t.Error("Custom logger should be returned")
}
}
type testLogger struct {
messages []string
}
@@ -394,6 +359,9 @@ func Test_Prefork_Lifecycle(t *testing.T) {
recoveredSpawnByPID[e.pids[0]] = true
}
if e.name == "recover" {
if !recoveredSpawnByPID[e.pids[1]] {
t.Errorf("OnChildRecover for PID %d happened before OnChildSpawn", e.pids[1])
}
recoveredPIDs[e.pids[1]] = true
}
}
@@ -403,3 +371,47 @@ func Test_Prefork_Lifecycle(t *testing.T) {
}
}
}
func Test_Prefork_RecoveredChildSpawnError(t *testing.T) {
prev := runtime.GOMAXPROCS(2)
t.Cleanup(func() {
runtime.GOMAXPROCS(prev)
})
expectedErr := errors.New("spawn failed")
var spawnCount int
var recoverCount int
p := &Prefork{
Reuseport: true,
RecoverThreshold: 1,
Logger: &testLogger{},
CommandProducer: func(_ []*os.File) (*exec.Cmd, error) {
cmd := exec.Command(os.Args[0], "-test.run=^$")
cmd.Env = append(os.Environ(), preforkChildEnvVariable+"=1")
err := cmd.Start()
return cmd, err
},
OnChildSpawn: func(pid int) error {
if pid <= 0 {
t.Errorf("OnChildSpawn called with invalid PID: %d", pid)
}
spawnCount++
if spawnCount > runtime.GOMAXPROCS(0) {
return expectedErr
}
return nil
},
OnChildRecover: func(_, _ int) {
recoverCount++
},
}
err := p.prefork(getAddr())
if !errors.Is(err, expectedErr) {
t.Fatalf("expected %v, got: %v", expectedErr, err)
}
if recoverCount != 0 {
t.Fatalf("OnChildRecover called %d times, want 0", recoverCount)
}
}