1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
|
// Package pgx defines and registers usql's PostgreSQL PGX driver.
//
// See: https://github.com/jackc/pgx
package pgx
import (
"context"
"database/sql"
"errors"
"fmt"
"io"
"strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/stdlib" // DRIVER
"github.com/xo/dburl"
"github.com/xo/usql/drivers"
"github.com/xo/usql/drivers/metadata"
pgmeta "github.com/xo/usql/drivers/metadata/postgres"
"github.com/xo/usql/text"
)
func init() {
drivers.Register("pgx", drivers.Driver{
AllowDollar: true,
AllowMultilineComments: true,
LexerName: "postgres",
Open: func(ctx context.Context, u *dburl.URL, stdout, stderr func() io.Writer) (func(string, string) (*sql.DB, error), error) {
return func(_, dsn string) (*sql.DB, error) {
config, err := pgx.ParseConfig(dsn)
if err != nil {
return nil, err
}
config.OnNotice = func(_ *pgconn.PgConn, notice *pgconn.Notice) {
out := stderr()
fmt.Fprintln(out, notice.Severity+": ", notice.Message)
if notice.Hint != "" {
fmt.Fprintln(out, "HINT: ", notice.Hint)
}
}
config.OnNotification = func(_ *pgconn.PgConn, notification *pgconn.Notification) {
var payload string
if notification.Payload != "" {
payload = fmt.Sprintf(text.NotificationPayload, notification.Payload)
}
fmt.Fprintln(stdout(), fmt.Sprintf(text.NotificationReceived, notification.Channel, payload, notification.PID))
}
// NOTE: as opposed to the github.com/lib/pq driver, this
// NOTE: driver has a "prefer" mode that is enabled by default.
// NOTE: as such there is no logic here to try to reconnect as
// NOTE: in the postgres driver.
return stdlib.OpenDB(*config), nil
}, nil
},
Version: func(ctx context.Context, db drivers.DB) (string, error) {
var ver string
err := db.QueryRowContext(ctx, `SHOW server_version`).Scan(&ver)
if err != nil {
return "", err
}
return "PostgreSQL " + ver, nil
},
ChangePassword: func(db drivers.DB, user, newpw, _ string) error {
_, err := db.Exec(`ALTER USER ` + user + ` PASSWORD '` + newpw + `'`)
return err
},
Err: func(err error) (string, string) {
var e *pgconn.PgError
if errors.As(err, &e) {
return e.Code, e.Message
}
return "", err.Error()
},
IsPasswordErr: func(err error) bool {
var e *pgconn.PgError
if errors.As(err, &e) {
return e.Code == "28P01"
}
return false
},
NewMetadataReader: pgmeta.NewReader(),
NewMetadataWriter: func(db drivers.DB, w io.Writer, opts ...metadata.ReaderOption) metadata.Writer {
return metadata.NewDefaultWriter(pgmeta.NewReader()(db, opts...))(db, w)
},
Copy: func(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
conn, err := db.Conn(context.Background())
if err != nil {
return 0, fmt.Errorf("failed to get a connection from pool: %w", err)
}
leftParen := strings.IndexRune(table, '(')
colQuery := "SELECT * FROM " + table + " WHERE 1=0"
if leftParen != -1 {
// pgx's CopyFrom needs a slice of column names and splitting them by a comma is unreliable
// so evaluate the possible expressions against the target table
colQuery = "SELECT " + table[leftParen+1:len(table)-1] + " FROM " + table[:leftParen] + " WHERE 1=0"
table = table[:leftParen]
}
colStmt, err := db.PrepareContext(ctx, colQuery)
if err != nil {
return 0, fmt.Errorf("failed to prepare query to determine target table columns: %w", err)
}
colRows, err := colStmt.QueryContext(ctx)
if err != nil {
return 0, fmt.Errorf("failed to execute query to determine target table columns: %w", err)
}
columns, err := colRows.Columns()
if err != nil {
return 0, fmt.Errorf("failed to fetch target table columns: %w", err)
}
clen := len(columns)
crows := ©Rows{
rows: rows,
values: make([]interface{}, clen),
}
for i := 0; i < clen; i++ {
crows.values[i] = new(interface{})
}
var n int64
err = conn.Raw(func(driverConn interface{}) error {
conn := driverConn.(*stdlib.Conn).Conn()
n, err = conn.CopyFrom(ctx, pgx.Identifier(strings.SplitN(table, ".", 2)), columns, crows)
return err
})
return n, err
},
})
}
// copyRows adapts a *sql.Rows result set to the Next/Values/Err row source
// that pgx's CopyFrom consumes.
type copyRows struct {
	// rows is the source result set being copied.
	rows *sql.Rows
	// values holds one *interface{} scan destination per target column. The
	// same destinations are reused for every row.
	values []interface{}
}

// Next advances to the next source row and reports whether one is available.
func (r *copyRows) Next() bool {
	return r.rows.Next()
}

// Values scans the current source row and returns its column values in a
// slice that is newly allocated for each row. If the scan fails, it returns
// a nil slice together with the error.
func (r *copyRows) Values() ([]interface{}, error) {
	if err := r.rows.Scan(r.values...); err != nil {
		return nil, err
	}
	actuals := make([]interface{}, len(r.values))
	for i, v := range r.values {
		actuals[i] = *(v.(*interface{}))
	}
	return actuals, nil
}

// Err returns the error, if any, encountered while iterating the source rows.
func (r *copyRows) Err() error {
	return r.rows.Err()
}
|