1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
|
// Package zstd implements the Zstandard decompressor.
package zstd
import (
"errors"
"fmt"
"io"
"runtime"
"sync"
"github.com/klauspost/compress/zstd"
)
// readCloser couples a zstd decoder with the closer of the stream it reads
// from, so that Close can release both together.
type readCloser struct {
// c closes the underlying compressed stream; it is nil once Close has succeeded.
c io.Closer
// r is the decoder that Read delegates to; it is nil once Close has succeeded.
r *zstd.Decoder
}
var (
// zstdReaderPool holds decoders handed back by readCloser.Close so that
// NewReader can reset and reuse them instead of creating new ones.
//nolint:gochecknoglobals
zstdReaderPool sync.Pool
// errAlreadyClosed is returned by Read and Close on a reader that has
// already been closed.
errAlreadyClosed = errors.New("zstd: already closed")
// errNeedOneReader is returned by NewReader when it is not given exactly
// one input reader.
errNeedOneReader = errors.New("zstd: need exactly one reader")
)
// Close closes the underlying stream and hands the decoder back to the pool
// so that a later NewReader call can reuse it.
//
// It returns errAlreadyClosed if the reader has already been closed. If
// closing the underlying stream fails, the wrapped error is returned and the
// reader is left untouched: the decoder is not pooled and Close may be
// called again.
func (rc *readCloser) Close() error {
	if rc.c == nil {
		return errAlreadyClosed
	}

	if err := rc.c.Close(); err != nil {
		return fmt.Errorf("zstd: error closing: %w", err)
	}

	// Detach the decoder from the stream that was just closed before pooling
	// it; otherwise an idle pooled decoder keeps that stream reachable until
	// the decoder is next reset. A decoder that cannot be reset is not
	// reusable, so it is dropped rather than pooled; the stream itself was
	// closed successfully, so Close still reports success.
	if err := rc.r.Reset(nil); err == nil {
		zstdReaderPool.Put(rc.r)
	}

	rc.c, rc.r = nil, nil

	return nil
}
// Read decompresses data into p by delegating to the wrapped decoder.
//
// It returns errAlreadyClosed once the reader has been closed. io.EOF is
// passed through as-is so callers can detect the end of the stream; every
// other error is wrapped with a "zstd: error reading" prefix.
func (rc *readCloser) Read(p []byte) (int, error) {
	decoder := rc.r
	if decoder == nil {
		return 0, errAlreadyClosed
	}

	count, readErr := decoder.Read(p)

	switch {
	case readErr == nil, errors.Is(readErr, io.EOF):
		// Success and end-of-stream are reported unchanged.
		return count, readErr
	default:
		return count, fmt.Errorf("zstd: error reading: %w", readErr)
	}
}
// NewReader returns an io.ReadCloser that decompresses the Zstandard stream
// read from the single entry in readers.
//
// The first two parameters are ignored. Exactly one reader must be supplied,
// otherwise errNeedOneReader is returned. A decoder is taken from the pool
// and reset onto the reader when one is available; otherwise a new decoder
// is created and given a finalizer that closes it when it is collected.
func NewReader(_ []byte, _ uint64, readers []io.ReadCloser) (io.ReadCloser, error) {
	if len(readers) != 1 {
		return nil, errNeedOneReader
	}

	src := readers[0]

	dec, pooled := zstdReaderPool.Get().(*zstd.Decoder)

	switch {
	case pooled:
		// Reuse a previously released decoder by pointing it at the new stream.
		if err := dec.Reset(src); err != nil {
			return nil, fmt.Errorf("zstd: error resetting: %w", err)
		}
	default:
		// The pool was empty: build a fresh decoder.
		fresh, err := zstd.NewReader(src)
		if err != nil {
			return nil, fmt.Errorf("zstd: error creating reader: %w", err)
		}

		// Nothing in this file closes a decoder explicitly, so a finalizer
		// closes it once it becomes unreachable.
		runtime.SetFinalizer(fresh, (*zstd.Decoder).Close)

		dec = fresh
	}

	return &readCloser{
		c: src,
		r: dec,
	}, nil
}
|