File: environment.py

package info (click to toggle)
osm2pgsql 2.2.0%2Bds-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,772 kB
  • sloc: cpp: 60,940; python: 1,115; ansic: 763; sh: 25; makefile: 14
file content (162 lines) | stat: -rw-r--r-- 5,861 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# SPDX-License-Identifier: GPL-2.0-or-later
#
# This file is part of osm2pgsql (https://osm2pgsql.org/).
#
# Copyright (C) 2006-2025 by the osm2pgsql developer community.
# For a full list of authors see the git log.
from contextlib import closing
from pathlib import Path
import subprocess
import tempfile
import importlib.util
import io
from importlib.machinery import SourceFileLoader

from behave import *
try:
    import psycopg2 as psycopg
    from psycopg2 import sql
except ImportError:
    import psycopg
    from psycopg import sql


from steps.geometry_factory import GeometryFactory
from steps.replication_server_mock import ReplicationServerMock

# Directory two levels above this file's path; all default locations
# below are derived from it.
TEST_BASE_DIR = (Path(__file__) / '..' / '..').resolve()

# The following parameters can be changed on the command line using
# the -D parameter. Example:
#
#    behave -DBINARY=/tmp/my-builddir/osm2pgsql -DKEEP_TEST_DB
#
# The values here are only defaults: before_all() copies them into
# context.config.userdata with setdefault(), so anything already set
# there takes precedence.
USER_CONFIG = {
    # Executable that before_all() runs with '--version'.
    'BINARY': (TEST_BASE_DIR / '..' / 'build' / 'osm2pgsql').resolve(),
    # Python script that before_all() loads as module 'osm2pgsql_replication'.
    'REPLICATION_SCRIPT': (TEST_BASE_DIR / '..' / 'scripts' / 'osm2pgsql-replication').resolve(),
    # Exposed to the steps as context.test_data_dir.
    'TEST_DATA_DIR': TEST_BASE_DIR / 'data',
    # Exposed to the steps as context.default_data_dir.
    'SRC_DIR': (TEST_BASE_DIR / '..').resolve(),
    # When true, the test database is not dropped after a scenario.
    'KEEP_TEST_DB': False,
    # Name of the database that is dropped and recreated for each scenario.
    'TEST_DB': 'osm2pgsql-test',
    # Replaced in before_all() by whether a tablespace named
    # 'tablespacetest' exists on the server.
    'HAVE_TABLESPACE': True,
    # Replaced in before_all() by whether the osm2pgsql version output
    # lacks the marker 'Proj [disabled]'.
    'HAVE_PROJ': True
}

# Select the 're' step matcher for the step definitions
# (presumably behave's regular-expression matcher; see behave docs).
use_step_matcher('re')

def _connect_db(context, dbname):
    """ Open a connection to the database `dbname` and return it in a
        form suitable for a `with` statement.

        Autocommit is switched on for the connection in both code paths.
        When the loaded driver reports a version starting with '2', the
        connection is wrapped in contextlib.closing() so that leaving the
        `with` block closes it. Otherwise the connection object created
        by the driver is returned directly.

        The `context` parameter is not used.
    """
    legacy_driver = psycopg.__version__.startswith('2')

    if not legacy_driver:
        return psycopg.connect(dbname=dbname, autocommit=True)

    connection = psycopg.connect(dbname=dbname)
    connection.autocommit = True

    return closing(connection)


def _drop_db(context, dbname, recreate_immediately=False):
    """ Remove the database `dbname` if it exists.

        The statements are issued over a connection to the 'postgres'
        database. With `recreate_immediately` set, an empty database of
        the same name is created right after the drop.
    """
    templates = ['DROP DATABASE IF EXISTS {}']
    if recreate_immediately:
        templates.append('CREATE DATABASE {}')

    with _connect_db(context, 'postgres') as conn:
        with conn.cursor() as cur:
            # Quote the name as an SQL identifier instead of pasting it
            # into the statement text.
            name = sql.Identifier(dbname)
            for template in templates:
                cur.execute(sql.SQL(template).format(name))


