File: ticketsender.go

package info (click to toggle)
golang-github-google-s2a-go 0.1.8-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,800 kB
  • sloc: sh: 144; makefile: 9
file content (178 lines) | stat: -rw-r--r-- 5,631 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
/*
 *
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package record

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/google/s2a-go/internal/handshaker/service"
	commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
	s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
	"github.com/google/s2a-go/internal/tokenmanager"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/grpclog"
)

// sessionTimeout is the deadline applied to the context used to dial the S2A
// handshaker service and set up a session with it.
const sessionTimeout = 5 * time.Second

// s2aTicketSender is the abstraction for handing session tickets to the S2A
// handshaker service; it allows a fake sender to be substituted in tests.
type s2aTicketSender interface {
	// sendTicketsToS2A delivers tickets to the S2A handshaker service.
	// Implementations use callComplete to signal that the attempt is over.
	sendTicketsToS2A(tickets [][]byte, callComplete chan bool)
}

// ticketStream is the subset of the S2A session stream that ticket delivery
// needs. The stream returned by SetUpSession is passed where a ticketStream is
// expected, and writeTicketsToStream uses it to send a request and read a reply.
type ticketStream interface {
	// Send writes req to the stream.
	Send(req *s2apb.SessionReq) error
	// Recv reads the next response from the stream.
	Recv() (resp *s2apb.SessionResp, err error)
}

// ticketSender is the s2aTicketSender that talks to a real S2A handshaker
// service: it dials hsAddr and sends resumption tickets over a session stream.
type ticketSender struct {
	// hsAddr stores the address of the S2A handshaker service.
	hsAddr string
	// connectionID is the connection identifier that was created and sent by
	// S2A at the end of a handshake.
	connectionID uint64
	// localIdentity is the local identity that was used by S2A during session
	// setup and included in the session result. May be nil, in which case the
	// default token (if any) is used for authentication.
	localIdentity *commonpb.Identity
	// tokenManager manages access tokens for authenticating to S2A. Optional:
	// when nil, no authentication mechanisms are attached to requests.
	tokenManager tokenmanager.AccessTokenManager
	// ensureProcessSessionTickets allows users to wait and ensure that all
	// available session tickets are sent to S2A before a process completes.
	// Optional: when nil, no WaitGroup accounting is performed.
	ensureProcessSessionTickets *sync.WaitGroup
}

// sendTicketsToS2A forwards sessionTickets to the S2A handshaker service at
// t.hsAddr. The exchange runs on a new goroutine, so this method returns
// immediately. Failures are written to the error log rather than returned.
// When the goroutine finishes, successfully or not, it sends true on
// callComplete and then closes it.
//
// If t.ensureProcessSessionTickets is non-nil, it is incremented before the
// goroutine starts and decremented once the S2A exchange is over. Callers can
// wait on it to be sure the tickets were handled before the process exits.
func (t *ticketSender) sendTicketsToS2A(sessionTickets [][]byte, callComplete chan bool) {
	// The goroutine is spawned here instead of by the caller because the fake
	// ticket sender used in tests has to run synchronously, so that tests can
	// read back the session tickets it captured.
	if t.ensureProcessSessionTickets != nil {
		t.ensureProcessSessionTickets.Add(1)
	}

	// deliver performs one complete S2A exchange for sessionTickets. Its
	// deferred calls run in reverse order: CloseSend, then cancel, then Done.
	deliver := func() error {
		defer func() {
			if t.ensureProcessSessionTickets != nil {
				t.ensureProcessSessionTickets.Done()
			}
		}()

		ctx, cancel := context.WithTimeout(context.Background(), sessionTimeout)
		defer cancel()

		// No transport credentials are passed: they are only needed when
		// talking to S2AV2 or when mTLS is required.
		conn, dialErr := service.Dial(ctx, t.hsAddr, nil)
		if dialErr != nil {
			return dialErr
		}
		stream, setupErr := s2apb.NewS2AServiceClient(conn).SetUpSession(ctx)
		if setupErr != nil {
			return setupErr
		}
		defer func() {
			if closeErr := stream.CloseSend(); closeErr != nil {
				grpclog.Error(closeErr)
			}
		}()
		return t.writeTicketsToStream(stream, sessionTickets)
	}

	go func() {
		sendErr := deliver()
		if sendErr != nil {
			grpclog.Errorf("failed to send resumption tickets to S2A with identity: %v, %v",
				t.localIdentity, sendErr)
		}
		// NOTE(review): if callComplete is unbuffered, this goroutine stays
		// blocked on the send below until some receiver reads from it.
		callComplete <- true
		close(callComplete)
	}()
}

// writeTicketsToStream sends sessionTickets to S2A over stream as a single
// ResumptionTicket request. The request also carries t.connectionID,
// t.localIdentity and whatever getAuthMechanisms returns. It then reads one
// response.
//
// Errors from Send or Recv are returned unchanged. A response whose status
// code is not codes.OK is turned into an error carrying that code and details.
func (t *ticketSender) writeTicketsToStream(stream ticketStream, sessionTickets [][]byte) error {
	req := &s2apb.SessionReq{
		ReqOneof: &s2apb.SessionReq_ResumptionTicket{
			ResumptionTicket: &s2apb.ResumptionTicketReq{
				InBytes:       sessionTickets,
				ConnectionId:  t.connectionID,
				LocalIdentity: t.localIdentity,
			},
		},
		AuthMechanisms: t.getAuthMechanisms(),
	}
	if sendErr := stream.Send(req); sendErr != nil {
		return sendErr
	}

	resp, recvErr := stream.Recv()
	if recvErr != nil {
		return recvErr
	}

	st := resp.GetStatus()
	if code := st.GetCode(); code != uint32(codes.OK) {
		return fmt.Errorf("s2a session ticket response had error status: %v, %v",
			code, st.GetDetails())
	}
	return nil
}

// getAuthMechanisms returns the authentication mechanisms to attach to a
// session request. The result is nil when there is no token manager or when a
// token cannot be obtained; the latter case is logged at info level.
// Otherwise it returns a single token-based mechanism:
//   - if t.localIdentity is nil, it uses the manager's default token and leaves
//     the mechanism's identity unset;
//   - otherwise it uses the token for t.localIdentity and records that identity.
func (t *ticketSender) getAuthMechanisms() []*s2apb.AuthenticationMechanism {
	if t.tokenManager == nil {
		return nil
	}

	var (
		token string
		err   error
	)
	switch {
	case t.localIdentity == nil:
		// The application provided no local identity.
		token, err = t.tokenManager.DefaultToken()
		if err != nil {
			grpclog.Infof("unable to get token for empty local identity: %v", err)
			return nil
		}
	default:
		// The application (or the S2A) specified a local identity.
		token, err = t.tokenManager.Token(t.localIdentity)
		if err != nil {
			grpclog.Infof("unable to get token for local identity %v: %v", t.localIdentity, err)
			return nil
		}
	}

	// In the default-token case t.localIdentity is nil, which is the same as
	// leaving the Identity field unset.
	return []*s2apb.AuthenticationMechanism{{
		Identity:       t.localIdentity,
		MechanismOneof: &s2apb.AuthenticationMechanism_Token{Token: token},
	}}
}