File: __init__.py

package info (click to toggle)
geoalchemy2 0.18.2-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,692 kB
  • sloc: python: 10,003; sh: 159; makefile: 133
file content (280 lines) | stat: -rw-r--r-- 8,459 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
import os
import platform
import re
import shutil

import pytest
import shapely
from packaging.version import parse as parse_version
from shapely.wkb import dumps
from sqlalchemy import __version__ as SA_VERSION
from sqlalchemy import create_engine
from sqlalchemy import select as raw_select
from sqlalchemy import text
from sqlalchemy.event import listen
from sqlalchemy.exc import OperationalError
from sqlalchemy.sql import func

from geoalchemy2 import load_spatialite
from geoalchemy2 import load_spatialite_gpkg
from geoalchemy2.elements import WKBElement
from geoalchemy2.elements import WKTElement
from geoalchemy2.shape import to_shape


class test_only_with_dialects:
    """Decorator tagging a test object with the dialects it should run against.

    The dialect names given to the constructor are kept as a tuple and copied
    onto every decorated object as its ``tested_dialects`` attribute.
    """

    def __init__(self, *dialect_names):
        self.tested_dialects = dialect_names

    def __call__(self, test_obj):
        """Attach the stored dialect names to ``test_obj`` and return it unchanged."""
        setattr(test_obj, "tested_dialects", self.tested_dialects)
        return test_obj


def get_postgis_major_version(bind):
    """Return the major version of the PostGIS library reported by ``bind``.

    The version string comes from ``postgis_lib_version()``. When that query
    raises ``OperationalError``, the version is treated as ``"0"``, so 0 is
    returned.
    """
    try:
        raw_version = bind.execute(func.postgis_lib_version()).scalar()
    except OperationalError:
        raw_version = "0"
    return parse_version(raw_version).major


def get_postgres_major_version(bind):
    """Return the major version of the PostgreSQL server as a string.

    The value is the run of leading digits of ``server_version``: ``"15"``
    for ``"15.4 (Debian ...)"``, ``"9"`` for ``"9.6.24"``, and ``"17"`` for a
    development build reporting ``"17devel"``.

    Returns:
        The major version as a string, or ``"0"`` when the query raises
        ``OperationalError`` or the reported value has no leading digits.
    """
    try:
        server_version = bind.execute(
            text("""SELECT current_setting('server_version');""")
        ).scalar()
    except OperationalError:
        return "0"
    # Only the leading digits are needed. Requiring a "." (as before) made
    # versions like "17devel" or "16beta1" crash on a None match.
    match = re.match(r"[0-9]+", server_version or "")
    return match.group(0) if match else "0"


def skip_sqla_lt_2():
    """Return a ``skipif`` marker that skips tests on SQLAlchemy releases below 2."""
    too_old = parse_version(SA_VERSION) < parse_version("2")
    return pytest.mark.skipif(too_old, reason="requires SQLAlchemy >= 2")


def skip_postgis1(bind):
    """Skip the running test when ``bind`` reports PostGIS major version 1."""
    major = get_postgis_major_version(bind)
    if major != 1:
        return
    pytest.skip("requires PostGIS != 1")


def skip_postgis2(bind):
    """Skip the running test when ``bind`` reports PostGIS major version 2."""
    major = get_postgis_major_version(bind)
    if major != 2:
        return
    pytest.skip("requires PostGIS != 2")


def skip_postgis3(bind):
    """Skip the running test when ``bind`` reports PostGIS major version 3."""
    major = get_postgis_major_version(bind)
    if major != 3:
        return
    pytest.skip("requires PostGIS != 3")


def skip_case_insensitivity():
    """Return a ``skipif`` marker for SQLAlchemy releases older than 1.3.4."""
    unsupported = parse_version(SA_VERSION) < parse_version("1.3.4")
    return pytest.mark.skipif(
        unsupported,
        reason="Case-insensitivity is only available for sqlalchemy>=1.3.4",
    )


def skip_pg12_sa1217(bind):
    """Skip reflection tests on PostgreSQL >= 12 when SQLAlchemy is below 1.2.17.

    The server is only queried when the installed SQLAlchemy is too old.
    """
    if parse_version(SA_VERSION) >= parse_version("1.2.17"):
        return
    if int(get_postgres_major_version(bind)) < 12:
        return
    pytest.skip("Reflection for PostgreSQL-12 is only supported by sqlalchemy>=1.2.17")


