1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266
|
package gock
import (
"compress/gzip"
"encoding/json"
"io"
"io/ioutil"
"net/http"
"reflect"
"regexp"
"strings"
"github.com/h2non/parth"
)
// EOL represents the end of line character: the line feed byte (0x0a, '\n').
// castToString strips a single trailing EOL from a body before bodies are
// compared.
const EOL = 0xa
// BodyTypes stores the supported MIME body types for matching.
// Currently only text-based types are listed.
//
// When the mock declares no Content-Type of its own, each entry is used as a
// regular expression pattern against the request Content-Type header, so a
// header such as "application/json; charset=utf-8" is still accepted.
var BodyTypes = []string{
"text/html",
"text/plain",
"application/json",
"application/xml",
"multipart/form-data",
"application/x-www-form-urlencoded",
}
// BodyTypeAliases maps a short alias (e.g. "json") to the full MIME type it
// stands for (e.g. "application/json").
var BodyTypeAliases = map[string]string{
"html": "text/html",
"text": "text/plain",
"json": "application/json",
"xml": "application/xml",
"form": "multipart/form-data",
"url": "application/x-www-form-urlencoded",
}
// CompressionSchemes stores the supported Content-Encoding types for decompression.
// Each entry is used as a regular expression pattern against the request
// Content-Encoding header.
var CompressionSchemes = []string{
"gzip",
}
// MatchMethod reports whether the HTTP method of req satisfies the method
// expected by ereq. An empty expected method accepts any request method;
// otherwise the comparison is exact and case-sensitive. The returned error is
// always nil.
func MatchMethod(req *http.Request, ereq *Request) (bool, error) {
	if ereq.Method == "" {
		// No method constraint configured on the mock.
		return true, nil
	}
	return ereq.Method == req.Method, nil
}
// MatchScheme reports whether the URL scheme of req satisfies the scheme
// expected by ereq. The match succeeds when either side has no scheme or when
// both schemes are identical. The returned error is always nil.
func MatchScheme(req *http.Request, ereq *Request) (bool, error) {
	expected := ereq.URLStruct.Scheme
	if expected == "" {
		// The mock does not constrain the scheme.
		return true, nil
	}
	actual := req.URL.Scheme
	return actual == "" || actual == expected, nil
}
// MatchHost reports whether the URL host of req satisfies the host expected
// by ereq.
//
// The hosts are first compared case-insensitively. If they differ and regexp
// host matching has not been disabled through ereq.Options.DisableRegexpHost,
// the expected host is used as a regular expression pattern against the
// request host; an invalid pattern is reported through the returned error.
func MatchHost(req *http.Request, ereq *Request) (bool, error) {
	expected, actual := ereq.URLStruct.Host, req.URL.Host
	switch {
	case strings.EqualFold(expected, actual):
		return true, nil
	case ereq.Options.DisableRegexpHost:
		// Only the literal comparison is allowed, and it already failed.
		return false, nil
	default:
		return regexp.MatchString(expected, actual)
	}
}
// MatchPath reports whether the URL path of req satisfies the path expected
// by ereq.
//
// Identical paths match immediately. Otherwise the expected path is used as a
// regular expression pattern against the request path; an invalid pattern is
// reported through the returned error.
func MatchPath(req *http.Request, ereq *Request) (bool, error) {
	actual, expected := req.URL.Path, ereq.URLStruct.Path
	if actual == expected {
		return true, nil
	}
	return regexp.MatchString(expected, actual)
}
// MatchHeaders reports whether every header expected by ereq is present in
// req with a matching value.
//
// For each expected header only its first value is considered. That value is
// tried against every request value stored under the same key, first as a
// regular expression and then as a literal (regexp metacharacters escaped), so
// values containing characters such as "(" or ")" still match. A header is
// satisfied as soon as one request value matches either way.
//
// An expected header with an empty value list imposes no constraint. The
// returned error is non-nil only if the escaped literal pattern itself fails
// to compile.
func MatchHeaders(req *http.Request, ereq *Request) (bool, error) {
	for key, value := range ereq.Header {
		// http.Header treats a key with no values as unset; skipping it also
		// avoids indexing into an empty slice below.
		if len(value) == 0 {
			continue
		}
		expected := value[0]
		literal := regexp.QuoteMeta(expected)

		matched := false
		for _, field := range req.Header[key] {
			// The expected value is not required to be a valid pattern (e.g. an
			// unbalanced "("), so a compile error here is deliberately ignored:
			// the escaped form below is the fallback.
			asPattern, _ := regexp.MatchString(expected, field)
			asLiteral, err := regexp.MatchString(literal, field)
			if err != nil {
				return false, err
			}
			if asPattern || asLiteral {
				matched = true
				break
			}
		}
		if !matched {
			return false, nil
		}
	}
	return true, nil
}
// MatchQueryParams reports whether every query parameter expected by ereq is
// present in the URL of req with a matching value.
//
// For each expected key only its first value is considered; it is used as a
// regular expression pattern and tried against every value the request
// carries for that key. A key is satisfied as soon as one request value
// matches. An invalid pattern is reported through the returned error, but only
// when the request actually carries a value for that key.
func MatchQueryParams(req *http.Request, ereq *Request) (bool, error) {
	expected := ereq.URLStruct.Query()
	if len(expected) == 0 {
		// Nothing to verify; the request URL is not inspected at all.
		return true, nil
	}

	// url.URL.Query re-parses the raw query on every call, so parse the
	// request side once instead of once per expected key.
	actual := req.URL.Query()

	for key, value := range expected {
		matched := false
		for _, field := range actual[key] {
			ok, err := regexp.MatchString(value[0], field)
			if err != nil {
				return false, err
			}
			if ok {
				matched = true
				break
			}
		}
		if !matched {
			return false, nil
		}
	}
	return true, nil
}
// MatchPathParams reports whether the URL path of req carries every path
// parameter expected by ereq.
//
// Each expected parameter is looked up in the request path with
// parth.Sequent and the extracted string must equal the expected value
// exactly. A lookup failure counts as a mismatch rather than an error, so the
// returned error is always nil.
//
// NOTE(review): parth.Sequent presumably yields the path segment that follows
// the segment named by the key — confirm against the parth documentation.
func MatchPathParams(req *http.Request, ereq *Request) (bool, error) {
	for name, expected := range ereq.PathParams {
		var actual string
		err := parth.Sequent(req.URL.Path, name, &actual)
		if err != nil || actual != expected {
			return false, nil
		}
	}
	return true, nil
}
// MatchBody reports whether the body of req satisfies the body expected by
// ereq.
// TODO: not too smart now, needs several improvements.
//
// HEAD requests and mocks without an expected body always match. Otherwise
// the request must use a supported Content-Type and Content-Encoding, and when
// the mock declares a compression scheme the request must declare exactly the
// same Content-Encoding. The body is then read in full (through a
// decompressing reader if a scheme is set) and compared in three steps:
// string equality, the expected body used as a regular expression, and
// finally equality of both bodies decoded as JSON objects.
//
// The bytes that were read are handed to createReadCloser and assigned back
// to req.Body; with a compression scheme set these are the decompressed bytes.
// Errors from building the decompressing reader or from reading the body are
// returned to the caller.
func MatchBody(req *http.Request, ereq *Request) (bool, error) {
	// Nothing to compare: HEAD requests and mocks without a body always pass.
	if req.Method == "HEAD" || len(ereq.BodyBuffer) == 0 {
		return true, nil
	}

	// Only certain MIME body types and compression schemes can be matched.
	if !supportedType(req, ereq) || !supportedCompressionScheme(req) {
		return false, nil
	}

	// Choose the stream to read from depending on the expected compression.
	source := req.Body
	if scheme := ereq.CompressionScheme; scheme != "" {
		if scheme != req.Header.Get("Content-Encoding") {
			return false, nil
		}
		decoded, err := compressionReader(req.Body, scheme)
		if err != nil {
			return false, err
		}
		source = decoded
	}

	// Consume the whole request body.
	raw, err := ioutil.ReadAll(source)
	if err != nil {
		return false, err
	}

	// Put the consumed bytes back on the request.
	req.Body = createReadCloser(raw)

	// An empty request body cannot satisfy a non-empty expectation.
	if len(raw) == 0 && len(ereq.BodyBuffer) != 0 {
		return false, nil
	}

	// Step 1: plain string comparison.
	actual, expected := castToString(raw), castToString(ereq.BodyBuffer)
	if actual == expected {
		return true, nil
	}

	// Step 2: expected body as a regular expression. The expected body is
	// often not a valid pattern, so a compile error just means "no match".
	if ok, _ := regexp.MatchString(expected, actual); ok {
		return true, nil
	}

	// Step 3: compare both bodies as JSON objects, which ignores key order and
	// whitespace.
	// todo - add a condition so this conversion is only attempted for JSON
	// bodies.
	var actualJSON, expectedJSON map[string]interface{}
	errActual := json.Unmarshal(raw, &actualJSON)
	errExpected := json.Unmarshal(ereq.BodyBuffer, &expectedJSON)
	sameJSON := errActual == nil && errExpected == nil && reflect.DeepEqual(actualJSON, expectedJSON)
	return sameJSON, nil
}
// supportedType reports whether the Content-Type of req can be body-matched.
//
// A request without a Content-Type header is always accepted. If the mock
// declares its own Content-Type, the two header values must be identical.
// Otherwise the request Content-Type must match one of the BodyTypes entries,
// each used as a regular expression pattern.
func supportedType(req *http.Request, ereq *Request) bool {
	actual := req.Header.Get("Content-Type")
	if actual == "" {
		return true
	}
	if expected := ereq.Header.Get("Content-Type"); expected != "" {
		return actual == expected
	}
	for _, pattern := range BodyTypes {
		// A pattern that fails to compile simply does not match.
		ok, _ := regexp.MatchString(pattern, actual)
		if ok {
			return true
		}
	}
	return false
}
// supportedCompressionScheme reports whether the Content-Encoding of req can
// be handled. A request without a Content-Encoding header is always accepted;
// otherwise the header must match one of the CompressionSchemes entries, each
// used as a regular expression pattern.
func supportedCompressionScheme(req *http.Request) bool {
	declared := req.Header.Get("Content-Encoding")
	if declared == "" {
		return true
	}
	for _, pattern := range CompressionSchemes {
		// A pattern that fails to compile simply does not match.
		ok, _ := regexp.MatchString(pattern, declared)
		if ok {
			return true
		}
	}
	return false
}
// castToString converts buf to a string and drops a single trailing EOL
// (line feed) byte if there is one. Only the last byte is inspected, so at
// most one line feed is removed. An empty or nil buf yields the empty string.
func castToString(buf []byte) string {
	str := string(buf)
	// Guard the index: an empty string has no last byte to inspect.
	if n := len(str); n > 0 && str[n-1] == EOL {
		return str[:n-1]
	}
	return str
}
// compressionReader wraps r in a reader that decompresses the given scheme.
//
// For "gzip" it returns a gzip reader over r, or a nil reader and the error
// if the gzip header cannot be read. Any other scheme returns r unchanged.
func compressionReader(r io.ReadCloser, scheme string) (io.ReadCloser, error) {
	switch scheme {
	case "gzip":
		zr, err := gzip.NewReader(r)
		if err != nil {
			// Return an untyped nil: passing the nil *gzip.Reader through the
			// interface would produce a non-nil io.ReadCloser.
			return nil, err
		}
		return zr, nil
	default:
		return r, nil
	}
}
|