1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
|
package connection
import (
"context"
"fmt"
"io"
"log"
"net/http"
"net/url"
"sync"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/microsoft/dev-tunnels/go/tunnels"
)
// clientName is the user-agent name this CLI passes to the tunnel manager.
const clientName = "gh"
// TunnelClient wraps a dev-tunnels client together with the connection state
// that CodespaceConnection.Connect and CodespaceConnection.Close maintain.
type TunnelClient struct {
	*tunnels.Client
	// connected reports whether Connect has succeeded since the last
	// successful Close.
	connected bool
	// mu is held by Connect and Close while they read or change connected and
	// open or close the underlying connection.
	mu sync.Mutex
}
// CodespaceConnection holds what is needed to reach a codespace over its dev
// tunnel: a tunnel manager, a tunnel client, and the tunnel definition and
// request options they were created from. Build one with
// NewCodespaceConnection.
type CodespaceConnection struct {
	// tunnelProperties are the codespace's tunnel connection properties
	// (service URI, tunnel and cluster IDs, domain, access tokens).
	tunnelProperties api.TunnelProperties
	// TunnelManager is used to look up the tunnel and to perform port
	// operations on it (add, remove, list, etc.).
	TunnelManager *tunnels.Manager
	// TunnelClient is the client used to connect to the tunnel; it is not
	// connected until Connect is called.
	TunnelClient *TunnelClient
	// Options are the request options used when fetching the tunnel.
	Options *tunnels.TunnelRequestOptions
	// Tunnel is the tunnel definition, carrying the connect and
	// manage-ports access tokens.
	Tunnel *tunnels.Tunnel
	// AllowedPortPrivacySettings is copied from the codespace's runtime
	// constraints.
	AllowedPortPrivacySettings []string
}
// NewCodespaceConnection prepares a connection to the given codespace.
//
// It builds a tunnel manager from the codespace's tunnel properties, fetches
// the codespace's tunnel, and creates a tunnel client for it. The client is
// NOT connected yet; call Connect to open it. The resulting connection
// enables port forwarding, which most codespace commands rely on.
func NewCodespaceConnection(ctx context.Context, codespace *api.Codespace, httpClient *http.Client) (connection *CodespaceConnection, err error) {
	props := codespace.Connection.TunnelProperties

	manager, err := getTunnelManager(props, httpClient)
	if err != nil {
		return nil, fmt.Errorf("error getting tunnel management client: %w", err)
	}

	// The tunnel definition carries two tokens: one to connect and one to
	// manage ports.
	target := &tunnels.Tunnel{
		AccessTokens: map[tunnels.TunnelAccessScope]string{
			tunnels.TunnelAccessScopeConnect:     props.ConnectAccessToken,
			tunnels.TunnelAccessScopeManagePorts: props.ManagePortsAccessToken,
		},
		TunnelID:  props.TunnelId,
		ClusterID: props.ClusterId,
		Domain:    props.Domain,
	}
	requestOpts := &tunnels.TunnelRequestOptions{IncludePorts: true}

	// Only constructs the client; no connection is opened here.
	client, err := getTunnelClient(ctx, manager, target, requestOpts)
	if err != nil {
		return nil, fmt.Errorf("error getting tunnel client: %w", err)
	}

	connection = &CodespaceConnection{
		tunnelProperties:           props,
		TunnelManager:              manager,
		TunnelClient:               client,
		Options:                    requestOpts,
		Tunnel:                     target,
		AllowedPortPrivacySettings: codespace.RuntimeConstraints.AllowedPortPrivacySettings,
	}
	return connection, nil
}
// Connect opens the tunnel client's connection to the tunnel.
//
// Calling Connect on an already-connected client is a no-op. The client's
// mutex is held for the whole call, which serializes Connect with Close so
// the underlying SSH connection is never opened and closed concurrently.
func (c *CodespaceConnection) Connect(ctx context.Context) error {
	tc := c.TunnelClient
	tc.mu.Lock()
	defer tc.mu.Unlock()

	if tc.connected {
		return nil
	}

	err := tc.Client.Connect(ctx, "")
	if err != nil {
		return fmt.Errorf("error connecting to tunnel: %w", err)
	}
	// Only record success; on failure a later Connect may retry.
	tc.connected = true
	return nil
}
// Close closes the underlying tunnel client SSH connection.
//
// It is a no-op when there is no tunnel client or the client is not
// connected. On success the client is marked disconnected, so Connect can
// reopen it. If closing fails, the error is wrapped and returned and the
// client stays marked as connected.
func (c *CodespaceConnection) Close() error {
	// Check for a missing client BEFORE touching its mutex: locking
	// c.TunnelClient.mu dereferences c.TunnelClient and would panic on nil.
	if c.TunnelClient == nil {
		return nil
	}

	// Serialize with Connect so the SSH connection isn't torn down mid-open.
	c.TunnelClient.mu.Lock()
	defer c.TunnelClient.mu.Unlock()

	if !c.TunnelClient.connected {
		return nil
	}
	if err := c.TunnelClient.Close(); err != nil {
		return fmt.Errorf("failed to close tunnel client connection: %w", err)
	}
	c.TunnelClient.connected = false
	return nil
}
// getTunnelManager creates a tunnel manager for the tunnel service named by
// tunnelProperties.ServiceUri, using httpClient for its requests and
// clientName as its user-agent name.
//
// The tunnel manager is used to get the tunnel hosted in the codespace that we
// want to connect to and to perform operations on its ports (add, remove,
// list, etc.). It returns an error if the service URI cannot be parsed or the
// manager cannot be created.
func getTunnelManager(tunnelProperties api.TunnelProperties, httpClient *http.Client) (tunnelManager *tunnels.Manager, err error) {
	userAgent := []tunnels.UserAgent{{Name: clientName}}

	// Named serviceURL (not url) so the net/url package isn't shadowed.
	serviceURL, err := url.Parse(tunnelProperties.ServiceUri)
	if err != nil {
		return nil, fmt.Errorf("error parsing tunnel service uri: %w", err)
	}

	// NOTE(review): the nil second argument is presumably the token provider;
	// access tokens are attached to the tunnel definition instead. Confirm
	// against tunnels.NewManager.
	tunnelManager, err = tunnels.NewManager(userAgent, nil, serviceURL, httpClient)
	if err != nil {
		return nil, fmt.Errorf("error creating tunnel manager: %w", err)
	}
	return tunnelManager, nil
}
// getTunnelClient fetches the tunnel described by tunnel from tunnelManager and
// creates a tunnel client for it.
//
// The tunnel client is used to connect to the tunnel and allows ports to be
// forwarded locally. The returned client is not connected yet. An error is
// returned if the tunnel cannot be fetched, the service returns no tunnel, or
// the client cannot be created.
func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel *tunnels.Tunnel, options *tunnels.TunnelRequestOptions) (tunnelClient *TunnelClient, err error) {
	codespaceTunnel, err := tunnelManager.GetTunnel(ctx, tunnel, options)
	if err != nil {
		return nil, fmt.Errorf("error getting tunnel: %w", err)
	}
	// GetTunnel returns a pointer; guard against a nil tunnel with a nil error
	// (e.g. a null response body) before dereferencing it below.
	if codespaceTunnel == nil {
		return nil, fmt.Errorf("error getting tunnel: tunnel %q not found", tunnel.TunnelID)
	}

	// Carry the access tokens over from our tunnel definition onto the fetched
	// tunnel. NOTE(review): presumably the fetched tunnel lacks them; confirm.
	codespaceTunnel.AccessTokens = tunnel.AccessTokens

	// Pass false for accept-local-connections: we don't want to automatically
	// connect to every forwarded port. Client logs are discarded.
	client, err := tunnels.NewClient(log.New(io.Discard, "", log.LstdFlags), codespaceTunnel, false)
	if err != nil {
		return nil, fmt.Errorf("error creating tunnel client: %w", err)
	}

	// The zero value of connected (false) marks the client as not yet connected.
	return &TunnelClient{Client: client}, nil
}
|