mirror of
https://github.com/valyala/fasthttp.git
synced 2026-06-14 15:56:44 +03:00
fix(prefork): tighten recovery callback flow
This commit is contained in:
+7
-7
@@ -81,7 +81,8 @@ type Prefork struct {
|
||||
// OnChildSpawn is called in the master process whenever a new child process is spawned.
|
||||
// It receives the PID of the newly spawned child process.
|
||||
//
|
||||
// If this callback returns an error, the prefork operation will be aborted.
|
||||
// If this callback returns an error, all child processes are killed and the
|
||||
// prefork operation returns that error.
|
||||
OnChildSpawn func(pid int) error
|
||||
|
||||
// OnMasterReady is called in the master process after all child processes have been spawned.
|
||||
@@ -298,7 +299,6 @@ func (p *Prefork) prefork(addr string) (err error) {
|
||||
defer func() {
|
||||
for _, proc := range childProcs {
|
||||
_ = proc.Process.Kill()
|
||||
_, _ = proc.Process.Wait() // avoid zombie processes after Kill
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -368,16 +368,16 @@ func (p *Prefork) prefork(addr string) (err error) {
|
||||
sigCh <- procSig{pid: pid, err: c.Wait()}
|
||||
}(cmd, newPid)
|
||||
|
||||
if p.OnChildRecover != nil {
|
||||
p.OnChildRecover(sig.pid, newPid)
|
||||
}
|
||||
|
||||
if p.OnChildSpawn != nil {
|
||||
if err = p.OnChildSpawn(newPid); err != nil {
|
||||
p.logger().Printf("OnChildSpawn callback failed for recovered PID %d: %v\n", newPid, err)
|
||||
break
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if p.OnChildRecover != nil {
|
||||
p.OnChildRecover(sig.pid, newPid)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
|
||||
+47
-35
@@ -54,10 +54,6 @@ func Test_New(t *testing.T) {
|
||||
t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork)
|
||||
}
|
||||
|
||||
if p.RecoverThreshold != defaultRecoverThreshold() {
|
||||
t.Errorf("Prefork.RecoverThreshold == %d, want %d", p.RecoverThreshold, defaultRecoverThreshold())
|
||||
}
|
||||
|
||||
if reflect.ValueOf(p.ServeFunc).Pointer() != reflect.ValueOf(s.Serve).Pointer() {
|
||||
t.Errorf("Prefork.ServeFunc == %p, want %p", p.ServeFunc, s.Serve)
|
||||
}
|
||||
@@ -71,17 +67,6 @@ func Test_New(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_defaultRecoverThreshold_SingleCore(t *testing.T) {
|
||||
prev := runtime.GOMAXPROCS(1)
|
||||
t.Cleanup(func() {
|
||||
runtime.GOMAXPROCS(prev)
|
||||
})
|
||||
|
||||
if threshold := defaultRecoverThreshold(); threshold != 1 {
|
||||
t.Errorf("defaultRecoverThreshold() == %d, want 1", threshold)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_listen(t *testing.T) {
|
||||
prev := runtime.GOMAXPROCS(0)
|
||||
t.Cleanup(func() {
|
||||
@@ -246,26 +231,6 @@ func Test_ListenAndServeTLSEmbed(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Prefork_Logger(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := &fasthttp.Server{}
|
||||
p := New(s)
|
||||
|
||||
// Test default logger
|
||||
logger := p.logger()
|
||||
if logger == nil {
|
||||
t.Error("Default logger should not be nil")
|
||||
}
|
||||
|
||||
// Test custom logger
|
||||
customLogger := &testLogger{}
|
||||
p.Logger = customLogger
|
||||
if p.logger() != customLogger {
|
||||
t.Error("Custom logger should be returned")
|
||||
}
|
||||
}
|
||||
|
||||
type testLogger struct {
|
||||
messages []string
|
||||
}
|
||||
@@ -394,6 +359,9 @@ func Test_Prefork_Lifecycle(t *testing.T) {
|
||||
recoveredSpawnByPID[e.pids[0]] = true
|
||||
}
|
||||
if e.name == "recover" {
|
||||
if !recoveredSpawnByPID[e.pids[1]] {
|
||||
t.Errorf("OnChildRecover for PID %d happened before OnChildSpawn", e.pids[1])
|
||||
}
|
||||
recoveredPIDs[e.pids[1]] = true
|
||||
}
|
||||
}
|
||||
@@ -403,3 +371,47 @@ func Test_Prefork_Lifecycle(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Prefork_RecoveredChildSpawnError(t *testing.T) {
|
||||
prev := runtime.GOMAXPROCS(2)
|
||||
t.Cleanup(func() {
|
||||
runtime.GOMAXPROCS(prev)
|
||||
})
|
||||
|
||||
expectedErr := errors.New("spawn failed")
|
||||
var spawnCount int
|
||||
var recoverCount int
|
||||
|
||||
p := &Prefork{
|
||||
Reuseport: true,
|
||||
RecoverThreshold: 1,
|
||||
Logger: &testLogger{},
|
||||
CommandProducer: func(_ []*os.File) (*exec.Cmd, error) {
|
||||
cmd := exec.Command(os.Args[0], "-test.run=^$")
|
||||
cmd.Env = append(os.Environ(), preforkChildEnvVariable+"=1")
|
||||
err := cmd.Start()
|
||||
return cmd, err
|
||||
},
|
||||
OnChildSpawn: func(pid int) error {
|
||||
if pid <= 0 {
|
||||
t.Errorf("OnChildSpawn called with invalid PID: %d", pid)
|
||||
}
|
||||
spawnCount++
|
||||
if spawnCount > runtime.GOMAXPROCS(0) {
|
||||
return expectedErr
|
||||
}
|
||||
return nil
|
||||
},
|
||||
OnChildRecover: func(_, _ int) {
|
||||
recoverCount++
|
||||
},
|
||||
}
|
||||
|
||||
err := p.prefork(getAddr())
|
||||
if !errors.Is(err, expectedErr) {
|
||||
t.Fatalf("expected %v, got: %v", expectedErr, err)
|
||||
}
|
||||
if recoverCount != 0 {
|
||||
t.Fatalf("OnChildRecover called %d times, want 0", recoverCount)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user