1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216
|
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package sql
import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"strings"
)
// ColumnScanner is the row-iteration contract the scan helpers in this
// file rely on. It consists of the four methods of sql.Rows that are
// needed to read a result set, so *sql.Rows satisfies it as-is.
type ColumnScanner interface {
	// Columns returns the column names of the result set.
	Columns() ([]string, error)
	// Next reports whether another row is available for Scan.
	Next() bool
	// Scan copies the columns of the current row into dest.
	Scan(dest ...interface{}) error
	// Err returns the error, if any, that ended the iteration.
	Err() error
}
// ScanOne reads a result set that must consist of exactly one column and
// exactly one row, and stores that single value in v.
//
// It returns sql.ErrNoRows when the result set is empty, the iteration
// error when rows.Err reports one, and a descriptive error when the result
// holds more than one column or more than one row.
func ScanOne(rows ColumnScanner, v interface{}) error {
	names, err := rows.Columns()
	switch {
	case err != nil:
		return fmt.Errorf("sql/scan: failed getting column names: %v", err)
	case len(names) != 1:
		return fmt.Errorf("sql/scan: unexpected number of columns: %d", len(names))
	}
	if hasRow := rows.Next(); !hasRow {
		// Tell an iteration failure apart from a genuinely empty result.
		if iterErr := rows.Err(); iterErr != nil {
			return iterErr
		}
		return sql.ErrNoRows
	}
	if scanErr := rows.Scan(v); scanErr != nil {
		return scanErr
	}
	// A second row violates the "exactly one" contract.
	if extra := rows.Next(); extra {
		return fmt.Errorf("sql/scan: expect exactly one row in result set")
	}
	return rows.Err()
}
// ScanInt64 reads the single value held by rows as an int64.
// It delegates to ScanOne and returns 0 together with any error it reports.
func ScanInt64(rows ColumnScanner) (int64, error) {
	var out int64
	err := ScanOne(rows, &out)
	if err != nil {
		return 0, err
	}
	return out, nil
}
// ScanInt reads the single value held by rows as an int.
// The value is obtained through ScanInt64 and then converted to int;
// on failure it returns 0 and the error from ScanInt64.
func ScanInt(rows ColumnScanner) (int, error) {
	wide, err := ScanInt64(rows)
	if err == nil {
		return int(wide), nil
	}
	return 0, err
}
// ScanString reads the single value held by rows as a string.
// It delegates to ScanOne and returns "" together with any error it reports.
func ScanString(rows ColumnScanner) (string, error) {
	var out string
	err := ScanOne(rows, &out)
	if err != nil {
		return "", err
	}
	return out, nil
}
// ScanValue scans and returns a driver.Value from the rows columns.
// It delegates to ScanOne, so the result set must hold exactly one
// column and one row. On failure it returns a nil value together with
// the error reported by ScanOne.
func ScanValue(rows ColumnScanner) (driver.Value, error) {
	var v driver.Value
	if err := ScanOne(rows, &v); err != nil {
		// Return nil rather than an empty string: "" is a valid,
		// non-nil driver.Value and must not accompany an error.
		return nil, err
	}
	return v, nil
}
// ScanSlice scans all rows of the given ColumnScanner (for example,
// *sql.Rows) and appends one element per row to the slice v points to.
//
// v must be a non-nil pointer to a slice. The slice element type decides
// how a row is decoded; it is resolved once through scanType before any
// row is read. Elements are appended to whatever the slice already holds.
//
// An error is returned when the column names cannot be read, when v is
// not a non-nil pointer to a slice, when the element type is unsupported,
// when the result has more columns than the element type can receive,
// or when scanning or iterating the rows fails.
func ScanSlice(rows ColumnScanner, v interface{}) error {
	columns, err := rows.Columns()
	if err != nil {
		return fmt.Errorf("sql/scan: failed getting column names: %v", err)
	}
	ptr := reflect.ValueOf(v)
	// The destination has to be reached through a pointer so that it is
	// settable; a slice passed by value would make rv.Set below panic.
	if ptr.Kind() != reflect.Ptr || ptr.IsNil() {
		return fmt.Errorf("sql/scan: invalid type %T. expected non-nil pointer to slice as an argument", v)
	}
	rv := ptr.Elem()
	if k := rv.Kind(); k != reflect.Slice {
		return fmt.Errorf("sql/scan: invalid type %s. expected slice as an argument", k)
	}
	scan, err := scanType(rv.Type().Elem(), columns)
	if err != nil {
		return err
	}
	if n, m := len(columns), len(scan.columns); n > m {
		return fmt.Errorf("sql/scan: columns do not match (%d > %d)", n, m)
	}
	for rows.Next() {
		// Fresh destinations per row, so appended elements never alias.
		values := scan.values()
		if err := rows.Scan(values...); err != nil {
			return fmt.Errorf("sql/scan: failed scanning rows: %v", err)
		}
		rv.Set(reflect.Append(rv, scan.value(values...)))
	}
	return rows.Err()
}
// rowScan holds what is needed to decode one result row: the Go types
// that receive the columns, and the function that assembles the scanned
// destinations into a single value.
type rowScan struct {
	// columns lists the destination type of every scanned column.
	columns []reflect.Type
	// value turns the scanned destinations of one row (as produced by
	// values and filled by Scan) into the resulting reflect.Value.
	value func(dests ...interface{}) reflect.Value
}
// values allocates one fresh scan destination per configured column type.
// Every entry is a pointer to a new zero value of that type, which makes
// the returned slice suitable as the argument list for Scan.
func (r *rowScan) values() []interface{} {
	dests := make([]interface{}, 0, len(r.columns))
	for _, typ := range r.columns {
		dests = append(dests, reflect.New(typ).Interface())
	}
	return dests
}
// scanType builds the rowScan for a slice element of the given type.
//
// Types that assignable accepts are scanned directly from a single column.
// Otherwise pointers are handled by scanPtr and structs by scanStruct;
// every other kind is rejected with an error.
func scanType(typ reflect.Type, columns []string) (*rowScan, error) {
	// Checked first on purpose: a pointer or struct type that assignable
	// accepts is scanned directly rather than taken apart.
	if assignable(typ) {
		direct := &rowScan{columns: []reflect.Type{typ}}
		direct.value = func(dests ...interface{}) reflect.Value {
			return reflect.Indirect(reflect.ValueOf(dests[0]))
		}
		return direct, nil
	}
	switch kind := typ.Kind(); kind {
	case reflect.Ptr:
		return scanPtr(typ, columns)
	case reflect.Struct:
		return scanStruct(typ, columns)
	default:
		return nil, fmt.Errorf("sql/scan: unsupported type ([]%s)", kind)
	}
}
// scannerType is the reflect.Type of the sql.Scanner interface.
var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()

