1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
|
package sqlmock
import (
"bytes"
"database/sql/driver"
"encoding/csv"
"fmt"
"io"
"strings"
)
// invalidate is the filler pattern that invalidateRaw repeats over byte
// slices previously handed out by Next, so that stale reads of those slices
// are easy to recognise.
const invalidate = "☠☠☠ MEMORY OVERWRITTEN ☠☠☠ "
// CSVColumnParser converts a trimmed CSV column string into its []byte
// representation. The only special case at the moment is the word NULL
// (compared case-insensitively), which is turned into nil; any other
// input is returned as the bytes of the string.
var CSVColumnParser = func(s string) []byte {
	if strings.EqualFold(s, "null") {
		return nil
	}
	return []byte(s)
}
// rowSets walks over one or more mocked result sets. Its Columns, Close
// and Next methods have the shape of the driver.Rows methods.
type rowSets struct {
	// sets holds the mocked result sets.
	sets []*Rows
	// pos is the index into sets of the result set currently being read.
	pos int
	// ex is the query expectation whose rowsWereClosed flag is set by Close.
	ex *ExpectedQuery
	// raw collects the byte slices handed out by the most recent Next call;
	// invalidateRaw overwrites them and resets the list.
	raw [][]byte
}
// Columns returns the column names of the result set currently being read.
func (rs *rowSets) Columns() []string {
	current := rs.sets[rs.pos]
	return current.cols
}
// Close overwrites any byte slices handed out by the last Next call, records
// on the expectation that the rows were closed, and returns the close error
// configured for the current result set (nil if none was set).
func (rs *rowSets) Close() error {
	rs.invalidateRaw()
	rs.ex.rowsWereClosed = true

	current := rs.sets[rs.pos]
	return current.closeErr
}
// Next advances the current result set by one row and copies that row's
// values into dest. Byte slices handed out by the previous call are
// overwritten first. Non-empty []byte values are passed on as private copies
// that are tracked so they can be overwritten by a later Next or Close.
// It returns io.EOF once the rows are exhausted, otherwise the error
// registered for the row (nil if there is none).
func (rs *rowSets) Next(dest []driver.Value) error {
	current := rs.sets[rs.pos]
	current.pos++
	rs.invalidateRaw()

	// current.pos counts consumed rows, so the row to serve is one behind it.
	idx := current.pos - 1
	if idx >= len(current.rows) {
		return io.EOF // per interface spec
	}

	for i, value := range current.rows[idx] {
		buf, isRaw := rawBytes(value)
		if !isRaw {
			dest[i] = value
			continue
		}
		// Remember the copy so it can be overwritten on the next call.
		rs.raw = append(rs.raw, buf)
		dest[i] = buf
	}
	return current.nextErr[idx]
}
// String renders the row sets as a printable string for debugging.
// It returns "with empty rows" when no set contains any row. A single
// result set is listed row by row; with several sets, each set is introduced
// by a "result set" header line. Surrounding whitespace is trimmed.
func (rs *rowSets) String() string {
	if rs.empty() {
		return "with empty rows"
	}

	// A builder avoids re-copying the whole message on every appended row.
	var msg strings.Builder
	msg.WriteString("should return rows:\n")

	if len(rs.sets) == 1 {
		for n, row := range rs.sets[0].rows {
			msg.WriteString(fmt.Sprintf(" row %d - %+v\n", n, row))
		}
	} else {
		for i, set := range rs.sets {
			msg.WriteString(fmt.Sprintf(" result set: %d\n", i))
			for n, row := range set.rows {
				msg.WriteString(fmt.Sprintf(" row %d - %+v\n", n, row))
			}
		}
	}
	return strings.TrimSpace(msg.String())
}
// empty reports whether none of the result sets contains a row.
func (rs *rowSets) empty() bool {
	for i := range rs.sets {
		if len(rs.sets[i].rows) != 0 {
			return false
		}
	}
	return true
}
// rawBytes returns a private copy of col when col is a non-empty []byte.
// For any other value, including an empty or nil []byte, it returns
// (nil, false).
//
// The copy is what gets handed to the caller, so that its content can be
// replaced later without touching the mocked row. This lets data scanned
// into sql.RawBytes become invalid on subsequent calls to Next(), Scan()
// or Close().
func rawBytes(col driver.Value) (_ []byte, ok bool) {
	switch src := col.(type) {
	case []byte:
		if len(src) > 0 {
			dup := make([]byte, len(src))
			copy(dup, src)
			return dup, true
		}
	}
	return nil, false
}
// invalidateRaw overwrites every byte slice recorded in rs.raw with the
// invalidate pattern and then clears the list.
//
// Bytes that may have been scanned as sql.RawBytes are only valid until the
// next call to Next, Scan or Close. Replacing their content simulates that
// shared memory and exposes misuse of sql.RawBytes.
func (rs *rowSets) invalidateRaw() {
	pattern := []byte(invalidate)
	for _, buf := range rs.raw {
		// Build a filler at least as long as buf; copy stops at len(buf).
		repeats := len(buf)/len(pattern) + 1
		filler := bytes.Repeat(pattern, repeats)
		copy(buf, filler)
	}
	// Start with fresh slices for the next scan.
	rs.raw = nil
}
// Rows is a mocked collection of rows to
// return for Query result
type Rows struct {
	// converter turns the values passed to AddRow into driver values.
	converter driver.ValueConverter
	// cols holds the column names.
	cols []string
	// rows holds the values of every row, one slice per row.
	rows [][]driver.Value
	// pos counts the rows consumed so far; rowSets.Next increments it
	// before reading rows[pos-1].
	pos int
	// nextErr maps a zero-based row index to the error that Next returns
	// for that row.
	nextErr map[int]error
	// closeErr is the error returned when the rows are closed.
	closeErr error
}
// NewRows creates an empty Rows for the given column names. Rows can then be
// filled from driver.Value slices (AddRow) or from a CSV string
// (FromCSVString) and used as sql driver.Rows.
// Values are converted with driver.DefaultParameterConverter; use
// Sqlmock.NewRows instead if using a custom converter.
func NewRows(columns []string) *Rows {
	rows := new(Rows)
	rows.cols = columns
	rows.converter = driver.DefaultParameterConverter
	rows.nextErr = map[int]error{}
	return rows
}
// CloseError sets the error that will be returned by the rows.Close
// function and returns the same instance so calls can be chained.
//
// The close error will be triggered only in cases
// when rows.Next() EOF was not yet reached, which is
// the default sql library behavior.
func (r *Rows) CloseError(err error) *Rows {
	r.closeErr = err
	return r
}
// RowError sets the error that will be returned when the given row is read
// and returns the same instance so calls can be chained.
//
// row is a zero-based index: 0 refers to the first row added. Setting an
// error for the same row again replaces the previous one.
func (r *Rows) RowError(row int, err error) *Rows {
	// A Rows that was not built through NewRows has a nil map, and writing
	// to a nil map panics, so create the map on first use.
	if r.nextErr == nil {
		r.nextErr = make(map[int]error)
	}
	r.nextErr[row] = err
	return r
}
// AddRow appends one row built from the given values and returns the same
// instance so that further calls can be chained.
//
// The number of values must match the number of columns; otherwise AddRow
// panics. Each value is passed through the Rows' converter, and AddRow also
// panics if a value cannot be converted.
func (r *Rows) AddRow(values ...driver.Value) *Rows {
	if len(values) != len(r.cols) {
		panic("Expected number of values to match number of columns")
	}

	converted := make([]driver.Value, len(r.cols))
	for i := range values {
		// Turn user-friendly values (such as int or driver.Valuer) into
		// database/sql native values (driver.Value such as int64).
		dv, err := r.converter.ConvertValue(values[i])
		if err != nil {
			panic(fmt.Errorf(
				"row #%d, column #%d (%q) type %T: %s",
				len(r.rows)+1, i, r.cols[i], values[i], err,
			))
		}
		converted[i] = dv
	}

	r.rows = append(r.rows, converted)
	return r
}
// FromCSVString builds rows from a CSV string and returns the same instance
// so that further calls can be chained.
//
// The input is trimmed of surrounding whitespace and parsed record by
// record. Every field is trimmed and passed through CSVColumnParser. A
// record with fewer fields than columns leaves the remaining columns nil.
//
// FromCSVString panics when the CSV cannot be parsed, or when a record has
// more fields than there are columns, so that a malformed fixture fails
// loudly instead of silently yielding fewer rows.
func (r *Rows) FromCSVString(s string) *Rows {
	csvReader := csv.NewReader(strings.NewReader(strings.TrimSpace(s)))

	for {
		record, err := csvReader.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			panic(fmt.Errorf("parsing CSV string failed: %s", err))
		}
		if len(record) > len(r.cols) {
			panic(fmt.Errorf(
				"row #%d: got %d CSV values for %d columns",
				len(r.rows)+1, len(record), len(r.cols),
			))
		}

		row := make([]driver.Value, len(r.cols))
		for i, field := range record {
			row[i] = CSVColumnParser(strings.TrimSpace(field))
		}
		r.rows = append(r.rows, row)
	}
	return r
}
|