1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474
|
package proto
import (
"fmt"
"github.com/inconshreveable/muxado/proto/frame"
"io"
"net"
"reflect"
"sync"
"sync/atomic"
"time"
)
const (
	// defaultWindowSize is the initial flow-control window handed to every
	// new stream (0x10000 = 64KB).
	defaultWindowSize = 0x10000

	// defaultAcceptQueueDepth is the buffer size of the channel holding
	// remotely-opened streams that Accept has not returned yet.
	defaultAcceptQueueDepth = 100

	// MinExtensionType is the smallest stream type reserved for session
	// extensions. SYNs whose type lies in [MinExtensionType, 0xFFFFFFFF] are
	// routed to the registered extension instead of the Accept queue, which
	// leaves room for 0x101 (257) extension types.
	MinExtensionType = 0xFFFFFFFF - 0x100
)
// stream is the package-private view of a stream that Session works with.
// On top of the public IStream methods it exposes the hooks the session's
// frame dispatcher uses to hand incoming frames to the stream.
type stream interface {
	IStream

	// Per-frame handlers called from Session.handleFrame.
	handleStreamData(*frame.RStreamData)
	handleStreamWndInc(*frame.RStreamWndInc)
	handleStreamRst(*frame.RStreamRst)

	// closeWith is called by the session (on shutdown, or when the remote's
	// GoAway excludes this stream) to close the stream with the given error.
	closeWith(error)
}
// ExtAccept is handed to an Extension's Start method. Each call blocks until
// the session delivers a remotely-opened stream of the extension's type. It
// returns an error once the session has stopped reading frames.
type ExtAccept func() (IStream, error)

// Extension is a pluggable session add-on. Start is called once while the
// session is being constructed. It receives the session and an accept
// function, and returns the stream type the extension claims. Incoming
// streams of that type (when it is >= MinExtensionType) bypass
// Session.Accept and are delivered through the accept function instead.
type Extension interface {
	Start(ISession, ExtAccept) frame.StreamType
}
// deadReason records why a session terminated. die() sends exactly one of
// these on the session's dead channel and Wait() unpacks it.
// Field order matters: die() builds it with a positional literal.
type deadReason struct {
	errorCode   frame.ErrorCode // protocol error code passed to die()
	err         error           // error passed to die()
	remoteDebug []byte          // debug payload from the remote's GoAway, if any arrived
}
// streamFactory builds the stream object for a newly opened stream. Session
// calls it both for streams it opens (finLocal = fin, finRemote = false) and
// for streams the remote opens (finLocal = false, finRemote = the SYN's fin).
type streamFactory func(id frame.StreamId, priority frame.StreamPriority, streamType frame.StreamType, finLocal bool, finRemote bool, windowSize uint32, sess session) stream

