1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261
|
"""Type helpers for adaptation of parameters."""
from __future__ import annotations
from datetime import date, datetime, time, timedelta, tzinfo
from json import dumps as jsonencode
from re import compile as regex
from time import localtime
from typing import Any, Callable, Iterable
from uuid import UUID as Uuid # noqa: N811
from .typecode import TypeCode
# Public API of this module: the DB-API 2 type objects and constructors,
# plus the additional PyGreSQL type helpers (Binary, Hstore, Interval,
# Json, Literal, Uuid) that are defined here as well.
__all__ = [
    'ARRAY',
    'BINARY',
    'BOOL',
    'DATE',
    'DATETIME',
    'FLOAT',
    'HSTORE',
    'INTEGER',
    'INTERVAL',
    'JSON',
    'LONG',
    'MONEY',
    'NUMBER',
    'NUMERIC',
    'RECORD',
    'ROWID',
    'SMALLINT',
    'STRING',
    'TIME',
    'TIMESTAMP',
    'UUID',
    'ArrayType',
    'Binary',
    'Date',
    'DateFromTicks',
    'DbType',
    'Hstore',
    'Interval',
    'Json',
    'Literal',
    'RecordType',
    'Time',
    'TimeFromTicks',
    'Timestamp',
    'TimestampFromTicks',
    'Uuid',
]
class DbType(frozenset):
    """Frozen set of PostgreSQL type names acting as one DB-API type object.

    Types in PostgreSQL are dynamic (it is object-oriented), so type names
    rather than fixed numeric codes are used as the internal type codes.
    A type object compares equal to any name it contains; a single leading
    underscore (array type names) is ignored, so an array type name also
    equals the type object of its element type.
    """

    def __new__(cls, values: str | Iterable[str]) -> DbType:
        """Build the type object from whitespace-separated names or an iterable."""
        names = values.split() if isinstance(values, str) else values
        return super().__new__(cls, names)

    def __eq__(self, other: Any) -> bool:
        """Return whether *other* names this type (or equals it as a set)."""
        if not isinstance(other, str):
            return super().__eq__(other)
        name = other[1:] if other.startswith('_') else other
        return name in self

    def __ne__(self, other: Any) -> bool:
        """Return whether *other* does not name this type (or differs as a set)."""
        if not isinstance(other, str):
            return super().__ne__(other)
        name = other[1:] if other.startswith('_') else other
        return name not in self
class ArrayType:
    """Type object that matches every PostgreSQL array type.

    Any type name starting with an underscore (the naming used for array
    types) compares equal, as does any other ``ArrayType`` instance.
    """

    def __eq__(self, other: Any) -> bool:
        """Return whether *other* denotes an array type."""
        return (other.startswith('_') if isinstance(other, str)
                else isinstance(other, ArrayType))

    def __ne__(self, other: Any) -> bool:
        """Return whether *other* does not denote an array type."""
        return (not other.startswith('_') if isinstance(other, str)
                else not isinstance(other, ArrayType))
class RecordType:
    """Type object that matches PostgreSQL record (composite) types.

    It compares equal to a ``TypeCode`` whose ``type`` attribute is ``'c'``
    (presumably the pg_type code for composite types), to the plain type
    name ``'record'``, and to any other ``RecordType`` instance.
    """

    def __eq__(self, other: Any) -> bool:
        """Return whether *other* denotes a record type."""
        # Keep the TypeCode test ahead of the str test.
        # NOTE(review): if TypeCode subclasses str (not visible here),
        # swapping the order would change the result.
        if isinstance(other, TypeCode):
            matches = other.type == 'c'
        elif isinstance(other, str):
            matches = other == 'record'
        else:
            matches = isinstance(other, RecordType)
        return matches

    def __ne__(self, other: Any) -> bool:
        """Return whether *other* does not denote a record type."""
        # Same check order as __eq__ (TypeCode first).
        if isinstance(other, TypeCode):
            differs = other.type != 'c'
        elif isinstance(other, str):
            differs = other != 'record'
        else:
            differs = not isinstance(other, RecordType)
        return differs
# Type objects that the DB-API 2 specification requires:
STRING = DbType(['char', 'bpchar', 'name', 'text', 'varchar'])
BINARY = DbType(['bytea'])
NUMBER = DbType(['int2', 'int4', 'serial', 'int8',
                 'float4', 'float8', 'numeric', 'money'])
DATETIME = DbType(['date', 'time', 'timetz', 'timestamp', 'timestamptz',
                   'interval',
                   'abstime', 'reltime'])  # the last two are very old types
ROWID = DbType(['oid'])

# More specific type objects offered in addition to the mandatory ones:
BOOL = DbType(['bool'])
SMALLINT = DbType(['int2'])
INTEGER = DbType(['int2', 'int4', 'int8', 'serial'])
LONG = DbType(['int8'])
FLOAT = DbType(['float4', 'float8'])
NUMERIC = DbType(['numeric'])
MONEY = DbType(['money'])
DATE = DbType(['date'])
TIME = DbType(['time', 'timetz'])
TIMESTAMP = DbType(['timestamp', 'timestamptz'])
INTERVAL = DbType(['interval'])
UUID = DbType(['uuid'])
HSTORE = DbType(['hstore'])
JSON = DbType(['json', 'jsonb'])

# Matches any array type (array type names also equal their element types'
# DbType objects above):
ARRAY = ArrayType()

# Matches record types, i.e. all composite types:
RECORD = RecordType()
# Type constructors that the DB-API 2 specification requires:
def Date(year: int, month: int, day: int) -> date:  # noqa: N802
    """Return a ``datetime.date`` for the given year, month and day."""
    return date(year=year, month=month, day=day)
