File: streams_map_incoming.go

package info (click to toggle)
golang-github-lucas-clemente-quic-go 0.54.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 4,312 kB
  • sloc: sh: 54; makefile: 7
file content (209 lines) | stat: -rw-r--r-- 6,414 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
package quic

import (
	"context"
	"fmt"
	"sync"

	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/internal/qerr"
	"github.com/quic-go/quic-go/internal/wire"
)

// incomingStream is the constraint on the stream type held by an
// incomingStreamsMap. CloseWithError calls closeForShutdown, passing the
// close error, on every stream that is still in the map.
type incomingStream interface {
	closeForShutdown(error)
}

// incomingStreamEntry is the value type of incomingStreamsMap.streams.
//
// When a stream is deleted before it was accepted, we can't delete it from the map immediately.
// We need to wait until the application accepts it, and delete it then.
type incomingStreamEntry[T incomingStream] struct {
	stream       T
	shouldDelete bool // set by deleteStream for a not-yet-accepted stream; AcceptStream deletes the entry when it returns it
}

// incomingStreamsMap manages the streams of a single type (streamType) that
// are opened by the peer, from the moment they are opened until the
// application has accepted them and they have been deleted.
// All of its mutable state is written only while holding mutex.
type incomingStreamsMap[T incomingStream] struct {
	mutex         sync.RWMutex
	newStreamChan chan struct{} // capacity 1; signals AcceptStream that a stream was opened, closed by CloseWithError

	streamType protocol.StreamType
	streams    map[protocol.StreamID]incomingStreamEntry[T] // opened streams not yet deleted, keyed by stream ID

	nextStreamToAccept protocol.StreamID // the next stream that will be returned by AcceptStream()
	nextStreamToOpen   protocol.StreamID // the ID following the highest stream that the peer opened
	maxStream          protocol.StreamID // the highest stream that the peer is allowed to open
	maxNumStreams      uint64            // how many streams the map may hold; deleteStream derives maxStream from it

	// newStream constructs the stream for a newly opened ID;
	// queueMaxStreamID queues a MAX_STREAMS frame for the peer.
	newStream        func(protocol.StreamID) T
	queueMaxStreamID func(*wire.MaxStreamsFrame)

	closeErr error // set by CloseWithError; AcceptStream returns it from then on
}

// newIncomingStreamsMap creates the map for incoming streams of streamType.
// pers is the local endpoint's perspective; the peer may initially open up to
// maxStreams streams of this type. newStream is used to construct each stream
// the peer opens, and MAX_STREAMS frames are handed to queueControlFrame.
func newIncomingStreamsMap[T incomingStream](
	streamType protocol.StreamType,
	newStream func(protocol.StreamID) T,
	maxStreams uint64,
	queueControlFrame func(wire.Frame),
	pers protocol.Perspective,
) *incomingStreamsMap[T] {
	// Pick the first stream ID of this type that the peer can open.
	var firstID protocol.StreamID
	switch streamType {
	case protocol.StreamTypeBidi:
		switch pers {
		case protocol.PerspectiveServer:
			firstID = protocol.FirstIncomingBidiStreamServer
		case protocol.PerspectiveClient:
			firstID = protocol.FirstIncomingBidiStreamClient
		}
	case protocol.StreamTypeUni:
		switch pers {
		case protocol.PerspectiveServer:
			firstID = protocol.FirstIncomingUniStreamServer
		case protocol.PerspectiveClient:
			firstID = protocol.FirstIncomingUniStreamClient
		}
	}

	m := &incomingStreamsMap[T]{
		streamType:         streamType,
		maxNumStreams:      maxStreams,
		newStream:          newStream,
		nextStreamToAccept: firstID,
		nextStreamToOpen:   firstID,
	}
	m.streams = make(map[protocol.StreamID]incomingStreamEntry[T])
	m.newStreamChan = make(chan struct{}, 1)
	// The limit applies to streams initiated by the peer, i.e. the opposite perspective.
	m.maxStream = protocol.StreamNum(maxStreams).StreamID(streamType, pers.Opposite())
	m.queueMaxStreamID = func(f *wire.MaxStreamsFrame) { queueControlFrame(f) }
	return m
}

// AcceptStream blocks until the peer has opened the stream with ID
// nextStreamToAccept, then returns it and advances nextStreamToAccept to the
// next ID of the same type. It returns the close error once CloseWithError has
// been called, or ctx.Err() if ctx is done first. A stream that was deleted
// before being accepted is removed from the map here, but is still returned.
func (m *incomingStreamsMap[T]) AcceptStream(ctx context.Context) (T, error) {
	var zero T

	// Discard a stale signal, so we don't look into the map twice when the
	// stream we're waiting for isn't there yet.
	select {
	case <-m.newStreamChan:
	default:
	}

	m.mutex.Lock()
	for {
		if m.closeErr != nil {
			closeErr := m.closeErr
			m.mutex.Unlock()
			return zero, closeErr
		}

		id := m.nextStreamToAccept
		if entry, found := m.streams[id]; found {
			// Advance first: deleteStream only removes streams below nextStreamToAccept.
			m.nextStreamToAccept += 4
			var err error
			if entry.shouldDelete {
				err = m.deleteStream(id)
			}
			m.mutex.Unlock()
			if err != nil {
				return zero, err
			}
			return entry.stream, nil
		}

		// Wait without holding the lock until a stream is opened (or the map closed).
		m.mutex.Unlock()
		select {
		case <-ctx.Done():
			return zero, ctx.Err()
		case <-m.newStreamChan:
		}
		m.mutex.Lock()
	}
}

