1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
package api
import (
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/cli/cli/v2/internal/ghinstance"
"github.com/cli/cli/v2/utils"
ghAPI "github.com/cli/go-gh/v2/pkg/api"
)
// tokenGetter is the minimal credential source needed by the HTTP client:
// it maps a hostname to an auth token. An empty token means there are no
// credentials to attach for that host.
type tokenGetter interface {
	// ActiveToken returns the token to use for hostname. The second result is
	// not consumed in this file; presumably it describes where the token came
	// from — NOTE(review): confirm against the implementation.
	ActiveToken(hostname string) (token string, source string)
}
// HTTPClientOptions configures NewHTTPClient.
type HTTPClientOptions struct {
	// AppVersion is interpolated into the user agent value "GitHub CLI <AppVersion>".
	AppVersion string
	// CacheTTL is forwarded to the underlying client only when EnableCache is true.
	CacheTTL time.Duration
	// Config, when non-nil, supplies the per-host auth tokens that are attached
	// to requests at request time.
	Config tokenGetter
	// EnableCache is forwarded to the underlying client's cache option.
	EnableCache bool
	// Log receives HTTP logging output. It is forwarded only when verbose HTTP
	// logging is requested or debug mode is reported as enabled.
	Log io.Writer
	// LogColorize is forwarded together with Log under the same conditions.
	LogColorize bool
	// LogVerboseHTTP requests verbose HTTP logging. NewHTTPClient also turns it on
	// when the debug setting contains "api".
	LogVerboseHTTP bool
}
// NewHTTPClient builds an *http.Client on top of go-gh's API HTTP client,
// configured from opts:
//   - every request carries the user agent "GitHub CLI <AppVersion>";
//   - logging to opts.Log is enabled when verbose HTTP logging is requested or
//     debug mode is reported as enabled; a debug value containing "api" forces
//     verbose HTTP logging on;
//   - caching with opts.CacheTTL is enabled when opts.EnableCache is set;
//   - when opts.Config is non-nil, an auth token is attached per request host.
//
// Any error from the underlying constructor is returned unchanged.
func NewHTTPClient(opts HTTPClientOptions) (*http.Client, error) {
	isDebug, debugSetting := utils.IsDebugEnabled()
	verbose := opts.LogVerboseHTTP || strings.Contains(debugSetting, "api")

	// Deliberately bogus host and token so the underlying client does not try to
	// resolve real ones itself; the real host and token are supplied per request.
	clientOpts := ghAPI.ClientOptions{
		Host:         "none",
		AuthToken:    "none",
		LogIgnoreEnv: true,
		Headers: map[string]string{
			userAgent: fmt.Sprintf("GitHub CLI %s", opts.AppVersion),
		},
	}

	if verbose || isDebug {
		clientOpts.Log = opts.Log
		clientOpts.LogColorize = opts.LogColorize
		clientOpts.LogVerboseHTTP = verbose
	}

	if opts.EnableCache {
		clientOpts.EnableCache = true
		clientOpts.CacheTTL = opts.CacheTTL
	}

	client, err := ghAPI.NewHTTPClient(clientOpts)
	if err != nil {
		return nil, err
	}
	if opts.Config == nil {
		return client, nil
	}
	client.Transport = AddAuthTokenHeader(client.Transport, opts.Config)
	return client, nil
}
// NewCachedHTTPClient returns a shallow copy of httpClient whose transport tags
// every request with a cache-TTL header of ttl, unless the request already
// carries one. httpClient itself is not modified.
//
// A nil httpClient.Transport is treated as http.DefaultTransport, as
// http.Client itself does; otherwise the wrapper would panic on the first
// request. NOTE(review): http.DefaultTransport ignores the TTL header, so
// caching only takes effect when the wrapped transport implements it.
func NewCachedHTTPClient(httpClient *http.Client, ttl time.Duration) *http.Client {
	newClient := *httpClient

	transport := httpClient.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}
	newClient.Transport = AddCacheTTLHeader(transport, ttl)

	return &newClient
}
// AddCacheTTLHeader wraps rt so that every outgoing request carries a header
// (named by the cacheTTL constant) telling the cache that the request should be
// cached for ttl. The value is ttl formatted by time.Duration.String
// (e.g. "1h0m0s"). A value already present on the request is left untouched.
//
// The caller's *http.Request is never mutated: when the header must be added,
// a clone is sent instead, as the http.RoundTripper contract requires.
// If rt is nil, http.DefaultTransport is used, matching how http.Client
// treats a nil Transport.
func AddCacheTTLHeader(rt http.RoundTripper, ttl time.Duration) http.RoundTripper {
	if rt == nil {
		rt = http.DefaultTransport
	}
	// The formatted TTL is the same for every request, so compute it once.
	value := ttl.String()
	return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
		// If the header is already set in the request, don't overwrite it.
		if req.Header.Get(cacheTTL) != "" {
			return rt.RoundTrip(req)
		}
		tagged := req.Clone(req.Context())
		if tagged.Header == nil {
			tagged.Header = make(http.Header)
		}
		tagged.Header.Set(cacheTTL, value)
		return rt.RoundTrip(tagged)
	}}
}
// AddAuthTokenHeader wraps rt so that outgoing requests carry an authentication
// header (named by the authorization constant) with the value "token <TOKEN>".
// TOKEN is what cfg.ActiveToken returns for the request's normalized hostname.
//
// The header is added only when all of the following hold:
//   - the request does not already carry that header (a caller-supplied value
//     is never overwritten);
//   - the request is not a redirect that switched hosts: either it is not a
//     redirect at all, or its host equals the host of the request whose
//     response triggered the redirect;
//   - cfg returns a non-empty token for the hostname.
//
// NOTE(review): in a multi-hop redirect chain, the host comparison is against
// the immediately preceding hop, not the very first request.
//
// The caller's *http.Request is never mutated: when the header must be added,
// a clone is sent instead, as the http.RoundTripper contract requires.
// If rt is nil, http.DefaultTransport is used, matching how http.Client
// treats a nil Transport.
func AddAuthTokenHeader(rt http.RoundTripper, cfg tokenGetter) http.RoundTripper {
	if rt == nil {
		rt = http.DefaultTransport
	}
	return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
		// If the header is already set in the request, don't overwrite it.
		if req.Header.Get(authorization) != "" {
			return rt.RoundTrip(req)
		}

		// http.Client populates req.Response only for redirect requests; through
		// it we reach the request that produced the redirect. If the host changed
		// during that redirect, do not attach the token.
		if req.Response != nil && req.Response.Request != nil &&
			getHost(req) != getHost(req.Response.Request) {
			return rt.RoundTrip(req)
		}

		hostname := ghinstance.NormalizeHostname(getHost(req))
		token, _ := cfg.ActiveToken(hostname)
		if token == "" {
			return rt.RoundTrip(req)
		}

		authed := req.Clone(req.Context())
		if authed.Header == nil {
			authed.Header = make(http.Header)
		}
		authed.Header.Set(authorization, fmt.Sprintf("token %s", token))
		return rt.RoundTrip(authed)
	}}
}
// ExtractHeader returns transport middleware. The RoundTripper it produces
// forwards each request to the wrapped transport. Whenever a response arrives
// without error and has a non-empty value for the header called name, that
// value is stored in *dest. Later responses overwrite earlier values; an empty
// or missing header never clears dest.
func ExtractHeader(name string, dest *string) func(http.RoundTripper) http.RoundTripper {
	middleware := func(next http.RoundTripper) http.RoundTripper {
		capture := func(req *http.Request) (*http.Response, error) {
			resp, err := next.RoundTrip(req)
			if err != nil {
				return resp, err
			}
			if v := resp.Header.Get(name); v != "" {
				*dest = v
			}
			return resp, nil
		}
		return &funcTripper{roundTrip: capture}
	}
	return middleware
}
// funcTripper adapts a plain function to the http.RoundTripper interface,
// in the same spirit as http.HandlerFunc does for handlers.
type funcTripper struct {
	roundTrip func(*http.Request) (*http.Response, error)
}

// Compile-time check that funcTripper satisfies http.RoundTripper.
var _ http.RoundTripper = funcTripper{}

// RoundTrip delegates directly to the wrapped function.
func (f funcTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	fn := f.roundTrip
	return fn(req)
}
// getHost reports the host a request targets. It returns the explicit
// Request.Host override when that is set, and otherwise falls back to the host
// component of Request.URL. The URL is not consulted when Host is non-empty.
func getHost(r *http.Request) string {
	host := r.Host
	if host == "" {
		host = r.URL.Host
	}
	return host
}
|