1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
|
package httprc
import (
"context"
"fmt"
"net/http"
"sync"
)
// fetchRequest is the message handed from fetch to a worker goroutine
// over the fetcher's requests channel. Its fields are guarded by mu:
// fetch installs reply under the write lock, and the worker snapshots
// all fields under the read lock before doing any work.
type fetchRequest struct {
mu sync.RWMutex
// client contains the HTTP Client that can be used to make a
// request. By setting a custom *http.Client, you can for example
// provide a custom http.Transport
//
// If not specified, http.DefaultClient will be used.
client HTTPClient
// wl is an optional per-request whitelist. When non-nil, the worker
// rejects url unless wl allows it. This check is applied in addition
// to any whitelist configured on the fetcher itself.
wl Whitelist
// url contains the URL to be fetched
url string
// reply is a field that is only used by the internals of the fetcher.
// It is used to return the result of fetching; fetch replaces it with a
// fresh channel on every call.
reply chan *fetchResult
}
// fetchResult carries the outcome of a single fetch from a worker back to
// the caller waiting in fetch. res and err are what the HTTP client
// returned. When the URL was rejected by a whitelist, only err is set.
type fetchResult struct {
	mu  sync.RWMutex
	res *http.Response
	err error
}

// reply delivers fr on the reply channel, closes the channel, and returns
// nil. The caller is expected to be the channel's only sender. reply must
// be used at most once per channel, because it closes the channel.
//
// If the channel can accept fr immediately, fr is delivered even when ctx
// is already done, so a result that was already computed is never
// dropped. This is always the case for the buffer-of-one channel created
// by fetch. Only when the send would block does reply wait. If ctx is done
// first, it returns ctx.Err() and leaves the channel open.
func (fr *fetchResult) reply(ctx context.Context, reply chan *fetchResult) error {
	// Fast path. A single select over ctx.Done() and the send picks at
	// random when both are ready, which could discard a finished result
	// (and strand the waiting caller) while the worker is shutting down.
	select {
	case reply <- fr:
		close(reply)
		return nil
	default:
	}

	select {
	case <-ctx.Done():
		return ctx.Err()
	case reply <- fr:
	}
	close(reply)
	return nil
}
// fetcher is the Fetcher implementation returned by NewFetcher.
type fetcher struct {
// requests is shared with the worker goroutines started by NewFetcher,
// which creates it unbuffered. Each *fetchRequest sent on it is received
// (and handled) by a single worker.
requests chan *fetchRequest
}
// Fetcher fetches URLs on behalf of callers. NewFetcher returns an
// implementation backed by a pool of worker goroutines.
type Fetcher interface {
// Fetch retrieves the given URL, honoring the supplied FetchOption
// values, and returns the raw response. In the implementation returned
// by NewFetcher, the caller is responsible for closing the response Body.
Fetch(context.Context, string, ...FetchOption) (*http.Response, error)
// fetch is the unexported entry point taking an already-built
// *fetchRequest, so code inside the package can reuse a request object.
fetch(context.Context, *fetchRequest) (*http.Response, error)
}
// NewFetcher starts a pool of fetch workers bound to ctx and returns a
// Fetcher that dispatches requests to them.
//
// Recognized options:
//   - the fetcher worker count; values below 1 select the default of 3 workers.
//   - a whitelist that every worker checks for every URL it fetches.
//
// Options with any other identity are ignored.
//
// The workers stop once they observe that ctx is done. After that nothing
// receives from the request channel, so a Fetch call blocks until its own
// context is done. Pass Fetch a context that can end.
func NewFetcher(ctx context.Context, options ...FetcherOption) Fetcher {
	workerCount := 0
	var sharedWL Whitelist
	for _, opt := range options {
		//nolint:forcetypeassert
		switch opt.Ident() {
		case identFetcherWorkerCount{}:
			workerCount = opt.Value().(int)
		case identWhitelist{}:
			sharedWL = opt.Value().(Whitelist)
		}
	}
	if workerCount <= 0 {
		workerCount = 3 // default pool size
	}

	// Unbuffered: handing a request over completes only when a worker takes it.
	requests := make(chan *fetchRequest)
	for n := workerCount; n > 0; n-- {
		go runFetchWorker(ctx, requests, sharedWL)
	}
	return &fetcher{requests: requests}
}
func (f *fetcher) Fetch(ctx context.Context, u string, options ...FetchOption) (*http.Response, error) {
var client HTTPClient
var wl Whitelist
for _, option := range options {
//nolint:forcetypeassert
switch option.Ident() {
case identHTTPClient{}:
client = option.Value().(HTTPClient)
case identWhitelist{}:
wl = option.Value().(Whitelist)
}
}
req := fetchRequest{
client: client,
url: u,
wl: wl,
}
return f.fetch(ctx, &req)
}
// fetch (unexported) is the main fetching implementation.
// It allows the caller to reuse the same *fetchRequest object: every call
// installs a fresh reply channel on req before handing it to a worker.
//
// If ctx is done before a worker accepts req, nothing is fetched and
// ctx.Err() is returned.
//
// If ctx is done after a worker has accepted req, ctx.Err() is returned
// right away, but the worker still completes the fetch. The response it
// eventually produces would otherwise never reach anyone. To avoid leaking
// the underlying connection, it is drained in the background and its Body
// is closed.
func (f *fetcher) fetch(ctx context.Context, req *fetchRequest) (*http.Response, error) {
	// Capacity 1 lets the worker deposit its result without waiting for us,
	// even after we have stopped listening.
	reply := make(chan *fetchResult, 1)
	req.mu.Lock()
	req.reply = reply
	req.mu.Unlock()

	// Send a request to the backend
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case f.requests <- req:
	}

	// wait until we get a reply
	select {
	case <-ctx.Done():
		// A worker owns req now and will reply later. Nobody else will
		// consume that reply, so release its response body when it arrives.
		go func() {
			fr, ok := <-reply
			if !ok {
				return
			}
			fr.mu.RLock()
			res := fr.res
			fr.mu.RUnlock()
			if res != nil && res.Body != nil {
				_ = res.Body.Close() // best effort; the result is discarded anyway
			}
		}()
		return nil, ctx.Err()
	case fr := <-reply:
		fr.mu.RLock()
		res := fr.res
		err := fr.err
		fr.mu.RUnlock()
		return res, err
	}
}
// runFetchWorker is the loop run by each worker goroutine started by
// NewFetcher. It takes requests from incoming until ctx is done. For every
// request it:
//
//  1. snapshots the request's fields under the request's read lock;
//  2. rejects the URL if the worker-wide whitelist (wl) or the per-request
//     whitelist is non-nil and does not allow it. The worker-wide whitelist
//     is consulted first, and the first refusal wins;
//  3. otherwise fetches the URL with the request's client.Get, using
//     http.DefaultClient when no client was set;
//  4. sends the outcome back on the request's reply channel.
//
// The worker returns when ctx is done, or when a reply could not be
// delivered because ctx was done.
func runFetchWorker(ctx context.Context, incoming chan *fetchRequest, wl Whitelist) {
	for {
		var req *fetchRequest
		select {
		case <-ctx.Done():
			return
		case req = <-incoming:
		}

		req.mu.RLock()
		reply, client, target, reqwl := req.reply, req.client, req.url, req.wl
		req.mu.RUnlock()
		if client == nil {
			client = http.DefaultClient
		}

		var result *fetchResult
		for _, allow := range []Whitelist{wl, reqwl} {
			if allow == nil || allow.IsAllowed(target) {
				continue
			}
			result = &fetchResult{
				err: fmt.Errorf(`fetching url %q rejected by whitelist`, target),
			}
			break
		}

		if result == nil {
			// The body is handled by the consumer of the fetcher
			//nolint:bodyclose
			res, err := client.Get(target)
			result = &fetchResult{res: res, err: err}
		}

		if err := result.reply(ctx, reply); err != nil {
			return
		}
	}
}
|