// GetOrOpenStream returns the incoming stream with the given ID. If the peer
// has not opened it yet, it is opened now, together with every lower ID of the
// same type that hasn't been opened yet; AcceptStream is signaled.
//
// It returns a StreamLimitError if id is above the limit the peer may open.
// For a stream that was already opened but has since been deleted, or that is
// only waiting to be accepted before being deleted, it returns the zero value
// of T and a nil error.
func (m *incomingStreamsMap[T]) GetOrOpenStream(id protocol.StreamID) (T, error) {
	m.mutex.RLock()
	// Capture the limit while holding the lock: maxStream is written by
	// deleteStream under the write lock, so it must not be read after RUnlock.
	if limit := m.maxStream; id > limit {
		m.mutex.RUnlock()
		return *new(T), &qerr.TransportError{
			ErrorCode:    qerr.StreamLimitError,
			ErrorMessage: fmt.Sprintf("peer tried to open stream %d (current limit: %d)", id, limit),
		}
	}
	if id < m.nextStreamToOpen {
		s := m.openedStream(id)
		m.mutex.RUnlock()
		return s, nil
	}
	m.mutex.RUnlock()

	m.mutex.Lock()
	defer m.mutex.Unlock()
	// maxStream never decreases, so the limit check above still holds.
	// nextStreamToOpen, however, may have been advanced by a concurrent call
	// while no lock was held. Re-check it, so that we neither re-create streams
	// that already exist nor move nextStreamToOpen backwards.
	if id < m.nextStreamToOpen {
		return m.openedStream(id), nil
	}
	for num := m.nextStreamToOpen; num <= id; num += 4 {
		m.streams[num] = incomingStreamEntry[T]{stream: m.newStream(num)}
		// Non-blocking: a single pending signal is enough to wake AcceptStream.
		// NOTE(review): this send panics if CloseWithError has already closed
		// newStreamChan; assumes GetOrOpenStream is not called after closing.
		select {
		case m.newStreamChan <- struct{}{}:
		default:
		}
	}
	m.nextStreamToOpen = id + 4
	return m.streams[id].stream, nil
}

// openedStream returns the stream with the given ID if it is still in the map
// and not just waiting to be accepted before being deleted. Otherwise it
// returns the zero value of T. The caller must hold m.mutex (read or write).
func (m *incomingStreamsMap[T]) openedStream(id protocol.StreamID) T {
	if entry, ok := m.streams[id]; ok && !entry.shouldDelete {
		return entry.stream
	}
	var zero T
	return zero
}

// DeleteStream removes the stream with the given ID from the map. It takes the
// lock itself. Any failure reported by deleteStream (unknown stream, or a
// not-yet-accepted stream deleted twice) is returned as a transport error with
// code StreamStateError.
func (m *incomingStreamsMap[T]) DeleteStream(id protocol.StreamID) error {
	m.mutex.Lock()
	err := m.deleteStream(id)
	m.mutex.Unlock()

	if err == nil {
		return nil
	}
	return &qerr.TransportError{
		ErrorCode:    qerr.StreamStateError,
		ErrorMessage: err.Error(),
	}
}

// deleteStream removes the stream with the given ID. The caller must hold
// m.mutex.
//
// A stream that has not been accepted yet (id >= nextStreamToAccept) can't be
// removed right away. It is only marked with shouldDelete, and AcceptStream
// removes it once it hands it out. Removing an accepted stream frees room in
// the map, so the peer's limit (maxStream) may be raised and a MAX_STREAMS
// frame queued.
//
// It returns an error if the stream is unknown, or if a not-yet-accepted stream
// is deleted more than once.
func (m *incomingStreamsMap[T]) deleteStream(id protocol.StreamID) error {
	entry, ok := m.streams[id]
	if !ok {
		return fmt.Errorf("tried to delete unknown incoming stream %d", id)
	}

	// Don't delete this stream yet if it was not yet accepted: mark it, so that
	// it is deleted as soon as it gets accepted.
	if id >= m.nextStreamToAccept {
		if entry.shouldDelete {
			return fmt.Errorf("tried to delete incoming stream %d multiple times", id)
		}
		entry.shouldDelete = true
		m.streams[id] = entry // map values aren't addressable, so write the modified copy back
		return nil
	}

	delete(m.streams, id)
	// Queue a MAX_STREAMS frame, giving the peer the option to open a new stream.
	// Stream IDs of one type are spaced 4 apart.
	if m.maxNumStreams > uint64(len(m.streams)) {
		maxStream := m.nextStreamToOpen + 4*protocol.StreamID(m.maxNumStreams-uint64(len(m.streams))-1)
		// never send a value larger than the maximum value for a stream number
		if maxStream <= protocol.MaxStreamID {
			m.maxStream = maxStream
			m.queueMaxStreamID(&wire.MaxStreamsFrame{
				Type:         m.streamType,
				MaxStreamNum: m.maxStream.StreamNum(),
			})
		}
	}
	return nil
}

// CloseWithError records err as the close error and calls closeForShutdown(err)
// on every stream still in the map. It then closes newStreamChan, which wakes
// every AcceptStream call waiting on it; those calls then return err.
func (m *incomingStreamsMap[T]) CloseWithError(err error) {
	m.mutex.Lock()
	m.closeErr = err
	for id := range m.streams {
		m.streams[id].stream.closeForShutdown(err)
	}
	m.mutex.Unlock()

	close(m.newStreamChan)
}