File: fetcher.go

package info (click to toggle)
golang-github-lestrrat-go-httprc 1.0.6-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 168 kB
  • sloc: perl: 56; sh: 6; makefile: 2
file content (182 lines) | stat: -rw-r--r-- 3,672 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
package httprc

import (
	"context"
	"fmt"
	"net/http"
	"sync"
)

// fetchRequest describes a single URL fetch to be performed by a fetch
// worker. The same object may be reused across fetch calls; mu guards the
// fields while a worker reads them.
type fetchRequest struct {
	// mu guards the fields below: fetch takes the write lock to install a
	// fresh reply channel, and workers take the read lock to copy the fields.
	mu sync.RWMutex

	// client contains the HTTP Client that can be used to make a
	// request. By setting a custom *http.Client, you can for example
	// provide a custom http.Transport
	//
	// If not specified, http.DefaultClient will be used.
	client HTTPClient

	// wl, when non-nil, is a per-request Whitelist that url must pass in
	// addition to any fetcher-wide Whitelist.
	wl Whitelist

	// url contains the URL to be fetched
	url string

	// reply is a field that is only used by the internals of the fetcher
	// it is used to return the result of fetching
	reply chan *fetchResult
}

// fetchResult carries the outcome of one fetch from a worker back to the
// goroutine waiting in fetch: the HTTP response, or the error that replaced it.
type fetchResult struct {
	// mu guards res and err.
	mu sync.RWMutex

	// res is the response returned by the HTTP client, if any.
	res *http.Response
	// err is the error from the HTTP client, or a whitelist rejection.
	err error
}

