File: steps_db.py

package info (click to toggle)
osm2pgsql 2.2.0%2Bds-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,772 kB
  • sloc: cpp: 60,940; python: 1,115; ansic: 763; sh: 25; makefile: 14
file content (315 lines) | stat: -rw-r--r-- 10,540 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
# SPDX-License-Identifier: GPL-2.0-or-later
#
# This file is part of osm2pgsql (https://osm2pgsql.org/).
#
# Copyright (C) 2006-2025 by the osm2pgsql developer community.
# For a full list of authors see the git log.
"""
Steps that query the database.
"""
import math
import re
from typing import Iterable

try:
    from psycopg2 import sql
except ImportError:
    from psycopg import sql


@given("the database schema (?P<schema>.+)")
def create_db_schema(context, schema):
    """Create a database schema with the name given in the step text.

    The schema name is inserted verbatim into the statement, so it is
    subject to the usual SQL identifier rules (e.g. case folding).
    """
    statement = f"CREATE SCHEMA {schema}"
    with context.db.cursor() as cursor:
        cursor.execute(statement)


@when("deleting table (?P<table>.+)")
def delete_table(context, table):
    """Drop the table named in the step text.

    The table name is inserted verbatim into the statement.
    """
    statement = f"DROP TABLE {table}"
    with context.db.cursor() as cursor:
        cursor.execute(statement)


@then(r"table (?P<table>.+) has (?P<row_num>\d+) rows?(?P<has_where> with condition)?")
def db_table_row_count(context, table, row_num, has_where):
    """Check that a table contains the expected number of rows.

    When the step ends in 'with condition', the step's text block is
    appended verbatim as the WHERE clause of the counting query.
    """
    assert table_exists(context.db, table)

    table_ident = sql.Identifier(*table.split('.', 2))
    count_query = sql.SQL("SELECT count(*) FROM {}").format(table_ident)

    if has_where:
        where_clause = sql.SQL(context.text)
        count_query = sql.SQL("{} WHERE {}").format(count_query, where_clause)

    actual = scalar(context.db, count_query)
    matches = actual == int(row_num)

    assert matches, f"Table {table}: expected {row_num} rows, got {actual}"


@then(r"the sum of '(?P<formula>.+)' in table (?P<table>.+) is (?P<result>\d+)(?P<has_where> with condition)?")
def db_table_sum_up(context, table, formula, result, has_where):
    """Check the rounded sum of an SQL expression over a table.

    `formula` is inserted verbatim into `round(sum(...))`. When the step
    ends in 'with condition', the step's text block is appended verbatim
    as the WHERE clause.
    """
    assert table_exists(context.db, table)

    table_ident = sql.Identifier(*table.split('.', 2))
    sum_query = sql.SQL("SELECT round(sum({})) FROM {}").format(sql.SQL(formula),
                                                                 table_ident)

    if has_where:
        where_clause = sql.SQL(context.text)
        sum_query = sql.SQL("{} WHERE {}").format(sum_query, where_clause)

    actual = scalar(context.db, sum_query)
    matches = actual == int(result)

    assert matches, f"Table {table}: expected sum {result}, got {actual}"


@then("there (?:is|are) (?P<exists>no )?tables? (?P<tables>.+)")
def db_table_existance(context, exists, tables):
    """Check that each table of a comma-separated list exists (or does not).

    The optional 'no ' in the step text inverts the check.
    """
    must_be_absent = exists == 'no '

    for table in map(str.strip, tables.split(',')):
        present = table_exists(context.db, table)
        if must_be_absent:
            assert not present, f"Table '{table}' unexpectedly found"
        else:
            assert present, f"Table '{table}' not found"


