1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
|
// Package connlimit provides control over simultaneous connections coming from the same source
package connlimit
import (
"fmt"
"net/http"
"sync"
"github.com/vulcand/oxy/v2/utils"
)
// ConnLimiter tracks concurrent connection per token
// and is capable of rejecting connections if they are failed.
type ConnLimiter struct {
mutex *sync.Mutex
extract utils.SourceExtractor
connections map[string]int64
maxConnections int64
totalConnections int64
next http.Handler
errHandler utils.ErrorHandler
verbose bool
log utils.Logger
}
// New creates a new ConnLimiter.
func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, options ...Option) (*ConnLimiter, error) {
if extract == nil {
return nil, fmt.Errorf("extract function can not be nil")
}
cl := &ConnLimiter{
mutex: &sync.Mutex{},
extract: extract,
maxConnections: maxConnections,
connections: make(map[string]int64),
next: next,
log: &utils.NoopLogger{},
}
for _, o := range options {
if err := o(cl); err != nil {
return nil, err
}
}
if cl.errHandler == nil {
cl.errHandler = &ConnErrHandler{
debug: cl.verbose,
log: cl.log,
}
}
return cl, nil
}
// Wrap sets the next handler to be called by connection limiter handler.
func (cl *ConnLimiter) Wrap(h http.Handler) {
cl.next = h
}
func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
token, amount, err := cl.extract.Extract(r)
if err != nil {
cl.log.Error("failed to extract source of the connection: %v", err)
cl.errHandler.ServeHTTP(w, r, err)
return
}
if err := cl.acquire(token, amount); err != nil {
cl.log.Debug("limiting request source %s: %v", token, err)
cl.errHandler.ServeHTTP(w, r, err)
return
}
defer cl.release(token, amount)
cl.next.ServeHTTP(w, r)
}
func (cl *ConnLimiter) acquire(token string, amount int64) error {
cl.mutex.Lock()
defer cl.mutex.Unlock()
connections := cl.connections[token]
if connections >= cl.maxConnections {
return &MaxConnError{max: cl.maxConnections}
}
cl.connections[token] += amount
cl.totalConnections += amount
return nil
}
func (cl *ConnLimiter) release(token string, amount int64) {
cl.mutex.Lock()
defer cl.mutex.Unlock()
cl.connections[token] -= amount
cl.totalConnections -= amount
// Otherwise it would grow forever
if cl.connections[token] == 0 {
delete(cl.connections, token)
}
}
// MaxConnError maximum connections reached error.
type MaxConnError struct {
max int64
}
func (m *MaxConnError) Error() string {
return fmt.Sprintf("max connections reached: %d", m.max)
}
// ConnErrHandler connection limiter error handler.
type ConnErrHandler struct {
debug bool
log utils.Logger
}
func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
if e.debug {
dump := utils.DumpHTTPRequest(req)
e.log.Debug("vulcand/oxy/connlimit: begin ServeHttp on request: %s", dump)
defer e.log.Debug("vulcand/oxy/connlimit: completed ServeHttp on request: %s", dump)
}
//nolint:errorlint // must be changed
if _, ok := err.(*MaxConnError); ok {
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(err.Error()))
return
}
utils.DefaultHandler.ServeHTTP(w, req, err)
}
|