1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
|
package pgxtype
import (
"context"
"errors"
"github.com/jackc/pgconn"
"github.com/jackc/pgtype"
"github.com/jackc/pgx/v4"
)
// Querier is the set of query methods this package requires from a database handle.
// The functions in this file only call QueryRow and Query on it; Exec is part of the
// declared method set but is not called here.
// NOTE(review): presumably satisfied by pgx connection-like types such as *pgx.Conn — confirm.
type Querier interface {
// Exec takes a context, a SQL string and variadic arguments and returns a pgconn.CommandTag and an error.
Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error)
// Query takes a context, a SQL string and variadic options/arguments and returns pgx.Rows and an error.
Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error)
// QueryRow takes the same inputs as Query and returns a single pgx.Row; callers in this
// file obtain the error, if any, from Row.Scan.
QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row
}
// LoadDataType uses conn to inspect the database for typeName and produces a pgtype.DataType suitable for
// registration on ci.
//
// The type name is resolved to an OID with a regtype cast, and pg_type.typtype of that OID selects how the
// value is built:
//
//   - "b": handled as an array. The element OID is read from pg_type.typelem and must already be registered
//     on ci with a Value that implements pgtype.ValueTranscoder.
//   - "c": handled as a composite. Its fields come from GetCompositeFields.
//   - "e": handled as an enum. Its labels come from GetEnumMembers.
//
// Any other typtype yields an "unknown typtype" error. Errors from the underlying queries are returned
// unchanged, always together with a zero pgtype.DataType.
func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeName string) (pgtype.DataType, error) {
	var oid uint32
	err := conn.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid)
	if err != nil {
		return pgtype.DataType{}, err
	}

	var typtype string
	err = conn.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype)
	if err != nil {
		return pgtype.DataType{}, err
	}

	switch typtype {
	case "b": // array
		elementOID, err := GetArrayElementOID(ctx, conn, oid)
		if err != nil {
			return pgtype.DataType{}, err
		}

		// The element type has to be known to ci up front. Without this check a missing registration
		// would leave element nil and the closure below would only fail (by panicking inside
		// pgtype.NewValue) once the array type is actually used.
		dt, ok := ci.DataTypeForOID(elementOID)
		if !ok {
			return pgtype.DataType{}, errors.New("array element OID not registered")
		}
		element, ok := dt.Value.(pgtype.ValueTranscoder)
		if !ok {
			return pgtype.DataType{}, errors.New("array element OID not registered as ValueTranscoder")
		}

		// Each call hands the array type a fresh element value derived from the registered one.
		newElement := func() pgtype.ValueTranscoder {
			return pgtype.NewValue(element).(pgtype.ValueTranscoder)
		}

		at := pgtype.NewArrayType(typeName, elementOID, newElement)
		return pgtype.DataType{Value: at, Name: typeName, OID: oid}, nil
	case "c": // composite
		fields, err := GetCompositeFields(ctx, conn, oid)
		if err != nil {
			return pgtype.DataType{}, err
		}

		ct, err := pgtype.NewCompositeType(typeName, fields, ci)
		if err != nil {
			return pgtype.DataType{}, err
		}
		return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil
	case "e": // enum
		members, err := GetEnumMembers(ctx, conn, oid)
		if err != nil {
			return pgtype.DataType{}, err
		}
		return pgtype.DataType{Value: pgtype.NewEnumType(typeName, members), Name: typeName, OID: oid}, nil
	default:
		return pgtype.DataType{}, errors.New("unknown typtype")
	}
}
// GetArrayElementOID reads the typelem column of the pg_type row identified by oid and returns it.
// If the query or scan fails, it returns 0 and that error.
func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32, error) {
	const query = "select typelem from pg_type where oid=$1"

	var elemOID uint32
	row := conn.QueryRow(ctx, query, oid)
	if scanErr := row.Scan(&elemOID); scanErr != nil {
		return 0, scanErr
	}
	return elemOID, nil
}
// GetCompositeFields gets the fields of a composite type.
//
// It reads typrelid for oid from pg_type, then returns the name and type OID of every attribute of that
// relation from pg_attribute, ordered by attnum. System columns (attnum <= 0) and dropped attributes are
// excluded because they are not fields of the composite value. On any query or scan error it returns nil
// and that error.
func GetCompositeFields(ctx context.Context, conn Querier, oid uint32) ([]pgtype.CompositeTypeField, error) {
	var typrelid uint32
	err := conn.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid)
	if err != nil {
		return nil, err
	}

	rows, err := conn.Query(ctx, `select attname, atttypid
from pg_attribute
where attrelid=$1
  and attnum > 0
  and not attisdropped
order by attnum`, typrelid)
	if err != nil {
		return nil, err
	}
	// Release the result set on every exit path, including an early return after a Scan error.
	defer rows.Close()

	var fields []pgtype.CompositeTypeField
	for rows.Next() {
		var f pgtype.CompositeTypeField
		if err := rows.Scan(&f.Name, &f.OID); err != nil {
			return nil, err
		}
		fields = append(fields, f)
	}

	// Next returning false can mean either end of rows or a failure; Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return fields, nil
}
// GetEnumMembers gets the possible values of the enum by oid.
//
// Labels are read from pg_enum for enumtypid = oid and returned in enumsortorder order. When the query
// yields no rows the result is an empty, non-nil slice. On any query or scan error it returns nil and
// that error.
func GetEnumMembers(ctx context.Context, conn Querier, oid uint32) ([]string, error) {
	rows, err := conn.Query(ctx, "select enumlabel from pg_enum where enumtypid=$1 order by enumsortorder", oid)
	if err != nil {
		return nil, err
	}
	// Release the result set on every exit path, including an early return after a Scan error.
	defer rows.Close()

	// Deliberately non-nil so callers keep receiving an empty slice rather than nil for zero labels.
	members := []string{}
	for rows.Next() {
		var m string
		if err := rows.Scan(&m); err != nil {
			return nil, err
		}
		members = append(members, m)
	}

	// Next returning false can mean either end of rows or a failure; Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return members, nil
}
|