1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
|
package testutils
import (
"crypto/tls"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"github.com/vulcand/oxy/v2/internal/holsterv4/clock"
"github.com/vulcand/oxy/v2/utils"
)
// NewHandler starts an httptest.Server that dispatches every incoming request
// to handler. The caller owns the returned server and should Close it.
func NewHandler(handler http.HandlerFunc) *httptest.Server {
	srv := httptest.NewServer(handler)
	return srv
}
// NewResponder starts an httptest.Server that answers every request by writing
// response as the body. The caller owns the returned server and should Close it.
func NewResponder(response string) *httptest.Server {
	payload := []byte(response)
	reply := func(rw http.ResponseWriter, _ *http.Request) {
		// Write failures are deliberately ignored; there is no one to report them to.
		_, _ = rw.Write(payload)
	}
	return httptest.NewServer(http.HandlerFunc(reply))
}
// ParseURI parses uri with url.ParseRequestURI and panics on failure instead of
// returning an error. It exists to keep test setup code short.
func ParseURI(uri string) *url.URL {
	parsed, err := url.ParseRequestURI(uri)
	if err == nil {
		return parsed
	}
	panic(err)
}
// ReqOpts holds the request settings accumulated by ReqOption functions and
// consumed by MakeRequest.
type ReqOpts struct {
	// Host, when non-empty, replaces the request's Host field.
	Host string
	// Method is the HTTP method; MakeRequest uses GET when it is empty.
	Method string
	// Body is sent as the request body.
	Body string
	// Headers, when non-nil, are copied onto the outgoing request headers.
	Headers http.Header
	// Auth, when non-nil, is rendered into the Authorization header.
	Auth *utils.BasicAuth
}
// ReqOption is a functional option that mutates ReqOpts before MakeRequest
// builds the request. Returning a non-nil error makes MakeRequest abort and
// return that error without sending anything.
type ReqOption func(o *ReqOpts) error
// Method returns a ReqOption that sets the HTTP method MakeRequest will use.
func Method(m string) ReqOption {
	setMethod := func(opts *ReqOpts) error {
		opts.Method = m
		return nil
	}
	return setMethod
}
// Host returns a ReqOption that overrides the Host of the request built by
// MakeRequest. An empty value leaves the host derived from the URL in place.
func Host(h string) ReqOption {
	var opt ReqOption = func(opts *ReqOpts) error {
		opts.Host = h
		return nil
	}
	return opt
}
// Body returns a ReqOption that sets the payload MakeRequest sends as the
// request body.
func Body(b string) ReqOption {
	return ReqOption(func(opts *ReqOpts) error {
		opts.Body = b
		return nil
	})
}
// Header returns a ReqOption that adds one value for the header name. Values
// accumulate: applying it repeatedly for the same name appends rather than
// replaces.
func Header(name, val string) ReqOption {
	return func(opts *ReqOpts) error {
		hdr := opts.Headers
		if hdr == nil {
			// Lazily allocate so that options which never touch headers leave
			// the field nil.
			hdr = http.Header{}
			opts.Headers = hdr
		}
		hdr.Add(name, val)
		return nil
	}
}
// Headers returns a ReqOption that copies every entry of h into the request
// options' header set, allocating that set on first use.
func Headers(h http.Header) ReqOption {
	merge := func(opts *ReqOpts) error {
		if opts.Headers == nil {
			opts.Headers = http.Header{}
		}
		utils.CopyHeaders(opts.Headers, h)
		return nil
	}
	return merge
}
// BasicAuth returns a ReqOption that attaches the given credentials;
// MakeRequest renders them into the Authorization header.
func BasicAuth(username, password string) ReqOption {
	return func(opts *ReqOpts) error {
		// A fresh value per application, so separate requests never share a pointer.
		creds := utils.BasicAuth{Username: username, Password: password}
		opts.Auth = &creds
		return nil
	}
}
// MakeRequest create and do a request.
func MakeRequest(uri string, opts ...ReqOption) (*http.Response, []byte, error) {
o := &ReqOpts{}
for _, s := range opts {
if err := s(o); err != nil {
return nil, nil, err
}
}
if o.Method == "" {
o.Method = http.MethodGet
}
request, err := http.NewRequest(o.Method, uri, strings.NewReader(o.Body))
if err != nil {
return nil, nil, err
}
if o.Headers != nil {
utils.CopyHeaders(request.Header, o.Headers)
}
if o.Auth != nil {
request.Header.Set("Authorization", o.Auth.String())
}
if o.Host != "" {
request.Host = o.Host
}
var tr *http.Transport
if strings.HasPrefix(uri, "https") {
tr = &http.Transport{
DisableKeepAlives: true,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: request.Host,
},
}
} else {
tr = &http.Transport{
DisableKeepAlives: true,
}
}
client := &http.Client{
Transport: tr,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return errors.New("no redirects")
},
}
response, err := client.Do(request)
if err == nil {
bodyBytes, errRead := io.ReadAll(response.Body)
return response, bodyBytes, errRead
}
return response, nil, err
}
// Get performs a GET request against uri via MakeRequest. The GET method is
// applied after opts, so it overrides any Method option the caller supplied.
func Get(uri string, opts ...ReqOption) (*http.Response, []byte, error) {
	// Cap the slice at its length so append always allocates a new array.
	// Otherwise it could write into spare capacity of a slice the caller
	// passed as opts...
	opts = append(opts[:len(opts):len(opts)], Method(http.MethodGet))
	return MakeRequest(uri, opts...)
}
// Post performs a POST request against uri via MakeRequest. The POST method is
// applied after opts, so it overrides any Method option the caller supplied.
func Post(uri string, opts ...ReqOption) (*http.Response, []byte, error) {
	// Cap the slice at its length so append always allocates a new array.
	// Otherwise it could write into spare capacity of a slice the caller
	// passed as opts...
	opts = append(opts[:len(opts):len(opts)], Method(http.MethodPost))
	return MakeRequest(uri, opts...)
}
// FreezeTime freezes the clock package at the fixed instant
// 2012-03-04 05:06:07 UTC and returns clock.Unfreeze. Callers should defer the
// returned function to unfreeze time afterwards. Meant for testing only.
func FreezeTime() func() {
	frozenAt := clock.Date(2012, 3, 4, 5, 6, 7, 0, clock.UTC)
	clock.Freeze(frozenAt)
	return clock.Unfreeze
}
|