// parityFn classifies a stream id by its parity, e.g. Session.isLocal reports
// whether an id belongs to the streams this endpoint opens.
type parityFn func(frame.StreamId) bool
// halfState tracks one direction of the session: either the streams this
// endpoint opens (local) or the streams the peer opens (remote). Session
// accesses both fields only through sync/atomic.
type halfState struct {
	goneAway int32  // 1 once this half of the session has gone away, 0 otherwise
	lastId   uint32 // last stream id used by / seen from this half
}
// Session implements a simple streaming session manager. It has the following characteristics:
//
//   - When closing the Session, it does not linger, all pending write operations will fail immediately.
//   - It completely ignores stream priority when processing and writing frames
//   - It offers no customization of settings like window size/ping time
type Session struct {
	conn              net.Conn                         // connection the transport is running over
	transport         frame.Transport                  // frame reader/writer wrapping conn
	streams           StreamMap                        // all active streams, keyed by id
	local             halfState                        // state of the streams THIS endpoint opens (client or server)
	remote            halfState                        // state of the streams the PEER opens
	syn               *frame.WStreamSyn                // STREAM_SYN frame reused by OpenStream; only touched while holding wr
	wr                sync.Mutex                       // serializes frame writes to the transport
	accept            chan stream                      // non-extension streams opened by the remote, awaiting Accept
	diebit            int32                            // set to 1 (atomically) once die() has started
	remoteDebug       []byte                           // debugging data sent in the remote's GoAway frame
	defaultWindowSize uint32                           // window size when creating new streams
	newStream         streamFactory                    // factory function to make new streams
	dead              chan deadReason                  // receives the single termination reason; read by Wait
	isLocal           parityFn                         // reports whether a stream id belongs to this endpoint's half
	exts              map[frame.StreamType]chan stream // extension stream type -> accept channel for that extension
}
// NewSession wraps conn in a session and starts its reader goroutine.
// newStream builds stream objects. isClient selects which id parity this
// endpoint uses for the streams it opens. Every extension in exts is started
// before the reader goroutine begins reading frames.
func NewSession(conn net.Conn, newStream streamFactory, isClient bool, exts []Extension) ISession {
	sess := &Session{
		conn:              conn,
		transport:         frame.NewBasicTransport(conn),
		streams:           NewConcurrentStreamMap(),
		syn:               frame.NewWStreamSyn(),
		defaultWindowSize: defaultWindowSize,
		accept:            make(chan stream, defaultAcceptQueueDepth),
		newStream:         newStream,
		dead:              make(chan deadReason, 1), // buffered so die() never blocks when nobody calls Wait
		exts:              make(map[frame.StreamType]chan stream),
	}

	// Client ids are odd, server ids even. The client's half starts at 1 and
	// the server's at 0, so OpenStream's step of 2 stays on the right parity.
	clientHalf := &sess.remote
	sess.isLocal = sess.isServer
	if isClient {
		clientHalf = &sess.local
		sess.isLocal = sess.isClient
	}
	clientHalf.lastId = 1

	for _, ext := range exts {
		sess.startExtension(ext)
	}

	go sess.reader()
	return sess
}
////////////////////////////////
// public interface
////////////////////////////////
// Open opens a new stream with zero priority, stream type 0, and the local
// side left open (no FIN).
func (s *Session) Open() (IStream, error) {
	var (
		prio frame.StreamPriority
		typ  frame.StreamType
	)
	return s.OpenStream(prio, typ, false)
}
func (s *Session) OpenStream(priority frame.StreamPriority, streamType frame.StreamType, fin bool) (ret IStream, err error) {
// check if the remote has gone away
if atomic.LoadInt32(&s.remote.goneAway) == 1 {
return nil, fmt.Errorf("Failed to create stream, remote has gone away.")
}
// this lock prevents the following race:
// goroutine1 goroutine2
// - inc stream id
// - inc stream id
// - send streamsyn
// - send streamsyn
s.wr.Lock()
// get the next id we can use
nextId := frame.StreamId(atomic.AddUint32(&s.local.lastId, 2))
// make the stream
str := s.newStream(nextId, priority, streamType, fin, false, s.defaultWindowSize, s)
// add to to the stream map
s.streams.Set(nextId, str)
// write the frame
if err = s.syn.Set(nextId, priority, streamType, fin); err != nil {
s.wr.Unlock()
s.die(frame.InternalError, err)
return
}
if err = s.transport.WriteFrame(s.syn); err != nil {
s.wr.Unlock()
s.die(frame.InternalError, err)
return
}
s.wr.Unlock()
return str, nil
}
func (s *Session) Accept() (str IStream, err error) {
var ok bool
if str, ok = <-s.accept; !ok {
return nil, fmt.Errorf("Session closed")
}
return
}
// Kill closes the transport immediately, without sending a GoAway and without
// going through die(). Presumably the reader goroutine then sees a read error
// and runs die() itself.
func (s *Session) Kill() error {
	t := s.transport
	return t.Close()
}
// Close shuts the session down with NoError via die(). It returns an error if
// a shutdown is already in progress.
func (s *Session) Close() error {
	reason := fmt.Errorf("Session Close()")
	return s.die(frame.NoError, reason)
}
// GoAway tells the remote that this endpoint will accept no new streams.
// The frame carries errorCode, debug and the last remote stream id seen.
// Only the first call sends a frame; later calls return an error. If
// building or writing the frame fails, the session is killed and that error
// returned.
func (s *Session) GoAway(errorCode frame.ErrorCode, debug []byte) (err error) {
	if !atomic.CompareAndSwapInt32(&s.local.goneAway, 0, 1) {
		return fmt.Errorf("Already sent GoAway!")
	}

	// die() is called only after s.wr is released, because die() calls
	// GoAway (and hence locks s.wr) itself.
	err = func() error {
		s.wr.Lock()
		defer s.wr.Unlock()

		goAway := frame.NewWGoAway()
		lastRemote := frame.StreamId(atomic.LoadUint32(&s.remote.lastId))
		if setErr := goAway.Set(lastRemote, errorCode, debug); setErr != nil {
			return setErr
		}
		return s.transport.WriteFrame(goAway)
	}()
	if err != nil {
		s.die(frame.InternalError, err)
	}
	return err
}
// LocalAddr returns the local address of the underlying connection.
func (s *Session) LocalAddr() net.Addr {
	c := s.conn
	return c.LocalAddr()
}
// RemoteAddr returns the remote address of the underlying connection.
func (s *Session) RemoteAddr() net.Addr {
	c := s.conn
	return c.RemoteAddr()
}
// Wait blocks until the session dies. It returns the error code and error
// passed to die(), plus the debug data from the remote's GoAway (if any).
// die() sends exactly one reason, so only one Wait call ever returns.
func (s *Session) Wait() (frame.ErrorCode, error, []byte) {
	r := <-s.dead
	code, cause, debug := r.errorCode, r.err, r.remoteDebug
	return code, cause, debug
}
////////////////////////////////
// private interface for streams
////////////////////////////////
// removeStream drops id from the session's stream registry.
// Removing an id that is not registered is not an error.
func (s *Session) removeStream(id frame.StreamId) {
	s.streams.Delete(id)
}
// writeFrame writes f to the transport while holding the session's write lock
// and returns the error from the write. dl becomes the connection's write
// deadline for this frame only; the zero time means no deadline.
func (s *Session) writeFrame(f frame.WFrame, dl time.Time) (err error) {
	s.wr.Lock()
	defer s.wr.Unlock()

	s.conn.SetWriteDeadline(dl)
	if !dl.IsZero() {
		// A net.Conn deadline applies to all future writes until changed.
		// Clear ours, still under the lock, so that frames written without
		// a deadline (the SYN in OpenStream, GoAway, RSTs) cannot fail just
		// because this frame's deadline later expired.
		defer s.conn.SetWriteDeadline(time.Time{})
	}
	return s.transport.WriteFrame(f)
}
// die closes the session cleanly with the given error and protocol error code
//
// Only the first call does any work; the CAS on diebit makes every later call
// return an error immediately (this also stops recursion when GoAway's write
// fails and calls die). In order, it:
//   - sends a GoAway carrying err.Error() as debug data,
//   - closes the accept channel so Accept() returns an error,
//   - closes the transport,
//   - closes every registered stream,
//   - publishes the reason on s.dead for Wait().
//
// err must be non-nil: err.Error() is called unconditionally.
//
// NOTE(review): when die runs on a goroutine other than reader() (Close, or a
// failed OpenStream/GoAway write), close(s.accept) can race with the
// `s.accept <- str` in handleFrame. The resulting send-on-closed-channel panic
// would be caught by recoverPanic in reader(). s.remoteDebug is also written
// by reader() without synchronization. Confirm this is acceptable.
func (s *Session) die(errorCode frame.ErrorCode, err error) error {
// only one shutdown ever happens
if !atomic.CompareAndSwapInt32(&s.diebit, 0, 1) {
return fmt.Errorf("Shutdown already in progress")
}
// send a go away frame (best effort: the returned error is ignored, and
// GoAway is a no-op returning an error if one was already sent)
s.GoAway(errorCode, []byte(err.Error()))
// stop handing remotely-opened streams to Accept()
close(s.accept)
// we cleaned up as best as possible, close the transport
s.transport.Close()
// notify all of the streams that we're closing
s.streams.Each(func(id frame.StreamId, str stream) {
str.closeWith(fmt.Errorf("Session closed"))
})
// dead has a buffer of 1 and only this single call ever sends, so this
// does not block even if nobody calls Wait
s.dead <- deadReason{errorCode, err, s.remoteDebug}
return nil
}
////////////////////////////////
// internal methods
////////////////////////////////
// reader runs on its own goroutine for the session's lifetime. It reads frames
// from the transport and hands each one to handleFrame. On a read error it
// kills the session: ProtocolError for a *frame.FramingError, InternalError
// otherwise. A panic during frame handling is recovered and also kills the
// session.
func (s *Session) reader() {
	defer s.recoverPanic("reader()")

	// reader is the only goroutine that sends on the extension accept
	// channels, so it is the one place where closing them is safe. die()
	// cannot do it.
	defer func() {
		for _, ch := range s.exts {
			close(ch)
		}
	}()

	for {
		f, err := s.transport.ReadFrame()
		if err != nil {
			var code frame.ErrorCode = frame.InternalError
			if _, isFraming := err.(*frame.FramingError); isFraming {
				code = frame.ProtocolError
			}
			s.die(code, err)
			return
		}
		s.handleFrame(f)
	}
}
// handleFrame dispatches one frame read by reader(). It runs on the reader
// goroutine, so anything that blocks here (a full Accept queue, an extension
// that is not accepting, a slow RST write) stalls all later frame processing.
func (s *Session) handleFrame(rf frame.RFrame) {
	switch f := rf.(type) {
	case *frame.RStreamSyn:
		s.handleStreamSyn(f)

	case *frame.RStreamData:
		if str := s.getStream(f.StreamId()); str != nil {
			str.handleStreamData(f)
			return
		}
		// With no stream to deliver to, the payload must still be read out
		// so the transport stays in a good state for the next ReadFrame
		// (presumably it does not skip unread payload bytes on its own).
		discard := make([]byte, f.Length())
		if _, err := io.ReadFull(f.Reader(), discard); err != nil {
			s.die(frame.InternalError, err)
			return
		}
		// DATA frames on closed connections are just stream-level errors.
		s.writeRst(f.StreamId(), frame.StreamClosed)

	case *frame.RStreamRst:
		// delegate to the stream to handle these frames
		if str := s.getStream(f.StreamId()); str != nil {
			str.handleStreamRst(f)
		}

	case *frame.RStreamWndInc:
		// delegate to the stream to handle these frames
		if str := s.getStream(f.StreamId()); str != nil {
			str.handleStreamWndInc(f)
		}

	case *frame.RGoAway:
		atomic.StoreInt32(&s.remote.goneAway, 1)
		// NOTE(review): written on the reader goroutine but read by die(),
		// which may run elsewhere; this is not synchronized.
		s.remoteDebug = f.Debug()
		lastId := f.LastStreamId()
		s.streams.Each(func(id frame.StreamId, str stream) {
			// close all streams that we opened above the last handled id
			if s.isLocal(str.Id()) && str.Id() > lastId {
				str.closeWith(fmt.Errorf("Remote is going away"))
			}
		})

	default:
		s.die(frame.ProtocolError, fmt.Errorf("Unrecognized frame type: %v", reflect.TypeOf(f)))
	}
}

