// Copyright 2021 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package h2 contains basic HTTP/2 handling for Martian.
package h2

import (
	"bytes"
	"crypto/tls"
	"crypto/x509"
	"encoding/hex"
	"fmt"
	"io"
	"net/url"
	"sync"

	"github.com/google/martian/v3/log"
	"golang.org/x/net/http2"
)

var (
	// connectionPreface is the constant value of the connection preface.
	// https://tools.ietf.org/html/rfc7540#section-3.5
	connectionPreface = []byte("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
)

// Config stores the configuration information needed for HTTP/2 processing.
type Config struct {
	// AllowedHostsFilter is a function returning true if the argument is a host for which H2 is
	// permitted.
	AllowedHostsFilter func(string) bool

	// RootCAs is the pool of CA certificates used by the MitM client to authenticate the server.
	RootCAs *x509.CertPool

	// StreamProcessorFactories is a list of factories used to instantiate a chain of HTTP/2 stream
	// processors. A chain is created for every stream.
	StreamProcessorFactories []StreamProcessorFactory

	// EnableDebugLogs turns on fine-grained debug logging for HTTP/2.
	EnableDebugLogs bool
}

// Proxy relays HTTP/2 traffic between the client connection `cc` and the server at `url`,
// assuming h2 (HTTP/2 over TLS) is being used. Since no browsers use h2c, it's safe to assume
// all traffic uses TLS.
//
// Proxy blocks until both relay directions have stopped. Errors raised while relaying frames are
// logged rather than returned.
func (c *Config) Proxy(closing chan bool, cc io.ReadWriter, url *url.URL) error {
	if c.EnableDebugLogs {
		log.Infof("\u001b[1;35mProxying %v with HTTP/2\u001b[0m", url)
	}
	sc, err := tls.Dial("tcp", url.Host, &tls.Config{
		RootCAs:    c.RootCAs,
		NextProtos: []string{"h2"},
	})
	if err != nil {
		return fmt.Errorf("connecting h2 to %v: %w", url, err)
	}
	// Releases the server connection once both relays have finished.
	defer sc.Close()

	if err := forwardPreface(sc, cc); err != nil {
		return fmt.Errorf("initializing h2 with %v: %w", url, err)
	}

	cf, sf := http2.NewFramer(cc, cc), http2.NewFramer(sc, sc)
	cToS := newRelay(ClientToServer, "client", url.String(), cf, sf, &c.EnableDebugLogs)
	sToC := newRelay(ServerToClient, url.String(), "client", sf, cf, &c.EnableDebugLogs)

	// Completes circular parts of the initialization.
	// The client-to-server relay depends on the server-to-client relay and vice versa.
	cToS.peer, sToC.peer = sToC, cToS

	// Creating processors is circular because the create function references the relays and the
	// relays need to call create.
	cToS.processors = &streamProcessors{
		create: func(id uint32) *Processors {
			p := &Processors{cToS: &relayAdapter{id, cToS}, sToC: &relayAdapter{id, sToC}}

			// Chains the pipeline of processors together, starting from the last factory so that
			// the first factory ends up at the head of the chain.
			for i := len(c.StreamProcessorFactories) - 1; i >= 0; i-- {
				// These names shadow the relays above for the rest of the loop body.
				cToS, sToC := c.StreamProcessorFactories[i](url, p)
				// Bypasses any nil processors.
				if cToS == nil {
					cToS = p.ForDirection(ClientToServer)
				}
				if sToC == nil {
					sToC = p.ForDirection(ServerToClient)
				}
				p = &Processors{cToS: cToS, sToC: sToC}
			}
			return p
		},
	}
	sToC.processors = cToS.processors

	var wg sync.WaitGroup
	wg.Add(2)
	go func() { // Forwards frames from client to server.
		defer wg.Done()
		if err := cToS.relayFrames(closing); err != nil {
			log.Errorf("relaying frame from client to %v: %v", url, err)
		}
	}()
	go func() { // Forwards frames from server to client.
		defer wg.Done()
		if err := sToC.relayFrames(closing); err != nil {
			log.Errorf("relaying frame from %v to client: %v", url, err)
		}
	}()
	wg.Wait()
	return nil
}

// forwardPreface reads the connection preface from the client, verifies it, and forwards it to
// the server.
func forwardPreface(server io.Writer, client io.Reader) error {
	preface := make([]byte, len(connectionPreface))
	// io.ReadFull is used because a single Read may return fewer bytes than requested.
	if _, err := io.ReadFull(client, preface); err != nil {
		return fmt.Errorf("reading preface: %w", err)
	}
	if !bytes.Equal(preface, connectionPreface) {
		return fmt.Errorf("client sent unexpected preface: %s", hex.Dump(preface))
	}
	// Writes until the whole preface has been sent.
	for m := len(connectionPreface); m > 0; {
		n, err := server.Write(preface)
		if err != nil {
			return fmt.Errorf("writing preface: %w", err)
		}
		preface = preface[n:]
		m -= n
	}
	return nil
}

type streamProcessors struct {
	// processors stores `*Processors` instances keyed by uint32 stream ID.
	processors sync.Map

	// create creates `*Processors` for the given stream ID.
	create func(uint32) *Processors
}

// Get returns the processor for the given stream ID and direction, creating the stream's
// `*Processors` on first use.
func (s *streamProcessors) Get(id uint32, dir Direction) Processor {
	value, ok := s.processors.Load(id)
	if !ok {
		value, _ = s.processors.LoadOrStore(id, s.create(id))
	}
	return value.(*Processors).ForDirection(dir)
}