1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
|
package testfixtures
import (
	"database/sql"
	"fmt"
	"strings"
)
// MySQL is the MySQL-specific helper for this package. Its methods supply
// MySQL behaviour: "?" bind parameters, backtick-quoted identifiers, table
// discovery through information_schema, and loading with FOREIGN_KEY_CHECKS
// temporarily disabled.
//
// NOTE(review): baseHelper is embedded for shared behaviour defined elsewhere
// in the package; what it provides is not visible here.
type MySQL struct {
baseHelper
}
// paramType reports the bind-parameter placeholder style used in queries sent
// to MySQL: the positional question-mark form (paramTypeQuestion), as in the
// "table_schema=?" filter used by tableNames.
func (h *MySQL) paramType() (style int) {
	style = paramTypeQuestion
	return style
}
// quoteKeyword wraps an identifier in MySQL backtick quotes so it can be used
// verbatim in generated SQL, even when it collides with a reserved word.
//
// Any backtick inside str is doubled. This is how MySQL escapes the quote
// character inside a quoted identifier. Without it, a name containing "`"
// would close the quoting early and yield invalid (or injectable) SQL.
func (*MySQL) quoteKeyword(str string) string {
	return fmt.Sprintf("`%s`", strings.ReplaceAll(str, "`", "``"))
}
// databaseName returns the default schema reported by MySQL's DATABASE()
// function on the connection that serves the query.
//
// The signature has no error result, so failures are deliberately not
// propagated. If the query or the Scan fails, the empty string is returned.
// That includes DATABASE() returning NULL when no schema is selected, because
// NULL cannot be scanned into a string.
func (*MySQL) databaseName(db *sql.DB) string {
	var name string
	row := db.QueryRow("SELECT DATABASE()")
	_ = row.Scan(&name) // best effort: name stays "" on any failure
	return name
}
// tableNames lists the tables that information_schema.tables reports for the
// current database, as returned by databaseName.
//
// It returns an error if the query fails, if any row cannot be scanned, or if
// iteration stops early because of a driver or connection error. Without the
// rows.Err check, that last case would be reported as a successful but
// truncated list.
//
// NOTE(review): if databaseName yields "" (no schema selected or the lookup
// failed), the filter matches nothing and an empty list is returned. Also,
// information_schema.tables includes views; confirm callers expect them.
func (h *MySQL) tableNames(db *sql.DB) ([]string, error) {
	query := `
SELECT table_name
FROM information_schema.tables
WHERE table_schema=?;
`
	rows, err := db.Query(query, h.databaseName(db))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var tables []string
	for rows.Next() {
		var table string
		if err = rows.Scan(&table); err != nil {
			return nil, err
		}
		tables = append(tables, table)
	}
	// rows.Next returns false both at the end of the result set and on error;
	// only rows.Err tells them apart.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return tables, nil
}
// disableReferentialIntegrity runs loadFn inside a single transaction with
// MySQL foreign key checks switched off. This lets loadFn insert rows without
// regard to foreign-key ordering. The transaction is committed if loadFn
// succeeds and rolled back otherwise.
//
// FOREIGN_KEY_CHECKS is a session variable, so it is both disabled and
// re-enabled through tx, i.e. on the one connection the transaction holds.
// Re-enabling it through db would run on an arbitrary pooled connection. The
// transaction's connection could then return to the pool with checks still
// off.
//
// The setting is not transactional, so rolling back does not restore it. On
// the loadFn failure path it is therefore re-enabled explicitly before the
// rollback.
//
// Returns the first error from Begin, the SET statements, loadFn, or Commit.
func (h *MySQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}

	if _, err = tx.Exec("SET FOREIGN_KEY_CHECKS = 0"); err != nil {
		// Release the transaction (and its connection) instead of leaking it.
		_ = tx.Rollback()
		return err
	}

	if err = loadFn(tx); err != nil {
		// Best effort: restore checks on this same connection, then roll back.
		// The loadFn error is the one worth reporting, so the errors from
		// these cleanup calls are dropped.
		_, _ = tx.Exec("SET FOREIGN_KEY_CHECKS = 1")
		_ = tx.Rollback()
		return err
	}

	if _, err = tx.Exec("SET FOREIGN_KEY_CHECKS = 1"); err != nil {
		_ = tx.Rollback()
		return err
	}
	return tx.Commit()
}
|