@then("table (?P<table>.+) contains(?P<exact> exactly)?")
def db_check_table_content(context, table, exact):
    """Check that the table contains the rows listed in the step's table.

    Each heading of the step table is an SQL expression to select,
    optionally followed by '@<property>' which only influences how the
    value is compared. Every expected row must match a distinct database
    row. With 'exactly', no unmatched database rows may remain.

    Raises:
        AssertionError: if the table does not exist, an expected row is
            missing or (with 'exactly') unexpected rows remain.
    """
    assert table_exists(context.db, table)

    # Only the text after the LAST '@' is a comparison property. Splitting
    # once from the right keeps any '@' that is part of the SQL expression.
    columns = sql.SQL(', '.join(h.rsplit('@', 1)[0] for h in context.table.headings))

    with context.db.cursor() as cur:
        cur.execute(sql.SQL("SELECT {} FROM {}")
                       .format(columns, sql.Identifier(*table.split('.', 2))))

        actuals = [DBRow(r, context.table.headings, context.geometry_factory) for r in cur]

    for linenr, row in enumerate(context.table.rows, start=1):
        try:
            # Removing the match makes sure that every expected row
            # consumes its own database row.
            actuals.remove(row)
        except ValueError:
            # Raise explicitly: 'assert False' disappears under 'python -O'.
            raise AssertionError(
                f"{linenr}. entry not found in table. Full content:\n{actuals}") from None

    assert not exact or not actuals,\
           f"Unexpected lines in row:\n{actuals}"

@then("(?P<query>SELECT .*)")
def db_check_sql_statement(context, query):
    """Run the SELECT statement from the step text and check its result.

    Every row of the step's table must match at least one row of the
    query result. Result rows may be matched more than once and
    additional result rows are ignored.
    """
    with context.db.cursor() as cur:
        cur.execute(query)

        actuals = [DBRow(r, context.table.headings, context.geometry_factory) for r in cur]

    for linenr, row in enumerate(context.table.rows, start=1):
        found = False
        for candidate in actuals:
            if candidate == row:
                found = True
                break

        assert found, f"{linenr}. entry not found in table. Full content:\n{actuals}"


### Helper functions and classes

def scalar(conn, sql, args=None):
    """Execute a query that yields exactly one row and return its first column.

    Asserts that the cursor reports a row count of exactly 1.
    """
    with conn.cursor() as cursor:
        cursor.execute(sql, args)

        assert cursor.rowcount == 1
        row = cursor.fetchone()
        return row[0]

def table_exists(conn, table):
    """Return True when a table or view with the given name exists.

    Args:
        conn: Open database connection, handed on to scalar().
        table: Table name, optionally qualified as 'schema.table'.
            Unqualified names are looked up in the 'public' schema.

    Returns:
        True if exactly one matching entry is found in pg_tables or,
        failing that, in pg_views.
    """
    # Use the last two dot-separated components as schema and table name.
    # Callers build identifiers from up to three components, so a plain
    # two-name unpacking of the split result could raise ValueError here.
    parts = table.split('.')
    tablename = parts[-1]
    schema = parts[-2] if len(parts) > 1 else 'public'

    num = scalar(conn, """SELECT count(*) FROM pg_tables
                          WHERE tablename = %s AND schemaname = %s""",
                (tablename, schema))
    if num == 1:
        return True

    # Not a plain table, so it may still be a view.
    num = scalar(conn, """SELECT count(*) FROM pg_views
                          WHERE viewname = %s AND schemaname = %s""",
                (tablename, schema))
    return num == 1


class DBRow:
    """A database result row prepared for comparison with an expected row.

    Each value is wrapped according to its type and to the optional
    '@<property>' suffix of its column heading, so that a plain '=='
    against the expected cell text applies the right kind of comparison.
    """

    def __init__(self, row, headings, factory):
        """Wrap the values of `row`.

        Args:
            row: Iterable with the column values of one result row.
            headings: Column headings, in the same order as `row`. A
                heading may end in '@<property>'.
            factory: Passed on to DBValueGeometry for geometry columns.
        """
        self.data = []
        for value, head in zip(row, headings):
            if '@' in head:
                # Split only on the last '@': the expression in front of
                # it may contain further '@' characters.
                props = head.rsplit('@', 1)[1]
            else:
                props = None

            if isinstance(value, float):
                self.data.append(DBValueFloat(value, props))
            elif value is None:
                self.data.append(None)
            elif head.lower().startswith('st_astext('):
                self.data.append(DBValueGeometry(value, props, factory))
            elif props == 'fullmatch':
                self.data.append(DBValueRegex(value))
            else:
                self.data.append(str(value))

    def __eq__(self, other):
        """Compare against an iterable of expected cell values.

        The expected value 'NULL' matches a database NULL (None). Values
        are compared pairwise up to the length of the shorter side.
        """
        if not isinstance(other, Iterable):
            return False

        return all((a is None) if b == 'NULL' else (a == b)
                   for a, b in zip(self.data, other))

    def __repr__(self):
        # Leading newline puts every row on its own line when a list of
        # rows is printed in an assertion message.
        return '\n[' + ', '.join(str(s) for s in self.data) + ']'


