File: body.go

package info (click to toggle)
golang-github-lucas-clemente-quic-go 0.54.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 4,312 kB
  • sloc: sh: 54; makefile: 7
file content (133 lines) | stat: -rw-r--r-- 3,232 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package http3

import (
	"context"
	"errors"
	"io"
	"sync"

	"github.com/quic-go/quic-go"
)

// Hijacker is implemented by an http.ResponseWriter that can hand out the
// underlying HTTP/3 connection, giving access to the stream-creating part of
// a quic.Conn. WebTransport uses it to open its own streams after a session
// has been established.
type Hijacker interface {
	// Connection returns the connection backing the ResponseWriter.
	Connection() *Conn
}

// errTooMuchData is reported by body reads once the peer has sent more body
// bytes than its Content-Length announced.
var errTooMuchData = errors.New("peer sent too much data")

// body wraps a *Stream and enforces the Content-Length announced by the peer.
// It is embedded in requestBody (the body of an http.Request) and held by
// hijackableBody (the body of an http.Response, built by newResponseBody).
type body struct {
	str *Stream

	// remainingContentLength is the number of body bytes still permitted by
	// the Content-Length. It is only meaningful when hasContentLength is true.
	remainingContentLength int64
	// violatedContentLength records that the stream was already reset because
	// the peer exceeded the Content-Length, so the reset happens only once.
	violatedContentLength bool
	// hasContentLength is true when a non-negative content length was given.
	hasContentLength bool
}

// newBody wraps str. A negative contentLength means the length is unknown and
// no limit is enforced; otherwise reads are limited to contentLength bytes.
func newBody(str *Stream, contentLength int64) *body {
	if contentLength < 0 {
		return &body{str: str}
	}
	return &body{
		str:                    str,
		remainingContentLength: contentLength,
		hasContentLength:       true,
	}
}

// StreamID returns the ID of the QUIC stream this body is read from.
func (r *body) StreamID() quic.StreamID {
	return r.str.StreamID()
}

// checkContentLengthViolation returns errTooMuchData when the peer has sent
// more data than the Content-Length allows. That is the case when the
// remaining length went negative, or when it is exactly zero while the stream
// still has more data. The first time a violation is seen, both directions of
// the stream are cancelled with ErrCodeMessageError. Bodies without a
// Content-Length never violate.
func (r *body) checkContentLengthViolation() error {
	switch {
	case !r.hasContentLength:
		return nil
	case r.remainingContentLength > 0:
		return nil
	case r.remainingContentLength == 0 && !r.str.hasMoreData():
		// Exactly the announced amount was consumed and nothing follows.
		return nil
	}

	// Only reset the stream once; subsequent calls just report the error.
	if !r.violatedContentLength {
		code := quic.StreamErrorCode(ErrCodeMessageError)
		r.str.CancelRead(code)
		r.str.CancelWrite(code)
		r.violatedContentLength = true
	}
	return errTooMuchData
}

// Read reads up to len(b) bytes of the body from the stream. When a
// Content-Length is known, the read is capped at the number of bytes still
// expected, and errTooMuchData is returned once the peer exceeds that limit.
// Other stream errors are passed through maybeReplaceError.
func (r *body) Read(b []byte) (int, error) {
	if violation := r.checkContentLengthViolation(); violation != nil {
		return 0, violation
	}

	// The check above guarantees remainingContentLength >= 0 here.
	buf := b
	if r.hasContentLength && int64(len(buf)) > r.remainingContentLength {
		// Never offer the stream more room than the announced length allows.
		buf = buf[:r.remainingContentLength]
	}

	read, readErr := r.str.Read(buf)
	r.remainingContentLength -= int64(read)

	if violation := r.checkContentLengthViolation(); violation != nil {
		return read, violation
	}
	return read, maybeReplaceError(readErr)
}

// Close stops reading the body by cancelling the receive side of the stream
// with ErrCodeRequestCanceled. It always returns nil.
func (r *body) Close() error {
	code := quic.StreamErrorCode(ErrCodeRequestCanceled)
	r.str.CancelRead(code)
	return nil
}

// requestBody is the body of an http.Request. It embeds body for reading and
// Content-Length enforcement, and carries values handed to newRequestBody.
type requestBody struct {
	body
	// connCtx is stored as passed to newRequestBody.
	// NOTE(review): presumably the connection's context — confirm at call sites.
	connCtx context.Context
	// NOTE(review): presumably rcvdSettings is closed once the peer's SETTINGS
	// arrive and getSettings returns them — confirm with the callers.
	rcvdSettings <-chan struct{}
	getSettings  func() *Settings
}

// Compile-time check that *requestBody satisfies io.ReadCloser.
var _ io.ReadCloser = (*requestBody)(nil)

// newRequestBody creates the body of an http.Request read from str. A
// negative contentLength means the length is unknown. The remaining arguments
// are stored unchanged on the returned requestBody.
func newRequestBody(str *Stream, contentLength int64, connCtx context.Context, rcvdSettings <-chan struct{}, getSettings func() *Settings) *requestBody {
	rb := &requestBody{body: *newBody(str, contentLength)}
	rb.connCtx = connCtx
	rb.rcvdSettings = rcvdSettings
	rb.getSettings = getSettings
	return rb
}

// hijackableBody is the body of an http.Response (see newResponseBody). On top
// of the Content-Length enforcement done by body, it signals via reqDone when
// the caller has finished with the response.
type hijackableBody struct {
	body body

	// reqDone is only set for the http.Response. It is closed at most once,
	// when the user is done with the response: either when Read returns an
	// error, or when Close is called.
	reqDone chan<- struct{}
	// reqDoneOnce ensures reqDone is closed only once.
	reqDoneOnce sync.Once
}

// Compile-time check that *hijackableBody satisfies io.ReadCloser.
var _ io.ReadCloser = (*hijackableBody)(nil)

// newResponseBody creates the body of an http.Response read from str. A
// negative contentLength means the length is unknown. If done is non-nil, it
// is closed once the caller is finished with the response.
func newResponseBody(str *Stream, contentLength int64, done chan<- struct{}) *hijackableBody {
	hb := &hijackableBody{reqDone: done}
	hb.body = *newBody(str, contentLength)
	return hb
}

// Read reads from the underlying body. Any error, including io.EOF, means
// the caller is done with the response, so reqDone is signalled.
func (r *hijackableBody) Read(b []byte) (int, error) {
	read, readErr := r.body.Read(b)
	if readErr != nil {
		r.requestDone()
	}
	return read, maybeReplaceError(readErr)
}

// requestDone closes reqDone, if one was provided. Repeated calls are safe:
// the close happens only once.
func (r *hijackableBody) requestDone() {
	if r.reqDone == nil {
		return
	}
	r.reqDoneOnce.Do(func() { close(r.reqDone) })
}

// Close signals that the caller is done with the response, then cancels the
// receive side of the stream with ErrCodeRequestCanceled. It always returns nil.
func (r *hijackableBody) Close() error {
	r.requestDone()
	// If the EOF was read, CancelRead() is a no-op.
	str := r.body.str
	str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
	return nil
}