def skip_pypy(msg=None):
    """Skip the running test when the interpreter is PyPy.

    Args:
        msg: Optional skip reason; a generic message is used when omitted.
    """
    if platform.python_implementation() != "PyPy":
        return
    pytest.skip("Incompatible with PyPy" if msg is None else msg)


def select(args):
    """Build a ``SELECT`` from a list of columns on any supported SQLAlchemy.

    SQLAlchemy < 1.4 receives the list as a single argument; newer releases
    receive the columns unpacked as positional arguments.
    """
    legacy_signature = parse_version(SA_VERSION) < parse_version("1.4")
    return raw_select(args) if legacy_signature else raw_select(*args)


def format_wkt(wkt):
    """Return ``wkt`` with every ``", "`` separator collapsed to ``","``."""
    return ",".join(wkt.split(", "))


def copy_and_connect_sqlite_db(input_db, tmp_db, engine_echo, dialect):
    """Copy a SQLite/GeoPackage database and return an engine bound to the copy.

    The test is skipped when ``SPATIALITE_LIBRARY_PATH`` is not set in the
    environment. The engine maps the ``gis`` schema to ``None``, enables the
    ``geoalchemy2`` plugin, and registers a ``connect`` listener that loads
    SpatiaLite: ``load_spatialite_gpkg`` for the ``gpkg`` dialect and
    ``load_spatialite`` otherwise.

    Args:
        input_db: Path of the source database file.
        tmp_db: Destination path for the working copy.
        engine_echo: Value forwarded to ``create_engine(echo=...)``.
        dialect: URL scheme used to build the database URL.
    """
    if "SPATIALITE_LIBRARY_PATH" not in os.environ:
        pytest.skip("SPATIALITE_LIBRARY_PATH is not defined, skip SpatiaLite tests")

    shutil.copyfile(input_db, tmp_db)
    print("INPUT DB:", input_db)
    print("TEST DB:", tmp_db)

    engine = create_engine(
        f"{dialect}:///{tmp_db}",
        echo=engine_echo,
        execution_options={"schema_translate_map": {"gis": None}},
        plugins=["geoalchemy2"],
    )
    on_connect = load_spatialite_gpkg if dialect == "gpkg" else load_spatialite
    listen(engine, "connect", on_connect)
    return engine


def get_versions(conn):
    """Get all versions.

    Queries the backend behind ``conn`` for its database, GEOS and PROJ
    versions and the PROJ database path. Only the queries that apply to the
    dialect are run; the others are reported as ``""``. The Shapely version
    is added too, or ``""`` if Shapely exposes no ``__version__``.

    Returns:
        A dict with the keys ``dialect_name``, ``db_version``,
        ``geos_version``, ``proj_version``, ``proj_path`` and ``shapely``.
    """
    dialect_name = conn.dialect.name
    # Query expressions for (db_version, geos_version, proj_version, proj_path);
    # an empty expression means "not available for this dialect".
    if dialect_name == "postgresql":
        expressions = (
            "PostGIS_Full_Version()",
            "PostGIS_GEOS_Version()",
            "PostGIS_PROJ_Version()",
            "",
        )
    elif dialect_name in ["mysql", "mariadb"]:
        expressions = ("VERSION()", "", "", "")
    else:
        expressions = (
            "spatialite_version()",
            "geos_version()",
            "proj_version()",
            "PROJ_GetDatabasePath()",
        )

    def _query(expression):
        # Run a one-value SELECT, or skip it entirely for unsupported info.
        if not expression:
            return ""
        return conn.execute(text(f"SELECT {expression};")).fetchone()[0]

    result = {"dialect_name": dialect_name}
    keys = ("db_version", "geos_version", "proj_version", "proj_path")
    for key, expression in zip(keys, expressions):
        result[key] = _query(expression)
    result["shapely"] = getattr(shapely, "__version__", "")
    return result


def print_versions(versions):
    """Print the provided versions.

    Prints a banner naming the dialect and then one ``key: value`` line for
    each version entry that is truthy, inside ``#`` separator lines.

    Args:
        versions: Mapping as returned by ``get_versions``.
    """
    print("#########################################")
    print(f"Versions for the {versions['dialect_name']} dialect")
    for key in ("db_version", "geos_version", "proj_version", "proj_path", "shapely"):
        value = versions[key]
        if value:
            print(f"{key}:", value)
    print("#########################################")


