1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
|
package forward
import (
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vulcand/oxy/v2/testutils"
)
// TestDefaultErrHandler verifies that when the forwarder cannot reach its
// upstream, the proxy answers the client with 502 Bad Gateway.
func TestDefaultErrHandler(t *testing.T) {
	f := New(true)

	proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		// Redirect the request to a local port that is assumed to have no
		// listener, so the upstream dial fails.
		// NOTE(review): a fixed port can collide with something running on
		// the host; confirm this is acceptable for CI.
		req.URL = testutils.ParseURI("http://localhost:63450")
		f.ServeHTTP(w, req)
	}))
	t.Cleanup(proxy.Close)

	resp, err := http.Get(proxy.URL)
	require.NoError(t, err)
	// Always close the response body so the client transport can release
	// the underlying connection.
	defer func() { _ = resp.Body.Close() }()

	assert.Equal(t, http.StatusBadGateway, resp.StatusCode)
}
// TestXForwardedHostHeader checks that the forwarder's Director sets
// X-Forwarded-Host to the host the client originally requested, after the
// request URL has been rewritten to point at a backend. Each case runs with
// its own PassHostHeader setting.
func TestXForwardedHostHeader(t *testing.T) {
	tests := []struct {
		Description            string
		PassHostHeader         bool
		TargetURL              string
		ProxyfiedURL           string
		ExpectedXForwardedHost string
	}{
		{
			Description:            "XForwardedHost without PassHostHeader",
			PassHostHeader:         false,
			TargetURL:              "http://xforwardedhost.com",
			ProxyfiedURL:           "http://backend.com",
			ExpectedXForwardedHost: "xforwardedhost.com",
		},
		{
			Description:            "XForwardedHost with PassHostHeader",
			PassHostHeader:         true,
			TargetURL:              "http://xforwardedhost.com",
			ProxyfiedURL:           "http://backend.com",
			ExpectedXForwardedHost: "xforwardedhost.com",
		},
	}

	for _, test := range tests {
		test := test // capture the range variable for the parallel subtest (pre-Go 1.22 semantics)
		t.Run(test.Description, func(t *testing.T) {
			t.Parallel()

			// Build the forwarder with this case's setting so each table
			// entry actually exercises the PassHostHeader option it names.
			f := New(test.PassHostHeader)

			r, err := http.NewRequest(http.MethodGet, test.TargetURL, nil)
			require.NoError(t, err)

			// Rewrite the URL to the backend, as a load balancer would,
			// while r.Host still holds the client-facing host.
			backendURL, err := url.Parse(test.ProxyfiedURL)
			require.NoError(t, err)
			r.URL = backendURL

			f.Director(r)

			require.Equal(t, test.ExpectedXForwardedHost, r.Header.Get(XForwardedHost))
		})
	}
}
// TestForwardedProto ensures that a request reaching the proxy over TLS is
// forwarded to the backend with X-Forwarded-Proto set to "https".
func TestForwardedProto(t *testing.T) {
	// seenProto is filled in by the backend handler; it is only inspected
	// after the client round trip has returned.
	var seenProto string

	backend := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		seenProto = r.Header.Get(XForwardedProto)
		_, _ = rw.Write([]byte("hello"))
	}))
	t.Cleanup(backend.Close)

	fwd := New(true)

	// NewTLSServer is the stdlib shorthand for NewUnstartedServer + StartTLS.
	frontend := httptest.NewTLSServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		r.URL = testutils.ParseURI(backend.URL)
		fwd.ServeHTTP(rw, r)
	}))
	t.Cleanup(frontend.Close)

	resp, _, err := testutils.Get(frontend.URL)
	require.NoError(t, err)

	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Equal(t, "https", seenProto)
}
|