File: __init__.py

package info (click to toggle)
dj-database-url 3.1.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 364 kB
  • sloc: python: 723; makefile: 3
file content (253 lines) | stat: -rw-r--r-- 7,733 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import logging
import os
import urllib.parse as urlparse
from collections.abc import Callable
from typing import Any, TypedDict

# Environment variable consulted by config() when no other name is given.
DEFAULT_ENV = "DATABASE_URL"
# Registry mapping a URL scheme to its Engine; populated by register().
ENGINE_SCHEMES: dict[str, "Engine"] = dict()


# From https://docs.djangoproject.com/en/stable/ref/settings/#databases
class DBConfig(TypedDict, total=False):
    ATOMIC_REQUESTS: bool
    AUTOCOMMIT: bool
    CONN_MAX_AGE: int | None
    CONN_HEALTH_CHECKS: bool
    DISABLE_SERVER_SIDE_CURSORS: bool
    ENGINE: str
    HOST: str
    NAME: str
    OPTIONS: dict[str, Any]
    PASSWORD: str
    PORT: str | int
    TEST: dict[str, Any]
    TIME_ZONE: str
    USER: str


# Signature of an engine's postprocess hook: it receives the freshly parsed
# config and mutates it in place; any return value is ignored by parse().
PostprocessCallable = Callable[[DBConfig], None]
# Types a single query-string option value may be converted to by
# _parse_value().
OptionType = int | str | bool


class ParseError(ValueError):
    """Raised when a database URL cannot be broken into its components.

    The message is fixed and does not depend on constructor arguments.
    """

    def __str__(self) -> str:
        reason = (
            "possibly because some of its parts is not properly"
            " urllib.parse.quote()'ed."
        )
        return "This string is not a valid url, " + reason


class UnknownSchemeError(ValueError):
    """Raised when a URL's scheme has no backend registered via register().

    Attributes:
        scheme: the unrecognised scheme (without the trailing ``://``).
    """

    def __init__(self, scheme: str):
        # Forward the scheme to ValueError so ``args`` is always populated,
        # even for ``UnknownSchemeError(scheme=...)``. copy/pickle rebuild an
        # exception as ``cls(*args)``; with empty args that call would fail.
        super().__init__(scheme)
        self.scheme = scheme

    def __str__(self) -> str:
        # Computed at str() time, so the list reflects the registry as it is
        # when the message is rendered.
        schemes = ", ".join(sorted(ENGINE_SCHEMES.keys()))
        return (
            f"Scheme '{self.scheme}://' is unknown."
            " Did you forget to register custom backend?"
            f" Following schemes have registered backends: {schemes}."
        )


def default_postprocess(parsed_config: DBConfig) -> None:
    """No-op hook for engines that need no scheme-specific adjustments."""
    return None


class Engine:
    """A registered backend together with its config postprocess hook.

    Attributes:
        backend: value that parse() places in the config's ``ENGINE`` key.
        postprocess: callable invoked on the parsed config before explicit
            keyword settings are merged in; defaults to a no-op.
    """

    def __init__(
        self,
        backend: str,
        postprocess: PostprocessCallable = default_postprocess,
    ):
        self.backend, self.postprocess = backend, postprocess


def register(
    scheme: str, backend: str
) -> Callable[[PostprocessCallable], PostprocessCallable]:
    """Associate URL ``scheme`` with Django ``backend``.

    A fresh Engine is stored under ``scheme`` (replacing any previous one).
    The returned decorator installs the decorated function as that engine's
    postprocess hook and hands the function back unchanged, so register()
    may be used either as a decorator or as a plain call.
    """
    registered = Engine(backend)
    if scheme in ENGINE_SCHEMES:
        pass  # already listed in urllib's uses_netloc on first registration
    else:
        urlparse.uses_netloc.append(scheme)
    ENGINE_SCHEMES[scheme] = registered

    def install_hook(func: PostprocessCallable) -> PostprocessCallable:
        registered.postprocess = func
        return func

    return install_hook


# Backends whose parsed configs need no scheme-specific postprocessing.
for _scheme, _backend in (
    ("spatialite", "django.contrib.gis.db.backends.spatialite"),
    ("mysql-connector", "mysql.connector.django"),
    ("mysqlgis", "django.contrib.gis.db.backends.mysql"),
    ("oraclegis", "django.contrib.gis.db.backends.oracle"),
    ("cockroach", "django_cockroachdb"),
):
    register(_scheme, _backend)
del _scheme, _backend


