1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
|
package middleware
import (
"crypto/tls"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"path"
"reflect"
"runtime"
"testing"
"time"
)
// testdataDir is the slash-separated path of the "testdata" directory that
// sits next to the directory containing this source file. It is set in init.
var testdataDir string

// init resolves testdataDir relative to the location of this source file as
// reported by runtime.Caller.
func init() {
	_, thisFile, _, _ := runtime.Caller(0)
	sourceDir := path.Dir(thisFile)
	testdataDir = path.Join(sourceDir, "/../testdata")
}
// TestWrapWriterHTTP2 checks which optional interfaces are exposed by the
// writer returned from NewWrapResponseWriter when a request arrives over
// HTTP/2: it must be an http.Flusher and an http.Pusher, and must not be an
// http.Hijacker or an io.ReaderFrom.
//
// The test starts a real TLS server on port 7072 using the certificate and
// key found in testdataDir, then issues a single GET against it.
func TestWrapWriterHTTP2(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on a server goroutine, not on the goroutine
		// running the test, so failures are reported with t.Error/t.Errorf;
		// t.Fatal (FailNow) may only be called from the test goroutine.
		if r.Proto != "HTTP/2.0" {
			t.Errorf("request proto should be HTTP/2.0 but was %s", r.Proto)
			return
		}
		if _, fl := w.(http.Flusher); !fl {
			t.Error("request should have been a http.Flusher")
			return
		}
		if _, hj := w.(http.Hijacker); hj {
			t.Error("request should not have been a http.Hijacker")
			return
		}
		if _, rf := w.(io.ReaderFrom); rf {
			t.Error("request should not have been an io.ReaderFrom")
			return
		}
		if _, ps := w.(http.Pusher); !ps {
			t.Error("request should have been a http.Pusher")
			return
		}
		if _, err := w.Write([]byte("OK")); err != nil {
			t.Errorf("could not write response: %v", err)
		}
	})

	// wmw wraps the response writer before handing the request to next.
	wmw := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			next.ServeHTTP(NewWrapResponseWriter(w, r.ProtoMajor), r)
		})
	}

	server := http.Server{
		Addr:    ":7072",
		Handler: wmw(handler),
	}

	// By serving over TLS, we get HTTP2 requests.
	// The serve error is captured so that start-up problems (missing
	// certificate, port already in use, ...) are reported instead of being
	// silently dropped, and so the serving goroutine is awaited before the
	// test returns.
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- server.ListenAndServeTLS(testdataDir+"/cert.pem", testdataDir+"/key.pem")
	}()
	defer func() {
		server.Close()
		if err := <-serveErr; err != http.ErrServerClosed {
			t.Errorf("server exited unexpectedly: %v", err)
		}
	}()

	// We need the server to start before making the request.
	time.Sleep(100 * time.Millisecond)

	client := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				// The certificates we are using are self signed.
				InsecureSkipVerify: true,
			},
			ForceAttemptHTTP2: true,
		},
	}

	resp, err := client.Get("https://localhost:7072")
	if err != nil {
		t.Fatalf("could not get server: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != 200 {
		t.Fatalf("non 200 response: %v", resp.StatusCode)
	}
}
// testRequest sends a request with the given method and body to ts.URL+path
// using http.DefaultClient and returns the response together with its body
// read in full as a string.
//
// Any failure (building the request, performing it, or reading the body)
// aborts the calling test via t.Fatal. The response body has already been
// consumed and closed when the function returns.
func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) {
	t.Helper()

	req, err := http.NewRequest(method, ts.URL+path, body)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Register the close before reading so the body is released even when
	// the read below fails and t.Fatal unwinds this goroutine.
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	return resp, string(respBody)
}
// testRequestNoRedirect behaves like a plain request helper but uses a
// client that does not follow redirects: a redirect response from the
// server is returned to the caller as-is. It sends a request with the given
// method and body to ts.URL+path and returns the response together with its
// body read in full as a string.
//
// Any failure (building the request, performing it, or reading the body)
// aborts the calling test via t.Fatal. The response body has already been
// consumed and closed when the function returns.
func testRequestNoRedirect(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) {
	t.Helper()

	req, err := http.NewRequest(method, ts.URL+path, body)
	if err != nil {
		t.Fatal(err)
	}

	// http client that doesn't redirect
	httpClient := &http.Client{
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Register the close before reading so the body is released even when
	// the read below fails and t.Fatal unwinds this goroutine.
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	return resp, string(respBody)
}
// assertNoError aborts the calling test when err is non-nil. The unexpected
// error is included in the failure message so the cause is visible in the
// test output.
func assertNoError(t *testing.T, err error) {
	t.Helper()
	if err != nil {
		t.Fatalf("expecting no error, got: %v", err)
	}
}
// assertError aborts the calling test unless err is non-nil.
func assertError(t *testing.T, err error) {
	t.Helper()
	if err != nil {
		// An error is present, which is what the caller wanted.
		return
	}
	t.Fatalf("expecting error")
}
// assertEqual aborts the calling test unless a and b are deeply equal
// according to reflect.DeepEqual. Both values are printed on failure.
func assertEqual(t *testing.T, a, b interface{}) {
	t.Helper()
	if reflect.DeepEqual(a, b) {
		return
	}
	t.Fatalf("expecting values to be equal but got: '%v' and '%v'", a, b)
}
|