def Time(hour: int, minute: int = 0,  # noqa: N802
         second: int = 0, microsecond: int = 0,
         tzinfo: tzinfo | None = None) -> time:
    """Return a ``datetime.time`` built from the given components.

    Only *hour* is required; the other fields default to zero, and the
    result carries *tzinfo* (``None`` by default).
    """
    return time(hour=hour, minute=minute, second=second,
                microsecond=microsecond, tzinfo=tzinfo)
def Timestamp(year: int, month: int, day: int,  # noqa: N802
              hour: int = 0, minute: int = 0,
              second: int = 0, microsecond: int = 0,
              tzinfo: tzinfo | None = None) -> datetime:
    """Return a ``datetime.datetime`` built from the given components.

    The date fields are required; the time fields default to zero, and the
    result carries *tzinfo* (``None`` by default).
    """
    return datetime(year=year, month=month, day=day,
                    hour=hour, minute=minute, second=second,
                    microsecond=microsecond, tzinfo=tzinfo)
def DateFromTicks(ticks: float | None) -> date:  # noqa: N802
    """Return the local date for *ticks* seconds since the epoch.

    As with ``time.localtime``, ``None`` stands for the current time.
    """
    parts = localtime(ticks)
    return Date(parts.tm_year, parts.tm_mon, parts.tm_mday)
def TimeFromTicks(ticks: float | None) -> time:  # noqa: N802
    """Return the local time of day for *ticks* seconds since the epoch.

    As with ``time.localtime``, ``None`` stands for the current time.
    Fractions of a second are dropped (``localtime`` has whole seconds).
    """
    parts = localtime(ticks)
    return Time(parts.tm_hour, parts.tm_min, parts.tm_sec)
def TimestampFromTicks(ticks: float | None) -> datetime:  # noqa: N802
    """Return the local date and time for *ticks* seconds since the epoch.

    As with ``time.localtime``, ``None`` stands for the current time.
    Fractions of a second are dropped (``localtime`` has whole seconds).
    """
    year, month, day, hour, minute, second = localtime(ticks)[:6]
    return Timestamp(year, month, day, hour, minute, second)
class Binary(bytes):
    """Hold a binary (possibly long) string value.

    Instances are plain ``bytes`` in every respect; the distinct subclass
    presumably lets parameter adaptation recognize the value as binary.
    """
# Further type constructors specific to PyGreSQL:
def Interval(days: int | float,  # noqa: N802
             hours: int | float = 0, minutes: int | float = 0,
             seconds: int | float = 0, microseconds: int | float = 0
             ) -> timedelta:
    """Return a ``datetime.timedelta`` spanning the given amounts of time.

    *days* is required; the smaller units default to zero and may be
    fractional.
    """
    return timedelta(days=days, hours=hours, minutes=minutes,
                     seconds=seconds, microseconds=microseconds)
# ``Uuid`` is ``uuid.UUID`` (imported at the top of the module under this
# name). Rebinding it to itself has no runtime effect; presumably it is kept
# so the name reads as an intentional public type helper of this module.
Uuid = Uuid # Construct an object holding a UUID value
class Hstore(dict):
    """Wrapper class for marking hstore values.

    ``str()`` renders the mapping in hstore text input format, e.g.
    ``{'a': 1, 'b c': None}`` becomes ``a=>1,"b c"=>NULL``. Keys and values
    are converted with ``str()``; a ``None`` value becomes unquoted ``NULL``.
    """

    # Tokens that must be double-quoted: anything spelled like NULL (which
    # would otherwise be read as a null value), and anything containing
    # whitespace or one of the delimiters ``,``, ``=``, ``>``. Any whitespace
    # (not just a space) ends an unquoted hstore token, so tabs, newlines
    # etc. need quoting too.
    _re_quote = regex(r'^[Nn][Uu][Ll][Ll]$|[\s,=>]')
    # Double quotes and backslashes are escaped with a backslash.
    _re_escape = regex(r'(["\\])')

    @classmethod
    def _quote(cls, s: Any) -> str:
        """Return *s* formatted as a single hstore key or value token."""
        if s is None:
            return 'NULL'
        if not isinstance(s, str):
            s = str(s)
        if not s:
            # An empty string must be written as an explicit empty quote.
            return '""'
        # Decide on quoting before escaping, based on the raw text.
        quote = cls._re_quote.search(s)
        s = cls._re_escape.sub(r'\\\1', s)
        if quote:
            s = f'"{s}"'
        return s

    def __str__(self) -> str:
        """Create a printable representation of the hstore value."""
        q = self._quote
        return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items())
class Json:
    """Wrap an object so that it is rendered as JSON text.

    A ``str`` object is returned unchanged (presumably it already holds
    JSON text); anything else is serialized with *encode*, which falls back
    to ``json.dumps`` when no truthy encoder is given.
    """

    def __init__(self, obj: Any,
                 encode: Callable[[Any], str] | None = None) -> None:
        """Store *obj* and the function used to serialize it."""
        self.obj = obj
        self.encode = encode if encode else jsonencode

    def __str__(self) -> str:
        """Return the JSON text for the wrapped object."""
        value = self.obj
        return value if isinstance(value, str) else self.encode(value)
class Literal:
    """Hold a string that is meant to be used verbatim as SQL."""

    def __init__(self, sql: str) -> None:
        """Remember the literal SQL text in ``self.sql``."""
        self.sql = sql

    def __str__(self) -> str:
        """Return the wrapped SQL text unchanged."""
        return self.sql

    # The very same method is also exposed as the ``__pg_repr__`` hook.
    __pg_repr__ = __str__
|