1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
|
package jwe
import (
"bytes"
"compress/flate"
"fmt"
"io"
"github.com/lestrrat-go/jwx/v2/internal/pool"
)
// uncompress inflates src, a raw DEFLATE stream, and returns the
// decompressed bytes.
//
// maxBufferSize caps the number of inflated bytes (not the length of src):
// as soon as the running total read from the decompressor exceeds it, the
// function fails instead of continuing to grow the output. A maxBufferSize of
// zero or less therefore rejects every payload that inflates to one or more
// bytes.
//
// An error is returned when the limit is exceeded, when the decompressor
// reports anything other than io.EOF (corrupt or truncated input), or when
// appending to the output buffer fails.
func uncompress(src []byte, maxBufferSize int64) ([]byte, error) {
	inflater := flate.NewReader(bytes.NewReader(src))
	defer inflater.Close()

	var (
		out   bytes.Buffer
		chunk [16384]byte
		total int64
	)
	for {
		got, err := inflater.Read(chunk[:])

		// Enforce the budget first, so an oversized stream is reported as
		// such even when the same Read also returned an error.
		total += int64(got)
		if total > maxBufferSize {
			return nil, fmt.Errorf(`compressed payload exceeds maximum allowed size`)
		}

		// io.EOF marks the normal end of the stream; any other error is fatal.
		atEOF := err == io.EOF
		if err != nil && !atEOF {
			return nil, fmt.Errorf(`failed to read inflated data: %w`, err)
		}

		// A Read may hand back data together with io.EOF, so keep it.
		if _, werr := out.Write(chunk[:got]); werr != nil {
			return nil, fmt.Errorf(`failed to write inflated data: %w`, werr)
		}
		if atEOF {
			return out.Bytes(), nil
		}
	}
}
// compress deflates plaintext into a raw DEFLATE stream (compress/flate
// output, no zlib or gzip framing) using flate.BestSpeed. It returns the
// compressed bytes in a newly allocated slice owned by the caller.
//
// An error is returned if the flate writer cannot be created, written to,
// or flushed on Close.
func compress(plaintext []byte) ([]byte, error) {
	buf := pool.GetBytesBuffer()
	defer pool.ReleaseBytesBuffer(buf)

	w, err := flate.NewWriter(buf, flate.BestSpeed)
	if err != nil {
		return nil, fmt.Errorf(`failed to create compression writer: %w`, err)
	}
	// Per the io.Writer contract a short write always carries a non-nil
	// error, so a single Write either consumes all of plaintext or fails.
	if _, err := w.Write(plaintext); err != nil {
		return nil, fmt.Errorf(`failed to write to compression writer: %w`, err)
	}
	// Close flushes the remaining compressed data into buf.
	if err := w.Close(); err != nil {
		return nil, fmt.Errorf(`failed to close compression writer: %w`, err)
	}

	// buf is handed back to the pool when we return, so give the caller its
	// own copy of the bytes rather than a view into the pooled buffer.
	ret := make([]byte, buf.Len())
	copy(ret, buf.Bytes())
	return ret, nil
}
|