1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
|
//go:build integration
// +build integration
package s3control
import (
"context"
"crypto/tls"
"flag"
"fmt"
"net/http"
"os"
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/internal/integrationtest"
"github.com/aws/aws-sdk-go-v2/service/s3control"
"github.com/aws/aws-sdk-go-v2/service/sts"
)
// Shared state for the integration tests in this package. TestMain fills
// these in (from command-line flags or by calling AWS) before running tests.
var (
	// svc is the S3 Control client that TestMain constructs.
	svc *s3control.Client

	// s3ControlEndpoint is bound to the -s3-control-endpoint flag. When it is
	// non-empty, TestMain installs an endpoint resolver on svc that returns
	// this URL.
	s3ControlEndpoint string

	// stsEndpoint is bound to the -sts-endpoint flag. When it is non-empty,
	// TestMain installs an endpoint resolver on the STS client used for the
	// account lookup.
	stsEndpoint string

	// accountID is bound to the -account flag. When the flag is empty,
	// TestMain sets it from an STS GetCallerIdentity call.
	accountID string

	// insecureTLS is bound to the -insecure-tls flag and, when set, makes
	// TestMain turn on InsecureSkipVerify for the shared HTTP client.
	insecureTLS bool

	// useDualstack is bound to the -dualstack flag. TestMain itself does not
	// read it.
	useDualstack bool

	// region is the default region TestMain passes when loading the SDK
	// configuration.
	region = "us-west-2"
)
// TestMain is the entry point for the integration tests in this package. It
// registers the package's command-line flags, loads the SDK configuration,
// resolves accountID when it was not supplied on the command line, builds the
// shared S3 Control client stored in svc, and then runs the tests.
//
// The function always terminates the process through os.Exit:
//   - 0 when the -run / -test.run flag was set to `NONE`;
//   - 1 when loading the configuration fails or a panic is recovered;
//   - otherwise the value returned by m.Run.
func TestMain(m *testing.M) {
	// Register every custom flag before the first (and only) flag.Parse call.
	// The default command line exits the process when it encounters a flag
	// that has not been defined, so parsing first would make these flags
	// unusable.
	flag.StringVar(&stsEndpoint, "sts-endpoint", "",
		"The optional `URL` endpoint for the STS service.",
	)
	flag.StringVar(&s3ControlEndpoint, "s3-control-endpoint", "",
		"The optional `URL` endpoint for the S3 Control service.",
	)
	flag.BoolVar(&insecureTLS, "insecure-tls", false,
		"Disables TLS validation on request endpoints.",
	)
	flag.BoolVar(&useDualstack, "dualstack", false,
		"Enables usage of dualstack endpoints.",
	)
	flag.StringVar(&accountID, "account", "",
		"The AWS account `ID`.",
	)
	flag.Parse()

	// Skip all setup (and exit successfully) when the run filter is `NONE`.
	flag.CommandLine.Visit(func(f *flag.Flag) {
		if f.Name != "run" && f.Name != "test.run" {
			return
		}
		if f.Value.String() == `NONE` {
			os.Exit(0)
		}
	})

	// result is the process exit code. The deferred function is the single
	// exit point for everything below, and converts a panic into exit code 1.
	var result int
	defer func() {
		if r := recover(); r != nil {
			fmt.Fprintln(os.Stderr, "S3 integration tests panic,", r)
			result = 1
		}
		os.Exit(result)
	}()

	tlsCfg := &tls.Config{}
	if insecureTLS {
		tlsCfg.InsecureSkipVerify = true
	}
	httpClient := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: tlsCfg,
		},
	}

	cfg, err := integrationtest.LoadConfigWithDefaultRegion(region)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error occurred while loading config with region %v, %v\n", region, err)
		result = 1
		return
	}
	cfg.HTTPClient = httpClient

	ctx := context.Background()

	// Look the account ID up through STS only when -account was not given.
	if len(accountID) == 0 {
		stsOpts := func(*sts.Options) {}
		if len(stsEndpoint) != 0 {
			stsOpts = func(o *sts.Options) {
				// The region used for signing is the resolver's own argument,
				// not the package-level region variable.
				o.EndpointResolver = sts.EndpointResolverFunc(func(region string, _ sts.EndpointResolverOptions) (aws.Endpoint, error) {
					return aws.Endpoint{
						URL:           stsEndpoint,
						PartitionID:   "aws",
						SigningName:   "sts",
						SigningRegion: region,
					}, nil
				})
			}
		}

		stsClient := sts.NewFromConfig(cfg, stsOpts)
		identity, err := stsClient.GetCallerIdentity(ctx, nil)
		if err != nil {
			panic(fmt.Sprintf("failed to get accountID, %v", err))
		}
		// Account is a pointer field; fail with a clear message rather than
		// a nil-pointer dereference if the response omitted it.
		if identity == nil || identity.Account == nil {
			panic("failed to get accountID, GetCallerIdentity response has no account")
		}
		accountID = *identity.Account
	}

	s3controlOpts := func(*s3control.Options) {}
	if len(s3ControlEndpoint) != 0 {
		s3controlOpts = func(o *s3control.Options) {
			// As above, region here is the resolver's argument.
			o.EndpointResolver = s3control.EndpointResolverFunc(func(region string, _ s3control.EndpointResolverOptions) (aws.Endpoint, error) {
				return aws.Endpoint{
					URL:           s3ControlEndpoint,
					PartitionID:   "aws",
					SigningName:   "s3-control",
					SigningRegion: region,
				}, nil
			})
		}
	}

	svc = s3control.NewFromConfig(cfg, s3controlOpts)

	result = m.Run()
}
|