1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261
|
// FIXME(thaJeztah): remove once we are a module; the go:build directive prevents go from downgrading language version to go1.16:
//go:build go1.23
package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"runtime"
"strings"
"time"
"github.com/docker/cli/cli/version"
)
// OAuthAPI is the set of OAuth device-authorization operations used by the
// CLI login flow, listed in the order a login typically invokes them. API in
// this package satisfies it; presumably the interface exists so callers can
// substitute a fake in tests — confirm against callers.
type OAuthAPI interface {
	// GetDeviceCode starts a device-code flow for audience and returns the
	// resulting State (device code, verification URL, polling parameters).
	GetDeviceCode(ctx context.Context, audience string) (State, error)

	// WaitForDeviceToken blocks until the device code in state has been
	// authorized, the flow fails or expires, or ctx is done.
	WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error)

	// GetAutoPAT uses the access token in res to obtain a personal access
	// token for audience.
	GetAutoPAT(ctx context.Context, audience string, res TokenResponse) (string, error)

	// RevokeToken invalidates refreshToken so it can no longer be used to
	// obtain new tokens.
	RevokeToken(ctx context.Context, refreshToken string) error
}
// API represents API interactions with Auth0.
//
// It has value-receiver methods GetDeviceCode, WaitForDeviceToken,
// RevokeToken and GetAutoPAT, so an API value satisfies OAuthAPI.
type API struct {
	// TenantURL is the base used for each request to Auth0. Endpoint paths
	// such as "/oauth/token" are appended to it by plain string
	// concatenation, so it should not end with a trailing slash.
	TenantURL string
	// ClientID is the client ID for the application to auth with the tenant.
	// It is sent as the "client_id" form field on every tenant request.
	ClientID string
	// Scopes are the scopes that are requested during the device auth flow.
	// They are joined with single spaces into the "scope" form field.
	Scopes []string
}
// TokenResponse represents the response of the /oauth/token route.
//
// The same shape is used for both successful and error responses: while the
// user has not finished logging in, the tenant replies with Error set (e.g.
// "authorization_pending") and the token fields empty.
type TokenResponse struct {
	AccessToken  string `json:"access_token"`
	IDToken      string `json:"id_token"`
	RefreshToken string `json:"refresh_token"`
	Scope        string `json:"scope"`
	// ExpiresIn is the access token lifetime; presumably in seconds per
	// RFC 6749 section 5.1 — confirm against the tenant.
	ExpiresIn int    `json:"expires_in"`
	TokenType string `json:"token_type"`
	// Error is the OAuth error code. It is a pointer so that "no error" (nil)
	// can be distinguished from an empty error string.
	Error *string `json:"error,omitempty"`
	// ErrorDescription is the human-readable text accompanying Error, if any.
	ErrorDescription string `json:"error_description,omitempty"`
}
// ErrTimeout is returned by WaitForDeviceToken when the polling window given
// by the device-code State (its ExpiryDuration) elapses before the user has
// finished authenticating.
var ErrTimeout = errors.New("timed out waiting for device token")
// GetDeviceCode initiates the device-code auth flow with the tenant.
// The state returned contains the device code that the user must use to
// authenticate, as well as the URL to visit, etc.
//
// It sends a form POST to TenantURL+"/oauth/device/code" with the client ID,
// the requested audience and the configured scopes (space-separated). A
// non-200 response is converted into an error by tryDecodeOAuthError. On any
// error the returned State is the zero value.
func (a API) GetDeviceCode(ctx context.Context, audience string) (State, error) {
	data := url.Values{
		"client_id": {a.ClientID},
		"audience":  {audience},
		"scope":     {strings.Join(a.Scopes, " ")},
	}
	deviceCodeURL := a.TenantURL + "/oauth/device/code"
	resp, err := postForm(ctx, deviceCodeURL, strings.NewReader(data.Encode()))
	if err != nil {
		return State{}, err
	}
	defer func() {
		// The body has been consumed (or is irrelevant) by the time we
		// return; a close error is not actionable here.
		_ = resp.Body.Close()
	}()

	if resp.StatusCode != http.StatusOK {
		return State{}, tryDecodeOAuthError(resp)
	}

	var state State
	if err := json.NewDecoder(resp.Body).Decode(&state); err != nil {
		// Don't hand back a partially-decoded state alongside the error.
		return State{}, fmt.Errorf("failed to get device code: %w", err)
	}
	return state, nil
}
// tryDecodeOAuthError builds an error from an unsuccessful tenant response.
//
// It reads resp.Body as a JSON object and prefers, in order:
//   - a non-empty "error_description", returned verbatim;
//   - a non-empty "error" code, reported together with resp.Status;
//   - otherwise a generic message containing resp.Status.
//
// The body is decoded into a generic map (rather than a struct) so that an
// unexpected type in one field does not prevent reading the other. A body
// that is not valid JSON falls through to the generic message.
func tryDecodeOAuthError(resp *http.Response) error {
	var body map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&body); err == nil {
		// An empty description would yield an empty, useless error message,
		// so treat it the same as a missing one.
		if errorDescription, ok := body["error_description"].(string); ok && errorDescription != "" {
			return errors.New(errorDescription)
		}
		if errorCode, ok := body["error"].(string); ok && errorCode != "" {
			return fmt.Errorf("unexpected response from tenant: %s (%s)", resp.Status, errorCode)
		}
	}
	return errors.New("unexpected response from tenant: " + resp.Status)
}
// WaitForDeviceToken polls the tenant to get access/refresh tokens for the user.
// This should be called after GetDeviceCode, and will block until the user has
// authenticated or we have reached the time limit for authenticating (based on
// the response from GetDeviceCode).
//
// Polling follows RFC 8628, section 3.5:
//   - "authorization_pending": keep polling at the current interval;
//   - "slow_down": keep polling, with the interval increased by 5 seconds for
//     this and all subsequent requests;
//   - any other OAuth error ends the wait and is returned (together with the
//     tenant's response).
//
// It returns ctx.Err() if ctx is done, and ErrTimeout if state's expiry
// duration elapses before the user completes the login.
func (a API) WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error) {
	// Polling interval as specified by the tenant response. It only grows,
	// when the tenant asks us to slow down.
	interval := state.IntervalDuration()

	// Timer for polling the tenant for login. It is re-armed at the top of
	// every iteration, so the interval is measured from the end of the
	// previous request rather than from its start.
	pollTimer := time.NewTimer(interval)
	defer pollTimer.Stop()

	// The tenant tells us for how long we can poll it for credentials
	// while the user logs in through their browser. Time out if we don't get
	// credentials within this period.
	timeout := time.NewTimer(state.ExpiryDuration())
	defer timeout.Stop()

	for {
		resetTimer(pollTimer, interval)
		select {
		case <-ctx.Done():
			// user canceled login
			return TokenResponse{}, ctx.Err()
		case <-pollTimer.C:
			// tick, check for user login
			res, err := a.getDeviceToken(ctx, state)
			if err != nil {
				if errors.Is(err, context.Canceled) {
					// if the caller canceled the context, continue
					// and let the select hit the ctx.Done() branch
					continue
				}
				return TokenResponse{}, err
			}
			if res.Error != nil {
				switch *res.Error {
				case "authorization_pending":
					// The user hasn't finished logging in yet.
					continue
				case "slow_down":
					// RFC 8628 section 3.5: still pending, but the interval
					// MUST be increased by 5 seconds for this and all
					// subsequent requests.
					interval += 5 * time.Second
					continue
				}
				if res.ErrorDescription != "" {
					return res, errors.New(res.ErrorDescription)
				}
				// Avoid returning an empty error message when the tenant
				// sent only an error code.
				return res, fmt.Errorf("unexpected response from tenant: %s", *res.Error)
			}
			return res, nil
		case <-timeout.C:
			// login timed out
			return TokenResponse{}, ErrTimeout
		}
	}
}
// resetTimer stops t, discards any expiry still pending on t.C, and then
// re-arms t to fire after d.
//
// Before go1.23, Timer.Reset neither stops the timer nor drains its channel,
// so a stale tick could be received immediately after a reset.
// See: https://go-review.googlesource.com/c/go/+/568341
// FIXME: remove/simplify this after we update to go1.23. NOTE(review): this
// file already carries a go1.23 build constraint, but the new timer semantics
// also depend on the module's go version / GODEBUG asynctimerchan — confirm
// before removing.
func resetTimer(t *time.Timer, d time.Duration) {
	wasActive := t.Stop()
	if !wasActive {
		// The timer already expired (or was stopped). Drain without
		// blocking: the pending value may already have been received.
		select {
		case <-t.C:
			// dropped a stale expiry
		default:
			// nothing pending
		}
	}
	t.Reset(d)
}
// getDeviceToken calls the token endpoint of Auth0 and returns the response.
//
// It performs a single device_code grant request (a form POST to
// TenantURL+"/oauth/token"), bounded by a one-minute timeout layered on ctx.
//
// The endpoint reports "authorization_pending" and similar OAuth errors with a
// non-200 status, so the body is always decoded. OAuth errors are returned to
// the caller through TokenResponse.Error rather than as a Go error. A non-200
// response that carries no OAuth error is returned as a Go error, so it cannot
// be mistaken for a successful exchange with empty tokens. On any Go error the
// returned TokenResponse is the zero value.
func (a API) getDeviceToken(ctx context.Context, state State) (TokenResponse, error) {
	ctx, cancel := context.WithTimeout(ctx, 1*time.Minute)
	defer cancel()

	data := url.Values{
		"client_id":   {a.ClientID},
		"grant_type":  {"urn:ietf:params:oauth:grant-type:device_code"},
		"device_code": {state.DeviceCode},
	}
	oauthTokenURL := a.TenantURL + "/oauth/token"
	resp, err := postForm(ctx, oauthTokenURL, strings.NewReader(data.Encode()))
	if err != nil {
		return TokenResponse{}, fmt.Errorf("failed to get tokens: %w", err)
	}
	defer func() {
		// A close error is not actionable once the body has been decoded.
		_ = resp.Body.Close()
	}()

	// this endpoint returns a 403 with an `authorization_pending` error until the
	// user has authenticated, so we don't reject non-200 statuses up front and
	// instead decode the response and check for the error.
	var res TokenResponse
	if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
		return TokenResponse{}, fmt.Errorf("failed to decode response: %w", err)
	}
	if resp.StatusCode != http.StatusOK && res.Error == nil {
		// e.g. a 5xx with an unrelated JSON body: without this check the
		// caller would treat it as a completed login with empty tokens.
		return TokenResponse{}, errors.New("unexpected response from tenant: " + resp.Status)
	}
	return res, nil
}
// RevokeToken asks the tenant to revoke refreshToken, after which it can no
// longer be exchanged for new tokens.
//
// The request is a form POST to TenantURL+"/oauth/revoke". Only a 200 OK
// counts as success; any other status is turned into an error by
// tryDecodeOAuthError.
func (a API) RevokeToken(ctx context.Context, refreshToken string) error {
	form := url.Values{}
	form.Set("client_id", a.ClientID)
	form.Set("token", refreshToken)

	resp, err := postForm(ctx, a.TenantURL+"/oauth/revoke", strings.NewReader(form.Encode()))
	if err != nil {
		return err
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode == http.StatusOK {
		return nil
	}
	return tryDecodeOAuthError(resp)
}
// postForm sends data as an application/x-www-form-urlencoded POST to reqURL
// using http.DefaultClient, and returns the raw response for the caller to
// inspect and close.
//
// The User-Agent has the form "docker-cli:<version>:<GOOS>-<GOARCH>". Dots in
// the CLI version are replaced with underscores, presumably to keep the
// ':'-separated fields unambiguous — confirm with the tenant's parser.
func postForm(ctx context.Context, reqURL string, data io.Reader) (*http.Response, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, data)
	if err != nil {
		return nil, err
	}

	userAgent := "docker-cli:" + strings.ReplaceAll(version.Version, ".", "_") +
		":" + runtime.GOOS + "-" + runtime.GOARCH

	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("User-Agent", userAgent)
	return http.DefaultClient.Do(req)
}
// GetAutoPAT requests a personal access token (PAT) from the Hub.
//
// It POSTs to audience+"/v2/access-tokens/desktop-generate", authenticating
// with res.AccessToken as a Bearer token. The endpoint is expected to reply
// 201 Created with a JSON body whose "data" object holds the token under
// "token". Any other status, an undecodable body, or an empty token is
// returned as an error.
func (API) GetAutoPAT(ctx context.Context, audience string, res TokenResponse) (string, error) {
	patURL := audience + "/v2/access-tokens/desktop-generate"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, patURL, nil)
	if err != nil {
		return "", err
	}
	req.Header.Set("Authorization", "Bearer "+res.AccessToken)
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer func() {
		// A close error is not actionable once the body has been decoded.
		_ = resp.Body.Close()
	}()

	if resp.StatusCode != http.StatusCreated {
		return "", fmt.Errorf("unexpected response from Hub: %s", resp.Status)
	}

	var response patGenerateResponse
	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
		return "", fmt.Errorf("failed to decode response from Hub: %w", err)
	}
	if response.Data.Token == "" {
		// Without this check callers would store an empty credential.
		return "", errors.New("unexpected response from Hub: no token in response")
	}
	return response.Data.Token, nil
}
// patGenerateResponse is the JSON body decoded from the Hub's
// desktop-generate endpoint by GetAutoPAT.
//
// Data has no json tag, so encoding/json matches its key against "Data"
// case-insensitively (a "data" key in the response binds to it).
type patGenerateResponse struct {
	Data struct {
		// Token is the generated access token returned to the caller.
		Token string `json:"token"`
	}
}
|