1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
|
package utils
import (
	"fmt"
	"net"
	"net/http"
	"strings"
)
// SourceExtractor extracts the source of a request, for example the client IP
// or a particular header that identifies the source.
//
// Extract returns the token naming the source and the amount of connections
// the source consumes (usually 1 for connection limiters). It should return an
// error when the source cannot be identified.
type SourceExtractor interface {
Extract(req *http.Request) (token string, amount int64, err error)
}
// ExtractorFunc is a function type whose signature matches
// SourceExtractor.Extract, so an ordinary function can be used as an extractor.
type ExtractorFunc func(req *http.Request) (token string, amount int64, err error)

// Extract invokes the underlying function with req and hands back its results
// unchanged.
func (fn ExtractorFunc) Extract(req *http.Request) (string, int64, error) {
	token, amount, err := fn(req)
	return token, amount, err
}
// ExtractSource is a function type that receives a request and returns
// nothing.
// NOTE(review): nothing in the visible part of this file references it —
// confirm whether it is still needed.
type ExtractSource func(req *http.Request)
// NewExtractor creates a SourceExtractor for the given limiting variable.
//
// Supported variables:
//   - "client.ip": token produced by extractClientIP.
//   - "request.host": token produced by extractHost.
//   - "request.header.<name>": token is the value of request header <name>.
//
// An error is returned when a "request.header." variable carries no header
// name, and for any variable that is not listed above.
func NewExtractor(variable string) (SourceExtractor, error) {
	const headerPrefix = "request.header."

	switch {
	case variable == "client.ip":
		return ExtractorFunc(extractClientIP), nil
	case variable == "request.host":
		return ExtractorFunc(extractHost), nil
	case strings.HasPrefix(variable, headerPrefix):
		header := strings.TrimPrefix(variable, headerPrefix)
		if header == "" {
			// The header name is empty on this path, so report the offending
			// variable; interpolating the header would always print nothing.
			return nil, fmt.Errorf("wrong header: %s", variable)
		}
		return makeHeaderExtractor(header), nil
	}

	return nil, fmt.Errorf("unsupported limiting variable: '%s'", variable)
}
// extractClientIP returns the client IP taken from req.RemoteAddr as the
// source token, with an amount of 1.
//
// RemoteAddr is normally "host:port". net.SplitHostPort is used so that a
// bracketed IPv6 literal such as "[::1]:8080" yields "::1" instead of "[".
// Values without a port are still accepted: a bare IP literal is returned as
// is, and anything else falls back to the text before the first colon.
// An error is returned when no non-empty host can be derived.
func extractClientIP(req *http.Request) (string, int64, error) {
	host, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		host = req.RemoteAddr
		// Not in host:port form. Unless the whole value is an IP literal (for
		// example a bare IPv6 address), keep only what precedes the first
		// colon, which is what this function has always returned.
		if net.ParseIP(host) == nil {
			if i := strings.IndexByte(host, ':'); i >= 0 {
				host = host[:i]
			}
		}
	}
	if host == "" {
		return "", 0, fmt.Errorf("failed to parse client IP: %v", req.RemoteAddr)
	}
	return host, 1, nil
}
// extractHost uses the request's Host field as the source token. It always
// reports an amount of 1 and a nil error.
func extractHost(req *http.Request) (string, int64, error) {
	const amount int64 = 1
	host := req.Host
	return host, amount, nil
}
// makeHeaderExtractor builds a SourceExtractor whose token is the value of the
// named request header, read with Header.Get. The amount is always 1 and the
// error is always nil, including when the header yields an empty value.
func makeHeaderExtractor(header string) SourceExtractor {
	var byHeader ExtractorFunc = func(req *http.Request) (string, int64, error) {
		value := req.Header.Get(header)
		return value, 1, nil
	}
	return byHeader
}
|