def check_indexes(conn, dialect_name, expected, table_name):
    """Check that actual indexes are equal to the expected ones.

    Runs a dialect-specific query listing the index metadata of
    ``table_name`` and compares the fetched rows with
    ``expected[dialect_name]``. For PostgreSQL, each newline in the expected
    index definitions, together with the spaces after it, is collapsed into
    one space before comparing. On mismatch, both lists are printed and the
    ``AssertionError`` is re-raised.

    Args:
        conn: Connection used to run the query.
        dialect_name: One of ``"postgresql"``, ``"sqlite"`` or ``"geopackage"``.
        expected: Mapping from dialect name to the expected rows.
        table_name: Table whose indexes are inspected.
    """
    sql_by_dialect = {
        "postgresql": f"""SELECT indexname, indexdef
            FROM pg_indexes
            WHERE
                tablename = '{table_name}'
            ORDER BY indexname;""",
        "sqlite": f"""SELECT *
            FROM geometry_columns
            WHERE f_table_name = '{table_name}'
            ORDER BY f_table_name, f_geometry_column;""",
        "geopackage": f"""SELECT table_name, column_name, extension_name
            FROM gpkg_extensions
            WHERE table_name = '{table_name}' and extension_name = 'gpkg_rtree_index'
            """,
    }

    actual_indexes = conn.execute(text(sql_by_dialect[dialect_name])).fetchall()

    expected_indexes = expected[dialect_name]
    if dialect_name == "postgresql":
        # Expected definitions may be written over several indented lines.
        expected_indexes = [
            (entry[0], re.sub("\n *", " ", entry[1])) for entry in expected_indexes
        ]

    try:
        assert actual_indexes == expected_indexes
    except AssertionError:
        print("###############################################")

        print("Expected indexes:")
        for row in expected_indexes:
            print(row)

        print("Actual indexes:")
        for row in actual_indexes:
            print(row)

        print("###############################################")

        raise


def create_wkt_points(N=50):
    """Create a list of points for benchmarking.

    Returns ``N * N`` WKT ``POINT`` strings on a regular grid whose
    coordinates are ``x / N`` and ``y / N`` for ``x, y`` in ``range(N)``;
    the second coordinate varies fastest.
    """
    return [f"POINT({x / N} {y / N})" for x in range(N) for y in range(N)]


def create_points(N, convert_wkb=False, extended=False, raw=False):
    """Create ``N * N`` benchmark points in the requested representation.

    Args:
        N: Grid size forwarded to ``create_wkt_points``.
        convert_wkb: Convert the WKT strings to ISO WKB or, when ``extended``
            is set, to EWKB with SRID 4326.
        extended: Produce extended variants (EWKB, or WKT prefixed with
            ``"SRID=4326; "``).
        raw: Keep plain strings/bytes instead of wrapping them in
            ``WKTElement`` / ``WKBElement``.

    Returns:
        The list of points. It is empty when ``N <= 0``.
    """
    points = create_wkt_points(N)
    print(f"Number of points to insert: {len(points)}")

    if convert_wkb:
        if not extended:
            # Convert WKT to WKB
            points = [
                shapely.io.to_wkb(to_shape(WKTElement(point)), flavor="iso") for point in points
            ]
            print(f"Converted points to WKB: {len(points)}")
        else:
            # Convert WKT to EWKB
            points = [
                dumps(to_shape(WKTElement(point)), flavor="extended", srid=4326) for point in points
            ]
            print(f"Converted points to EWKB: {len(points)}")
        if not raw:
            # Convert WKB string to WKBElement
            points = [WKBElement(point) for point in points]
            print(f"Converted points to WKBElement: {len(points)}")
    else:
        if extended:
            # Convert WKT to EWKT
            points = ["SRID=4326; " + point for point in points]
        if not raw:
            # Convert WKT to WKTElement
            points = [WKTElement(point) for point in points]
            print(f"Converted points to WKTElement: {len(points)}")

    # With N <= 0 there is no element to show; indexing points[0] would
    # raise IndexError.
    if not points:
        return points

    if raw:
        print("Example data:", points[0])
    else:
        print("Example data:", points[0], "=>", points[0].data)

    return points