1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
|
package websocket
import (
"bufio"
"bytes"
"errors"
"net"
"net/http"
"strings"
)
var (
	// ErrInvalidMethod is returned by Upgrade when the request method is not GET.
	ErrInvalidMethod = errors.New("only GET supported")

	// ErrInvalidVersion is returned by Upgrade when the request's
	// Sec-Websocket-Version header is not "13".
	ErrInvalidVersion = errors.New("unsupported version, Sec-Websocket-Version must be 13")

	// ErrInvalidUpgrade is returned by Upgrade when the request's Upgrade
	// header does not name "websocket".
	ErrInvalidUpgrade = errors.New("can \"Upgrade\" only to \"websocket\"")

	// ErrInvalidConnection is returned by Upgrade when the request's
	// Connection header does not carry "Upgrade".
	ErrInvalidConnection = errors.New("\"Connection\" must be \"Upgrade\"")

	// ErrMissingKey is returned by Upgrade when the request has no
	// Sec-Websocket-Key header.
	ErrMissingKey = errors.New("missing Sec-Websocket-Key")

	// ErrHijacker is returned by Upgrade when the http.ResponseWriter does not
	// implement http.Hijacker, so the underlying connection cannot be taken over.
	ErrHijacker = errors.New("response writer does not implement http.Hijacker")

	// ErrNoEmptyConn is returned by Upgrade when the hijacked connection's
	// read buffer already holds data.
	ErrNoEmptyConn = errors.New("conn read buffer must be empty")
)
// Upgrade performs the server side of the WebSocket opening handshake for r
// and, on success, returns a *Conn built from the hijacked network connection.
//
// The request must be a GET with Sec-Websocket-Version "13", an Upgrade
// header containing the "websocket" token, a Connection header containing the
// "Upgrade" token (both compared case-insensitively), and a non-empty
// Sec-Websocket-Key. A failed check returns the matching Err* sentinel without
// writing anything to w, so the caller may still send its own HTTP response.
//
// responseHeader may be nil; each of its values is appended as an extra
// header line of the 101 Switching Protocols response. Once the connection
// has been hijacked, any later failure closes it before returning the error.
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
	// hasToken reports whether any value of header name, read as a
	// comma-separated token list, contains token (ASCII case-insensitive).
	// Clients may send e.g. "Connection: keep-alive, Upgrade", so exact
	// string equality would wrongly reject them.
	hasToken := func(name, token string) bool {
		for _, v := range r.Header[http.CanonicalHeaderKey(name)] {
			for _, t := range strings.Split(v, ",") {
				if strings.EqualFold(strings.TrimSpace(t), token) {
					return true
				}
			}
		}
		return false
	}

	if r.Method != "GET" {
		return nil, ErrInvalidMethod
	}
	if r.Header.Get("Sec-Websocket-Version") != "13" {
		return nil, ErrInvalidVersion
	}
	if !hasToken("Upgrade", "websocket") {
		return nil, ErrInvalidUpgrade
	}
	if !hasToken("Connection", "upgrade") {
		return nil, ErrInvalidConnection
	}
	key := r.Header.Get("Sec-Websocket-Key")
	if key == "" {
		return nil, ErrMissingKey
	}
	acceptKey := calcAcceptKey(key)

	h, ok := w.(http.Hijacker)
	if !ok {
		return nil, ErrHijacker
	}
	var (
		netConn net.Conn
		rw      *bufio.ReadWriter
		err     error
	)
	netConn, rw, err = h.Hijack()
	if err != nil {
		// netConn and rw are not usable when Hijack fails; touching rw here
		// would dereference nil.
		return nil, err
	}
	if rw.Reader.Buffered() > 0 {
		// The client sent bytes beyond the handshake request. NewConn is handed
		// netConn rather than rw, so those buffered bytes would presumably be
		// lost; refuse the connection instead.
		netConn.Close()
		return nil, ErrNoEmptyConn
	}
	c := NewConn(netConn, true)

	buf := bytes.NewBufferString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ")
	buf.WriteString(acceptKey)
	buf.WriteString("\r\n")
	if subProtocol := selectSubProtocol(r); subProtocol != "" {
		buf.WriteString("Sec-Websocket-Protocol: ")
		buf.WriteString(subProtocol)
		buf.WriteString("\r\n")
	}
	for k, vs := range responseHeader {
		for _, v := range vs {
			buf.WriteString(k)
			buf.WriteString(": ")
			buf.WriteString(v)
			buf.WriteString("\r\n")
		}
	}
	// Blank line terminates the response header section.
	buf.WriteString("\r\n")
	if _, err = netConn.Write(buf.Bytes()); err != nil {
		netConn.Close()
		return nil, err
	}
	return c, nil
}
// selectSubProtocol returns the subprotocol to echo back in the handshake
// response: the first entry of the request's Sec-Websocket-Protocol list,
// with surrounding whitespace removed. It returns "" when the header is
// absent or its first entry is empty.
//
// NOTE(review): the first offer is accepted unconditionally; it is not
// checked against any list of protocols the server actually supports.
func selectSubProtocol(r *http.Request) string {
	offers := r.Header.Get("Sec-Websocket-Protocol")
	if i := strings.IndexByte(offers, ','); i >= 0 {
		offers = offers[:i]
	}
	return strings.TrimSpace(offers)
}
|