File: callback-server.go

package info (click to toggle)
golang-github-minio-pkg 3.4.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,640 kB
  • sloc: xml: 37; makefile: 35; asm: 22
file content (150 lines) | stat: -rw-r--r-- 4,628 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
// Copyright (c) 2015-2025 MinIO, Inc.
//
// # This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <http://www.gnu.org/licenses/>.

package oidc

import (
	"context"
	"crypto/rand"
	"errors"
	"fmt"
	"net"
	"net/http"
	"time"

	"github.com/minio/minio-go/v7/pkg/credentials"
)

// alphabet is the symbol set used by randStr. It has exactly 32 symbols, which
// divides 256 evenly, so reducing a uniformly random byte modulo len(alphabet)
// picks every symbol with equal probability.
const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"

// randStr returns a string of n symbols drawn from alphabet using the
// cryptographically secure source in crypto/rand. It returns an error only if
// reading from that source fails.
func randStr(n int) (string, error) {
	buf := make([]byte, n)
	_, err := rand.Read(buf)
	if err != nil {
		return "", err
	}
	// Replace each random byte in place with the symbol it selects.
	for i, v := range buf {
		buf[i] = alphabet[int(v)%len(alphabet)]
	}
	return string(buf), nil
}

// CallbackServer represents a local HTTP server that handles OAuth callback redirects.
// Create one with NewCallbackServer, obtain the login URL with GetLoginURL, and
// then block on WaitForCredentials for the outcome.
type CallbackServer struct {
	// port is the localhost TCP port the listener was bound to (chosen by the OS).
	port      int
	// reqID is the random request ID sent when requesting the login URL and
	// used to verify the callback's signed "code" parameter.
	reqID     string
	// credsChan delivers verified credentials from the callback handler to
	// WaitForCredentials.
	credsChan chan credentials.Value
	// errChan delivers an unexpected error (anything other than
	// http.ErrServerClosed) returned by the HTTP server's Serve loop.
	errChan   chan error
	// server is the HTTP server; it is shut down when the context passed to
	// NewCallbackServer is canceled.
	server    *http.Server
}

// NewCallbackServer creates and starts a callback server on a random available
// localhost port.
//
// A 16-character request ID is generated for the server. It is passed along
// when requesting the login URL (see GetLoginURL), and every callback's "code"
// query parameter is verified against it with ParseSignedCredentials.
// Requests with a missing code, or a code that fails verification, receive
// HTTP 400. The first verified credentials are made available to
// WaitForCredentials. Credentials from a further callback are acknowledged but
// dropped while an earlier set is still waiting to be read.
//
// The server is shut down, with a 5-second grace period, once ctx is canceled.
// Callers must eventually cancel ctx; otherwise the listener and the goroutines
// started here are never released.
func NewCallbackServer(ctx context.Context) (*CallbackServer, error) {
	reqID, err := randStr(16)
	if err != nil {
		return nil, fmt.Errorf("failed to generate request ID: %w", err)
	}

	// Listen on localhost only; port 0 asks the OS for any free port.
	listener, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		return nil, fmt.Errorf("failed to start listener: %w", err)
	}

	// A "tcp" listener always reports a *net.TCPAddr.
	port := listener.Addr().(*net.TCPAddr).Port

	cs := &CallbackServer{
		port:      port,
		reqID:     reqID,
		credsChan: make(chan credentials.Value, 1),
		errChan:   make(chan error, 1),
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		code := r.URL.Query().Get("code")
		if code == "" {
			http.Error(w, "Missing code parameter", http.StatusBadRequest)
			return
		}

		creds, err := ParseSignedCredentials(code, reqID)
		if err != nil {
			http.Error(w, "Invalid code parameter: "+err.Error(), http.StatusBadRequest)
			return
		}

		w.WriteHeader(http.StatusOK)
		// Best effort: a failed write to the browser does not affect delivery.
		_, _ = fmt.Fprintf(w, "Credentials received successfully. You can close this window.")

		// credsChan holds a single value. A blocking send would hang this
		// handler forever on a repeated callback (e.g. a browser reload) made
		// before WaitForCredentials reads the first value. A hung handler also
		// stalls Shutdown until its timeout, so drop the duplicate instead.
		select {
		case cs.credsChan <- creds:
		default:
		}
	})

	cs.server = &http.Server{
		Handler: mux,
		// Without a header deadline, a client that never finishes sending its
		// request headers holds its connection open indefinitely.
		ReadHeaderTimeout: 10 * time.Second,
	}

	go func() {
		// ErrServerClosed is the normal result of Shutdown and is not reported.
		if err := cs.server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
			cs.errChan <- err
		}
	}()

	go func() {
		<-ctx.Done()
		// ctx is already done, so bound the graceful shutdown with a fresh context.
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		_ = cs.server.Shutdown(shutdownCtx)
	}()

	return cs, nil
}

// reqClient is the dependency GetLoginURL needs: something that can obtain an
// OpenID login URL for the given request ID, configuration name, and local
// callback port.
type reqClient interface {
	GetOpenIDLoginURL(ctx context.Context, reqID string, configName string, port int) (string, error)
}

// GetLoginURL asks client for the OpenID login URL of the configuration named
// configName. It passes along this server's request ID and local port. Any
// error from client is wrapped.
func (cs *CallbackServer) GetLoginURL(ctx context.Context, client reqClient, configName string) (string, error) {
	u, err := client.GetOpenIDLoginURL(ctx, cs.reqID, configName, cs.port)
	if err == nil {
		return u, nil
	}
	return "", fmt.Errorf("failed to get login URL: %w", err)
}

// WaitForCredentials blocks until credentials arrive via the callback, the
// callback server fails, or ctx is done.
//
// On success it returns the received credentials. If the HTTP server stops
// with an error, that error is returned wrapped. If ctx is done first, the
// returned error wraps ctx.Err(), so callers can distinguish the cases with
// errors.Is(err, context.DeadlineExceeded) or errors.Is(err, context.Canceled).
// Credentials that were already delivered when ctx expires are still returned.
func (cs *CallbackServer) WaitForCredentials(ctx context.Context) (credentials.Value, error) {
	select {
	case creds := <-cs.credsChan:
		return creds, nil
	case err := <-cs.errChan:
		return credentials.Value{}, fmt.Errorf("callback server error: %w", err)
	case <-ctx.Done():
		// select chooses randomly among ready cases. If credentials landed just
		// as ctx expired, prefer returning them over an error.
		select {
		case creds := <-cs.credsChan:
			return creds, nil
		default:
		}
		ctxErr := ctx.Err()
		if errors.Is(ctxErr, context.DeadlineExceeded) {
			return credentials.Value{}, fmt.Errorf("timeout waiting for authentication callback: %w", ctxErr)
		}
		return credentials.Value{}, fmt.Errorf("authentication canceled: %w", ctxErr)
	}
}