@register("sqlite", "django.db.backends.sqlite3")
def default_to_in_memory_db(parsed_config: DBConfig) -> None:
    """Use an in-memory SQLite database when the URL names none.

    A missing or empty ``NAME`` is replaced by ``":memory:"``, mimicking
    SQLAlchemy; a non-empty ``NAME`` is left alone.
    """
    if parsed_config.get("NAME"):
        return
    parsed_config["NAME"] = ":memory:"


@register("oracle", "django.db.backends.oracle")
@register("mssqlms", "mssql")
@register("mssql", "sql_server.pyodbc")
def stringify_port(parsed_config: DBConfig) -> None:
    """Store ``PORT`` as a string for the Oracle/MSSQL engines above.

    A missing port becomes the empty string; an int port becomes its
    decimal text.
    """
    port = parsed_config.get("PORT", "")
    parsed_config["PORT"] = str(port)


@register("mysql", "django.db.backends.mysql")
@register("mysql2", "django.db.backends.mysql")
def apply_ssl_ca(parsed_config: DBConfig) -> None:
    """Move an ``ssl-ca`` query option into a nested ``ssl`` option.

    ``OPTIONS["ssl-ca"]`` is always removed; if its value was truthy it is
    re-added as ``OPTIONS["ssl"] = {"ca": value}``.
    """
    options = parsed_config.get("OPTIONS", {})
    ca_path = options.pop("ssl-ca", None)
    if not ca_path:
        return
    options["ssl"] = {"ca": ca_path}


@register("postgres", "django.db.backends.postgresql")
@register("postgresql", "django.db.backends.postgresql")
@register("pgsql", "django.db.backends.postgresql")
@register("postgis", "django.contrib.gis.db.backends.postgis")
@register("redshift", "django_redshift_backend")
@register("timescale", "timescale.db.backends.postgresql")
@register("timescalegis", "timescale.db.backends.postgis")
def apply_current_schema(parsed_config: DBConfig) -> None:
    """Turn a ``currentSchema`` query option into a ``search_path`` flag.

    ``OPTIONS["currentSchema"]`` is always removed. If its value was truthy,
    ``-c search_path=<schema>`` is written to ``OPTIONS["options"]``. When the
    URL already supplied a non-empty string ``options`` value, the flag is
    appended to it instead of overwriting it. Any other existing value
    (missing, empty, or non-string) is replaced.
    """
    options = parsed_config.get("OPTIONS", {})
    schema = options.pop("currentSchema", None)
    if not schema:
        return

    search_path = f"-c search_path={schema}"
    existing = options.get("options")
    if isinstance(existing, str) and existing:
        # libpq treats ``options`` as space-separated command-line options,
        # so keep flags the URL already carried (e.g. a statement_timeout)
        # rather than silently dropping them.
        options["options"] = f"{existing} {search_path}"
    else:
        options["options"] = search_path


def config(
    env: str = DEFAULT_ENV,
    default: str | None = None,
    engine: str | None = None,
    conn_max_age: int | None = 0,
    conn_health_checks: bool = False,
    disable_server_side_cursors: bool = False,
    ssl_require: bool = False,
    test_options: dict[str, Any] | None = None,
) -> DBConfig:
    """Build a DATABASES entry from the URL held in environment variable ``env``.

    ``default`` is used only when ``env`` is not set at all. When neither
    source supplies a URL, a warning is logged and ``{}`` is returned. An
    empty-string URL also yields ``{}``, without a warning. Otherwise the
    URL and the remaining keyword arguments are passed to parse().
    """
    database_url = os.environ.get(env, default)

    if database_url is None:
        logging.warning(
            "No %s environment variable set, and so no databases setup", env
        )
        return {}
    if not database_url:
        return {}

    return parse(
        database_url,
        engine=engine,
        conn_max_age=conn_max_age,
        conn_health_checks=conn_health_checks,
        disable_server_side_cursors=disable_server_side_cursors,
        ssl_require=ssl_require,
        test_options=test_options,
    )


