1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
|
package pgproto3
import (
"encoding/binary"
"errors"
"fmt"
"io"
)
// Backend acts as a server for the PostgreSQL wire protocol version 3.
//
// It reads frontend messages from cr and writes encoded backend messages to w.
// Most decoded messages are stored in the flyweight fields below and returned
// by pointer, so a returned message is overwritten by a later receive call.
type Backend struct {
	// cr supplies the raw bytes sent by the frontend.
	cr ChunkReader
	// w receives the encoded messages passed to Send.
	w io.Writer
	// Frontend message flyweights
	bind Bind
	cancelRequest CancelRequest
	_close Close
	copyFail CopyFail
	copyData CopyData
	copyDone CopyDone
	describe Describe
	execute Execute
	flush Flush
	functionCall FunctionCall
	gssEncRequest GSSEncRequest
	parse Parse
	query Query
	sslRequest SSLRequest
	startupMessage StartupMessage
	sync Sync
	terminate Terminate
	// bodyLen is the body length of the message currently being received
	// (the length field from the header minus the 4 bytes of the field itself).
	bodyLen int
	// msgType is the type byte taken from the most recently read message header.
	msgType byte
	// partialMsg is true while a header has been read but its body has not yet
	// been consumed, so that Receive can resume without re-reading the header.
	partialMsg bool
	// authType is the last value accepted by SetAuthType; Receive uses it to
	// choose the concrete message type for 'p' messages.
	authType uint32
}
// Bounds on the startup packet body length (the length field minus its own
// 4 bytes) that ReceiveStartupMessage enforces before reading the body.
const (
	minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code.
	maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
)
// NewBackend returns a Backend that reads frontend messages from cr and writes
// backend messages to w. All other state starts at its zero value.
func NewBackend(cr ChunkReader, w io.Writer) *Backend {
	backend := new(Backend)
	backend.cr = cr
	backend.w = w
	return backend
}
// Send encodes msg and writes the result to the frontend. It returns the error
// reported by the underlying writer, if any.
func (b *Backend) Send(msg BackendMessage) error {
	// Passing nil lets Encode allocate the destination buffer itself.
	encoded := msg.Encode(nil)
	if _, err := b.w.Write(encoded); err != nil {
		return err
	}
	return nil
}
// ReceiveStartupMessage receives the initial connection message. It must be used
// instead of Receive for that first message, because the initial message is
// "special": it has no leading message type byte, only a 4-byte length followed
// by a 4-byte version or request code.
//
// The result is one of *StartupMessage, *SSLRequest, *CancelRequest or
// *GSSEncRequest. The returned value points into the Backend and is overwritten
// by the next startup message of the same kind.
//
// An error is returned when the declared length is out of range, the stream
// ends early, the code is not recognized, or the body fails to decode.
func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
	lengthBytes, err := b.cr.Next(4)
	if err != nil {
		return nil, err
	}

	// The length field counts itself, so the body is 4 bytes shorter. The
	// subtraction happens in uint32: a declared length below 4 wraps around
	// instead of going negative, and the range check below still rejects it.
	msgSize := int(binary.BigEndian.Uint32(lengthBytes) - 4)
	if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
		return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
	}

	body, err := b.cr.Next(msgSize)
	if err != nil {
		// The header promised more bytes, so running out here is unexpected.
		return nil, translateEOFtoErrUnexpectedEOF(err)
	}

	// The first four bytes of the body select which flyweight decodes it.
	code := binary.BigEndian.Uint32(body)

	var msg FrontendMessage
	switch code {
	case ProtocolVersionNumber:
		msg = &b.startupMessage
	case sslRequestNumber:
		msg = &b.sslRequest
	case cancelRequestCode:
		msg = &b.cancelRequest
	case gssEncReqNumber:
		msg = &b.gssEncRequest
	default:
		return nil, fmt.Errorf("unknown startup message code: %d", code)
	}

	if decodeErr := msg.Decode(body); decodeErr != nil {
		return nil, decodeErr
	}
	return msg, nil
}
// Receive receives a message from the frontend. The returned message is only
// valid until the next call to Receive.
//
// A message is read in two steps: a 5-byte header (type byte plus 4-byte
// length) and then the body. If reading the body fails, the parsed header is
// kept so that a later call resumes with the body instead of reading a new
// header. A header whose length field is smaller than 4 is rejected without
// recording any state, so the Backend never holds a negative body length.
//
// For an unknown message type the header has already been consumed and is
// kept, so later calls report the same error.
//
// Messages of type 'p' are decoded according to the value last given to
// SetAuthType and are freshly allocated; all other messages are decoded into
// fields of the Backend and returned by pointer.
func (b *Backend) Receive() (FrontendMessage, error) {
	if !b.partialMsg {
		header, err := b.cr.Next(5)
		if err != nil {
			return nil, translateEOFtoErrUnexpectedEOF(err)
		}

		// The length field counts its own 4 bytes, so anything below 4 is invalid.
		// Validate before committing header state: otherwise a later call would
		// skip the header read and pass a negative length to cr.Next.
		bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4
		if bodyLen < 0 {
			return nil, errors.New("invalid message with negative body length received")
		}

		b.msgType = header[0]
		b.bodyLen = bodyLen
		b.partialMsg = true
	}

	var msg FrontendMessage
	switch b.msgType {
	case 'B':
		msg = &b.bind
	case 'C':
		msg = &b._close
	case 'D':
		msg = &b.describe
	case 'E':
		msg = &b.execute
	case 'F':
		msg = &b.functionCall
	case 'f':
		msg = &b.copyFail
	case 'd':
		msg = &b.copyData
	case 'c':
		msg = &b.copyDone
	case 'H':
		msg = &b.flush
	case 'P':
		msg = &b.parse
	case 'p':
		// 'p' is shared by several authentication responses; the concrete type
		// depends on which authentication exchange is in progress.
		switch b.authType {
		case AuthTypeSASL:
			msg = &SASLInitialResponse{}
		case AuthTypeSASLContinue:
			msg = &SASLResponse{}
		case AuthTypeSASLFinal:
			msg = &SASLResponse{}
		case AuthTypeGSS, AuthTypeGSSCont:
			msg = &GSSResponse{}
		case AuthTypeCleartextPassword, AuthTypeMD5Password:
			fallthrough
		default:
			// to maintain backwards compatibility
			msg = &PasswordMessage{}
		}
	case 'Q':
		msg = &b.query
	case 'S':
		msg = &b.sync
	case 'X':
		msg = &b.terminate
	default:
		return nil, fmt.Errorf("unknown message type: %c", b.msgType)
	}

	msgBody, err := b.cr.Next(b.bodyLen)
	if err != nil {
		return nil, translateEOFtoErrUnexpectedEOF(err)
	}

	// The body has been consumed; the next call starts with a fresh header.
	b.partialMsg = false

	err = msg.Decode(msgBody)
	return msg, err
}
// SetAuthType records the authentication type currently in use by the backend.
//
// Several frontend messages share the type byte 'p' (password, GSSAPI, SSPI
// and SASL responses), and the protocol expects the receiver to deduce the
// exact message type from context. Receive consults the value stored here to
// pick the concrete type for the next 'p' message:
//
//	AuthTypeSASL                               -> SASLInitialResponse
//	AuthTypeSASLContinue, AuthTypeSASLFinal    -> SASLResponse
//	AuthTypeGSS, AuthTypeGSSCont               -> GSSResponse
//	anything else                              -> PasswordMessage
//
// Callers should therefore call SetAuthType whenever the authentication
// exchange moves to a new request type. An unrecognized authType leaves the
// stored value unchanged and returns an error.
func (b *Backend) SetAuthType(authType uint32) error {
	switch authType {
	case AuthTypeOk,
		AuthTypeCleartextPassword,
		AuthTypeMD5Password,
		AuthTypeSCMCreds,
		AuthTypeGSS,
		AuthTypeGSSCont,
		AuthTypeSSPI,
		AuthTypeSASL,
		AuthTypeSASLContinue,
		AuthTypeSASLFinal:
		b.authType = authType
		return nil
	}
	return fmt.Errorf("authType not recognized: %d", authType)
}
|