1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
|
// Copyright 2016 Qiang Xue. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
// Package dbx provides a set of DB-agnostic and easy-to-use query building methods for relational databases.
package dbx
import (
"bytes"
"context"
"database/sql"
"regexp"
"strings"
"time"
)
// LogFunc logs a message for each SQL statement being executed.
// The signature mirrors the fmt.Printf family: when only the format argument
// is given it is the log message itself; when further arguments are given they
// are meant to be combined with the format through fmt.Sprintf() to produce
// the log message.
type LogFunc func(format string, a ...interface{})

// PerfFunc is called when a query finishes execution so that DB performance
// can be profiled.
// "ns" is the number of nanoseconds the SQL statement took, "sql" is the
// statement, and "execute" tells whether the statement was executed (true)
// or queried (false, usually SELECT statements).
type PerfFunc func(ns int64, sql string, execute bool)

// QueryLogFunc is called each time a SQL query is performed.
// "t" is the time the SQL statement took to execute, while rows and err are
// the outcome of the query.
type QueryLogFunc func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error)

// ExecLogFunc is called each time a SQL statement is executed.
// "t" is the time the SQL statement took to execute, while result and err are
// the outcome of the execution.
type ExecLogFunc func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error)

// BuilderFunc creates a Builder instance using the given DB instance and Executor.
type BuilderFunc func(*DB, Executor) Builder

// DB enhances sql.DB by providing a set of DB-agnostic query building methods.
// It makes it easier to build queries and to populate Go variables with data.
type DB struct {
	// Builder is the query builder bound to this DB; its methods are promoted.
	Builder

	// FieldMapper maps struct fields to DB columns. Defaults to DefaultFieldMapFunc.
	FieldMapper FieldMapFunc
	// TableMapper maps structs to table names. Defaults to GetTableName.
	TableMapper TableMapFunc
	// LogFunc logs the SQL statements being executed. Defaults to nil, meaning no logging.
	LogFunc LogFunc
	// PerfFunc logs the SQL execution time. Defaults to nil, meaning no performance profiling.
	// Deprecated: Please use QueryLogFunc and ExecLogFunc instead.
	PerfFunc PerfFunc
	// QueryLogFunc is called each time when performing a SQL query that returns data.
	QueryLogFunc QueryLogFunc
	// ExecLogFunc is called each time when a SQL statement is executed.
	ExecLogFunc ExecLogFunc

	// sqlDB is the wrapped database handle.
	sqlDB *sql.DB
	// driverName selects the builder implementation (see BuilderFuncMap).
	driverName string
	// ctx is the context attached through WithContext; nil when none is attached.
	ctx context.Context
}

