1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
|
// Package proxy provides functionality for configuring and using a proxy server.
package proxy
import (
"fmt"
"net/http"
"net/http/httputil"
"net/url"
"sync"
"time"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/helper"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/helper/nginx"
)
// bufferPoolSize is the length in bytes (32 KiB) of every buffer handed out by
// the shared buffer pool; it matches the default size used in
// httputil.ReverseProxy.
const bufferPoolSize = 32 << 10
// defaultTarget is the upstream NewProxy falls back to when it receives a nil URL.
var defaultTarget = helper.URLMustParse("http://localhost")

// pool is the buffer pool installed on every Proxy's reverse proxy, so that
// buffers are reused across all Proxy instances instead of per instance.
var pool = newBufferPool()
// Proxy represents a proxy configuration with various settings.
// It forwards requests through an httputil.ReverseProxy; build it with NewProxy.
type Proxy struct {
	// Version is sent upstream in the Gitlab-Workhorse request header.
	Version string
	// reverseProxy performs the forwarding; NewProxy extends its Director.
	reverseProxy *httputil.ReverseProxy
	// AllowResponseBuffering makes ServeHTTP call nginx.AllowResponseBuffering
	// before proxying. NewProxy sets it to true.
	AllowResponseBuffering bool
	// customHeaders are set on every upstream request (see WithCustomHeaders).
	customHeaders map[string]string
	// forceTargetHostHeader, enabled by WithForcedTargetHostHeader, makes
	// NewProxy rewrite the request Host to the upstream host.
	forceTargetHostHeader bool
}
// WithCustomHeaders returns a NewProxy option that makes the proxy set every
// key/value pair in customHeaders on each upstream request. The map is kept by
// reference rather than copied, so callers should not modify it afterwards.
func WithCustomHeaders(customHeaders map[string]string) func(*Proxy) {
	applyHeaders := func(target *Proxy) {
		target.customHeaders = customHeaders
	}
	return applyHeaders
}
// WithForcedTargetHostHeader returns a NewProxy option that turns on
// forceTargetHostHeader. With it enabled, NewProxy rewrites each upstream
// request's Host to the target host. It also forwards the original Host in the
// X-Forwarded-Host and Forwarded headers.
func WithForcedTargetHostHeader() func(*Proxy) {
	enableForcedHost := func(target *Proxy) {
		target.forceTargetHostHeader = true
	}
	return enableForcedHost
}
// NewProxy creates a new Proxy instance with the provided options.
//
// Requests are forwarded to the scheme and host of myURL; its path is
// discarded. When myURL is nil, http://localhost is used. roundTripper becomes
// the transport for upstream requests. Every Proxy shares the package-level
// buffer pool.
//
// Each proxied request gets these headers, evaluated per request:
//   - Gitlab-Workhorse: p.Version.
//   - Gitlab-Workhorse-Proxy-Start: the current time in Unix nanoseconds.
//   - Every entry configured via WithCustomHeaders.
//
// options are applied after that director is installed. The forced-Host
// director is appended afterwards when WithForcedTargetHostHeader was given.
func NewProxy(myURL *url.URL, version string, roundTripper http.RoundTripper, options ...func(*Proxy)) *Proxy {
	p := Proxy{Version: version, AllowResponseBuffering: true, customHeaders: make(map[string]string)}

	if myURL == nil {
		myURL = defaultTarget
	}

	// Work on a copy so neither the caller's URL nor the shared defaultTarget
	// is mutated. Clear RawPath together with Path so the copy carries no
	// leftover escaped form of the discarded path.
	u := *myURL
	u.Path = ""
	u.RawPath = ""

	p.reverseProxy = httputil.NewSingleHostReverseProxy(&u)
	p.reverseProxy.Transport = roundTripper
	p.reverseProxy.BufferPool = pool

	// The closure captures p itself (the same variable returned as &p), so
	// Version and customHeaders are read on every request.
	chainDirector(p.reverseProxy, func(r *http.Request) {
		r.Header.Set("Gitlab-Workhorse", p.Version)
		r.Header.Set("Gitlab-Workhorse-Proxy-Start", fmt.Sprintf("%d", time.Now().UnixNano()))
		for k, v := range p.customHeaders {
			r.Header.Set(k, v)
		}
	})

	for _, option := range options {
		option(&p)
	}

	if p.forceTargetHostHeader {
		// because of https://github.com/golang/go/issues/28168, the
		// upstream won't receive the expected Host header unless this
		// is forced in the Director func here
		chainDirector(p.reverseProxy, func(request *http.Request) {
			// send original host along for the upstream
			// to know it's being proxied under a different Host
			// (for redirects and other stuff that depends on this)
			request.Header.Set("X-Forwarded-Host", request.Host)
			// NOTE(review): RFC 7239 requires a quoted-string value when the
			// host contains ':' or '[]' (host:port, IPv6). The value is left
			// unquoted here; confirm upstream parsers before changing it.
			request.Header.Set("Forwarded", fmt.Sprintf("host=%s", request.Host))
			// override the Host with the target
			request.Host = request.URL.Host
		})
	}

	return &p
}
func chainDirector(rp *httputil.ReverseProxy, nextDirector func(*http.Request)) {
previous := rp.Director
rp.Director = func(r *http.Request) {
previous(r)
nextDirector(r)
}
}
// ServeHTTP forwards r to the upstream through the reverse proxy. When
// AllowResponseBuffering is set, nginx is told it may buffer the response.
//
// httputil.ReverseProxy panics with http.ErrAbortHandler when the client
// disconnects before the response is fully written. That panic is discarded to
// keep the error log clean; any other panic value is re-raised unchanged.
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if p.AllowResponseBuffering {
		nginx.AllowResponseBuffering(w)
	}

	defer func() {
		rec := recover()
		if rec == nil || rec == http.ErrAbortHandler {
			return
		}
		panic(rec)
	}()

	p.reverseProxy.ServeHTTP(w, r)
}
// bufferPool wraps a sync.Pool of byte slices behind the Get/Put methods that
// NewProxy relies on when it installs the pool as a ReverseProxy BufferPool.
type bufferPool struct {
	// pool holds []byte values; newBufferPool sets its New function.
	pool sync.Pool
}
// newBufferPool returns an empty bufferPool. When it has no spare buffer, it
// allocates a fresh bufferPoolSize-byte slice.
func newBufferPool() *bufferPool {
	allocate := func() any {
		return make([]byte, bufferPoolSize)
	}

	bp := &bufferPool{}
	bp.pool.New = allocate
	return bp
}
// Get takes a buffer from the pool. If the pool is empty, it falls back to the
// pool's New function (newBufferPool allocates bufferPoolSize bytes). It panics
// if the pool yields a value that is not a []byte.
func (bp *bufferPool) Get() []byte {
	buf := bp.pool.Get().([]byte)
	return buf
}
// Put hands v back to the pool so a later Get can reuse it.
func (bp *bufferPool) Put(v []byte) {
	backing := &bp.pool
	backing.Put(v) //lint:ignore SA6002 we either allocate manually to satisfy the linter or let the compiler allocate for us and silence the linter
}
|