1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
|
package test
import (
"bytes"
"io/ioutil"
"net/http"
"net/http/httptest"
)
// httpMockResponse is a canned response registered by HTTPMockTransport.Add.
// The embedded http.Response supplies the status and headers; Data holds the
// body bytes.
type httpMockResponse struct {
	http.Response
	// Data is the response body; a new reader over it is created per request.
	Data []byte
}

// ServeHTTP replays the canned response through w. It copies the headers,
// writes StatusCode (http.StatusOK when unset), then writes Data. This makes
// the value a working http.Handler. It therefore behaves the same whether a
// caller type-asserts it or it is served through an http.ServeMux.
func (r *httpMockResponse) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
	dst := w.Header()
	for key, values := range r.Header {
		for _, v := range values {
			dst.Add(key, v)
		}
	}

	code := r.StatusCode
	if code == 0 {
		// A zero status would make WriteHeader panic; fall back to net/http's default.
		code = http.StatusOK
	}
	w.WriteHeader(code)

	// A write error means the client went away. A handler has no channel to
	// report it, so it is deliberately ignored.
	_, _ = w.Write(r.Data)
}
// HTTPMockTransport answers HTTP requests from an in-memory routing table
// instead of the network, through its RoundTrip method.
// Add registers canned responses. AddHandlerFunc registers arbitrary handlers.
// Clear drops every route.
type HTTPMockTransport struct {
	// mux routes each request to its registered handler by path.
	// Add and AddHandlerFunc create it lazily; Clear replaces it.
	mux *http.ServeMux
}
// Clear drops every registered route by installing a brand-new, empty
// ServeMux. After that, requests get RoundTrip's not-found response until
// routes are registered again. Patterns used before may be registered anew.
func (t *HTTPMockTransport) Clear() {
	empty := http.NewServeMux()
	t.mux = empty
}
// Add registers a canned response for requests whose path matches the
// ServeMux pattern path. Every RoundTrip that hits the route returns the
// status and headers from res, with a body read from data.
//
// A nil res is treated as a bare 200 OK response. res.Header and data are
// copied, so changes the caller makes after Add do not leak into the
// registered fixture. res.Body is ignored, because a fresh body is built
// from data on each request. Registering a pattern that conflicts with one
// already registered panics, as with http.ServeMux.Handle.
func (t *HTTPMockTransport) Add(path string, res *http.Response, data []byte) {
	if t.mux == nil {
		t.mux = http.NewServeMux()
	}

	var canned http.Response
	if res == nil {
		canned.Status = "200 " + http.StatusText(http.StatusOK)
		canned.StatusCode = http.StatusOK
		canned.Header = make(http.Header)
	} else {
		canned = *res
		// Detach from the caller's header map so later edits don't alter the fixture.
		canned.Header = res.Header.Clone()
	}
	// The body is rebuilt from data per request, so don't keep the caller's reader alive.
	canned.Body = nil

	t.mux.Handle(path, &httpMockResponse{
		Response: canned,
		// Copy so the caller may reuse or modify its buffer after registering.
		Data: append([]byte(nil), data...),
	})
}
// AddHandlerFunc registers hf to serve requests whose path matches the
// ServeMux pattern path. RoundTrip runs hf against a response recorder, so
// hf may compute the status, headers and body per request. Registering a
// conflicting pattern panics, as with http.ServeMux.
func (t *HTTPMockTransport) AddHandlerFunc(path string, hf http.HandlerFunc) {
	mux := t.mux
	if mux == nil {
		// Create the router lazily so a zero-value transport can register routes.
		mux = http.NewServeMux()
		t.mux = mux
	}
	mux.HandleFunc(path, hf)
}
// RoundTrip implements http.RoundTripper. It dispatches req to the route
// registered for its path instead of sending it over the network.
//
//   - A route added with Add returns a copy of the canned response. The copy
//     has its own header map and a fresh body reader over the stored data.
//     The stored fixture is never handed out or mutated, so callers may
//     modify the returned response and concurrent requests do not share a Body.
//   - A route added with AddHandlerFunc is run against an
//     httptest.ResponseRecorder, and the recorded result is returned.
//   - A request that matches no route gets a 404 with an empty body. This
//     includes any request to a transport on which nothing was registered yet.
//
// In every case the response's Request field is set to req. The returned
// error is always nil.
func (t *HTTPMockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	var (
		h       http.Handler
		pattern string
	)
	// A zero-value transport has no mux yet. Treat it as having no routes
	// instead of dereferencing a nil *ServeMux.
	if t.mux != nil {
		h, pattern = t.mux.Handler(req)
	}

	if pattern == "" {
		return &http.Response{
			Status:     "404 " + http.StatusText(http.StatusNotFound),
			StatusCode: http.StatusNotFound,
			Header:     make(http.Header),
			Body:       http.NoBody,
			Request:    req,
		}, nil
	}

	if canned, ok := h.(*httpMockResponse); ok {
		// Return a shallow copy with its own header map and body reader.
		// Handing out the stored value would let callers mutate the fixture,
		// and concurrent requests would race on its Body.
		resp := canned.Response
		resp.Header = canned.Header.Clone()
		resp.Body = ioutil.NopCloser(bytes.NewReader(canned.Data))
		resp.Request = req
		return &resp, nil
	}

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	resp := rec.Result()
	resp.Request = req
	return resp, nil
}
|