1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285
|
package imds
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"net/url"
"path"
"time"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
"github.com/aws/aws-sdk-go-v2/aws/retry"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
// addAPIRequestMiddleware registers the shared request middleware for a GET
// operation via addRequestMiddleware. Unless options.disableAPIToken is set, it
// also inserts options.tokenProvider into the stack twice: in the Finalize step
// right after the retry attempt middleware, and in the Deserialize step right
// before the "OperationDeserializer" middleware. The first registration error
// is returned.
func addAPIRequestMiddleware(stack *middleware.Stack,
	options Options,
	getPath func(interface{}) (string, error),
	getOutput func(*smithyhttp.Response) (interface{}, error),
) (err error) {
	if err = addRequestMiddleware(stack, options, "GET", getPath, getOutput); err != nil {
		return err
	}

	// With API tokens disabled there is nothing more to register.
	if options.disableAPIToken {
		return nil
	}

	// Token serializer build and state management.
	attemptID := (*retry.Attempt)(nil).ID()
	if err = stack.Finalize.Insert(options.tokenProvider, attemptID, middleware.After); err != nil {
		return err
	}
	return stack.Deserialize.Insert(options.tokenProvider, "OperationDeserializer", middleware.Before)
}
// addRequestMiddleware registers the middleware common to every IMDS operation
// sent with the given HTTP method. Registrations run in this order:
//
//   - the "ec2-imds" SDK agent key (metadata feature),
//   - operationTimeout at the front of the Initialize step,
//   - serializeRequest at the end of the Serialize step, with resolveEndpoint
//     inserted just before it,
//   - deserializeResponse and then a RequestResponseLogger (configured from
//     options.ClientLogMode) at the end of the Deserialize step,
//   - the logger middleware,
//   - the retry middlewares.
//
// Registration stops at the first error, which is returned.
func addRequestMiddleware(stack *middleware.Stack,
	options Options,
	method string,
	getPath func(interface{}) (string, error),
	getOutput func(*smithyhttp.Response) (interface{}, error),
) (err error) {
	registrations := []func(*middleware.Stack) error{
		awsmiddleware.AddSDKAgentKey(awsmiddleware.FeatureMetadata, "ec2-imds"),
		// Operation timeout.
		func(s *middleware.Stack) error {
			return s.Initialize.Add(&operationTimeout{
				DefaultTimeout: defaultOperationTimeout,
			}, middleware.Before)
		},
		// Operation serializer.
		func(s *middleware.Stack) error {
			return s.Serialize.Add(&serializeRequest{
				GetPath: getPath,
				Method:  method,
			}, middleware.After)
		},
		// Endpoint resolution must run before the serializer sets the path.
		func(s *middleware.Stack) error {
			return s.Serialize.Insert(&resolveEndpoint{
				Endpoint:     options.Endpoint,
				EndpointMode: options.EndpointMode,
			}, "OperationSerializer", middleware.Before)
		},
		// Operation deserializer.
		func(s *middleware.Stack) error {
			return s.Deserialize.Add(&deserializeResponse{
				GetOutput: getOutput,
			}, middleware.After)
		},
		// Wire-level request/response logging.
		func(s *middleware.Stack) error {
			return s.Deserialize.Add(&smithyhttp.RequestResponseLogger{
				LogRequest:          options.ClientLogMode.IsRequest(),
				LogRequestWithBody:  options.ClientLogMode.IsRequestWithBody(),
				LogResponse:         options.ClientLogMode.IsResponse(),
				LogResponseWithBody: options.ClientLogMode.IsResponseWithBody(),
			}, middleware.After)
		},
		func(s *middleware.Stack) error {
			return addSetLoggerMiddleware(s, options)
		},
	}

	for _, register := range registrations {
		if err = register(stack); err != nil {
			return err
		}
	}

	// Retry support.
	return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{
		Retryer:          options.Retryer,
		LogRetryAttempts: options.ClientLogMode.IsRetries(),
	})
}
// addSetLoggerMiddleware registers options.Logger on the stack using
// middleware.AddSetLoggerMiddleware and returns any registration error.
func addSetLoggerMiddleware(stack *middleware.Stack, options Options) error {
	logger := options.Logger
	return middleware.AddSetLoggerMiddleware(stack, logger)
}
// serializeRequest is the operation serializer. It sets the HTTP method to
// Method and the URL path to the value GetPath computes from the operation's
// input parameters.
type serializeRequest struct {
	GetPath func(interface{}) (string, error)
	Method  string
}

// ID returns the middleware identifier, "OperationSerializer".
func (*serializeRequest) ID() string {
	return "OperationSerializer"
}

// HandleSerialize applies the URL path and HTTP method to the outgoing request
// and then calls the next serialize handler. It returns an error if the
// transport request is not a *smithyhttp.Request or if GetPath fails.
func (s *serializeRequest) HandleSerialize(
	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
) (
	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
	httpReq, isHTTP := in.Request.(*smithyhttp.Request)
	if !isHTTP {
		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
	}

	urlPath, pathErr := s.GetPath(in.Parameters)
	if pathErr != nil {
		return out, metadata, fmt.Errorf("unable to get request URL path, %w", pathErr)
	}

	httpReq.Request.Method = s.Method
	httpReq.Request.URL.Path = urlPath
	return next.HandleSerialize(ctx, in)
}
// deserializeResponse is the operation deserializer. After the rest of the
// stack has produced a raw response, it buffers the response body, rejects
// non-2xx status codes, and converts the response into the operation result
// with GetOutput.
type deserializeResponse struct {
	GetOutput func(*smithyhttp.Response) (interface{}, error)
}

