1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
|
package customizations
import (
"context"
"fmt"
"net/url"
"github.com/aws/smithy-go"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
// AddPredictEndpointMiddleware registers the middleware that overrides the
// request endpoint with Predict's PredictEndpoint input member. It is inserted
// into the serialize step directly after the "ResolveEndpoint" middleware, so it
// runs after endpoint resolution and can overwrite the resolved request URL.
//
// endpoint extracts the PredictEndpoint value from an operation input; it is
// called by the middleware on every request passing through the stack. A nil
// endpoint is rejected with an error rather than deferred to a runtime panic.
func AddPredictEndpointMiddleware(stack *middleware.Stack, endpoint func(interface{}) (*string, error)) error {
	if endpoint == nil {
		return fmt.Errorf("predict endpoint fetcher must not be nil")
	}
	// The fetcher must be stored on the middleware: previously it was dropped,
	// leaving fetchPredictEndpoint nil and causing a panic on first use.
	return stack.Serialize.Insert(&predictEndpoint{
		fetchPredictEndpoint: endpoint,
	}, "ResolveEndpoint", middleware.After)
}
// predictEndpoint is a serialize-step middleware that swaps the outgoing
// request URL for the PredictEndpoint value carried by the operation input,
// but only when that value is present and non-empty.
type predictEndpoint struct {
	// fetchPredictEndpoint pulls the PredictEndpoint member out of the
	// operation input; a nil or empty result means "leave the URL alone".
	fetchPredictEndpoint func(interface{}) (*string, error)
}
// ID reports the unique name under which this middleware is registered in
// the middleware stack.
func (*predictEndpoint) ID() string {
	return "MachineLearning:PredictEndpoint"
}
// HandleSerialize implements the middleware.SerializeMiddleware interface.
//
// It asks fetchPredictEndpoint for the PredictEndpoint value carried by the
// operation input (in.Parameters). When that value is non-nil and non-empty it
// is parsed as a URL and assigned to the outgoing request's URL, replacing it
// entirely (including any path or query set earlier). The request is then
// handed to the next serialize handler.
//
// Errors are returned as *smithy.SerializationError. This happens when the
// request is not a *smithyhttp.Request, when no fetcher is configured, or when
// fetching or parsing the endpoint fails. Underlying causes are wrapped with %w
// so callers can inspect them with errors.Is / errors.As.
func (m *predictEndpoint) HandleSerialize(
	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
) (
	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
	req, ok := in.Request.(*smithyhttp.Request)
	if !ok {
		return out, metadata, &smithy.SerializationError{
			Err: fmt.Errorf("unknown request type %T", in.Request),
		}
	}

	// Calling a nil func would panic inside the request pipeline; surface a
	// proper serialization error instead.
	if m.fetchPredictEndpoint == nil {
		return out, metadata, &smithy.SerializationError{
			Err: fmt.Errorf("PredictEndpoint fetcher is not configured"),
		}
	}

	endpoint, err := m.fetchPredictEndpoint(in.Parameters)
	if err != nil {
		return out, metadata, &smithy.SerializationError{
			Err: fmt.Errorf("failed to fetch PredictEndpoint value, %w", err),
		}
	}

	if endpoint != nil && len(*endpoint) != 0 {
		// parseErr (not err) avoids shadowing the named result.
		uri, parseErr := url.Parse(*endpoint)
		if parseErr != nil {
			return out, metadata, &smithy.SerializationError{
				Err: fmt.Errorf("unable to parse predict endpoint, %w", parseErr),
			}
		}
		req.URL = uri
	}

	return next.HandleSerialize(ctx, in)
}
|