1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
|
package requestcompression
import (
"bytes"
"context"
"fmt"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"io"
"net/http"
)
// captureUncompressedRequestID is the identifier returned by captureUncompressedRequestMiddleware.ID.
const captureUncompressedRequestID = "CaptureUncompressedRequest"
// AddCaptureUncompressedRequestMiddleware registers, in the Serialize step of
// stack, a middleware that copies the HTTP request body into buf. The
// middleware is placed before the one identified as "RequestCompression", so
// the bytes written to buf are the payload as it looked prior to compression.
//
// The error reported by the stack's Insert call is returned unchanged.
func AddCaptureUncompressedRequestMiddleware(stack *middleware.Stack, buf *bytes.Buffer) error {
	capture := &captureUncompressedRequestMiddleware{buf: buf}
	return stack.Serialize.Insert(capture, "RequestCompression", middleware.Before)
}
// captureUncompressedRequestMiddleware is a serialize-step middleware that
// records a copy of the request body in buf.
type captureUncompressedRequestMiddleware struct {
	// req is not referenced by any code visible in this file.
	req *http.Request

	// buf is the destination the request body stream is copied into by
	// HandleSerialize.
	buf *bytes.Buffer

	// bytes is not referenced by any code visible in this file.
	bytes []byte
}
// ID reports this middleware's identifier, which is the
// captureUncompressedRequestID constant.
func (m *captureUncompressedRequestMiddleware) ID() string {
	return captureUncompressedRequestID
}
// HandleSerialize copies the request body into m.buf and then rewinds the body
// so that the next handler in the chain reads it from its start position. The
// intent is to capture the payload before the request compression middleware
// replaces it with its compressed form.
//
// A request whose body stream is nil is forwarded to next with nothing
// captured.
//
// An error is returned, without invoking next, when input.Request is not a
// *smithyhttp.Request, when copying the body fails, or when the body cannot be
// rewound. Copy and rewind failures wrap the underlying error so callers can
// inspect it with errors.Is / errors.As.
func (m *captureUncompressedRequestMiddleware) HandleSerialize(
	ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler,
) (
	output middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
	request, ok := input.Request.(*smithyhttp.Request)
	if !ok {
		return output, metadata, fmt.Errorf("error when retrieving http request")
	}

	// smithy-go's GetStream returns the underlying reader as-is, which is nil
	// for a request without a body; io.Copy would panic reading from it.
	stream := request.GetStream()
	if stream == nil {
		return next.HandleSerialize(ctx, input)
	}

	if _, err = io.Copy(m.buf, stream); err != nil {
		return output, metadata, fmt.Errorf("error when copying http request stream: %w", err)
	}

	// The copy above consumed the stream; reset it so downstream middleware
	// sees the complete payload.
	if err = request.RewindStream(); err != nil {
		return output, metadata, fmt.Errorf("error when rewinding request stream: %w", err)
	}

	return next.HandleSerialize(ctx, input)
}
|