File: compression.go

package info (click to toggle)
golang-mongodb-mongo-driver 1.17.1%2Bds1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie, trixie-proposed-updates
  • size: 25,988 kB
  • sloc: perl: 533; ansic: 491; python: 432; sh: 327; makefile: 174
file content (151 lines) | stat: -rw-r--r-- 4,029 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package driver

import (
	"bytes"
	"compress/zlib"
	"fmt"
	"io"
	"sync"

	"github.com/golang/snappy"
	"github.com/klauspost/compress/zstd"
	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)

// CompressionOpts holds settings for how to compress a payload.
//
// The same struct is passed to CompressPayload and DecompressPayload; each of them
// reads only the fields relevant to the selected Compressor.
type CompressionOpts struct {
	// Compressor selects the algorithm (no-op, snappy, zlib or zstd).
	Compressor wiremessage.CompressorID

	// ZlibLevel is the compression level handed to zlib when Compressor is zlib.
	ZlibLevel int

	// ZstdLevel is the zstd level, converted with zstd.EncoderLevelFromZstd, used when
	// Compressor is zstd.
	ZstdLevel int

	// UncompressedSize is the expected length of the decompressed payload. In this file
	// it is read only by DecompressPayload.
	UncompressedSize int32
}

// zstdEncoders caches one *zstd.Encoder per encoder level so CompressPayload does not
// build a fresh encoder for every message.
var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder

// getZstdEncoder returns the cached encoder for level, creating it on first use.
//
// Concurrent first calls for the same level may each construct an encoder, but
// LoadOrStore guarantees they all receive the single instance that was stored first,
// so every caller shares one encoder per level. Encoders built by callers that lost
// that race are simply dropped.
func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) {
	if cached, ok := zstdEncoders.Load(level); ok {
		return cached.(*zstd.Encoder), nil
	}

	encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
	if err != nil {
		return nil, err
	}

	actual, _ := zstdEncoders.LoadOrStore(level, encoder)
	return actual.(*zstd.Encoder), nil
}

// zlibEncoders caches one *zlibEncoder per zlib compression level.
var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder

// getZlibEncoder returns the cached encoder for the given zlib compression level,
// creating it on first use. An invalid level is reported by zlib.NewWriterLevel and
// nothing is cached for it.
//
// Concurrent first calls for the same level may each construct an encoder, but
// LoadOrStore guarantees they all receive the single instance that was stored first;
// an encoder built by a caller that lost the race is discarded.
func getZlibEncoder(level int) (*zlibEncoder, error) {
	if cached, ok := zlibEncoders.Load(level); ok {
		return cached.(*zlibEncoder), nil
	}

	writer, err := zlib.NewWriterLevel(nil, level)
	if err != nil {
		return nil, err
	}

	candidate := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)}
	actual, _ := zlibEncoders.LoadOrStore(level, candidate)
	return actual.(*zlibEncoder), nil
}

// zlibEncoder bundles a reusable zlib.Writer with the buffer it writes into. Both are
// shared by every caller of Encode, so mu serializes access to them.
type zlibEncoder struct {
	mu     sync.Mutex
	writer *zlib.Writer
	buf    *bytes.Buffer
}

// Encode zlib-compresses src and returns the compressed bytes appended to dst[:0],
// reusing dst's backing array when it has enough capacity. On error it returns nil
// and the error from the underlying writer.
func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) {
	e.mu.Lock()
	defer e.mu.Unlock()

	// Point the shared writer at an emptied buffer so output from a previous call
	// cannot leak into this one; this also makes a previously closed writer usable again.
	e.buf.Reset()
	e.writer.Reset(e.buf)

	if _, err := e.writer.Write(src); err != nil {
		return nil, err
	}
	// Close flushes all pending compressed data, completing the zlib stream in e.buf.
	if err := e.writer.Close(); err != nil {
		return nil, err
	}

	// Copy out of e.buf: its contents are overwritten by the next Encode call.
	return append(dst[:0], e.buf.Bytes()...), nil
}

