1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496
|
// Copyright 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package tcpproxy lets users build TCP proxies, optionally making
// routing decisions based on HTTP/1 Host headers and the SNI hostname
// in TLS connections.
//
// Typical usage:
//
// var p tcpproxy.Proxy
// p.AddHTTPHostRoute(":80", "foo.com", tcpproxy.To("10.0.0.1:8081"))
// p.AddHTTPHostRoute(":80", "bar.com", tcpproxy.To("10.0.0.2:8082"))
// p.AddRoute(":80", tcpproxy.To("10.0.0.1:8081")) // fallback
// p.AddSNIRoute(":443", "foo.com", tcpproxy.To("10.0.0.1:4431"))
// p.AddSNIRoute(":443", "bar.com", tcpproxy.To("10.0.0.2:4432"))
// p.AddRoute(":443", tcpproxy.To("10.0.0.1:4431")) // fallback
// log.Fatal(p.Run())
//
// Calling Run (or Start) on a proxy also starts all the necessary
// listeners.
//
// For each accepted connection, the rules for that ipPort are
// matched, in order. If one matches (currently HTTP Host, SNI, or
// always), then the connection is handed to the target.
//
// The two predefined Target implementations are:
//
// 1) DialProxy, proxying to another address (use the To func to return a
// DialProxy value),
//
// 2) TargetListener, making the matched connection available via a
// net.Listener.Accept call.
//
// But Target is an interface, so you can also write your own.
//
// Note that tcpproxy does not do any TLS encryption or decryption. It
// only (via DialProxy) copies bytes around. The SNI hostname in the TLS
// header is unencrypted, for better or worse.
//
// This package makes no API stability promises. If you depend on it,
// vendor it.
package tcpproxy
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"log"
"net"
"time"
)
// Proxy is a proxy. Its zero value is a valid proxy that does
// nothing. Call methods to add routes before calling Start or Run.
//
// The order that routes are added in matters; each is matched in the order
// registered.
type Proxy struct {
	configs map[string]*config // ip:port => config
	lns     []net.Listener     // listeners opened by Start; closed by Close
	donec   chan struct{}      // closed after err is set; Wait blocks on it
	err     error              // first error reported by a listener's Accept loop

	// ListenFunc optionally specifies an alternate listen
	// function. If nil, net.Listen is used.
	// The provided net is always "tcp".
	ListenFunc func(net, laddr string) (net.Listener, error)
}
// Matcher reports whether hostname matches the Matcher's criteria.
type Matcher func(ctx context.Context, hostname string) bool

// equals builds a Matcher that accepts only a hostname byte-for-byte
// identical to want (case-sensitive). The context is ignored.
func equals(want string) Matcher {
	return Matcher(func(_ context.Context, got string) bool {
		return got == want
	})
}
type (
	// config contains the proxying state for one listener.
	config struct {
		// routes are tried in the order they were added.
		routes []route
	}

	// A route matches a connection to a target.
	route interface {
		// match inspects the first bytes of a connection. It returns
		// the Target the stream should be proxied to, or nil when the
		// connection does not match.
		//
		// match may only Peek at br; it must not consume any bytes
		// from it.
		//
		// When an SNI name or HTTP Host header was successfully
		// parsed, it is returned as hostName.
		match(br *bufio.Reader) (t Target, hostName string)
	}
)
// netListen returns the function Start uses to open listeners:
// p.ListenFunc when it is set, and net.Listen otherwise.
func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) {
	listen := p.ListenFunc
	if listen == nil {
		listen = net.Listen
	}
	return listen
}
// configFor returns the config for the listener on ipPort. It lazily
// allocates both the configs map and the per-listener config, which
// keeps the zero Proxy usable.
func (p *Proxy) configFor(ipPort string) *config {
	if p.configs == nil {
		p.configs = make(map[string]*config)
	}
	cfg := p.configs[ipPort]
	if cfg == nil {
		cfg = &config{}
		p.configs[ipPort] = cfg
	}
	return cfg
}
// addRoute appends r to the routes of the listener on ipPort, creating
// that listener's config on first use.
func (p *Proxy) addRoute(ipPort string, r route) {
	routes := &p.configFor(ipPort).routes
	*routes = append(*routes, r)
}
// AddRoute appends a route to the ipPort listener that matches every
// connection and sends it to dest.
//
// Typically this is either the sole rule (a plain TCP proxy) or the
// last-resort fallback after more specific rules for the same ipPort.
//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddRoute(ipPort string, dest Target) {
	always := fixedTarget{t: dest}
	p.addRoute(ipPort, always)
}
// fixedTarget is a route that matches every connection, always
// yielding the same Target and an empty host name.
type fixedTarget struct {
	t Target
}