def before_all(context):
    """ Global setup hook: prepare configuration shared by all scenarios.

        - Configures logging and fills context.config.userdata with the
          defaults from USER_CONFIG (values already present win).
        - Stores the server version number as userdata['PG_VERSION'].
        - Downgrades userdata['HAVE_TABLESPACE'] to False when the
          tablespace 'tablespacetest' does not exist.
        - Runs the osm2pgsql binary with '--version' and downgrades
          userdata['HAVE_PROJ'] when the output contains 'Proj [disabled]'.
        - Sets context.test_data_dir and context.default_data_dir.
        - Loads the replication script as a module into
          context.osm2pgsql_replication.

        Raises RuntimeError when the osm2pgsql binary exits with a
        non-zero status.
    """
    # logging setup
    context.config.setup_logging()
    # set up -D options
    for k, v in USER_CONFIG.items():
        context.config.userdata.setdefault(k, v)

    with _connect_db(context, 'postgres') as conn:
        with conn.cursor() as cur:
            # before_tag() reads PG_VERSION unconditionally, so the version
            # must be determined even when tablespace tests are disabled.
            cur.execute("""SELECT setting FROM pg_settings
                           WHERE name = 'server_version_num'""")
            context.config.userdata['PG_VERSION'] = int(cur.fetchone()[0])

            if context.config.userdata['HAVE_TABLESPACE']:
                cur.execute("""SELECT spcname FROM pg_tablespace
                               WHERE spcname = 'tablespacetest'""")
                context.config.userdata['HAVE_TABLESPACE'] = cur.rowcount > 0

    # Get the osm2pgsql configuration
    proc = subprocess.run([str(context.config.userdata['BINARY']), '--version'],
                          stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if proc.returncode != 0:
        raise RuntimeError('Cannot run osm2pgsql')
    # The version information is read from stderr.
    ver_info = proc.stderr.decode('utf-8')

    if context.config.userdata['HAVE_PROJ']:
        context.config.userdata['HAVE_PROJ'] = 'Proj [disabled]' not in ver_info

    context.test_data_dir = Path(context.config.userdata['TEST_DATA_DIR']).resolve()
    context.default_data_dir = Path(context.config.userdata['SRC_DIR']).resolve()

    # Set up replication script.
    # The script has no '.py' suffix, so it is loaded through an explicit
    # SourceFileLoader rather than a regular import.
    replicationfile = str(Path(context.config.userdata['REPLICATION_SCRIPT']).resolve())
    spec = importlib.util.spec_from_loader('osm2pgsql_replication',
                                           SourceFileLoader('osm2pgsql_replication',
                                                            replicationfile))
    assert spec, f"File not found: {replicationfile}"
    context.osm2pgsql_replication = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(context.osm2pgsql_replication)


def before_scenario(context, scenario):
    """ Set up a fresh, empty test database and per-scenario state.

        Scenarios tagged 'config.have_proj' are marked as skipped when
        userdata['HAVE_PROJ'] is false. The remaining setup resets the
        import data, the osm2pgsql parameters, the working directory and
        the geometry factory, and replaces the replication module's
        server class and `urlopen` with test doubles.
    """
    # Imported here because the module does not import urllib at file
    # level; without it the mock below fails with a NameError instead of
    # raising URLError.
    import urllib.error

    if 'config.have_proj' in scenario.tags and not context.config.userdata['HAVE_PROJ']:
        scenario.skip("Generic proj library not configured.")

    context.db = use_fixture(test_db, context)
    context.import_file = None
    context.import_data = {'n': [], 'w': [], 'r': []}
    context.osm2pgsql_params = []
    context.workdir = use_fixture(working_directory, context)
    context.geometry_factory = GeometryFactory()
    context.osm2pgsql_replication.ReplicationServer = ReplicationServerMock()
    context.urlrequest_responses = {}

    def _mock_urlopen(request):
        """ Serve the canned response registered for the request's URL
            in context.urlrequest_responses, or raise URLError.
        """
        if request.full_url not in context.urlrequest_responses:
            raise urllib.error.URLError('Unknown URL')

        return closing(io.BytesIO(context.urlrequest_responses[request.full_url].encode('utf-8')))

    context.osm2pgsql_replication.urlrequest.urlopen = _mock_urlopen


@fixture
def test_db(context, **kwargs):
    """ Fixture providing a connection to a newly created test database.

        The database named by userdata['TEST_DB'] is dropped and
        recreated, the postgis and hstore extensions are installed and
        the open connection is yielded. After the fixture is resumed the
        database is dropped again unless userdata['KEEP_TEST_DB'] is set.
    """
    name = context.config.userdata['TEST_DB']
    _drop_db(context, name, recreate_immediately=True)

    setup_statements = ('CREATE EXTENSION postgis',
                        'CREATE EXTENSION hstore')

    with _connect_db(context, name) as conn:
        with conn.cursor() as cur:
            for statement in setup_statements:
                cur.execute(statement)

        yield conn

    # The connection context has been left at this point.
    keep_database = context.config.userdata['KEEP_TEST_DB']
    if not keep_database:
        _drop_db(context, name)


@fixture
def working_directory(context, **kwargs):
    """ Fixture yielding the path of a temporary directory, which is
        cleaned up again once the fixture is finalized.
    """
    scratch = tempfile.TemporaryDirectory()
    try:
        yield Path(scratch.name)
    finally:
        scratch.cleanup()


def before_tag(context, tag):
    """ Skip the current scenario when a tag states a requirement that
        the database server does not meet.

        Only the tag 'needs-pg-index-includes' is handled: the scenario
        is skipped when userdata['PG_VERSION'] is below 110000.
    """
    if tag != 'needs-pg-index-includes':
        return

    server_version = context.config.userdata['PG_VERSION']
    if server_version < 110000:
        context.scenario.skip("No index includes in PostgreSQL < 11")