package pgx
import (
"database/sql/driver"
"fmt"
"reflect"
"github.com/jackc/pgtype"
)
// extendedQueryBuilder accumulates the encoded parameter values and the
// parameter and result format codes for one query. Reset clears it so the
// same builder (and its buffers) can be reused for the next query.
type extendedQueryBuilder struct {
	// paramValues holds one encoded value per parameter, in the order
	// AppendParam was called. A nil entry is what encodeExtendedParamValue
	// produces for nil inputs. Entries produced by encoders alias
	// paramValueBytes; string arguments are copied into their own slice.
	paramValues [][]byte
	// paramValueBytes is the shared buffer that parameter encoders append into.
	paramValueBytes []byte
	// paramFormats holds the format code chosen for each parameter by AppendParam.
	paramFormats []int16
	// resultFormats holds the format codes added via AppendResultFormat.
	resultFormats []int16
}
// AppendParam encodes arg as the next query parameter for the data type
// identified by oid and records it together with the format code that
// chooseParameterFormatCode selected for it.
//
// The value is encoded before anything is recorded, so when encoding fails
// the builder is left unchanged and paramFormats stays the same length as
// paramValues.
func (eqb *extendedQueryBuilder) AppendParam(ci *pgtype.ConnInfo, oid uint32, arg interface{}) error {
	f := chooseParameterFormatCode(ci, oid, arg)

	v, err := eqb.encodeExtendedParamValue(ci, oid, f, arg)
	if err != nil {
		return err
	}

	eqb.paramFormats = append(eqb.paramFormats, f)
	eqb.paramValues = append(eqb.paramValues, v)
	return nil
}
// AppendResultFormat adds f to the list of result format codes for the
// query being built.
func (eqb *extendedQueryBuilder) AppendResultFormat(f int16) {
	formats := eqb.resultFormats
	formats = append(formats, f)
	eqb.resultFormats = formats
}
// Reset empties eqb so it can build another query.
//
// Each slice keeps its backing array for reuse unless its capacity has grown
// past a fixed limit (64 entries for the per-parameter and result-format
// slices, 256 bytes for the encoded-value buffer). In that case it is replaced
// with a fresh allocation at the limit. Presumably this keeps one unusually
// large query from pinning that memory for the builder's lifetime.
func (eqb *extendedQueryBuilder) Reset() {
	const maxRetainedEntries = 64
	const maxRetainedBytes = 256

	if cap(eqb.paramValues) > maxRetainedEntries {
		eqb.paramValues = make([][]byte, 0, maxRetainedEntries)
	} else {
		eqb.paramValues = eqb.paramValues[:0]
	}

	if cap(eqb.paramValueBytes) > maxRetainedBytes {
		eqb.paramValueBytes = make([]byte, 0, maxRetainedBytes)
	} else {
		eqb.paramValueBytes = eqb.paramValueBytes[:0]
	}

	if cap(eqb.paramFormats) > maxRetainedEntries {
		eqb.paramFormats = make([]int16, 0, maxRetainedEntries)
	} else {
		eqb.paramFormats = eqb.paramFormats[:0]
	}

	if cap(eqb.resultFormats) > maxRetainedEntries {
		eqb.resultFormats = make([]int16, 0, maxRetainedEntries)
	} else {
		eqb.resultFormats = eqb.resultFormats[:0]
	}
}
// encodeExtendedParamValue encodes arg as a parameter value for oid in the
// format given by formatCode (TextFormatCode or BinaryFormatCode).
//
// Strategies are tried in this order:
//
//  1. a string is returned as-is (copied into a new slice);
//  2. arg's own pgtype.TextEncoder or pgtype.BinaryEncoder, matching formatCode;
//  3. a non-nil pointer is dereferenced and its target encoded recursively;
//  4. the data type registered in ci for oid: arg is Set into it (falling back
//     to arg's driver.Valuer when Set fails) and that value is encoded;
//  5. the data type registered in ci for arg's Go type, via its text encoder;
//  6. if stripNamedType succeeds for arg, its result is encoded recursively.
//
// If none applies, a SerializationError is returned.
//
// A nil arg, a nil pointer, or an encoder that returns a nil buffer yields
// (nil, nil). Encoder output is appended to eqb.paramValueBytes and the
// returned slice aliases that buffer, so it must not be retained past Reset,
// which reuses the buffer.
func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid uint32, formatCode int16, arg interface{}) ([]byte, error) {
	if arg == nil {
		return nil, nil
	}

	refVal := reflect.ValueOf(arg)
	argIsPtr := refVal.Kind() == reflect.Ptr
	if argIsPtr && refVal.IsNil() {
		return nil, nil
	}

	// Presumably this guarantees encoders start from a non-nil buffer: a nil
	// result is treated as NULL below, and appending zero bytes to a nil slice
	// would yield nil instead of an empty value.
	if eqb.paramValueBytes == nil {
		eqb.paramValueBytes = make([]byte, 0, 128)
	}

	// Encoders append to eqb.paramValueBytes; everything from pos onward is
	// the bytes written for this parameter.
	pos := len(eqb.paramValueBytes)

	// commit takes an encoder's (buffer, error) result, adopts the grown buffer,
	// and returns only the bytes written for this parameter.
	commit := func(buf []byte, err error) ([]byte, error) {
		if err != nil {
			return nil, err
		}
		if buf == nil {
			return nil, nil
		}
		eqb.paramValueBytes = buf
		return eqb.paramValueBytes[pos:], nil
	}

	if s, ok := arg.(string); ok {
		return []byte(s), nil
	}

	if formatCode == TextFormatCode {
		if enc, ok := arg.(pgtype.TextEncoder); ok {
			return commit(enc.EncodeText(ci, eqb.paramValueBytes))
		}
	} else if formatCode == BinaryFormatCode {
		if enc, ok := arg.(pgtype.BinaryEncoder); ok {
			return commit(enc.EncodeBinary(ci, eqb.paramValueBytes))
		}
	}

	if argIsPtr {
		// The nil-pointer case was handled above, so Elem is safe here.
		return eqb.encodeExtendedParamValue(ci, oid, formatCode, refVal.Elem().Interface())
	}

	if dt, ok := ci.DataTypeForOID(oid); ok {
		// NOTE(review): if dt.Value holds a pointer, Set mutates the instance
		// registered in ci, so ci must not be shared across goroutines — confirm.
		value := dt.Value
		if err := value.Set(arg); err != nil {
			if valuer, ok := arg.(driver.Valuer); ok {
				v, err := callValuerValue(valuer)
				if err != nil {
					return nil, err
				}
				return eqb.encodeExtendedParamValue(ci, oid, formatCode, v)
			}
			return nil, err
		}
		return eqb.encodeExtendedParamValue(ci, oid, formatCode, value)
	}

	// There is no data type registered for the destination OID, but there may
	// be one registered for the arg type. If so, use its text encoder (if
	// available).
	// NOTE(review): the text encoder is used even when formatCode is
	// BinaryFormatCode; presumably chooseParameterFormatCode never selects
	// binary in this situation — confirm.
	if dt, ok := ci.DataTypeForValue(arg); ok {
		value := dt.Value
		if textEncoder, ok := value.(pgtype.TextEncoder); ok {
			if err := value.Set(arg); err != nil {
				return nil, err
			}
			return commit(textEncoder.EncodeText(ci, eqb.paramValueBytes))
		}
	}

	if strippedArg, ok := stripNamedType(&refVal); ok {
		return eqb.encodeExtendedParamValue(ci, oid, formatCode, strippedArg)
	}

	return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
}