// Errors represents a list of errors.
type Errors []error
// BuilderFuncMap associates DB driver names with the BuilderFunc that creates
// the matching query builder.
// Add entries to this map to support additional DB drivers. Drivers that are
// not present here fall back to the StandardBuilder.
var BuilderFuncMap = map[string]BuilderFunc{
	// SQLite drivers
	"sqlite3": NewSqliteBuilder,
	"sqlite":  NewSqliteBuilder,

	// MySQL driver
	"mysql": NewMysqlBuilder,

	// PostgreSQL drivers
	"pgx":      NewPgsqlBuilder,
	"postgres": NewPgsqlBuilder,

	// Microsoft SQL Server driver
	"mssql": NewMssqlBuilder,

	// Oracle driver
	"oci8": NewOciBuilder,
}
// NewFromDB wraps an already opened sql.DB handle in a dbx.DB.
// driverName selects the query builder (see BuilderFuncMap). The field and
// table mappers are initialized to DefaultFieldMapFunc and GetTableName.
func NewFromDB(sqlDB *sql.DB, driverName string) *DB {
	wrapped := new(DB)
	wrapped.sqlDB = sqlDB
	wrapped.driverName = driverName
	wrapped.FieldMapper = DefaultFieldMapFunc
	wrapped.TableMapper = GetTableName
	// The builder is created last because it receives the DB it belongs to.
	wrapped.Builder = wrapped.newBuilder(wrapped.sqlDB)
	return wrapped
}
// Open opens a database identified by a driver name and a data source name (DSN).
// Open neither validates the DSN nor tries to establish a DB connection.
// Please refer to sql.Open() for more information.
func Open(driverName, dsn string) (*DB, error) {
	handle, openErr := sql.Open(driverName, dsn)
	if openErr != nil {
		return nil, openErr
	}
	db := NewFromDB(handle, driverName)
	return db, nil
}
// MustOpen opens a database and verifies the connection by pinging it.
// When the ping fails the freshly opened handle is closed and the ping error
// is returned.
// Please refer to sql.Open() and (*sql.DB).Ping() for more information.
func MustOpen(driverName, dsn string) (*DB, error) {
	db, err := Open(driverName, dsn)
	if err != nil {
		return nil, err
	}

	pingErr := db.sqlDB.Ping()
	if pingErr == nil {
		return db, nil
	}

	// Best-effort cleanup: the ping error is what gets reported to the caller.
	db.Close()
	return nil, pingErr
}
// Clone makes a shallow copy of DB.
// The copy shares the underlying sql.DB handle and carries over the driver
// name, the mappers, the logging/profiling callbacks and the associated
// context. It receives its own Builder so that queries built from the copy
// refer to the copy rather than to the original DB.
func (db *DB) Clone() *DB {
	db2 := &DB{
		driverName:   db.driverName,
		sqlDB:        db.sqlDB,
		FieldMapper:  db.FieldMapper,
		TableMapper:  db.TableMapper,
		PerfFunc:     db.PerfFunc,
		LogFunc:      db.LogFunc,
		QueryLogFunc: db.QueryLogFunc,
		ExecLogFunc:  db.ExecLogFunc,
		// Keep the context: a copy of a context-bound DB must stay bound to it.
		ctx: db.ctx,
	}
	// The builder must be created from db2 so that it is bound to the copy.
	db2.Builder = db2.newBuilder(db.sqlDB)
	return db2
}
// WithContext returns a copy of DB (see Clone) that is associated with the
// given context. The receiver itself is left untouched.
func (db *DB) WithContext(ctx context.Context) *DB {
	bound := db.Clone()
	bound.ctx = ctx
	return bound
}
// Context reports the context that was attached to this DB instance.
// The result is nil when no context has been attached.
func (db *DB) Context() context.Context {
	return db.ctx
}
// DB gives access to the sql.DB handle wrapped by this dbx.DB.
func (db *DB) DB() *sql.DB {
	return db.sqlDB
}
// Close closes the wrapped sql.DB, releasing any open resources, and returns
// the error reported by it.
// Closing a DB is rarely needed: the handle is meant to be long-lived and
// shared between many goroutines.
func (db *DB) Close() error {
	closeErr := db.sqlDB.Close()
	return closeErr
}
// Begin starts a transaction.
// When a context is associated with the DB (see WithContext) the transaction
// is started with that context and default options; otherwise a plain
// transaction is started.
func (db *DB) Begin() (*Tx, error) {
	var (
		sqlTx *sql.Tx
		err   error
	)
	if ctx := db.ctx; ctx == nil {
		sqlTx, err = db.sqlDB.Begin()
	} else {
		sqlTx, err = db.sqlDB.BeginTx(ctx, nil)
	}
	if err != nil {
		return nil, err
	}

	// The transaction gets its own builder that executes through sqlTx.
	builder := db.newBuilder(sqlTx)
	return &Tx{builder, sqlTx}, nil
}
// BeginTx starts a transaction using the given context and transaction options.
// The context associated with the DB, if any, is not consulted here.
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
	sqlTx, beginErr := db.sqlDB.BeginTx(ctx, opts)
	if beginErr != nil {
		return nil, beginErr
	}

	// The transaction gets its own builder that executes through sqlTx.
	builder := db.newBuilder(sqlTx)
	return &Tx{builder, sqlTx}, nil
}
// Wrap turns an already started sql.Tx into a dbx transaction whose builder
// executes through that sql.Tx.
func (db *DB) Wrap(sqlTx *sql.Tx) *Tx {
	builder := db.newBuilder(sqlTx)
	return &Tx{builder, sqlTx}
}
// Transactional runs f inside a transaction started with Begin.
//
// The transaction is committed when f returns nil and rolled back when f
// returns an error. If the rollback itself fails, both errors are returned
// wrapped in Errors. sql.ErrTxDone reported by the commit or the rollback is
// not treated as a failure (f has already finished the transaction itself).
// If f panics, the transaction is rolled back and the panic is re-raised.
func (db *DB) Transactional(f func(*Tx) error) (err error) {
	tx, err := db.Begin()
	if err != nil {
		return err
	}

	defer func() {
		// recover has to be called directly by the deferred function.
		if p := recover(); p != nil {
			tx.Rollback()
			panic(p)
		}

		// Success path: commit, tolerating an already finished transaction.
		if err == nil {
			err = tx.Commit()
			if err == sql.ErrTxDone {
				err = nil
			}
			return
		}

		// Failure path: roll back and report a rollback failure alongside
		// the original error, unless the transaction was already finished.
		rollbackErr := tx.Rollback()
		if rollbackErr != nil && rollbackErr != sql.ErrTxDone {
			err = Errors{err, rollbackErr}
		}
	}()

	return f(tx)
}
// TransactionalContext runs f inside a transaction started with BeginTx using
// the given context and transaction options.
//
// The transaction is committed when f returns nil and rolled back when f
// returns an error. If the rollback itself fails, both errors are returned
// wrapped in Errors. sql.ErrTxDone reported by the commit or the rollback is
// not treated as a failure (f has already finished the transaction itself).
// If f panics, the transaction is rolled back and the panic is re-raised.
func (db *DB) TransactionalContext(ctx context.Context, opts *sql.TxOptions, f func(*Tx) error) (err error) {
	tx, err := db.BeginTx(ctx, opts)
	if err != nil {
		return err
	}

	defer func() {
		// recover has to be called directly by the deferred function.
		if p := recover(); p != nil {
			tx.Rollback()
			panic(p)
		}

		// Success path: commit, tolerating an already finished transaction.
		if err == nil {
			err = tx.Commit()
			if err == sql.ErrTxDone {
				err = nil
			}
			return
		}

		// Failure path: roll back and report a rollback failure alongside
		// the original error, unless the transaction was already finished.
		rollbackErr := tx.Rollback()
		if rollbackErr != nil && rollbackErr != sql.ErrTxDone {
			err = Errors{err, rollbackErr}
		}
	}()

	return f(tx)
}
// DriverName reports the DB driver name this DB was created with.
func (db *DB) DriverName() string {
	return db.driverName
}
// QuoteTableName quotes the given table name appropriately.
// A name with a DB schema prefix is split on "." and every segment is quoted
// separately with QuoteSimpleTableName.
// The name is returned unchanged when it contains a parenthesis "(" or the
// "{{" quoting marker.
func (db *DB) QuoteTableName(s string) string {
	skip := strings.Contains(s, "(") || strings.Contains(s, "{{")
	if skip {
		return s
	}

	// A name without "." yields a single segment, so one code path covers
	// both the simple and the schema-qualified case.
	segments := strings.Split(s, ".")
	quoted := make([]string, len(segments))
	for i := range segments {
		quoted[i] = db.QuoteSimpleTableName(segments[i])
	}
	return strings.Join(quoted, ".")
}
// QuoteColumnName quotes the given column name appropriately.
// If the column name contains a table name prefix, everything before the last
// "." is quoted with QuoteTableName and the remainder with
// QuoteSimpleColumnName.
// The name is returned unchanged when it contains a parenthesis "(" or one of
// the "{{" / "[[" quoting markers.
func (db *DB) QuoteColumnName(s string) string {
	for _, marker := range []string{"(", "{{", "[["} {
		if strings.Contains(s, marker) {
			return s
		}
	}

	dot := strings.LastIndex(s, ".")
	if dot == -1 {
		// No table prefix: only the column itself needs quoting.
		return db.QuoteSimpleColumnName(s)
	}

	table, column := s[:dot], s[dot+1:]
	return db.QuoteTableName(table) + "." + db.QuoteSimpleColumnName(column)
}
// plRegex matches named parameter placeholders of the form {:name}.
var plRegex = regexp.MustCompile(`\{:\w+\}`)