// CompressPayload takes a byte slice and compresses it according to the options passed.
//
// CompressorNoOp returns in unchanged. Snappy allocates a fresh output slice; zlib and
// zstd go through encoders cached per compression level. Any other compressor ID
// yields an error.
func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
	switch id := opts.Compressor; id {
	case wiremessage.CompressorNoOp:
		return in, nil

	case wiremessage.CompressorSnappy:
		return snappy.Encode(nil, in), nil

	case wiremessage.CompressorZLib:
		zlibEnc, err := getZlibEncoder(opts.ZlibLevel)
		if err != nil {
			return nil, err
		}
		return zlibEnc.Encode(nil, in)

	case wiremessage.CompressorZstd:
		level := zstd.EncoderLevelFromZstd(opts.ZstdLevel)
		zstdEnc, err := getZstdEncoder(level)
		if err != nil {
			return nil, err
		}
		return zstdEnc.EncodeAll(in, nil), nil

	default:
		return nil, fmt.Errorf("unknown compressor ID %v", id)
	}
}

// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed.
//
// opts.Compressor selects the algorithm and opts.UncompressedSize is the number of
// decompressed bytes the caller expects:
//   - CompressorNoOp returns in itself, without copying.
//   - CompressorSnappy fails unless the length recorded in the snappy header equals
//     UncompressedSize.
//   - CompressorZLib and CompressorZstd read exactly UncompressedSize bytes from the
//     stream; a shorter stream yields the error from io.ReadFull and anything beyond
//     that many bytes is ignored.
//
// For the snappy, zlib and zstd compressors a negative UncompressedSize is rejected
// with an error instead of panicking while allocating the output buffer.
func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) {
	switch opts.Compressor {
	case wiremessage.CompressorSnappy, wiremessage.CompressorZLib, wiremessage.CompressorZstd:
		// UncompressedSize sizes the output buffer below and make panics on a negative
		// length, so reject such a value before doing any work.
		if opts.UncompressedSize < 0 {
			return nil, fmt.Errorf("invalid uncompressed size %d", opts.UncompressedSize)
		}
	}

	switch opts.Compressor {
	case wiremessage.CompressorNoOp:
		return in, nil
	case wiremessage.CompressorSnappy:
		l, err := snappy.DecodedLen(in)
		if err != nil {
			return nil, fmt.Errorf("decoding compressed length %w", err)
		}
		// Compare without narrowing l to int32, which could wrap large lengths.
		if l != int(opts.UncompressedSize) {
			return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l)
		}
		uncompressed = make([]byte, opts.UncompressedSize)
		return snappy.Decode(uncompressed, in)
	case wiremessage.CompressorZLib:
		// Assign (not :=) so err is the named result and the deferred Close below can
		// report a failure through it.
		var r io.ReadCloser
		r, err = zlib.NewReader(bytes.NewReader(in))
		if err != nil {
			return nil, err
		}
		defer func() {
			// Surface a Close error only when nothing else has already failed.
			if cerr := r.Close(); cerr != nil && err == nil {
				uncompressed, err = nil, cerr
			}
		}()
		uncompressed = make([]byte, opts.UncompressedSize)
		if _, err = io.ReadFull(r, uncompressed); err != nil {
			return nil, err
		}
		return uncompressed, nil
	case wiremessage.CompressorZstd:
		// NOTE(review): kept as *bytes.Buffer rather than *bytes.Reader; the zstd reader
		// appears to special-case inputs exposing Bytes() — confirm before changing.
		r, err := zstd.NewReader(bytes.NewBuffer(in))
		if err != nil {
			return nil, err
		}
		defer r.Close()
		uncompressed = make([]byte, opts.UncompressedSize)
		if _, err = io.ReadFull(r, uncompressed); err != nil {
			return nil, err
		}
		return uncompressed, nil
	default:
		return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
	}
}