// reply sends fr on the given channel and then closes the channel, since a
// reply channel carries exactly one result. If ctx is done before the send
// can happen, the channel is left untouched and ctx.Err() is returned.
func (fr *fetchResult) reply(ctx context.Context, reply chan *fetchResult) error {
	select {
	case reply <- fr:
		close(reply)
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// fetcher is the package's Fetcher implementation: it dispatches each fetch
// to a pool of worker goroutines.
type fetcher struct {
	// requests is shared by all workers; sending a *fetchRequest on it hands
	// the request to whichever worker receives it first. NewFetcher creates
	// it unbuffered.
	requests chan *fetchRequest
}

// Fetcher fetches URLs over HTTP on behalf of its callers.
//
// The interface includes the unexported fetch method, so it can only be
// implemented inside this package.
type Fetcher interface {
	// Fetch retrieves the resource at the given URL, configured by options.
	Fetch(ctx context.Context, u string, options ...FetchOption) (*http.Response, error)
	// fetch performs a fetch using a prepared (and possibly reused) request.
	fetch(ctx context.Context, req *fetchRequest) (*http.Response, error)
}

// NewFetcher creates a Fetcher backed by a pool of worker goroutines, all
// started here and all bound to ctx (the workers stop once ctx is done).
//
// Recognized options are the worker count (3 when unset or below 1) and a
// Whitelist that every URL fetched through the returned Fetcher must pass.
func NewFetcher(ctx context.Context, options ...FetcherOption) Fetcher {
	workers := 0
	var whitelist Whitelist
	for _, opt := range options {
		//nolint:forcetypeassert
		switch opt.Ident() {
		case identFetcherWorkerCount{}:
			workers = opt.Value().(int)
		case identWhitelist{}:
			whitelist = opt.Value().(Whitelist)
		}
	}

	if workers <= 0 {
		workers = 3
	}

	requests := make(chan *fetchRequest)
	for remaining := workers; remaining > 0; remaining-- {
		go runFetchWorker(ctx, requests, whitelist)
	}
	return &fetcher{requests: requests}
}

// Fetch performs an HTTP GET of u through the fetcher's worker pool and
// returns the response. Closing the response body is left to the caller.
//
// Recognized options are an HTTPClient to use for this request (a worker
// falls back to http.DefaultClient without one) and a Whitelist that u must
// pass in addition to any fetcher-wide whitelist.
func (f *fetcher) Fetch(ctx context.Context, u string, options ...FetchOption) (*http.Response, error) {
	req := &fetchRequest{url: u}
	for _, opt := range options {
		//nolint:forcetypeassert
		switch opt.Ident() {
		case identHTTPClient{}:
			req.client = opt.Value().(HTTPClient)
		case identWhitelist{}:
			req.wl = opt.Value().(Whitelist)
		}
	}
	return f.fetch(ctx, req)
}

// fetch (unexported) is the main fetching implementation.
// It allows the caller to reuse the same *fetchRequest object: a fresh
// reply channel is installed on req (under req.mu) for every call.
//
// fetch first waits for a worker to accept req, then waits for that worker's
// result. If ctx is done during either wait, ctx.Err() is returned.
//
// If ctx is done after a worker has accepted req, the worker's result is
// drained in the background and any response body it carries is closed, so
// the HTTP connection is not leaked. NOTE(review): a worker whose own context
// is done may never send on or close the reply channel; in that case the
// background drain goroutine stays blocked.
func (f *fetcher) fetch(ctx context.Context, req *fetchRequest) (*http.Response, error) {
	// Buffered so that the worker's send never blocks, even after we stop
	// waiting for it.
	reply := make(chan *fetchResult, 1)
	req.mu.Lock()
	req.reply = reply
	req.mu.Unlock()

	// Send a request to the backend
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case f.requests <- req:
	}

	// wait until we get a reply
	select {
	case <-ctx.Done():
		// The worker already owns req and will normally still deliver a
		// result. Nobody else will read it, so release its response body.
		go func() {
			fr, ok := <-reply
			if !ok || fr == nil {
				return
			}
			fr.mu.RLock()
			res := fr.res
			fr.mu.RUnlock()
			if res != nil && res.Body != nil {
				_ = res.Body.Close() // best effort; the result is discarded anyway
			}
		}()
		return nil, ctx.Err()
	case fr := <-reply:
		fr.mu.RLock()
		res := fr.res
		err := fr.err
		fr.mu.RUnlock()
		return res, err
	}
}

// runFetchWorker serves fetch requests from incoming until ctx is done.
//
// For each request it checks the URL against the fetcher-wide whitelist wl
// and the request's own whitelist; either may be nil, and every non-nil one
// must allow the URL. A rejected URL produces an error result without any
// HTTP request. Otherwise the worker performs an HTTP GET with the request's
// client (http.DefaultClient when none is set) and sends the result on the
// request's reply channel.
//
// The response body is owned by whoever receives the result. If the result
// cannot be delivered because ctx is done, the worker closes the body itself
// so the connection is not leaked, and then exits.
func runFetchWorker(ctx context.Context, incoming chan *fetchRequest, wl Whitelist) {
LOOP:
	for {
		select {
		case <-ctx.Done():
			break LOOP
		case req := <-incoming:
			// Copy what we need under the read lock; the caller may reuse req.
			req.mu.RLock()
			reply := req.reply
			client := req.client
			u := req.url
			reqwl := req.wl
			req.mu.RUnlock()

			if client == nil {
				client = http.DefaultClient
			}

			rejected := false
			for _, w := range [...]Whitelist{wl, reqwl} {
				if w != nil && !w.IsAllowed(u) {
					rejected = true
					break
				}
			}
			if rejected {
				r := &fetchResult{
					err: fmt.Errorf(`fetching url %q rejected by whitelist`, u),
				}
				if err := r.reply(ctx, reply); err != nil {
					break LOOP
				}
				continue LOOP
			}

			// The body is handled by the consumer of the fetcher
			//nolint:bodyclose
			res, err := client.Get(u)
			r := &fetchResult{
				res: res,
				err: err,
			}
			if err := r.reply(ctx, reply); err != nil {
				// The result was not delivered, so no consumer will ever
				// close this body; close it here to release the connection.
				if res != nil && res.Body != nil {
					_ = res.Body.Close() // best effort during shutdown
				}
				break LOOP
			}
		}
	}
}