File: recursion_detection.go

package info (click to toggle)
golang-github-aws-aws-sdk-go-v2 1.24.1-2~bpo12%2B1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm-backports
  • size: 554,032 kB
  • sloc: java: 15,941; makefile: 419; sh: 175
file content (94 lines) | stat: -rw-r--r-- 2,447 bytes parent folder | download | duplicates (7)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
package middleware

import (
	"context"
	"fmt"
	"github.com/aws/smithy-go/middleware"
	smithyhttp "github.com/aws/smithy-go/transport/http"
	"os"
)

// Names used to detect the Lambda environment and to propagate its trace ID.
const (
	// envAwsLambdaFunctionName is the environment variable whose presence is
	// treated as "running inside AWS Lambda".
	envAwsLambdaFunctionName = "AWS_LAMBDA_FUNCTION_NAME"

	// envAmznTraceID is the environment variable holding the trace ID value
	// copied into the request header.
	envAmznTraceID = "_X_AMZN_TRACE_ID"

	// amznTraceIDHeader is the HTTP request header that receives the trace ID.
	amznTraceIDHeader = "X-Amzn-Trace-Id"
)

// AddRecursionDetection registers a RecursionDetection middleware in the
// Build step of stack, using the middleware.After relative position. It
// returns whatever error the stack reports when adding the middleware.
func AddRecursionDetection(stack *middleware.Stack) error {
	detector := &RecursionDetection{}
	return stack.Build.Add(detector, middleware.After)
}

// RecursionDetection is a stateless Build-step middleware. When the process
// appears to be running in AWS Lambda, it copies the Lambda trace ID from the
// environment into the X-Amzn-Trace-Id request header, if that header is not
// already set. The purpose is to help guard against recursive Lambda
// invocation.
type RecursionDetection struct{}

// ID reports the identifier this middleware is registered under.
func (m *RecursionDetection) ID() string {
	const identifier = "RecursionDetection"
	return identifier
}

// HandleBuild is the Build-step handler for RecursionDetection.
//
// It sets the X-Amzn-Trace-Id header on the outgoing request to the
// percent-encoded value of the _X_AMZN_TRACE_ID environment variable. This
// happens only when all of the following hold:
//   - the request has no (or an empty) X-Amzn-Trace-Id header already,
//   - AWS_LAMBDA_FUNCTION_NAME is set (the process is treated as running in
//     Lambda), and
//   - _X_AMZN_TRACE_ID is set.
//
// In every other case the request is left untouched. The request is then
// passed to next.
//
// If in.Request is not a *smithyhttp.Request, HandleBuild returns an error
// naming the actual request type and does not call next.
func (m *RecursionDetection) HandleBuild(
	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
) (
	out middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
	req, ok := in.Request.(*smithyhttp.Request)
	if !ok {
		// Format in.Request, not req: after a failed assertion req is a typed
		// nil *smithyhttp.Request, so %T on it would never show the real type.
		return out, metadata, fmt.Errorf("unknown request type %T", in.Request)
	}

	// A caller-supplied header value always takes precedence, so the
	// environment is consulted only when the header is absent or empty.
	if req.Header.Get(amznTraceIDHeader) == "" {
		_, inLambda := os.LookupEnv(envAwsLambdaFunctionName)
		traceID, hasTraceID := os.LookupEnv(envAmznTraceID)
		if inLambda && hasTraceID {
			req.Header.Set(amznTraceIDHeader, percentEncode(traceID))
		}
	}

	return next.HandleBuild(ctx, in)
}

// percentEncode returns s with every byte for which shouldEncode reports true
// replaced by '%' followed by two upper-case hexadecimal digits. The input is
// processed byte by byte, so a multi-byte UTF-8 sequence yields one escape per
// byte. If no byte needs escaping, s itself is returned.
func percentEncode(s string) string {
	const hexDigits = "0123456789ABCDEF"

	// First pass: count escapes so the output can be sized exactly.
	escapes := 0
	for _, b := range []byte(s) {
		if shouldEncode(b) {
			escapes++
		}
	}
	if escapes == 0 {
		return s
	}

	// Each escape turns one byte into three ("%XX").
	out := make([]byte, 0, len(s)+2*escapes)
	for _, b := range []byte(s) {
		if !shouldEncode(b) {
			out = append(out, b)
			continue
		}
		out = append(out, '%', hexDigits[b>>4], hexDigits[b&0x0F])
	}
	return string(out)
}

// shouldEncode reports whether byte c must be percent-encoded. ASCII letters,
// ASCII digits, and the punctuation - = ; : + & [ ] { } " ' , are passed
// through unchanged; every other byte value is encoded.
func shouldEncode(c byte) bool {
	switch c {
	case '-', '=', ';', ':', '+', '&', '[', ']', '{', '}', '"', '\'', ',':
		return false
	}
	isAlnum := ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || ('0' <= c && c <= '9')
	return !isAlnum
}