// match implements route without looking at the reader.
func (ft fixedTarget) match(*bufio.Reader) (Target, string) {
	return ft.t, ""
}
// Run calls Start, and then Wait.
//
// It blocks until there's an error. The return value is always
// non-nil.
func (p *Proxy) Run() error {
	if err := p.Start(); err != nil {
		return err
	}
	return p.Wait()
}
// Wait waits for the Proxy to finish running. Currently this can only
// happen if a Listener is closed, or Close is called on the proxy.
// It returns the first error reported by any listener's Accept loop.
//
// It is only valid to call Wait after a successful call to Start. If
// Start has not been called, Wait returns an error. Without this check
// it would block forever receiving from a nil channel.
func (p *Proxy) Wait() error {
	if p.donec == nil {
		return errors.New("not started")
	}
	<-p.donec
	return p.err
}
// Close closes every listener that Start opened for this proxy. Errors
// from the individual Close calls are discarded; the result is always
// nil. Closing the listeners makes their Accept loops fail, which is
// what unblocks Wait.
func (p *Proxy) Close() error {
	for i := range p.lns {
		p.lns[i].Close()
	}
	return nil
}
// Start creates a TCP listener for each unique ipPort from the
// previously created routes and starts the proxy. It returns any
// error from starting listeners.
//
// If it returns a non-nil error, any successfully opened listeners
// are closed. The Proxy also returns to its not-started state, so Start
// may be called again. Calling Start again after a successful Start
// returns an error.
//
// If no routes have been added, no listeners are opened and Wait will
// block indefinitely.
func (p *Proxy) Start() error {
	if p.donec != nil {
		return errors.New("already started")
	}
	p.donec = make(chan struct{})
	// Each serveListener goroutine sends exactly one error. Buffering one
	// slot per listener means none of them ever blocks, even though
	// awaitFirstError receives only the first.
	errc := make(chan error, len(p.configs))
	p.lns = make([]net.Listener, 0, len(p.configs))
	listen := p.netListen()
	for ipPort, cfg := range p.configs {
		ln, err := listen("tcp", ipPort)
		if err != nil {
			p.Close()
			// Undo the "started" markers so a later Start can retry
			// instead of failing with "already started".
			p.lns = nil
			p.donec = nil
			return err
		}
		p.lns = append(p.lns, ln)
		go p.serveListener(errc, ln, cfg.routes)
	}
	go p.awaitFirstError(errc)
	return nil
}
// awaitFirstError blocks until some listener reports an error. It
// stores that error as the proxy's result and then closes donec to
// release Wait.
func (p *Proxy) awaitFirstError(errc <-chan error) {
	first := <-errc
	p.err = first
	close(p.donec)
}
// serveListener accepts connections from ln until Accept fails. Each
// accepted connection is handed to serveConn on its own goroutine.
// The terminating Accept error is sent on ret.
func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, routes []route) {
	c, err := ln.Accept()
	for err == nil {
		go p.serveConn(c, routes)
		c, err = ln.Accept()
	}
	ret <- err
}
// serveConn runs in its own goroutine and tries routes against c in
// order. The first route that yields a Target receives the connection.
// If matching left bytes buffered, c is first wrapped in a *Conn that
// carries those bytes and the matched host name.
// When nothing matches, the connection is logged and closed.
// The boolean result reports whether a route matched; it exists purely
// for testing.
func (p *Proxy) serveConn(c net.Conn, routes []route) bool {
	br := bufio.NewReader(c)
	for _, r := range routes {
		target, hostName := r.match(br)
		if target == nil {
			continue
		}
		if n := br.Buffered(); n > 0 {
			// Peeking at most the already-buffered count performs no
			// I/O, so the error can be ignored.
			peeked, _ := br.Peek(n)
			c = &Conn{HostName: hostName, Peeked: peeked, Conn: c}
		}
		target.HandleConn(c)
		return true
	}
	// TODO: hook for this?
	log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String())
	c.Close()
	return false
}
// Conn is an incoming connection that has had some bytes read from it
// to determine how to route the connection. The Read method stitches
// the peeked bytes and unread bytes back together.
type Conn struct {
	// HostName is the hostname field that was sent to the request router.
	// In the case of TLS, this is the SNI header, in the case of HTTPHost
	// route, it will be the host header. In the case of a fixed
	// route, i.e. those created with AddRoute(), this will always be
	// empty. This can be useful in the case where further routing decisions
	// need to be made in the Target implementation.
	HostName string

	// Peeked are the bytes that have been read from Conn for the
	// purposes of route matching, but have not yet been consumed
	// by Read calls. It is set to nil by Read when fully consumed.
	Peeked []byte

	// Conn is the underlying connection.
	// It can be type asserted against *net.TCPConn or other types
	// as needed. It should not be read from directly unless
	// Peeked is nil.
	net.Conn
}