def parse(
    url: str,
    engine: str | None = None,
    conn_max_age: int | None = 0,
    conn_health_checks: bool = False,
    disable_server_side_cursors: bool = False,
    ssl_require: bool = False,
    test_options: dict[str, Any] | None = None,
) -> DBConfig:
    """Turn a database URL into a Django ``DATABASES`` entry.

    The URL scheme selects a registered Engine, which provides ``ENGINE`` and
    a postprocess hook. USER, PASSWORD, HOST, PORT, NAME and query-string
    OPTIONS come from the URL. After the hook runs, the keyword arguments
    (translated by _convert_to_settings) are layered on top; their OPTIONS
    are merged into the URL's. An empty OPTIONS dict is dropped from the
    result.

    The exact URL ``"sqlite://:memory:"`` returns a bare ENGINE/NAME dict and
    does not receive the keyword settings.

    Raises:
        UnknownSchemeError: no backend is registered for the URL's scheme.
        ParseError: urllib rejects the URL (e.g. an invalid port) or
            converting one of its components raises ValueError.
    """
    overrides = _convert_to_settings(
        engine,
        conn_max_age,
        conn_health_checks,
        disable_server_side_cursors,
        ssl_require,
        test_options,
    )

    # urlsplit would try to read "memory" as a port number, so this exact
    # URL is answered directly; no other settings are added for it.
    if url == "sqlite://:memory:":
        return {"ENGINE": ENGINE_SCHEMES["sqlite"].backend, "NAME": ":memory:"}

    try:
        parts = urlparse.urlsplit(url)
    except ValueError:
        raise ParseError() from None

    scheme_engine = ENGINE_SCHEMES.get(parts.scheme)
    if scheme_engine is None:
        raise UnknownSchemeError(parts.scheme)

    try:
        query = urlparse.parse_qs(parts.query)
        url_options = {
            key: _parse_option_values(values) for key, values in query.items()
        }
        parsed_config: DBConfig = {
            "ENGINE": scheme_engine.backend,
            "USER": urlparse.unquote(parts.username or ""),
            "PASSWORD": urlparse.unquote(parts.password or ""),
            "HOST": urlparse.unquote(parts.hostname or ""),
            "PORT": parts.port or "",
            "NAME": urlparse.unquote(parts.path[1:]),
            "OPTIONS": url_options,
        }
    except ValueError:
        raise ParseError() from None

    # Hooks edit OPTIONS in place, so it must already exist (possibly empty)
    # when postprocess() is called.
    assert isinstance(parsed_config["OPTIONS"], dict)
    scheme_engine.postprocess(parsed_config)

    # Settings passed explicitly take precedence over the URL's values.
    explicit_options = overrides.pop("OPTIONS", {})
    parsed_config["OPTIONS"].update(explicit_options)
    parsed_config.update(overrides)

    if not parsed_config["OPTIONS"]:
        del parsed_config["OPTIONS"]
    return parsed_config


def _parse_option_values(values: list[str]) -> OptionType | list[OptionType]:
    """Convert the list of values parse_qs() returns for one query key.

    Each value goes through _parse_value(). A single value is returned
    unwrapped; any other count is returned as a list.
    """
    converted = list(map(_parse_value, values))
    if len(converted) == 1:
        return converted[0]
    return converted


def _parse_value(value: str) -> OptionType:
    """Convert one query-string value to ``int``, ``bool`` or leave it a ``str``.

    Strings made only of decimal digits become ``int``; ``"true"``/``"false"``
    (any letter case) become ``bool``; everything else is returned unchanged.
    """
    # isdecimal(), not isdigit(): isdigit() is also true for characters such
    # as "²" or "①" that int() rejects, which made a harmless option value
    # raise ValueError (surfaced by parse() as ParseError). int() accepts
    # exactly the strings isdecimal() accepts.
    if value.isdecimal():
        return int(value)
    lowered = value.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    return value


def _convert_to_settings(
    engine: str | None,
    conn_max_age: int | None,
    conn_health_checks: bool,
    disable_server_side_cursors: bool,
    ssl_require: bool,
    test_options: dict[str, Any] | None,
) -> DBConfig:
    """Translate parse()'s keyword arguments into DBConfig overrides.

    CONN_MAX_AGE, CONN_HEALTH_CHECKS and DISABLE_SERVER_SIDE_CURSORS are
    always present. ENGINE, OPTIONS (``sslmode=require``) and TEST are added,
    in that order, only when ``engine``, ``ssl_require`` and
    ``test_options`` are truthy.
    """
    settings: DBConfig = {
        "CONN_MAX_AGE": conn_max_age,
        "CONN_HEALTH_CHECKS": conn_health_checks,
        "DISABLE_SERVER_SIDE_CURSORS": disable_server_side_cursors,
    }
    if engine:
        settings["ENGINE"] = engine
    if ssl_require:
        settings["OPTIONS"] = {"sslmode": "require"}
    if test_options:
        settings["TEST"] = test_options
    return settings