File: pgxtype.go

package info (click to toggle)
golang-github-jackc-pgtype 1.10.0-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, forky, sid, trixie
  • size: 1,656 kB
  • sloc: sh: 32; makefile: 4
file content (145 lines) | stat: -rw-r--r-- 3,807 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package pgxtype

import (
	"context"
	"errors"

	"github.com/jackc/pgconn"
	"github.com/jackc/pgtype"
	"github.com/jackc/pgx/v4"
)

// Querier is the set of query methods the functions in this file accept as their database handle.
// The loaders visible here call only QueryRow and Query on it.
// NOTE(review): presumably satisfied by pgx connection, transaction and pool types — confirm against callers.
type Querier interface {
	// Exec runs sql with arguments and reports the resulting command tag.
	Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error)
	// Query runs sql and returns a row set for iteration.
	Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error)
	// QueryRow runs sql and returns a single row; any error surfaces when the row is scanned.
	QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row
}

// LoadDataType uses conn to inspect the database for typeName and produces a pgtype.DataType suitable for
// registration on ci.
//
// typeName is resolved to an OID through a regtype cast, and the type's pg_type.typtype then selects how the
// DataType value is built:
//   - "b": handled as an array type; the element type (pg_type.typelem) must already be registered on ci
//     with a value implementing pgtype.ValueTranscoder.
//   - "c": composite type, built by pgtype.NewCompositeType from the fields GetCompositeFields reports.
//   - "e": enum type, built by pgtype.NewEnumType from the labels GetEnumMembers reports.
//
// Any query failure, an unregistered or unsuitable array element type, or any other typtype yields an error
// together with a zero pgtype.DataType.
func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeName string) (pgtype.DataType, error) {
	var oid uint32

	err := conn.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid)
	if err != nil {
		return pgtype.DataType{}, err
	}

	var typtype string

	err = conn.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype)
	if err != nil {
		return pgtype.DataType{}, err
	}

	switch typtype {
	case "b": // array (pg_type lists array types under the base typtype)
		elementOID, err := GetArrayElementOID(ctx, conn, oid)
		if err != nil {
			return pgtype.DataType{}, err
		}

		// Without a registered element type there is no prototype to clone array elements from. Fail here
		// rather than hand back an array type whose element factory would be called with a nil value.
		dt, ok := ci.DataTypeForOID(elementOID)
		if !ok {
			return pgtype.DataType{}, errors.New("array element OID not registered")
		}

		element, ok := dt.Value.(pgtype.ValueTranscoder)
		if !ok {
			return pgtype.DataType{}, errors.New("array element OID not registered as ValueTranscoder")
		}

		// Each call produces a fresh value of the registered element type.
		newElement := func() pgtype.ValueTranscoder {
			return pgtype.NewValue(element).(pgtype.ValueTranscoder)
		}

		at := pgtype.NewArrayType(typeName, elementOID, newElement)
		return pgtype.DataType{Value: at, Name: typeName, OID: oid}, nil
	case "c": // composite
		fields, err := GetCompositeFields(ctx, conn, oid)
		if err != nil {
			return pgtype.DataType{}, err
		}
		ct, err := pgtype.NewCompositeType(typeName, fields, ci)
		if err != nil {
			return pgtype.DataType{}, err
		}
		return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil
	case "e": // enum
		members, err := GetEnumMembers(ctx, conn, oid)
		if err != nil {
			return pgtype.DataType{}, err
		}
		return pgtype.DataType{Value: pgtype.NewEnumType(typeName, members), Name: typeName, OID: oid}, nil
	default:
		return pgtype.DataType{}, errors.New("unknown typtype")
	}
}

// GetArrayElementOID looks up pg_type.typelem for the type identified by oid and returns it.
// It returns 0 and the scan error if the lookup fails.
func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32, error) {
	var elementOID uint32

	row := conn.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid)
	if scanErr := row.Scan(&elementOID); scanErr != nil {
		return 0, scanErr
	}

	return elementOID, nil
}

// GetCompositeFields gets the fields of a composite type.
//
// It reads pg_type.typrelid for oid and then lists that relation's attributes from pg_attribute, returning
// one pgtype.CompositeTypeField (attribute name and type OID) per attribute in attnum order. Dropped
// attributes and system columns (attnum <= 0) are excluded so that only the user-visible fields of the type
// are reported. The first query, scan, or iteration error is returned with a nil slice.
func GetCompositeFields(ctx context.Context, conn Querier, oid uint32) ([]pgtype.CompositeTypeField, error) {
	var typrelid uint32

	err := conn.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid)
	if err != nil {
		return nil, err
	}

	rows, err := conn.Query(ctx, `select attname, atttypid
from pg_attribute
where attrelid=$1
	and not attisdropped
	and attnum > 0
order by attnum`, typrelid)
	if err != nil {
		return nil, err
	}
	// Release the result set on every return path, including the early return on a scan error.
	defer rows.Close()

	var fields []pgtype.CompositeTypeField

	for rows.Next() {
		var f pgtype.CompositeTypeField
		if err := rows.Scan(&f.Name, &f.OID); err != nil {
			return nil, err
		}
		fields = append(fields, f)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return fields, nil
}

// GetEnumMembers gets the possible values of the enum by oid.
//
// Labels are read from pg_enum for enumtypid=oid, ordered by enumsortorder. When no labels are found the
// result is an empty, non-nil slice. The first query, scan, or iteration error is returned with a nil slice.
func GetEnumMembers(ctx context.Context, conn Querier, oid uint32) ([]string, error) {
	members := []string{}

	rows, err := conn.Query(ctx, "select enumlabel from pg_enum where enumtypid=$1 order by enumsortorder", oid)
	if err != nil {
		return nil, err
	}
	// Release the result set on every return path, including the early return on a scan error.
	defer rows.Close()

	for rows.Next() {
		var m string
		if err := rows.Scan(&m); err != nil {
			return nil, err
		}
		members = append(members, m)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return members, nil
}