File: channel_settings.go

package info (click to toggle)
gitlab 17.6.5-19
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 629,368 kB
  • sloc: ruby: 1,915,304; javascript: 557,307; sql: 60,639; xml: 6,509; sh: 4,567; makefile: 1,239; python: 406
file content (139 lines) | stat: -rw-r--r-- 3,664 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
// Package api provides internal APIs for gitlab-workhorse.
package api

import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"net/http"
	"net/url"

	"github.com/gorilla/websocket"
	"gitlab.com/gitlab-org/labkit/log"
)

// ChannelSettings holds the configuration settings for establishing a websocket channel.
//
// The methods in this file read these fields: URL parses Url, Dialer uses
// Subprotocols and CAPem, Dial uses Url and Header, and Validate checks
// Subprotocols and the Url scheme.
type ChannelSettings struct {
	// Subprotocols lists the websocket subprotocols offered when dialing.
	// The channel provider may require use of a particular subprotocol. If so,
	// it must be specified here, and Workhorse must have a matching codec.
	// Validate rejects settings whose list is empty.
	Subprotocols []string

	// Url is the websocket URL to connect to. Validate requires its scheme to
	// be "ws" or "wss".
	Url string //nolint:revive,stylecheck // when JSON decoding, the value is provided via 'url'

	// Header holds any headers (e.g., Authorization) to send with the websocket request.
	Header http.Header

	// CAPem holds PEM-encoded CA certificates used to validate the remote
	// endpoint for wss:// URLs. Dialer adds them to the system-provided CA
	// pool (or to an empty pool if the system pool cannot be loaded). If this
	// is blank, only the system-provided pool is used.
	CAPem string

	// MaxSessionTime is the maximum session length, specified in seconds.
	// It is converted to time.Duration later.
	MaxSessionTime int
}

// URL parses the Url field with url.Parse and hands back its result
// unchanged, including any parse error.
func (t *ChannelSettings) URL() (*url.URL, error) {
	raw := t.Url
	parsed, err := url.Parse(raw)
	return parsed, err
}

// Dialer returns a websocket Dialer configured with the settings from ChannelSettings.
//
// The dialer offers t.Subprotocols during the handshake. For TLS it trusts the
// system certificate pool plus any certificates found in t.CAPem. If the
// system pool cannot be loaded, an empty pool is used instead, so only the
// CAPem certificates are trusted.
//
// Both failure modes are logged rather than returned, so a Dialer is always
// produced:
//   - the system pool cannot be loaded;
//   - a non-empty CAPem yields no usable certificates. Otherwise this would
//     only surface later as an opaque certificate-verification error during
//     the TLS handshake.
func (t *ChannelSettings) Dialer() *websocket.Dialer {
	dialer := &websocket.Dialer{
		Subprotocols: t.Subprotocols,
	}

	pool, err := x509.SystemCertPool()
	if err != nil {
		log.WithError(err).Print("failed to load system cert pool")
		pool = x509.NewCertPool()
	}

	// AppendCertsFromPEM reports whether any certificate was parsed; a false
	// result means the configured CA bundle was silently unusable.
	if len(t.CAPem) > 0 && !pool.AppendCertsFromPEM([]byte(t.CAPem)) {
		log.WithError(fmt.Errorf("no valid PEM certificates found in CAPem")).Print("failed to append custom CA certificates")
	}

	dialer.TLSClientConfig = &tls.Config{RootCAs: pool}
	return dialer
}

// Clone creates and returns a deep copy of the ChannelSettings instance.
// Calling Clone on a nil receiver returns nil.
//
// The reference-typed fields are copied:
//   - Subprotocols is copied (preserving nil), so changing an element on the
//     clone cannot affect the original.
//   - Header is copied with http.Header.Clone. A nil Header on the original
//     becomes an empty, non-nil Header on the clone, so callers can add
//     headers to it directly.
func (t *ChannelSettings) Clone() *ChannelSettings {
	if t == nil {
		return nil
	}

	// Strings are immutable in Go, so the shallow struct copy is already safe
	// for Url and CAPem; only the slice and map fields need explicit copies.
	cloned := *t

	if t.Subprotocols != nil {
		cloned.Subprotocols = make([]string, len(t.Subprotocols))
		copy(cloned.Subprotocols, t.Subprotocols)
	}

	cloned.Header = t.Header.Clone()
	if cloned.Header == nil {
		cloned.Header = make(http.Header)
	}

	return &cloned
}

// Dial opens a websocket connection to t.Url, sending t.Header, using the
// dialer built by Dialer. It passes through exactly what the dialer's Dial
// returns: the connection, the HTTP handshake response, and any error.
func (t *ChannelSettings) Dial() (*websocket.Conn, *http.Response, error) {
	d := t.Dialer()
	conn, resp, err := d.Dial(t.Url, t.Header)
	return conn, resp, err
}

// Validate checks if the ChannelSettings instance is valid.
//
// It returns an error, checked in this order, when:
//   - the receiver is nil;
//   - no subprotocol is listed;
//   - Url cannot be parsed;
//   - the Url scheme is neither "ws" nor "wss".
//
// Otherwise it returns nil.
func (t *ChannelSettings) Validate() error {
	switch {
	case t == nil:
		return fmt.Errorf("channel details not specified")
	case len(t.Subprotocols) == 0:
		return fmt.Errorf("no subprotocol specified")
	}

	u, err := t.URL()
	if err != nil {
		// NOTE(review): the underlying parse error is not included here,
		// possibly to avoid echoing the raw URL into messages — confirm
		// before wrapping it.
		return fmt.Errorf("invalid URL")
	}

	switch u.Scheme {
	case "ws", "wss":
		return nil
	default:
		return fmt.Errorf("invalid websocket scheme: %q", u.Scheme)
	}
}

// IsEqual compares the current ChannelSettings with another ChannelSettings instance.
// It returns true if both instances are equal (or both nil), otherwise false.
//
// The comparison works field by field:
//   - Url, CAPem and MaxSessionTime must be identical.
//   - Subprotocols must match element by element, in order.
//   - Header must contain exactly the same keys, and each key's values must
//     match in order. A nil Header and an empty Header are treated as equal.
func (t *ChannelSettings) IsEqual(other *ChannelSettings) bool {
	if t == nil || other == nil {
		// Equal only when both are nil.
		return t == other
	}

	// Cheap scalar comparisons first.
	if t.Url != other.Url ||
		t.CAPem != other.CAPem ||
		t.MaxSessionTime != other.MaxSessionTime {
		return false
	}

	if len(t.Subprotocols) != len(other.Subprotocols) {
		return false
	}
	for i, subprotocol := range t.Subprotocols {
		if other.Subprotocols[i] != subprotocol {
			return false
		}
	}

	if len(t.Header) != len(other.Header) {
		return false
	}
	for key, values := range t.Header {
		// Use the comma-ok form: a plain index on a missing key yields a nil
		// slice of length 0, which would wrongly match a key in t whose value
		// list is empty.
		otherValues, ok := other.Header[key]
		if !ok || len(values) != len(otherValues) {
			return false
		}
		for i, value := range values {
			if otherValues[i] != value {
				return false
			}
		}
	}

	return true
}