File: source.go

package info (click to toggle)
golang-github-vulcand-oxy 2.0.0-3
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 728 kB
  • sloc: makefile: 14
file content (61 lines) | stat: -rw-r--r-- 1,878 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
package utils

import (
	"fmt"
	"net"
	"net/http"
	"strings"
)

// SourceExtractor identifies which source a request comes from, so that
// limiters can account for consumption per source.
type SourceExtractor interface {
	// Extract returns a token naming the request's source (for example the
	// client IP, or the value of a particular header that identifies the
	// source) together with the amount the source consumes — usually 1 for
	// connection limiters. It returns an error when the source cannot be
	// identified.
	Extract(req *http.Request) (token string, amount int64, err error)
}

// ExtractorFunc lets an ordinary function be used as a SourceExtractor, in the
// same spirit as http.HandlerFunc for http.Handler.
type ExtractorFunc func(req *http.Request) (token string, amount int64, err error)

// Extract invokes fn with req and hands back its three results untouched.
func (fn ExtractorFunc) Extract(req *http.Request) (string, int64, error) {
	token, amount, err := fn(req)
	return token, amount, err
}

// ExtractSource is a function type that receives a request and returns
// nothing.
//
// NOTE(review): nothing in this file uses ExtractSource, and unlike
// ExtractorFunc it yields no token, amount or error; confirm whether callers
// elsewhere still depend on it.
type ExtractSource func(*http.Request)

// NewExtractor creates a SourceExtractor for the named variable.
//
// Supported variables:
//   - "client.ip": the client IP derived from req.RemoteAddr.
//   - "request.host": the value of req.Host.
//   - "request.header.<Name>": the value of request header <Name>, read with
//     http.Header.Get.
//
// It returns an error for any other variable, or when "request.header." is not
// followed by a header name.
func NewExtractor(variable string) (SourceExtractor, error) {
	const headerPrefix = "request.header."

	switch {
	case variable == "client.ip":
		return ExtractorFunc(extractClientIP), nil
	case variable == "request.host":
		return ExtractorFunc(extractHost), nil
	case strings.HasPrefix(variable, headerPrefix):
		header := strings.TrimPrefix(variable, headerPrefix)
		if header == "" {
			// Report the whole variable: the header name here is always empty,
			// so formatting it alone would produce a message with no context.
			return nil, fmt.Errorf("wrong header: empty header name in %q", variable)
		}
		return makeHeaderExtractor(header), nil
	default:
		return nil, fmt.Errorf("unsupported limiting variable: '%s'", variable)
	}
}

// extractClientIP uses the host part of req.RemoteAddr as the source token,
// with an amount of 1.
//
// RemoteAddr is typically "host:port", and IPv6 hosts are bracketed, e.g.
// "[::1]:8080". Splitting on the first ':' would yield "[" for every IPv6
// client and merge them all into a single source, so net.SplitHostPort is used
// instead. If RemoteAddr carries no port (SplitHostPort fails), the whole value,
// minus any surrounding brackets, is used as the host. An error is returned
// when no host can be determined (e.g. an empty RemoteAddr or ":80").
func extractClientIP(req *http.Request) (string, int64, error) {
	host, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		// No port present: fall back to the raw address, as the previous
		// implementation did for values without a ':'.
		host = strings.Trim(req.RemoteAddr, "[]")
	}
	if host == "" {
		return "", 0, fmt.Errorf("failed to parse client IP: %v", req.RemoteAddr)
	}
	return host, 1, nil
}

// extractHost reports req.Host, unmodified and unvalidated, as the source
// token, always with an amount of 1 and a nil error.
func extractHost(req *http.Request) (string, int64, error) {
	host := req.Host
	return host, 1, nil
}

// makeHeaderExtractor builds a SourceExtractor whose token is the value of the
// given request header as returned by http.Header.Get. A missing header
// therefore yields an empty token rather than an error. The amount is always 1.
func makeHeaderExtractor(header string) SourceExtractor {
	extract := func(req *http.Request) (string, int64, error) {
		value := req.Header.Get(header)
		return value, 1, nil
	}
	return ExtractorFunc(extract)
}