diff --git a/fs.go b/fs.go index 1273f72..ba532a9 100644 --- a/fs.go +++ b/fs.go @@ -1740,13 +1740,16 @@ func (h *fsHandler) compressFileNolock( _ = zf.Close() _ = f.Close() if err != nil { + _ = os.Remove(tmpFilePath) return nil, fmt.Errorf("error when compressing file %q to %q: %w", filePath, tmpFilePath, err) } if err = os.Chtimes(tmpFilePath, time.Now(), fileInfo.ModTime()); err != nil { + _ = os.Remove(tmpFilePath) return nil, fmt.Errorf("cannot change modification time to %v for tmp file %q: %v", fileInfo.ModTime(), tmpFilePath, err) } if err = os.Rename(tmpFilePath, compressedFilePath); err != nil { + _ = os.Remove(tmpFilePath) return nil, fmt.Errorf("cannot move compressed file from %q to %q: %w", tmpFilePath, compressedFilePath, err) } return h.newCompressedFSFile(compressedFilePath, fileEncoding) diff --git a/fs_test.go b/fs_test.go index cc283ad..6e872e5 100644 --- a/fs_test.go +++ b/fs_test.go @@ -3,8 +3,10 @@ package fasthttp import ( "bufio" "bytes" + "errors" "fmt" "io" + iofs "io/fs" "math/rand" "os" "path/filepath" @@ -882,6 +884,42 @@ func runFSCompressSingleThread(t *testing.T, fs *FS) { testFSCompress(t, h, "/README.md") } +// errReadFile is an fs.File whose Read always fails, used to force the +// compression step in compressFileNolock to return an error. +type errReadFile struct { + fi iofs.FileInfo +} + +func (f *errReadFile) Stat() (iofs.FileInfo, error) { return f.fi, nil } +func (f *errReadFile) Read([]byte) (int, error) { return 0, errors.New("forced read error") } +func (f *errReadFile) Close() error { return nil } + +func TestFSCompressTmpFileRemovedOnError(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + + // Use this test file as the source for a valid fs.FileInfo. + fi, err := os.Stat("fs_test.go") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + h := &fsHandler{} + compressedFilePath := filepath.Join(dir, "out.gz") + tmpFilePath := compressedFilePath + ".tmp" + + _, err = h.compressFileNolock( + &errReadFile{fi: fi}, fi, "fs_test.go", compressedFilePath, "gzip") + if err == nil { + t.Fatalf("expecting error when compression fails") + } + + if _, err := os.Stat(tmpFilePath); !os.IsNotExist(err) { + t.Fatalf("temporary file %q must be removed on compression error, stat err: %v", tmpFilePath, err) + } +} + func testFSCompress(t *testing.T, h RequestHandler, filePath string) { t.Helper()