// Read first returns any remaining Peeked bytes. A single call never
// mixes them with data from the underlying connection. Once Peeked is
// exhausted (and reset to nil), Read reads from the embedded net.Conn.
func (c *Conn) Read(p []byte) (n int, err error) {
	if len(c.Peeked) > 0 {
		n = copy(p, c.Peeked)
		c.Peeked = c.Peeked[n:]
		if len(c.Peeked) == 0 {
			c.Peeked = nil
		}
		return n, nil
	}
	return c.Conn.Read(p)
}
// Target receives every incoming connection that a route has matched.
// Target is an interface, so custom implementations are possible.
type Target interface {
	// HandleConn is invoked with a connection once it matched a route.
	// After that the tcpproxy package never touches conn again; the
	// implementation owns it and must close it when appropriate.
	//
	// If route matching consumed any bytes, conn is a *Conn whose
	// Peeked field holds them.
	HandleConn(conn net.Conn)
}
// To is a shorthand way of writing &tcpproxy.DialProxy{Addr: addr}.
func To(addr string) *DialProxy {
	return &DialProxy{Addr: addr}
}

// DialProxy implements Target by dialing a new connection to Addr
// and then proxying data back and forth.
//
// The To func is a shorthand way of creating a DialProxy.
type DialProxy struct {
	// Addr is the TCP address to proxy to.
	Addr string

	// KeepAlivePeriod sets the period between TCP keep alives.
	// If zero, a default (currently one minute) is used. To disable,
	// use a negative number.
	// The keep-alive setting is applied to both the client connection
	// and the connection dialed to Addr, whenever they are (or wrap) a
	// *net.TCPConn.
	KeepAlivePeriod time.Duration

	// DialTimeout optionally specifies a dial timeout.
	// If zero, a default (currently ten seconds) is used.
	// If negative, the timeout is disabled.
	DialTimeout time.Duration

	// DialContext optionally specifies an alternate dial function
	// for TCP targets. If nil, the standard
	// net.Dialer.DialContext method is used.
	DialContext func(ctx context.Context, network, address string) (net.Conn, error)

	// OnDialError optionally specifies an alternate way to handle errors
	// dialing Addr. It is also called if writing the PROXY protocol
	// header to the dialed connection fails.
	// If nil, the error is logged and src is closed.
	// If non-nil, src is not closed automatically.
	OnDialError func(src net.Conn, dstDialErr error)

	// ProxyProtocolVersion optionally specifies the version of
	// HAProxy's PROXY protocol to use. The PROXY protocol provides
	// connection metadata to the DialProxy target, via a header
	// inserted ahead of the client's traffic. The DialProxy target
	// must explicitly support and expect the PROXY header; there is
	// no graceful downgrade.
	// If zero, no PROXY header is sent. Currently, version 1 is supported.
	ProxyProtocolVersion int
}
// UnderlyingConn unwraps c. For a *Conn it returns the embedded
// net.Conn; any other connection is returned unchanged.
func UnderlyingConn(c net.Conn) net.Conn {
	wrapped, isWrapped := c.(*Conn)
	if !isWrapped {
		return c
	}
	return wrapped.Conn
}
// tcpConn unwraps c (if it is a *Conn) and reports whether the result
// is a *net.TCPConn. UnderlyingConn already returns non-*Conn values
// unchanged, so a single type assertion covers both cases.
func tcpConn(c net.Conn) (t *net.TCPConn, ok bool) {
	t, ok = UnderlyingConn(c).(*net.TCPConn)
	return t, ok
}
// goCloseConn closes c on a new goroutine so the caller never waits
// for Close to return.
func goCloseConn(c net.Conn) {
	closeFn := c.Close
	go closeFn()
}
// closeRead shuts down the reading half of c when it is (or wraps) a
// *net.TCPConn. Otherwise it does nothing. Errors are ignored.
func closeRead(c net.Conn) {
	tc, ok := tcpConn(c)
	if !ok {
		return
	}
	tc.CloseRead()
}
// closeWrite shuts down the writing half of c when it is (or wraps) a
// *net.TCPConn. Otherwise it does nothing. Errors are ignored.
func closeWrite(c net.Conn) {
	tc, ok := tcpConn(c)
	if !ok {
		return
	}
	tc.CloseWrite()
}
// HandleConn implements the Target interface. It dials dp.Addr,
// optionally writes a PROXY protocol header, and applies the keep-alive
// setting. It then copies bytes in both directions until both copies
// finish. Both connections are closed (asynchronously) when it returns.
//
// If dialing or writing the PROXY header fails, the error is passed to
// the OnDialError handler. By default that handler logs the error and
// closes src.
func (dp *DialProxy) HandleConn(src net.Conn) {
	ctx := context.Background()
	var cancel context.CancelFunc
	if dp.DialTimeout >= 0 {
		ctx, cancel = context.WithTimeout(ctx, dp.dialTimeout())
	}
	dst, err := dp.dialContext()(ctx, "tcp", dp.Addr)
	// The context only bounds the dial. Release it immediately rather
	// than holding it for the lifetime of the proxied stream.
	if cancel != nil {
		cancel()
	}
	if err != nil {
		dp.onDialError()(src, err)
		return
	}
	defer goCloseConn(dst)

	if err = dp.sendProxyHeader(dst, src); err != nil {
		dp.onDialError()(src, err)
		return
	}
	defer goCloseConn(src)

	// A positive period enables keep-alives with that period. A negative
	// one explicitly disables them, as documented on KeepAlivePeriod.
	// Skipping configuration would not be enough, because the net package
	// may already have enabled keep-alives on dialed and accepted
	// connections.
	if ka := dp.keepAlivePeriod(); ka != 0 {
		for _, c := range []net.Conn{src, dst} {
			tc, ok := tcpConn(c)
			if !ok {
				continue
			}
			if ka < 0 {
				tc.SetKeepAlive(false)
				continue
			}
			tc.SetKeepAlive(true)
			tc.SetKeepAlivePeriod(ka)
		}
	}

	errc := make(chan error, 2)
	go proxyCopy(errc, src, dst)
	go proxyCopy(errc, dst, src)
	<-errc
	<-errc
}
// sendProxyHeader writes to w the PROXY protocol header selected by
// dp.ProxyProtocolVersion. The header describes src: its remote address
// is the source and its local address is the destination.
//
// Version 0 writes nothing. Version 1 writes the text-form v1 line.
// It falls back to "PROXY UNKNOWN" when either address is not a
// *net.TCPAddr, when an IP is invalid, or when the two addresses are not
// in the same family (v1 cannot express mixed families). Any other
// version is an error.
func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
	switch dp.ProxyProtocolVersion {
	case 0:
		return nil
	case 1:
		srcAddr, _ := src.RemoteAddr().(*net.TCPAddr)
		dstAddr, _ := src.LocalAddr().(*net.TCPAddr)
		family := ""
		if srcAddr != nil && dstAddr != nil {
			if f := proxyV1Family(srcAddr.IP); f == proxyV1Family(dstAddr.IP) {
				family = f
			}
		}
		if family == "" {
			_, err := io.WriteString(w, "PROXY UNKNOWN\r\n")
			return err
		}
		_, err := fmt.Fprintf(w, "PROXY %s %s %s %d %d\r\n", family, srcAddr.IP, dstAddr.IP, srcAddr.Port, dstAddr.Port)
		return err
	default:
		return fmt.Errorf("PROXY protocol version %d not supported", dp.ProxyProtocolVersion)
	}
}