// handleStreamSyn processes a STREAM_SYN from the remote. It refuses the
// stream if we have sent a GoAway and validates the id (increasing, remote
// parity). It then registers the stream and delivers it to the extension
// owning its type, or otherwise to the Accept queue.
func (s *Session) handleStreamSyn(f *frame.RStreamSyn) {
	id := f.StreamId()

	// if we're going away, refuse new streams
	if atomic.LoadInt32(&s.local.goneAway) == 1 {
		go s.writeRst(id, frame.RefusedStream)
		return
	}
	if id <= frame.StreamId(atomic.LoadUint32(&s.remote.lastId)) {
		s.die(frame.ProtocolError, fmt.Errorf("Stream id %d is less than last remote id.", id))
		return
	}
	if s.isLocal(id) {
		s.die(frame.ProtocolError, fmt.Errorf("Stream id has wrong parity for remote endpoint: %d", id))
		return
	}
	atomic.StoreUint32(&s.remote.lastId, uint32(id))

	// Resolve the destination before creating the stream. A SYN for an
	// unregistered extension type is then reset without leaving a stream in
	// the map that nobody will ever accept or close.
	var extAccept chan stream
	if f.StreamType() >= MinExtensionType {
		var ok bool
		if extAccept, ok = s.exts[f.StreamType()]; !ok {
			// Extension type of stream not registered
			s.writeRst(id, frame.StreamClosed)
			return
		}
	}

	str := s.newStream(id, f.StreamPriority(), f.StreamType(), false, f.Fin(), s.defaultWindowSize, s)
	s.streams.Set(id, str)

	if extAccept != nil {
		// unbuffered: blocks until the extension calls its accept function
		extAccept <- str
		return
	}
	// blocks once defaultAcceptQueueDepth streams are waiting for Accept
	s.accept <- str
}

