Skip to content

Commit

Permalink
Improve memory usage when extracting zip archives (#21)
Browse files Browse the repository at this point in the history
* zip: if the reader is capable of seeking, do not buffer the entire archive

* lint: Remove usage of deprecated io/ioutil package

* Avoid buffering in archive detection if the stream is seekable

* Slightly increase test limits
  • Loading branch information
cmaglie committed May 26, 2023
1 parent 40e27c6 commit 07d6c33
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 25 deletions.
31 changes: 14 additions & 17 deletions extract_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
Expand Down Expand Up @@ -187,9 +186,9 @@ func TestArchiveFailure(t *testing.T) {

func TestExtract(t *testing.T) {
for _, test := range ExtractCases {
dir, _ := ioutil.TempDir("", "")
dir, _ := os.MkdirTemp("", "")
dir = filepath.Join(dir, "test")
data, err := ioutil.ReadFile(test.Archive)
data, err := os.ReadFile(test.Archive)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -222,8 +221,8 @@ func TestExtract(t *testing.T) {
}

func BenchmarkArchive(b *testing.B) {
dir, _ := ioutil.TempDir("", "")
data, _ := ioutil.ReadFile("testdata/archive.tar.bz2")
dir, _ := os.MkdirTemp("", "")
data, _ := os.ReadFile("testdata/archive.tar.bz2")

b.StartTimer()

Expand All @@ -244,8 +243,8 @@ func BenchmarkArchive(b *testing.B) {
}

func BenchmarkTarBz2(b *testing.B) {
dir, _ := ioutil.TempDir("", "")
data, _ := ioutil.ReadFile("testdata/archive.tar.bz2")
dir, _ := os.MkdirTemp("", "")
data, _ := os.ReadFile("testdata/archive.tar.bz2")

b.StartTimer()

Expand All @@ -266,8 +265,8 @@ func BenchmarkTarBz2(b *testing.B) {
}

func BenchmarkTarGz(b *testing.B) {
dir, _ := ioutil.TempDir("", "")
data, _ := ioutil.ReadFile("testdata/archive.tar.gz")
dir, _ := os.MkdirTemp("", "")
data, _ := os.ReadFile("testdata/archive.tar.gz")

b.StartTimer()

Expand All @@ -288,8 +287,8 @@ func BenchmarkTarGz(b *testing.B) {
}

func BenchmarkZip(b *testing.B) {
dir, _ := ioutil.TempDir("", "")
data, _ := ioutil.ReadFile("testdata/archive.zip")
dir, _ := os.MkdirTemp("", "")
data, _ := os.ReadFile("testdata/archive.zip")

b.StartTimer()

Expand Down Expand Up @@ -319,7 +318,7 @@ func testWalk(t *testing.T, dir string, testFiles Files) {
} else if info.Mode()&os.ModeSymlink != 0 {
files[path] = "link"
} else {
data, err := ioutil.ReadFile(filepath.Join(dir, path))
data, err := os.ReadFile(filepath.Join(dir, path))
require.NoError(t, err)
files[path] = strings.TrimSpace(string(data))
}
Expand Down Expand Up @@ -370,7 +369,7 @@ func TestTarGzMemoryConsumption(t *testing.T) {
runtime.GC()
runtime.ReadMemStats(&m)

err = extract.Gz(context.Background(), f, tmpDir.String(), nil)
err = extract.Archive(context.Background(), f, tmpDir.String(), nil)
require.NoError(t, err)

runtime.ReadMemStats(&m2)
Expand Down Expand Up @@ -398,7 +397,7 @@ func TestZipMemoryConsumption(t *testing.T) {
runtime.GC()
runtime.ReadMemStats(&m)

err = extract.Zip(context.Background(), f, tmpDir.String(), nil)
err = extract.Archive(context.Background(), f, tmpDir.String(), nil)
require.NoError(t, err)

runtime.ReadMemStats(&m2)
Expand All @@ -407,9 +406,7 @@ func TestZipMemoryConsumption(t *testing.T) {
heapUsed = 0
}
fmt.Println("Heap memory used during the test:", heapUsed)
// the .zip file require random access, so the full io.Reader content must be cached, since
// the test file is 130MB, that's the reason for the high memory consumed.
require.True(t, heapUsed < 250000000, "heap consumption should be less than 250M but is %d", heapUsed)
require.True(t, heapUsed < 10000000, "heap consumption should be less than 10M but is %d", heapUsed)
}

func download(t require.TestingT, url string, file *paths.Path) error {
Expand Down
40 changes: 32 additions & 8 deletions extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
"compress/bzip2"
"compress/gzip"
"context"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -237,11 +237,27 @@ func (e *Extractor) Tar(ctx context.Context, body io.Reader, location string, re
// Zip extracts a .zip archived stream of data in the specified location.
// It accepts a rename function to handle the names of the files (see the example).
func (e *Extractor) Zip(ctx context.Context, body io.Reader, location string, rename Renamer) error {
// read the whole body into a buffer. Not sure this is the best way to do it
buffer := bytes.NewBuffer([]byte{})
copyCancel(ctx, buffer, body)

archive, err := zip.NewReader(bytes.NewReader(buffer.Bytes()), int64(buffer.Len()))
var bodySize int64
bodyReaderAt, isReaderAt := (body).(io.ReaderAt)
if bodySeeker, isSeeker := (body).(io.Seeker); isReaderAt && isSeeker {
// get the size by seeking to the end
endPos, err := bodySeeker.Seek(0, io.SeekEnd)
if err != nil {
return fmt.Errorf("failed to seek to the end of the body: %s", err)
}
// reset the reader to the beginning
if _, err := bodySeeker.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek to the beginning of the body: %w", err)
}
bodySize = endPos
} else {
// read the whole body into a buffer. Not sure this is the best way to do it
buffer := bytes.NewBuffer([]byte{})
copyCancel(ctx, buffer, body)
bodyReaderAt = bytes.NewReader(buffer.Bytes())
bodySize = int64(buffer.Len())
}
archive, err := zip.NewReader(bodyReaderAt, bodySize)
if err != nil {
return errors.Annotatef(err, "Read the zip file")
}
Expand Down Expand Up @@ -290,7 +306,7 @@ func (e *Extractor) Zip(ctx context.Context, body io.Reader, location string, re
case info.Mode()&os.ModeSymlink != 0:
if f, err := header.Open(); err != nil {
return errors.Annotatef(err, "Open link %s", path)
} else if name, err := ioutil.ReadAll(f); err != nil {
} else if name, err := io.ReadAll(f); err != nil {
return errors.Annotatef(err, "Read address of link %s", path)
} else {
links = append(links, link{Path: path, Name: string(name)})
Expand Down Expand Up @@ -347,7 +363,15 @@ func match(r io.Reader) (io.Reader, types.Type, error) {
return nil, types.Unknown, err
}

r = io.MultiReader(bytes.NewBuffer(buffer[:n]), r)
if seeker, ok := r.(io.Seeker); ok {
// if the stream is seekable, we just rewind it
if _, err := seeker.Seek(0, io.SeekStart); err != nil {
return nil, types.Unknown, err
}
} else {
// otherwise we create a new reader that will prepend the buffer
r = io.MultiReader(bytes.NewBuffer(buffer[:n]), r)
}

typ, err := filetype.Match(buffer)

Expand Down

0 comments on commit 07d6c33

Please sign in to comment.