// proxyV1Family returns the PROXY v1 family token for ip. It returns
// "TCP4" for IPv4 addresses (including IPv4-mapped IPv6, which print in
// dotted form), "TCP6" for other valid IPv6 addresses, and "" when ip is
// not a valid IP address.
func proxyV1Family(ip net.IP) string {
	switch {
	case ip.To4() != nil:
		return "TCP4"
	case ip.To16() != nil:
		return "TCP6"
	default:
		return ""
	}
}
// proxyCopy copies bytes from src to dst and sends the result on errc.
// The result is the io.Copy error, or the error from a failed write of
// peeked bytes. On return it half-closes dst for writing and src for
// reading, where they are (or wrap) a *net.TCPConn.
//
// It's a named function instead of a func literal so users get
// named goroutines in debug goroutine stack dumps.
func proxyCopy(errc chan<- error, dst, src net.Conn) {
	defer closeRead(src)
	defer closeWrite(dst)

	// Forward the bytes consumed during route matching before src is
	// unwrapped; reading from the raw conn would otherwise skip them.
	if wrapped, ok := src.(*Conn); ok && len(wrapped.Peeked) > 0 {
		_, err := dst.Write(wrapped.Peeked)
		if err != nil {
			errc <- err
			return
		}
		wrapped.Peeked = nil
	}

	// Copy between the unwrapped connections (e.g. *net.TCPConn) so Go
	// 1.11's splice optimization kicks in.
	_, err := io.Copy(UnderlyingConn(dst), UnderlyingConn(src))
	errc <- err
}
// keepAlivePeriod returns dp.KeepAlivePeriod, substituting one minute
// when the field is left at zero. Negative values pass through.
func (dp *DialProxy) keepAlivePeriod() time.Duration {
	if dp.KeepAlivePeriod == 0 {
		return time.Minute
	}
	return dp.KeepAlivePeriod
}
// dialTimeout returns dp.DialTimeout when it is positive and ten
// seconds otherwise. HandleConn consults it only when DialTimeout is
// non-negative.
func (dp *DialProxy) dialTimeout() time.Duration {
	if dp.DialTimeout <= 0 {
		return 10 * time.Second
	}
	return dp.DialTimeout
}
// defaultDialer is the zero-value net.Dialer shared by every DialProxy
// that leaves DialContext unset.
var defaultDialer = &net.Dialer{}

// dialContext returns the dial function HandleConn uses: dp.DialContext
// when set, else defaultDialer.DialContext.
func (dp *DialProxy) dialContext() func(ctx context.Context, network, address string) (net.Conn, error) {
	dial := dp.DialContext
	if dial == nil {
		dial = defaultDialer.DialContext
	}
	return dial
}
// onDialError returns the handler HandleConn calls on failure. That is
// dp.OnDialError when set. Otherwise it is a handler that logs the
// failure along with the client's address and dp.Addr, then closes src.
func (dp *DialProxy) onDialError() func(src net.Conn, dstDialErr error) {
	handler := dp.OnDialError
	if handler == nil {
		handler = func(src net.Conn, dstDialErr error) {
			log.Printf("tcpproxy: for incoming conn %v, error dialing %q: %v", src.RemoteAddr().String(), dp.Addr, dstDialErr)
			src.Close()
		}
	}
	return handler
}
|