// assignable reports whether a value of the given type can be filled in
// directly by `Rows.Scan`. This holds for types implementing sql.Scanner,
// the empty interface, strings, the kinds from reflect.Bool through
// reflect.Float64, and slices or arrays whose element kind is uint8.
func assignable(typ reflect.Type) bool {
	if typ.Implements(scannerType) {
		return true
	}
	switch kind := typ.Kind(); kind {
	case reflect.Interface:
		// Only the empty interface can receive an arbitrary column value.
		return typ.NumMethod() == 0
	case reflect.String:
		return true
	case reflect.Slice, reflect.Array:
		return typ.Elem().Kind() == reflect.Uint8
	default:
		return kind >= reflect.Bool && kind <= reflect.Float64
	}
}
// scanStruct returns a configuration for scanning an sql.Row into a struct
// of the given type.
//
// Each column is matched to an exported struct field by name. A field is
// registered under its `sql` tag if it has one, otherwise under the name
// part of its `json` tag, otherwise under its lowercased Go name; an empty
// tag name falls back to the lowercased Go name. Column names are
// lowercased and cut at the first "(" before the lookup, so `COUNT(*)`
// matches a field registered as `count`.
//
// An error is returned for a column that has no matching exported field.
func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
	var (
		scan  = &rowScan{columns: make([]reflect.Type, 0, len(columns))}
		names = make(map[string]int, typ.NumField())
		idx   = make([]int, 0, len(columns))
	)
	for i := 0; i < typ.NumField(); i++ {
		f := typ.Field(i)
		// Unexported fields cannot be set through reflection; registering
		// them would make the value function panic on a matching column.
		if f.PkgPath != "" {
			continue
		}
		name := strings.ToLower(f.Name)
		if tag, ok := f.Tag.Lookup("sql"); ok {
			if tag != "" {
				name = tag
			}
		} else if tag, ok := f.Tag.Lookup("json"); ok {
			// Tags such as `json:",omitempty"` carry options but no name.
			if tagName := strings.Split(tag, ",")[0]; tagName != "" {
				name = tagName
			}
		}
		names[name] = i
	}
	for _, c := range columns {
		// normalize columns if necessary, for example: COUNT(*) => count.
		name := strings.ToLower(strings.Split(c, "(")[0])
		i, ok := names[name]
		if !ok {
			return nil, fmt.Errorf("sql/scan: missing struct field for column: %s (%s)", c, name)
		}
		idx = append(idx, i)
		scan.columns = append(scan.columns, typ.Field(i).Type)
	}
	scan.value = func(vs ...interface{}) reflect.Value {
		st := reflect.New(typ).Elem()
		// vs[i] is the destination scanned for column i; idx[i] is the
		// struct field that column was matched to.
		for i, v := range vs {
			st.Field(idx[i]).Set(reflect.Indirect(reflect.ValueOf(v)))
		}
		return st
	}
	return scan, nil
}
// scanPtr builds the rowScan for a pointer element type. The pointed-to
// type is resolved through scanType, and the resulting value function is
// wrapped so that it returns a pointer to a newly allocated copy of the
// scanned value instead of the value itself.
func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) {
	elemScan, err := scanType(typ.Elem(), columns)
	if err != nil {
		return nil, err
	}
	inner := elemScan.value
	elemScan.value = func(dests ...interface{}) reflect.Value {
		elem := inner(dests...)
		ptr := reflect.New(elem.Type())
		ptr.Elem().Set(elem)
		return ptr
	}
	return elemScan, nil
}
|