// writeRst sends a STREAM_RST for id with the given code through writeFrame
// (no deadline). If the frame cannot be built, the session is killed instead.
// Write errors are ignored here, as in the original code.
func (s *Session) writeRst(id frame.StreamId, code frame.ErrorCode) {
	rst := frame.NewWStreamRst()
	if err := rst.Set(id, code); err != nil {
		s.die(frame.InternalError, err)
		return
	}
	s.writeFrame(rst, time.Time{})
}
// recoverPanic must be deferred directly (defer s.recoverPanic("...")) so its
// recover() call takes effect. If the goroutine is panicking, it stops the
// panic and kills the session with an InternalError tagged with prefix.
func (s *Session) recoverPanic(prefix string) {
	r := recover()
	if r == nil {
		return
	}
	s.die(frame.InternalError, fmt.Errorf("%s panic: %v", prefix, r))
}
func (s *Session) getStream(id frame.StreamId) (str stream) {
// decide if this id is in the "idle" state (i.e. greater than any we've seen for that parity)
var lastId *uint32
if s.isLocal(id) {
lastId = &s.local.lastId
} else {
lastId = &s.remote.lastId
}
if uint32(id) > atomic.LoadUint32(lastId) {
s.die(frame.ProtocolError, fmt.Errorf("%d is an invalid, unassigned stream id", id))
}
// find the stream in the stream map
var ok bool
if str, ok = s.streams.Get(id); !ok {
return nil
}
return
}
// isClient reports whether id belongs to a client-opened stream.
// Client stream ids are odd.
func (s *Session) isClient(id frame.StreamId) bool {
	return uint32(id)%2 != 0
}
// isServer reports whether id belongs to a server-opened stream.
// Server stream ids are even.
func (s *Session) isServer(id frame.StreamId) bool {
	return uint32(id)%2 == 0
}
//////////////////////////////////////////////
// session extensions
//////////////////////////////////////////////
// startExtension creates the unbuffered channel on which the reader goroutine
// delivers streams of the extension's type, starts the extension, and records
// the channel under the type ext.Start returns. reader() closes the channel
// when it exits, which makes the accept function return an error.
//
// NOTE(review): a second extension returning the same type overwrites the
// first one's channel, which is then never closed.
func (s *Session) startExtension(ext Extension) {
	incoming := make(chan stream)
	accept := func() (IStream, error) {
		str, open := <-incoming
		if !open {
			return nil, fmt.Errorf("Failed to accept connection, shutting down")
		}
		return str, nil
	}
	typ := ext.Start(s, accept)
	s.exts[typ] = incoming
}
//////////////////////////////////////////////
// net adaptors
//////////////////////////////////////////////
func (s *Session) NetDial(_, _ string) (net.Conn, error) {
str, err := s.Open()
return net.Conn(str), err
}
// NetListener exposes the session as a net.Listener whose Accept returns the
// streams opened by the remote.
func (s *Session) NetListener() net.Listener {
	return &netListenerAdaptor{Session: s}
}
// netListenerAdaptor adapts *Session to net.Listener. Close comes from the
// embedded Session; Addr and Accept are defined below.
type netListenerAdaptor struct {
	*Session
}
// Addr reports the session's local address as the listener address.
func (a *netListenerAdaptor) Addr() net.Addr {
	return a.Session.LocalAddr()
}
func (a *netListenerAdaptor) Accept() (net.Conn, error) {
str, err := a.Session.Accept()
return net.Conn(str), err
}
|