// ID returns the middleware identifier, "OperationDeserializer".
func (*deserializeResponse) ID() string {
	return "OperationDeserializer"
}

// HandleDeserialize calls the next deserialize handler and post-processes its
// raw response. The original body is read in full and then closed when this
// method returns. resp.Body is replaced with an in-memory reader over the read
// bytes. A status outside [200, 300) yields a *smithyhttp.ResponseError.
// Otherwise out.Result is set to GetOutput's value.
func (d *deserializeResponse) HandleDeserialize(
	ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
) (
	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
) {
	if out, metadata, err = next.HandleDeserialize(ctx, in); err != nil {
		return out, metadata, err
	}

	resp, isHTTP := out.RawResponse.(*smithyhttp.Response)
	if !isHTTP {
		return out, metadata, fmt.Errorf(
			"unexpected transport response type, %T, want %T", out.RawResponse, resp)
	}
	// The deferred call binds the original body, not the replacement below.
	defer resp.Body.Close()

	// Read the whole body now so that the operation timeout's cleanup cannot
	// race with a later read of the body.
	payload, readErr := ioutil.ReadAll(resp.Body)
	if readErr != nil {
		return out, metadata, fmt.Errorf("read response body failed, %w", readErr)
	}
	resp.Body = ioutil.NopCloser(bytes.NewReader(payload))

	// Only 2xx status codes are treated as success.
	if isSuccess := resp.StatusCode >= 200 && resp.StatusCode < 300; !isSuccess {
		return out, metadata, &smithyhttp.ResponseError{
			Response: resp,
			Err:      fmt.Errorf("request to EC2 IMDS failed"),
		}
	}

	result, outputErr := d.GetOutput(resp)
	if outputErr != nil {
		return out, metadata, fmt.Errorf(
			"unable to get deserialized result for response, %w", outputErr,
		)
	}
	out.Result = result
	return out, metadata, nil
}
// resolveEndpoint sets the outgoing request's URL to the IMDS endpoint. A
// non-empty Endpoint is used as-is. Otherwise EndpointMode selects between
// defaultIPv6Endpoint and defaultIPv4Endpoint, and an unset mode means IPv4.
type resolveEndpoint struct {
	Endpoint     string
	EndpointMode EndpointModeState
}

// ID returns the middleware identifier, "ResolveEndpoint".
func (*resolveEndpoint) ID() string {
	return "ResolveEndpoint"
}

// HandleSerialize replaces the request URL with the parsed endpoint and then
// calls the next serialize handler. It returns an error for a non-HTTP
// transport request, an unsupported EndpointMode (only consulted when Endpoint
// is empty), or an endpoint that url.Parse rejects.
func (r *resolveEndpoint) HandleSerialize(
	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
) (
	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
	req, isHTTP := in.Request.(*smithyhttp.Request)
	if !isHTTP {
		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
	}

	endpoint := r.Endpoint
	if endpoint == "" {
		switch r.EndpointMode {
		case EndpointModeStateIPv6:
			endpoint = defaultIPv6Endpoint
		case EndpointModeStateIPv4, EndpointModeStateUnset:
			endpoint = defaultIPv4Endpoint
		default:
			return out, metadata, fmt.Errorf("unsupported IMDS endpoint mode")
		}
	}

	if req.URL, err = url.Parse(endpoint); err != nil {
		return out, metadata, fmt.Errorf("failed to parse endpoint URL: %w", err)
	}
	return next.HandleSerialize(ctx, in)
}
// defaultOperationTimeout is the DefaultTimeout that addRequestMiddleware
// configures on operationTimeout. It bounds an operation whose context has no
// deadline of its own.
const defaultOperationTimeout = 5 * time.Second
// operationTimeout is an Initialize-step middleware. When the incoming Context
// has no deadline, it runs the rest of the stack under a context that times out
// after DefaultTimeout. That context is canceled when the next handler
// returns, so the next middleware must finish before the timeout.
//
// A zero DefaultTimeout disables this: a Context without a deadline is passed
// through unchanged.
//
// The next middleware must also fully consume any resources tied to the
// stack's context before returning. Otherwise the cancellation on return races
// with upstream code that is still consuming them.
type operationTimeout struct {
	DefaultTimeout time.Duration
}

// ID returns the middleware identifier, "OperationTimeout".
func (*operationTimeout) ID() string { return "OperationTimeout" }

// HandleInitialize calls the next initialize handler. It derives a
// DefaultTimeout-bounded context first when ctx has no deadline and
// DefaultTimeout is non-zero.
func (o *operationTimeout) HandleInitialize(
	ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler,
) (
	output middleware.InitializeOutput, metadata middleware.Metadata, err error,
) {
	_, hasDeadline := ctx.Deadline()
	if hasDeadline || o.DefaultTimeout == 0 {
		return next.HandleInitialize(ctx, input)
	}

	timeoutCtx, cancel := context.WithTimeout(ctx, o.DefaultTimeout)
	defer cancel()
	return next.HandleInitialize(timeoutCtx, input)
}
// appendURIPath joins the URI path component add onto base with `/`
// separators, cleaning the result with path.Join.
//
// path.Join drops trailing slashes, so when add ends with `/` the slash is
// restored. It is restored only if the joined path does not already end with
// one. Joining "/" onto an empty or root base yields "/", and unconditionally
// appending would produce "//".
func appendURIPath(base, add string) string {
	reqPath := path.Join(base, add)
	addHasTrailingSlash := len(add) != 0 && add[len(add)-1] == '/'
	joinedHasTrailingSlash := len(reqPath) != 0 && reqPath[len(reqPath)-1] == '/'
	if addHasTrailingSlash && !joinedHasTrailingSlash {
		reqPath += "/"
	}
	return reqPath
}
|