File: setup_test.go

package info (click to toggle)
golang-github-aws-aws-sdk-go-v2 1.24.1-2~bpo12%2B1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm-backports
  • size: 554,032 kB
  • sloc: java: 15,941; makefile: 419; sh: 175
file content (133 lines) | stat: -rw-r--r-- 3,209 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
//go:build integration
// +build integration

package s3control

import (
	"context"
	"crypto/tls"
	"flag"
	"fmt"
	"net/http"
	"os"
	"testing"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/internal/integrationtest"
	"github.com/aws/aws-sdk-go-v2/service/s3control"
	"github.com/aws/aws-sdk-go-v2/service/sts"
)

// Package-level state shared by the S3 Control integration tests. TestMain
// fills these in from command-line flags (and, for the account ID, from an
// STS lookup) before calling m.Run.
var (
	// svc is the S3 Control client the tests issue requests through.
	svc *s3control.Client

	// Optional endpoint overrides for S3 Control and STS. An empty value
	// leaves the SDK's default endpoint resolution in place.
	s3ControlEndpoint string
	stsEndpoint       string

	// accountID is the AWS account the tests run against; when not given on
	// the command line it is taken from STS GetCallerIdentity.
	accountID string

	// Boolean switches set from the -insecure-tls and -dualstack flags.
	// NOTE(review): useDualstack is not read by TestMain itself; presumably
	// individual tests consult it — verify.
	insecureTLS  bool
	useDualstack bool

	// region is the default region handed to the shared config loader.
	region = "us-west-2"
)

// TestMain prepares the shared S3 Control client used by the integration
// tests in this package, runs the tests, and exits with their result code.
//
// Custom flags (pass them after -args when invoking go test):
//
//	-sts-endpoint         optional URL override for the STS service
//	-s3-control-endpoint  optional URL override for the S3 Control service
//	-insecure-tls         disable TLS certificate verification
//	-dualstack            sets useDualstack (not consumed here)
//	-account              AWS account ID; looked up via STS when empty
//
// Running with -run NONE exits with status 0 before any configuration is
// loaded. A panic raised on this goroutine during setup (for example, a
// failed STS lookup) is printed to stderr and converted into exit status 1.
func TestMain(m *testing.M) {
	// Custom flags must be registered before flag.Parse. Otherwise the flag
	// package rejects them as undefined and exits the process.
	flag.StringVar(&stsEndpoint, "sts-endpoint", "",
		"The optional `URL` endpoint for the STS service.",
	)
	flag.StringVar(&s3ControlEndpoint, "s3-control-endpoint", "",
		"The optional `URL` endpoint for the S3 Control service.",
	)
	flag.BoolVar(&insecureTLS, "insecure-tls", false,
		"Disables TLS validation on request endpoints.",
	)
	flag.BoolVar(&useDualstack, "dualstack", false,
		"Enables usage of dualstack endpoints.",
	)
	flag.StringVar(&accountID, "account", "",
		"The AWS account `ID`.",
	)
	flag.Parse()

	// Let `-run NONE` skip all setup, including config loading and the STS call.
	flag.CommandLine.Visit(func(f *flag.Flag) {
		if (f.Name == "run" || f.Name == "test.run") && f.Value.String() == `NONE` {
			os.Exit(0)
		}
	})

	// os.Exit skips other deferred calls, so this single deferred function
	// owns the exit code and turns a setup panic into a failing status.
	var result int
	defer func() {
		if r := recover(); r != nil {
			fmt.Fprintln(os.Stderr, "S3 integration tests panic,", r)
			result = 1
		}
		os.Exit(result)
	}()

	tlsCfg := &tls.Config{}
	if insecureTLS {
		// Opt-in only, via -insecure-tls.
		tlsCfg.InsecureSkipVerify = true
	}

	httpClient := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: tlsCfg,
		},
	}

	cfg, err := integrationtest.LoadConfigWithDefaultRegion(region)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error occurred while loading config with region %v, %v\n", region, err)
		result = 1
		return
	}
	cfg.HTTPClient = httpClient

	ctx := context.Background()

	// Resolve the account ID from the caller's identity unless it was given
	// explicitly with -account.
	if len(accountID) == 0 {
		stsOpts := func(*sts.Options) {}
		if len(stsEndpoint) != 0 {
			stsOpts = func(o *sts.Options) {
				o.EndpointResolver = sts.EndpointResolverFunc(func(reqRegion string, _ sts.EndpointResolverOptions) (aws.Endpoint, error) {
					return aws.Endpoint{
						URL:           stsEndpoint,
						PartitionID:   "aws",
						SigningName:   "sts",
						SigningRegion: reqRegion,
					}, nil
				})
			}
		}

		stsClient := sts.NewFromConfig(cfg, stsOpts)

		identity, err := stsClient.GetCallerIdentity(ctx, nil)
		if err != nil {
			panic(fmt.Sprintf("failed to get accountID, %v", err))
		}
		// Account is a *string; guard it so a missing value gives a clear
		// message instead of a nil-pointer dereference.
		if identity.Account == nil {
			panic("failed to get accountID, GetCallerIdentity returned no account")
		}
		accountID = *identity.Account
	}

	// Only override endpoint resolution when -s3-control-endpoint was given.
	s3controlOpts := func(*s3control.Options) {}
	if len(s3ControlEndpoint) != 0 {
		s3controlOpts = func(o *s3control.Options) {
			o.EndpointResolver = s3control.EndpointResolverFunc(func(reqRegion string, _ s3control.EndpointResolverOptions) (aws.Endpoint, error) {
				return aws.Endpoint{
					URL:           s3ControlEndpoint,
					PartitionID:   "aws",
					SigningName:   "s3-control",
					SigningRegion: reqRegion,
				}, nil
			})
		}
	}
	svc = s3control.NewFromConfig(cfg, s3controlOpts)

	result = m.Run()
}