// quoteRegex matches table names written as {{name}} and column names written
// as [[name]]; a name may consist of word characters, "-", "." and spaces.
var quoteRegex = regexp.MustCompile(`(\{\{[\w\-\. ]+\}\}|\[\[[\w\-\. ]+\]\])`)
// processSQL prepares a SQL string for execution in two passes.
// First, every named parameter placeholder ({:name}) is replaced with the
// anonymous placeholder produced by GeneratePlaceholder for its 1-based
// position. Second, table names enclosed in {{...}} and column names enclosed
// in [[...]] are quoted.
// It returns the rewritten SQL together with the parameter names in the order
// in which they appear.
func (db *DB) processSQL(s string) (string, []string) {
	var names []string

	withPlaceholders := plRegex.ReplaceAllStringFunc(s, func(match string) string {
		// match has the form "{:name}": drop the leading "{:" and trailing "}".
		names = append(names, match[2:len(match)-1])
		// After the append, len(names) is the 1-based position of this parameter.
		return db.GeneratePlaceholder(len(names))
	})

	quoted := quoteRegex.ReplaceAllStringFunc(withPlaceholders, func(match string) string {
		// Strip the two-character opening and closing markers.
		inner := match[2 : len(match)-2]
		if match[0] != '{' {
			// "[[...]]" encloses a column name.
			return db.QuoteColumnName(inner)
		}
		// "{{...}}" encloses a table name.
		return db.QuoteTableName(inner)
	})

	return quoted, names
}
// newBuilder creates the query builder for the given executor.
// The builder constructor registered for the DB's driver name in
// BuilderFuncMap is used; unregistered drivers get NewStandardBuilder.
func (db *DB) newBuilder(executor Executor) Builder {
	if factory, registered := BuilderFuncMap[db.driverName]; registered {
		return factory(db, executor)
	}
	return NewStandardBuilder(db, executor)
}
// Error implements the error interface for Errors.
// The messages of the individual errors are concatenated in order, separated
// by newline characters; an empty list yields an empty string.
func (errs Errors) Error() string {
	var out bytes.Buffer
	sep := ""
	for _, item := range errs {
		// sep is empty for the first message and "\n" for every later one.
		out.WriteString(sep)
		out.WriteString(item.Error())
		sep = "\n"
	}
	return out.String()
}
|