File: __init__.py

package info (click to toggle)
geoalchemy2 0.15.2-2
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 1,572 kB
  • sloc: python: 8,731; makefile: 133; sh: 132
file content (186 lines) | stat: -rw-r--r-- 5,269 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import os
import platform
import re
import shutil
import sys

import pytest
from packaging.version import parse as parse_version
from sqlalchemy import __version__ as SA_VERSION
from sqlalchemy import create_engine
from sqlalchemy import select as raw_select
from sqlalchemy import text
from sqlalchemy.event import listen
from sqlalchemy.exc import OperationalError
from sqlalchemy.sql import func

from geoalchemy2 import load_spatialite
from geoalchemy2 import load_spatialite_gpkg


class test_only_with_dialects:
    """Decorator that tags a test object with the dialects it should run against.

    The dialect names given to the constructor are attached to the decorated
    object as its ``tested_dialects`` attribute; the object itself is returned
    unchanged.
    """

    def __init__(self, *dialects):
        # Positional dialect names, kept as the tuple Python builds for *dialects.
        self.tested_dialects = dialects

    def __call__(self, test_obj):
        setattr(test_obj, "tested_dialects", self.tested_dialects)
        return test_obj


def get_postgis_major_version(bind):
    """Return the major version of the PostGIS library seen through ``bind``.

    When executing ``postgis_lib_version()`` raises ``OperationalError``, the
    version string ``"0"`` is used instead, so the result is ``0``.
    """
    try:
        lib_version = bind.execute(func.postgis_lib_version()).scalar()
    except OperationalError:
        lib_version = "0"
    return parse_version(lib_version).major


def get_postgres_major_version(bind):
    """Return the PostgreSQL server major version as a string.

    The major version is the leading run of digits of the ``server_version``
    setting, e.g. ``"12"`` for ``"12.3 (Debian 12.3-1)"`` and ``"16"`` for
    ``"16devel"``.

    Returns:
        The major version as a string, or ``"0"`` when the query raises
        ``OperationalError`` or the setting has no leading digits.
    """
    try:
        server_version = bind.execute(
            text("""SELECT current_setting('server_version');""")
        ).scalar()
    except OperationalError:
        return "0"
    # Only the leading digits matter. A strict "X.Y" pattern returns None for
    # versions without a dot (e.g. "16devel", "17beta1") and then crashes on
    # ``.group``.
    match = re.match(r"\s*([0-9]+)", server_version or "")
    return match.group(1) if match else "0"


def skip_postgis1(bind):
    """Skip the running test when the PostGIS major version reported for ``bind`` is 1."""
    major = get_postgis_major_version(bind)
    if major != 1:
        return
    pytest.skip("requires PostGIS != 1")


def skip_postgis2(bind):
    """Skip the running test when the PostGIS major version reported for ``bind`` is 2."""
    major = get_postgis_major_version(bind)
    if major != 2:
        return
    pytest.skip("requires PostGIS != 2")


def skip_postgis3(bind):
    """Skip the running test when the PostGIS major version reported for ``bind`` is 3."""
    major = get_postgis_major_version(bind)
    if major != 3:
        return
    pytest.skip("requires PostGIS != 3")


def skip_case_insensitivity():
    """Build a ``pytest.mark.skipif`` marker for case-insensitivity tests.

    The marker's condition is true when the installed SQLAlchemy version is
    older than 1.3.4.
    """
    sqlalchemy_too_old = parse_version(SA_VERSION) < parse_version("1.3.4")
    return pytest.mark.skipif(
        sqlalchemy_too_old,
        reason="Case-insensitivity is only available for sqlalchemy>=1.3.4",
    )


def skip_pg12_sa1217(bind):
    """Skip when SQLAlchemy < 1.2.17 is used against a PostgreSQL >= 12 server.

    The server version is only queried when the installed SQLAlchemy is older
    than 1.2.17.
    """
    if parse_version(SA_VERSION) >= parse_version("1.2.17"):
        return
    if int(get_postgres_major_version(bind)) < 12:
        return
    pytest.skip("Reflection for PostgreSQL-12 is only supported by sqlalchemy>=1.2.17")


def skip_pypy(msg=None):
    """Skip the running test when the interpreter is PyPy.

    Args:
        msg: Skip reason; ``"Incompatible with PyPy"`` is used when it is None.
    """
    if platform.python_implementation() != "PyPy":
        return
    reason = "Incompatible with PyPy" if msg is None else msg
    pytest.skip(reason)


