File: request_compression_test.go

package info (click to toggle)
golang-github-aws-smithy-go 1.19.0-1~bpo12+1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm-backports
  • size: 2,680 kB
  • sloc: java: 15,917; xml: 166; sh: 131; makefile: 66
file content (125 lines) | stat: -rw-r--r-- 3,879 bytes | parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package requestcompression

import (
	"bytes"
	"compress/gzip"
	"context"
	"fmt"
	"github.com/aws/smithy-go/middleware"
	"github.com/aws/smithy-go/transport/http"
	"io"
	"reflect"
	"strings"
	"testing"
)

// TestRequestCompression runs requestCompression.HandleSerialize against a
// capturing next handler and checks the request stream and headers that the
// middleware forwards. The stream is checked with testUnzipContent, which
// expects gzip output unless compression is disabled or the payload is
// smaller than RequestMinCompressSizeBytes.
func TestRequestCompression(t *testing.T) {
	cases := map[string]struct {
		DisableRequestCompression   bool
		RequestMinCompressSizeBytes int64
		ContentLength               int64
		Header                      map[string][]string
		Stream                      io.Reader
		ExpectedStream              []byte
		ExpectedHeader              map[string][]string
	}{
		"GZip request stream": {
			Stream:         strings.NewReader("Hi, world!"),
			ExpectedStream: []byte("Hi, world!"),
			ExpectedHeader: map[string][]string{
				"Content-Encoding": {"gzip"},
			},
		},
		"GZip request stream with existing encoding header": {
			Stream:         strings.NewReader("Hi, world!"),
			ExpectedStream: []byte("Hi, world!"),
			Header: map[string][]string{
				"Content-Encoding": {"custom"},
			},
			ExpectedHeader: map[string][]string{
				"Content-Encoding": {"custom, gzip"},
			},
		},
		"GZip request stream smaller than min compress request size": {
			RequestMinCompressSizeBytes: 100,
			Stream:                      strings.NewReader("Hi, world!"),
			ExpectedStream:              []byte("Hi, world!"),
			ExpectedHeader:              map[string][]string{},
		},
		"Disable GZip request stream": {
			DisableRequestCompression: true,
			Stream:                    strings.NewReader("Hi, world!"),
			ExpectedStream:            []byte("Hi, world!"),
			ExpectedHeader:            map[string][]string{},
		},
	}

	for name, c := range cases {
		t.Run(name, func(t *testing.T) {
			req := http.NewStackRequest().(*http.Request)
			req.ContentLength = c.ContentLength
			// SetStream reports an error; surface it instead of discarding it so
			// a setup failure isn't misreported as a middleware failure.
			req, err := req.SetStream(c.Stream)
			if err != nil {
				t.Fatalf("expect no error setting request stream, got %v", err)
			}
			if c.Header != nil {
				req.Header = c.Header
			}

			// Captured by the next handler so the forwarded request can be inspected.
			var updatedRequest *http.Request

			m := requestCompression{
				disableRequestCompression:   c.DisableRequestCompression,
				requestMinCompressSizeBytes: c.RequestMinCompressSizeBytes,
				compressAlgorithms:          []string{GZIP},
			}
			_, _, err = m.HandleSerialize(context.Background(),
				middleware.SerializeInput{Request: req},
				middleware.SerializeHandlerFunc(func(ctx context.Context, input middleware.SerializeInput) (
					out middleware.SerializeOutput, metadata middleware.Metadata, err error) {
					updatedRequest = input.Request.(*http.Request)
					return out, metadata, nil
				}),
			)
			if err != nil {
				t.Fatalf("expect no error, got %v", err)
			}
			// Guard against the middleware short-circuiting; without this the
			// dereference below would panic instead of failing the test.
			if updatedRequest == nil {
				t.Fatalf("expect next handler to be invoked with the request")
			}

			stream := updatedRequest.GetStream()
			switch {
			case stream != nil:
				if err := testUnzipContent(stream, c.ExpectedStream, c.DisableRequestCompression, c.RequestMinCompressSizeBytes); err != nil {
					t.Errorf("error while checking request stream: %v", err)
				}
			case c.ExpectedStream != nil:
				// A dropped body must not pass silently when a payload is expected.
				t.Errorf("expect request stream to be forwarded, got nil")
			}

			if e, a := c.ExpectedHeader, map[string][]string(updatedRequest.Header); !reflect.DeepEqual(e, a) {
				t.Errorf("expect request header to be %q, got %q", e, a)
			}
		})
	}
}

// testUnzipContent checks that content carries the expected payload.
//
// If disableRequestCompression is set, or expect is shorter than
// requestMinCompressionSizeBytes, content is compared byte-for-byte with
// expect. Otherwise content is decoded as gzip first, and the decompressed
// bytes are compared with expect.
//
// It returns nil on a match. On a mismatch it returns a descriptive error.
// Read or decode failures are wrapped with %w (for example gzip.ErrHeader or
// gzip.ErrChecksum), so callers can inspect them with errors.Is.
func testUnzipContent(content io.Reader, expect []byte, disableRequestCompression bool, requestMinCompressionSizeBytes int64) error {
	if disableRequestCompression || int64(len(expect)) < requestMinCompressionSizeBytes {
		raw, err := io.ReadAll(content)
		if err != nil {
			return fmt.Errorf("error while reading request: %w", err)
		}
		if !bytes.Equal(expect, raw) {
			return fmt.Errorf("expect content to be %q, got %q", expect, raw)
		}
		return nil
	}

	r, err := gzip.NewReader(content)
	if err != nil {
		return fmt.Errorf("error while reading request: %w", err)
	}
	// Close releases the decoder; it does not close the underlying content reader.
	defer r.Close()

	// Reading to EOF is what makes gzip verify the trailer checksum.
	actual, err := io.ReadAll(r)
	if err != nil {
		return fmt.Errorf("error while unzipping request payload: %w", err)
	}
	if !bytes.Equal(expect, actual) {
		return fmt.Errorf("expect unzipped content to be %q, got %q", expect, actual)
	}
	return nil
}