1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
|
package http
import (
"context"
"fmt"
"github.com/aws/smithy-go/middleware"
)
// ComputeContentLength provides a middleware that sets the request's
// ContentLength to the length of the serialized request body, when that
// length has not already been set and can be determined.
type ComputeContentLength struct{}
// AddComputeContentLengthMiddleware registers a ComputeContentLength
// middleware in the Build step of the given stack, using the
// middleware.After relative position. Any error from the Build step's Add
// is returned unchanged.
func AddComputeContentLengthMiddleware(stack *middleware.Stack) error {
	mw := &ComputeContentLength{}
	return stack.Build.Add(mw, middleware.After)
}
// ID reports the fixed identifier, "ComputeContentLength", under which this
// middleware is registered in a stack.
func (m *ComputeContentLength) ID() string {
	return "ComputeContentLength"
}
// HandleBuild sets the request's ContentLength to the length of the
// serialized request stream when ContentLength is negative and the stream
// length can be determined.
//
// A ContentLength of 0 or greater is left untouched. If the stream length
// cannot be determined, ContentLength is left negative and the request is
// still passed on. An error is returned, without calling next, when the build
// input's request is not a *Request or when querying the stream length fails.
func (m *ComputeContentLength) HandleBuild(
	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
) (
	out middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
	req, ok := in.Request.(*Request)
	if !ok {
		// Format in.Request, not req: after a failed assertion req is a typed
		// nil *Request, so %T on it would never show the offending type.
		return out, metadata, fmt.Errorf("unknown request type %T", in.Request)
	}

	// do nothing if request content-length was set to 0 or above.
	if req.ContentLength >= 0 {
		return next.HandleBuild(ctx, in)
	}

	// attempt to compute stream length
	n, known, err := req.StreamLength()
	if err != nil {
		return out, metadata, fmt.Errorf(
			"failed getting length of request stream, %w", err)
	}
	if known {
		req.ContentLength = n
	}

	return next.HandleBuild(ctx, in)
}
// validateContentLength provides a middleware that rejects a serialized
// request whose ContentLength is negative. A ContentLength of zero or greater
// is accepted.
type validateContentLength struct{}
// ValidateContentLengthHeader adds middleware to the stack's Build step, using
// the middleware.After relative position, that returns an error when the
// request's ContentLength is negative. A ContentLength of zero or greater is
// accepted.
func ValidateContentLengthHeader(stack *middleware.Stack) error {
	return stack.Build.Add(&validateContentLength{}, middleware.After)
}
// ID returns the identifier for the validateContentLength middleware.
func (m *validateContentLength) ID() string { return "ValidateContentLength" }
// HandleBuild checks that the serialized request has a non-negative
// ContentLength before passing it to the next handler.
//
// It returns an error, without calling next, when the build input's request
// is not a *Request or when its ContentLength is negative. A ContentLength of
// 0 or greater is accepted.
func (m *validateContentLength) HandleBuild(
	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
) (
	out middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
	req, ok := in.Request.(*Request)
	if !ok {
		// Format in.Request, not req: after a failed assertion req is a typed
		// nil *Request, so %T on it would never show the offending type.
		return out, metadata, fmt.Errorf("unknown request type %T", in.Request)
	}

	// if request content-length was set to less than 0, return an error
	if req.ContentLength < 0 {
		return out, metadata, fmt.Errorf(
			"content length for payload is required and must be at least 0")
	}

	return next.HandleBuild(ctx, in)
}
|