class DBValueGeometry:
    """Geometry read from the database as WKT, comparable to a short notation.

    The WKT text is parsed into nested lists/tuples of DBValueFloat
    coordinates. The expected geometry is given in a compact text form:

      * 'x y' or a single integer (resolved via factory.grid_node())
        for a point,
      * 'pt, pt, ...' for a line,
      * '(pt, pt, ...),(pt, ...)' for a polygon made of rings,
      * '[geom; geom; ...]' for a MULTI* geometry,
      * '{geom; geom; ...}' for a geometry collection.
    """

    def __init__(self, value, props, factory):
        """Parse the WKT string `value`.

        Args:
            value: WKT text of the geometry.
            props: Optional relative precision for coordinate comparison,
                0.0001 when empty or None.
            factory: Object providing grid_node(), used to resolve
                single-integer coordinates of the expected geometry.
        """
        self.precision = float(props) if props else 0.0001
        self.orig_value = value
        self.set_coordinates(value)
        self.factory = factory

    def set_coordinates(self, value):
        """Parse WKT into self.geom_type and self.value.

        Raises:
            RuntimeError: if the WKT cannot be parsed or has trailing content.
        """
        prefix = 'GEOMETRYCOLLECTION('
        if value.startswith(prefix):
            geoms = []
            # Strip the prefix and the closing bracket of the collection.
            remain = value[len(prefix):-1]
            while remain:
                _, geom, remain = self._parse_simple_wkt(remain)
                remain = remain[1:]  # delete comma
                geoms.append(geom)
            self.geom_type = 'GEOMETRYCOLLECTION'
            self.value = geoms
        else:
            self.geom_type, self.value, remain = self._parse_simple_wkt(value)
            if remain:
                raise RuntimeError('trailing content for geometry: ' + value)

    def _parse_simple_wkt(self, value):
        """Parse one (MULTI)POINT/LINESTRING/POLYGON at the start of `value`.

        Returns:
            Tuple of geometry type, parsed coordinates and the unparsed
            rest of the string.
        """
        # Coordinates never contain upper-case letters, so '[^A-Z]*' stops
        # in front of the next geometry keyword.
        m = re.fullmatch(r'(MULTI)?(POINT|LINESTRING|POLYGON)\(([^A-Z]*)\)(.*)', value)
        if not m:
            raise RuntimeError(f'Unparsable WKT: {value}')
        geom_type = (m[1] or '') + m[2]
        if m[1] == 'MULTI':
            splitup = m[3][1:-1].split('),(')
            if m[2] == 'POINT':
                value = [self._parse_wkt_coord(c) for c in splitup]
            elif m[2] == 'LINESTRING':
                value = [self._parse_wkt_line(c) for c in splitup]
            elif m[2] == 'POLYGON':
                value = [[self._parse_wkt_line(ln) for ln in poly[1:-1].split('),(')]
                         for poly in splitup]
        else:
            if m[2] == 'POINT':
                value = self._parse_wkt_coord(m[3])
            elif m[2] == 'LINESTRING':
                value = self._parse_wkt_line(m[3])
            elif m[2] == 'POLYGON':
                value = [self._parse_wkt_line(ln) for ln in m[3][1:-1].split('),(')]

        return geom_type, value, m[4]

    def _parse_wkt_coord(self, coord):
        """Convert a whitespace-separated WKT coordinate into a tuple of DBValueFloat."""
        return tuple(DBValueFloat(float(f.strip()), self.precision) for f in coord.split())

    def _parse_wkt_line(self, coords):
        """Convert a comma-separated WKT coordinate list into a list of coordinates."""
        return [self._parse_wkt_coord(pt) for pt in coords.split(',')]

    def __eq__(self, other):
        """Compare with an expected geometry in the compact text notation.

        Raises:
            RuntimeError: for a '[...]' geometry mixing different geometry
                types or for a malformed coordinate.
        """
        if not isinstance(other, str):
            return NotImplemented

        if other.startswith('[') and other.endswith(']'):
            gtype = 'MULTI'
            toparse = other[1:-1].split(';')
        elif other.startswith('{') and other.endswith('}'):
            gtype = 'GEOMETRYCOLLECTION'
            toparse = other[1:-1].split(';')
        else:
            gtype = None
            toparse = [other]

        geoms = []
        for sub in toparse:
            sub = sub.strip()
            if ',' not in sub:
                simple_type = 'POINT'
                geoms.append(self._parse_input_coord(sub))
            elif '(' not in sub:
                simple_type = 'LINESTRING'
                geoms.append(self._parse_input_line(sub))
            else:
                simple_type = 'POLYGON'
                geoms.append([self._parse_input_line(ln) for ln in sub[1:-1].split('),(')])
            gtype = self._merge_geom_type(gtype, simple_type)

        # Simple geometries are stored directly, not as one-element list.
        if not gtype.startswith('MULTI') and gtype != 'GEOMETRYCOLLECTION':
            geoms = geoms[0]

        return gtype == self.geom_type and self.value == geoms

    @staticmethod
    def _merge_geom_type(gtype, simple_type):
        """Return the overall geometry type after seeing a part of `simple_type`."""
        if gtype is None:
            return simple_type
        if gtype == 'MULTI':
            # First part of a '[...]' geometry decides the MULTI* type.
            return 'MULTI' + simple_type
        if gtype.startswith('MULTI') and gtype != 'MULTI' + simple_type:
            raise RuntimeError('MULTI* geometry with different geometry types is not supported.')
        return gtype

    def _parse_input_coord(self, other):
        """Parse an expected point: 'x y' or a single integer for grid_node()."""
        # split() without argument tolerates repeated whitespace.
        coords = other.split()
        if len(coords) == 1:
            return self.factory.grid_node(int(coords[0]))
        if len(coords) == 2:
            return tuple(float(c.strip()) for c in coords)

        raise RuntimeError(f'Bad coordinate: {other}')

    def _parse_input_line(self, other):
        """Parse a comma-separated list of expected points."""
        return [self._parse_input_coord(pt.strip()) for pt in other.split(',')]

    def __repr__(self):
        return self.orig_value


class DBValueFloat:
    """Float value from the database that is compared with a tolerance."""

    def __init__(self, value, props):
        """Store the value and its comparison precision.

        Args:
            value: The float value to wrap.
            props: Relative tolerance as a number or numeric string,
                0.0001 when empty or None.
        """
        self.precision = float(props) if props else 0.0001
        self.value = value

    def __eq__(self, other):
        """Return True if `other` converts to a float close to the value.

        Anything that cannot be converted to a float compares unequal.
        """
        try:
            fother = float(other)
        except (TypeError, ValueError, OverflowError):
            # Not a number: unequal. Unrelated exceptions such as
            # KeyboardInterrupt are deliberately not swallowed.
            return False

        return math.isclose(self.value, fother, rel_tol=self.precision)

    def __repr__(self):
        return repr(self.value)


class DBValueRegex:
    """Database value that is matched against a regular expression."""

    def __init__(self, value):
        # Always compare on the string form of the database value.
        self.value = str(value)

    def __eq__(self, other):
        """Treat `other` as a pattern that must match the complete value."""
        pattern = str(other)
        if re.fullmatch(pattern, self.value) is None:
            return False
        return True

    def __repr__(self):
        return repr(self.value)