File: adapt.py

package info (click to toggle)
pygresql 1%3A6.1.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,944 kB
  • sloc: python: 15,052; ansic: 5,730; makefile: 16; sh: 10
file content (261 lines) | stat: -rw-r--r-- 7,449 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
"""Type helpers for adaptation of parameters."""

from __future__ import annotations

from datetime import date, datetime, time, timedelta, tzinfo
from json import dumps as jsonencode
from re import compile as regex
from time import localtime
from typing import Any, Callable, Iterable
from uuid import UUID as Uuid  # noqa: N811

from .typecode import TypeCode

# Public names exported by a star import of this module: first the type
# objects (upper case), then the type classes and constructor helpers.
# NOTE(review): Binary, Interval, Uuid, Hstore, Json and Literal are defined
# further down in this module but are not listed here -- presumably they are
# imported by name where needed; confirm before adding them.
__all__ = [
    'ARRAY',
    'BINARY',
    'BOOL',
    'DATE',
    'DATETIME',
    'FLOAT',
    'HSTORE',
    'INTEGER',
    'INTERVAL',
    'JSON',
    'LONG',
    'MONEY',
    'NUMBER',
    'NUMERIC',
    'RECORD',
    'ROWID',
    'SMALLINT',
    'STRING',
    'TIME',
    'TIMESTAMP',
    'UUID',
    'ArrayType',
    'Date',
    'DateFromTicks',
    'DbType',
    'RecordType',
    'Time',
    'TimeFromTicks',
    'Timestamp',
    'TimestampFromTicks'

]


class DbType(frozenset):
    """Set of PostgreSQL type names acting as one DB-API type object.

    PostgreSQL types are dynamic, so type names serve as the internal
    type codes.  An instance is a frozenset of such names that also
    compares equal to any single name it contains.  A name with a leading
    underscore is compared without that underscore.

    Because ``__eq__`` is overridden without ``__hash__``, instances are
    not hashable.
    """

    def __new__(cls, values: str | Iterable[str]) -> DbType:
        """Build the type object.

        A string is split on whitespace into individual type names;
        any other iterable is taken as the collection of names itself.
        """
        names = values.split() if isinstance(values, str) else values
        return super().__new__(cls, names)

    def __eq__(self, other: Any) -> bool:
        """Tell whether *other* is one of these type names or the same set.

        Non-string operands fall back to the ordinary frozenset comparison.
        """
        if not isinstance(other, str):
            return super().__eq__(other)
        # drop a single leading underscore before the membership test
        name = other[1:] if other.startswith('_') else other
        return name in self

    def __ne__(self, other: Any) -> bool:
        """Tell whether *other* is neither one of these names nor this set.

        Non-string operands fall back to the ordinary frozenset comparison.
        """
        if not isinstance(other, str):
            return super().__ne__(other)
        # drop a single leading underscore before the membership test
        name = other[1:] if other.startswith('_') else other
        return name not in self


class ArrayType:
    """Type object matching PostgreSQL array types.

    An instance equals every string that starts with an underscore and
    every other ``ArrayType`` instance; nothing else compares equal.
    """

    def __eq__(self, other: Any) -> bool:
        """Return True for underscore-prefixed names and ArrayType objects."""
        if not isinstance(other, str):
            return isinstance(other, ArrayType)
        return other.startswith('_')

    def __ne__(self, other: Any) -> bool:
        """Return True unless *other* is an array name or an ArrayType."""
        if not isinstance(other, str):
            return not isinstance(other, ArrayType)
        return not other.startswith('_')


class RecordType:
    """Type object matching PostgreSQL record types.

    Comparison rules, checked in this order:

    * a ``TypeCode`` matches when its ``type`` attribute is ``'c'``,
    * any other string matches when it is exactly ``'record'``,
    * any other object matches when it is a ``RecordType`` instance.
    """

    def __eq__(self, other: Any) -> bool:
        """Return True if *other* denotes a record type."""
        # TypeCode must be tested before str so that it takes precedence
        if isinstance(other, TypeCode):
            return other.type == 'c'
        return (other == 'record' if isinstance(other, str)
                else isinstance(other, RecordType))

    def __ne__(self, other: Any) -> bool:
        """Return True if *other* does not denote a record type."""
        # TypeCode must be tested before str so that it takes precedence
        if isinstance(other, TypeCode):
            return other.type != 'c'
        return (other != 'record' if isinstance(other, str)
                else not isinstance(other, RecordType))


# Mandatory type objects defined by DB-API 2 specs.
# Each DbType compares equal to any of the PostgreSQL type names listed in
# its argument, and to such a name prefixed with a single underscore:

STRING = DbType('char bpchar name text varchar')
BINARY = DbType('bytea')
NUMBER = DbType('int2 int4 serial int8 float4 float8 numeric money')
DATETIME = DbType('date time timetz timestamp timestamptz interval'
                ' abstime reltime')  # abstime and reltime are very old types
ROWID = DbType('oid')


# Additional type objects (more specific than the mandatory ones above;
# the type names overlap with those of NUMBER and DATETIME):

BOOL = DbType('bool')
SMALLINT = DbType('int2')
INTEGER = DbType('int2 int4 int8 serial')
LONG = DbType('int8')
FLOAT = DbType('float4 float8')
NUMERIC = DbType('numeric')
MONEY = DbType('money')
DATE = DbType('date')
TIME = DbType('time timetz')
TIMESTAMP = DbType('timestamp timestamptz')
INTERVAL = DbType('interval')
UUID = DbType('uuid')
HSTORE = DbType('hstore')
JSON = DbType('json jsonb')

# Type object for arrays: it equals every type name that starts with an
# underscore (such names also equate to the DbType of their base type):

ARRAY = ArrayType()