def select(args):
    """Call SQLAlchemy's ``select`` with a column list on any supported version.

    With SQLAlchemy older than 1.4 the list is passed as one argument;
    otherwise its items are passed positionally.
    """
    legacy_signature = parse_version(SA_VERSION) < parse_version("1.4")
    return raw_select(args) if legacy_signature else raw_select(*args)


def format_wkt(wkt):
    """Return ``wkt`` with every ``", "`` sequence collapsed to ``","``."""
    return ",".join(wkt.split(", "))


def copy_and_connect_sqlite_db(input_db, tmp_db, engine_echo, dialect):
    """Copy a SQLite/GeoPackage database and return an engine bound to the copy.

    The running test is skipped when ``SPATIALITE_LIBRARY_PATH`` is not set in
    the environment.

    Args:
        input_db: Path of the source database file; it is copied with
            ``shutil.copyfile`` and not modified.
        tmp_db: Destination path of the copy; the engine connects to it.
        engine_echo: Forwarded to ``create_engine(echo=...)``.
        dialect: URL scheme of the engine. ``"gpkg"`` registers
            ``load_spatialite_gpkg`` on connect; any other value registers
            ``load_spatialite``.

    Returns:
        The SQLAlchemy engine connected to ``tmp_db``.
    """
    if "SPATIALITE_LIBRARY_PATH" not in os.environ:
        pytest.skip("SPATIALITE_LIBRARY_PATH is not defined, skip SpatiaLite tests")

    shutil.copyfile(input_db, tmp_db)

    print("INPUT DB:", input_db)
    print("TEST DB:", tmp_db)

    db_url = f"{dialect}:///{tmp_db}"
    engine = create_engine(
        db_url, echo=engine_echo, execution_options={"schema_translate_map": {"gis": None}}
    )

    # Load the SpatiaLite extension on every new connection.
    loader = load_spatialite_gpkg if dialect == "gpkg" else load_spatialite
    listen(engine, "connect", loader)

    with engine.begin() as connection:
        version_queries = [
            ("SPATIALITE VERSION:", "SELECT spatialite_version();"),
            ("GEOS VERSION:", "SELECT geos_version();"),
        ]
        # Compare the full version tuple. Testing ``sys.version_info.minor`` alone
        # gives the wrong answer for any interpreter whose major version is not 3.
        if sys.version_info >= (3, 8):
            version_queries += [
                ("PROJ VERSION:", "SELECT proj_version();"),
                ("PROJ DB PATH:", "SELECT PROJ_GetDatabasePath();"),
            ]
        for label, query in version_queries:
            print(label, connection.execute(text(query)).fetchone()[0])

    return engine


def check_indexes(conn, dialect_name, expected, table_name):
    """Check that actual indexes are equal to the expected ones.

    Args:
        conn: Connection used to run the dialect-specific index query.
        dialect_name: Key selecting the query: ``"postgresql"`` (``pg_indexes``),
            ``"sqlite"`` (``geometry_columns``) or ``"geopackage"``
            (``gpkg_extensions`` R-tree entries).
        expected: Mapping from dialect name to the list of expected rows.
        table_name: Name of the table whose indexes are checked.

    Raises:
        KeyError: If ``dialect_name`` has no query or no entry in ``expected``.
        AssertionError: If the actual rows differ from the expected ones; both
            lists are printed before the error propagates.
    """
    index_queries = {
        "postgresql": """SELECT indexname, indexdef
            FROM pg_indexes
            WHERE
                tablename = :table_name
            ORDER BY indexname;""",
        "sqlite": """SELECT *
            FROM geometry_columns
            WHERE f_table_name = :table_name
            ORDER BY f_table_name, f_geometry_column;""",
        "geopackage": """SELECT table_name, column_name, extension_name
            FROM gpkg_extensions
            WHERE table_name = :table_name and extension_name = 'gpkg_rtree_index'
            """,
    }

    # Bind the table name instead of formatting it into the SQL, so names
    # containing quotes cannot break or alter the query.
    index_query = text(index_queries[dialect_name]).bindparams(table_name=table_name)
    actual_indexes = conn.execute(index_query).fetchall()

    expected_indexes = expected[dialect_name]
    if dialect_name == "postgresql":
        # Expected definitions may be written over several indented lines;
        # collapse each newline plus indentation into a single space.
        expected_indexes = [(i[0], re.sub("\n *", " ", i[1])) for i in expected_indexes]

    try:
        assert actual_indexes == expected_indexes
    except AssertionError:
        print("###############################################")

        print("Expected indexes:")
        for i in expected_indexes:
            print(i)

        print("Actual indexes:")
        for i in actual_indexes:
            print(i)

        print("###############################################")

        raise