1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
|
package testdb
import (
"database/sql/driver"
"errors"
)
// conn is the fake driver connection. Every behavior can be overridden by an
// optional hook; Query, Exec and Prepare fall back to stubbed entries when the
// corresponding hook is nil.
type conn struct {
	// queries holds stubbed responses, keyed by getQueryHash of the query text.
	queries map[string]query

	// Statement hooks. When non-nil they take precedence over stubbed entries.
	queryFunc func(query string, args []driver.Value) (driver.Rows, error)
	execFunc  func(query string, args []driver.Value) (driver.Result, error)

	// Transaction hooks. beginFunc replaces Begin entirely; commitFunc and
	// rollbackFunc are attached to the *Tx that Begin builds otherwise.
	beginFunc                func() (driver.Tx, error)
	commitFunc, rollbackFunc func() error
}
// newConn returns a connection with an empty (non-nil) stub table and no
// hooks installed.
func newConn() *conn {
	c := new(conn)
	c.queries = map[string]query{}
	return c
}
// Prepare implements driver.Conn. It builds a statement whose query and exec
// behavior comes from, in order of precedence:
//
//  1. the connection's queryFunc / execFunc hooks, when set;
//  2. a stub registered for this query text (looked up via getQueryHash).
//
// A stub that carries only an error still produces a statement; executing it
// returns the stubbed error, the same as conn.Query and conn.Exec do. If
// neither source supplies any behavior, Prepare returns a
// "Query not stubbed" error.
func (c *conn) Prepare(query string) (driver.Stmt, error) {
	s := new(stmt)

	if c.queryFunc != nil {
		s.queryFunc = func(args []driver.Value) (driver.Rows, error) {
			return c.queryFunc(query, args)
		}
	}
	if c.execFunc != nil {
		s.execFunc = func(args []driver.Value) (driver.Result, error) {
			return c.execFunc(query, args)
		}
	}

	// NOTE(review): stubs are read from the package-level d.conn rather than
	// from c (Query and Exec do the same). Presumably this keeps stubs visible
	// to connections handed out before d.conn was replaced — confirm before
	// switching to c.queries.
	if q, ok := d.conn.queries[getQueryHash(query)]; ok {
		// Error-only stubs must also install a func. Otherwise the stubbed
		// error never reaches the caller and a misleading "not stubbed"
		// error is reported instead.
		if s.queryFunc == nil && (q.rows != nil || q.err != nil) {
			s.queryFunc = func(args []driver.Value) (driver.Rows, error) {
				if q.rows == nil {
					return nil, q.err
				}
				// Our own rows type is cloned per execution, presumably so
				// repeated executions don't share cursor state — TODO confirm.
				if r, ok := q.rows.(*rows); ok {
					return r.clone(), nil
				}
				return q.rows, nil
			}
		}
		if s.execFunc == nil && (q.result != nil || q.err != nil) {
			s.execFunc = func(args []driver.Value) (driver.Result, error) {
				if q.result == nil {
					return nil, q.err
				}
				return q.result, nil
			}
		}
	}

	if s.queryFunc == nil && s.execFunc == nil {
		return new(stmt), errors.New("Query not stubbed: " + query)
	}
	return s, nil
}
// Close implements driver.Conn. It performs no work and always returns nil.
func (c *conn) Close() error {
	return nil
}
// Begin implements driver.Conn. When a beginFunc hook is installed, it alone
// determines the result. Otherwise Begin returns a new *Tx, attaching the
// connection's commit and rollback hooks only when they are set.
func (c *conn) Begin() (driver.Tx, error) {
	if begin := c.beginFunc; begin != nil {
		return begin()
	}

	tx := new(Tx)
	if commit := c.commitFunc; commit != nil {
		tx.SetCommitFunc(commit)
	}
	if rollback := c.rollbackFunc; rollback != nil {
		tx.SetRollbackFunc(rollback)
	}
	return tx, nil
}
// Query implements driver.Queryer. The queryFunc hook, when set, handles
// every query. Otherwise the stub registered for this query text is used:
//
//   - stub with rows: the rows (cloned when they are our own *rows type) are
//     returned together with the stub's error;
//   - stub with only an error: that error is returned;
//   - no usable stub: a "Query not stubbed" error is returned.
//
// A stub with neither rows nor an error (for example one registered only for
// Exec) is treated as not stubbed. Returning (nil, nil) there would hand
// database/sql a nil driver.Rows, which it dereferences on first use.
func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
	if c.queryFunc != nil {
		return c.queryFunc(query, args)
	}

	// NOTE(review): stubs are read from the package-level d.conn rather than
	// from c, as in Prepare and Exec — confirm that is intended.
	if q, ok := d.conn.queries[getQueryHash(query)]; ok {
		switch {
		case q.rows != nil:
			// Presumably cloned so repeated queries don't share cursor
			// state — TODO confirm.
			if r, ok := q.rows.(*rows); ok {
				return r.clone(), q.err
			}
			return q.rows, q.err
		case q.err != nil:
			return nil, q.err
		}
	}
	return nil, errors.New("Query not stubbed: " + query)
}
// Exec implements driver.Execer. The execFunc hook, when set, handles every
// statement. Otherwise the stub registered for this query text supplies
// either its result or its error. Anything else produces an
// "Exec call not stubbed" error.
func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
	if fn := c.execFunc; fn != nil {
		return fn(query, args)
	}

	stub, found := d.conn.queries[getQueryHash(query)]
	switch {
	case found && stub.result != nil:
		return stub.result, nil
	case found && stub.err != nil:
		return nil, stub.err
	default:
		return nil, errors.New("Exec call not stubbed: " + query)
	}
}
|