# Type object for records (encompassing all composite types): it equals
# the type name 'record' and any TypeCode whose type attribute is 'c':

RECORD = RecordType()


# Mandatory type helpers defined by DB-API 2 specs:

def Date(year: int, month: int, day: int) -> date:  # noqa: N802
    """Return a ``datetime.date`` for the given year, month and day."""
    return date(year=year, month=month, day=day)


def Time(hour: int, minute: int = 0,  # noqa: N802
         second: int = 0, microsecond: int = 0,
         tzinfo: tzinfo | None = None) -> time:
    """Return a ``datetime.time`` built from the given time-of-day fields.

    Only the hour is required; the remaining fields default to zero and
    the time zone defaults to None.
    """
    return time(hour=hour, minute=minute, second=second,
                microsecond=microsecond, tzinfo=tzinfo)


def Timestamp(year: int, month: int, day: int,  # noqa: N802
              hour: int = 0, minute: int = 0,
              second: int = 0, microsecond: int = 0,
              tzinfo: tzinfo | None = None) -> datetime:
    """Return a ``datetime.datetime`` built from the given fields.

    The date fields are required; the time-of-day fields default to zero
    and the time zone defaults to None.
    """
    return datetime(year=year, month=month, day=day,
                    hour=hour, minute=minute, second=second,
                    microsecond=microsecond, tzinfo=tzinfo)


def DateFromTicks(ticks: float | None) -> date:  # noqa: N802
    """Return the date that ``time.localtime`` reports for *ticks*."""
    year, month, day = localtime(ticks)[:3]
    return Date(year, month, day)


def TimeFromTicks(ticks: float | None) -> time:  # noqa: N802
    """Return the time of day that ``time.localtime`` reports for *ticks*.

    Only hour, minute and second are used.
    """
    hour, minute, second = localtime(ticks)[3:6]
    return Time(hour, minute, second)


def TimestampFromTicks(ticks: float | None) -> datetime:  # noqa: N802
    """Return the time stamp that ``time.localtime`` reports for *ticks*.

    The first six fields (year through second) are used.
    """
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)


class Binary(bytes):
    """Construct an object capable of holding a binary (long) string value.

    This is a plain ``bytes`` subclass that adds no behavior of its own;
    wrapping a value in it only gives the value a distinct type.
    """


# Additional type helpers for PyGreSQL:

def Interval(days: int | float,  # noqa: N802
             hours: int | float = 0, minutes: int | float = 0,
             seconds: int | float = 0, microseconds: int | float = 0
             ) -> timedelta:
    """Return a ``datetime.timedelta`` spanning the given amounts of time.

    Only the number of days is required; the other parts default to zero.
    """
    return timedelta(days=days, seconds=seconds, microseconds=microseconds,
                     minutes=minutes, hours=hours)


# Construct an object holding a UUID value.
# The self-assignment is intentional: ``uuid.UUID`` is imported at the top
# of this module under the name ``Uuid`` and is re-bound here so that it
# appears alongside the other type helpers.
Uuid = Uuid


class Hstore(dict):
    """Wrapper class for marking hstore values.

    This is a dict whose ``str()`` is the hstore text representation of
    its items: ``key=>value`` pairs joined by commas, with keys and values
    quoted and escaped as required by the hstore input syntax.
    """

    # A key or value must be double-quoted when it spells NULL in any
    # letter case (unquoted, that is the null value), or when it contains
    # whitespace or one of the delimiter characters ',', '=' and '>'.
    # Any whitespace counts, not only the space character: an unquoted
    # tab or newline would end the token early.
    _re_quote = regex(r'^[Nn][Uu][Ll][Ll]$|[\s,=>]')
    # Double quotes and backslashes are always backslash-escaped.
    _re_escape = regex(r'(["\\])')

    @classmethod
    def _quote(cls, s: Any) -> str:
        """Return *s* as a key or value in hstore text syntax.

        None becomes the unquoted word ``NULL``, the empty string becomes
        ``""``, and any other object is converted with ``str()``, escaped,
        and wrapped in double quotes when necessary.
        """
        if s is None:
            return 'NULL'
        if not isinstance(s, str):
            s = str(s)
        if not s:
            return '""'
        # decide on quoting before escaping; escaping only adds backslashes
        needs_quotes = cls._re_quote.search(s)
        s = cls._re_escape.sub(r'\\\1', s)
        return f'"{s}"' if needs_quotes else s

    def __str__(self) -> str:
        """Create a printable representation of the hstore value."""
        q = self._quote
        return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items())


class Json:
    """Wrapper holding an object that is to be rendered as JSON text.

    The wrapped object is kept in ``obj`` and the serializer in ``encode``.
    """

    def __init__(self, obj: Any,
                 encode: Callable[[Any], str] | None = None) -> None:
        """Store the object and the encoder.

        ``json.dumps`` is used when no encoder (or a falsy one) is passed.
        """
        self.encode = encode or jsonencode
        self.obj = obj

    def __str__(self) -> str:
        """Return the JSON text for the wrapped object.

        A string is returned unchanged instead of being encoded.
        """
        payload = self.obj
        return payload if isinstance(payload, str) else self.encode(payload)


class Literal:
    """Wrapper marking a string as a literal piece of SQL.

    The text is kept in the ``sql`` attribute and returned as is by both
    ``str()`` and ``__pg_repr__()``.
    """

    def __init__(self, sql: str) -> None:
        """Remember the SQL text."""
        self.sql = sql

    def __str__(self) -> str:
        """Return the SQL text exactly as it was given."""
        return self.sql

    # __pg_repr__ is the very same function as __str__
    # NOTE(review): presumably a hook looked up by the adaptation code --
    # confirm against the caller.
    __pg_repr__ = __str__