File: tests.py

package info (click to toggle)
tds-fdw 2.0.5-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,080 kB
  • sloc: ansic: 5,402; sql: 581; python: 418; makefile: 23; sh: 1
file content (161 lines) | stat: -rw-r--r-- 6,548 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from glob import glob
from json import load
from lib.messages import print_error, print_info
from os import listdir
from os.path import basename, isfile, realpath
from re import match
from psycopg2.extensions import Diagnostics


def version_to_array(version, dbtype):
    """ Convert a version string to a list of integers, or return an empty
        list if the original string was empty

    Keyword arguments:
    version -- A string (or UTF-8 encoded bytes) with a dot separated version
    dbtype  -- A string with the database type (postgresql|mssql)

    Only the leading numeric part of the string is used, anything after it
    is ignored. Raises ValueError if a non-empty version does not start
    with a digit.
    """
    # Covers '', b'' and None alike
    if not version:
        return []
    # Some drivers hand the version back as bytes
    if isinstance(version, bytes):
        version = version.decode('utf-8', 'replace')
    # Cleanup version, since Ubuntu decided to add their own
    # versioning starting with PostgreSQL 10.2:
    # '10.2 (Ubuntu 10.2-1.pgdg14.04+1' instead of '10.2'
    # The pattern also accepts a single component ('9') and stops before a
    # trailing dot ('10.2.'), so int() never sees an empty component.
    found = match(r'(\d+(?:\.\d+)*)', version)
    if found is None:
        raise ValueError("Cannot parse version string: %r" % (version,))
    parts = [int(part) for part in found.group(1).split('.')]
    # To be able to work with MSSQL 2000 and 7.0,
    # see https://sqlserverbuilds.blogspot.ru/
    # Three component MSSQL versions are padded to four components.
    if dbtype == 'mssql' and len(parts) == 3:
        parts.append(0)
    return parts


def check_ver(conn, min_ver, max_ver, dbtype):
    """ Check SQL server version against test required max and min versions.

    Keyword arguments:
    conn    -- A db connection (according to Python DB API v2.0)
    min_ver -- A string with the minimum version required
    max_ver -- A string with the maximum version required
    dbtype  -- A string with the database type (postgresql|mssql)

    Returns True when the server version is not lower than min_ver and not
    higher than max_ver, False otherwise. An empty max_ver means there is no
    upper bound. Raises ValueError if dbtype is not supported.
    """
    if dbtype == 'postgresql':
        sentence = "SELECT setting FROM pg_settings WHERE name = 'server_version';"
    elif dbtype == 'mssql':
        sentence = "SELECT serverproperty('ProductVersion');"
    else:
        # Fail with a clear message instead of an unbound local variable
        raise ValueError("Unsupported database type: %r" % (dbtype,))
    cursor = conn.cursor()
    try:
        cursor.execute(sentence)
        server_ver = cursor.fetchone()[0]
    finally:
        # Release the cursor even when the query fails
        cursor.close()
    min_ver = version_to_array(min_ver, dbtype)
    max_ver = version_to_array(max_ver, dbtype)
    server_ver = version_to_array(server_ver, dbtype)
    # Versions are lists of integers, so they compare element by element
    return server_ver >= min_ver and (not max_ver or server_ver <= max_ver)

def get_logs_path(conn, dbtype):
    """ Get the paths of the PostgreSQL logs

    Keyword arguments:
    conn   -- A db connection (according to Python DB API v2.0)
    dbtype -- A string with the database type (postgresql|mssql). mssql will
              return an empty list.

    Returns a list of strings with the paths found. The list is empty for
    mssql, or when the data_directory setting cannot be read.
    """
    logs = []
    if dbtype == 'mssql':
        return logs
    cursor = conn.cursor()
    try:
        try:
            cursor.execute("SELECT setting FROM pg_catalog.pg_settings WHERE name = 'data_directory';")
            data_dir = cursor.fetchone()[0]
        except TypeError:
            # Raised when fetchone() returns None, i.e. the setting is not
            # visible through this connection
            print_error("The user does not have SUPERUSER access to PostgreSQL.")
            print_error("Cannot access pg_catalog.pg_settings required values, so logs cannot be found")
            return logs
        cursor.execute("SELECT setting FROM pg_catalog.pg_settings WHERE name = 'log_directory';")
        log_dir = cursor.fetchone()[0]
        cursor.execute("SELECT setting FROM pg_catalog.pg_settings WHERE name = 'logging_collector';")
        logging_collector = cursor.fetchone()[0]
    finally:
        # The cursor is not needed anymore, whatever the outcome was
        cursor.close()
    # A relative log_directory is relative to the data directory
    if not log_dir.startswith('/'):
        log_dir = "%s/%s" % (data_dir, log_dir)
    # No logging collector, add stdout from postmaster (assume stderr is redirected to stdout)
    if logging_collector == 'off':
        # First line of postmaster.pid is the PID of the postmaster
        with open("%s/postmaster.pid" % data_dir, "r") as f:
            postmaster_pid = f.readline().rstrip('\n')
        postmaster_log = "/proc/%s/fd/1" % postmaster_pid
        if isfile(postmaster_log):
            logs.append(realpath(postmaster_log))
        return logs
    # Logging collector enabled
    # Add stderr (fd 2) from logger (assume stderr is redirected to stdout)
    for pid in listdir('/proc'):
        if not pid.isdigit():
            continue
        try:
            # Context manager, so one handle per scanned process is not leaked
            with open('/proc/' + pid + '/cmdline', 'r') as cmdline_file:
                cmdline = cmdline_file.read()
        except IOError:  # proc has already terminated
            continue
        if 'postgres: logger' not in cmdline:
            continue
        logger_log = "/proc/%s/fd/2" % pid
        if isfile(logger_log):
            logs.append(realpath(logger_log))
    # Add all files from log_dir
    for name in listdir(log_dir):
        logs.append(realpath(log_dir + '/' + name))
    return logs


def run_tests(path, conn, replaces, dbtype, debugging=False, unattended_debugging=False):
    """Run SQL tests over a connection, returns a dict with results.

    Keyword arguments:
    path                 -- String with the glob pattern matching the SQL
                            files for tests
    conn                 -- A db connection (according to Python DB API v2.0)
    replaces             -- A dict with replaces to perform at testing code
    dbtype               -- A string with the database type (postgresql|mssql)
    debugging            -- If true, print each query before running it
    unattended_debugging -- If true, print each query before running it

    Every SQL file must have a JSON file next to it, with the same name but
    the extension replaced by .json, providing 'test_desc' and
    'server' -> 'version' -> 'min'/'max'.

    The returned dict has the keys 'total', 'ok' and 'errors'. Tests skipped
    because the server version is out of the required range are not counted.
    """
    tests = {'total': 0, 'ok': 0, 'errors': 0}
    for fname in sorted(glob(path)):
        # Context managers make sure neither file stays open between tests
        with open('%s.json' % fname.rsplit('.', 1)[0], 'r') as test_file:
            test_properties = load(test_file)
        test_desc = test_properties['test_desc']
        # File names start with the test number followed by an underscore
        test_number = basename(fname).split('_')[0]
        req_ver = test_properties['server']['version']
        if not check_ver(conn, req_ver['min'], req_ver['max'], dbtype):
            continue
        tests['total'] += 1
        with open(fname, 'r') as f:
            sentence = f.read()
        for key, elem in replaces.items():
            sentence = sentence.replace(key, elem)
        print_info("%s: Testing %s" % (test_number, test_desc))
        if debugging or unattended_debugging:
            print_info("Query:")
            print(sentence)
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(sentence)
                conn.commit()
            finally:
                # Close the cursor on failure too, before rolling back
                cursor.close()
            tests['ok'] += 1
        except Exception as e:
            print_error("Error running %s (%s)" % (test_desc, fname))
            print_error("Query:")
            print(sentence)
            try:
                # psycopg2 specific details, not every driver provides them
                print_error(e.pgcode)
                print_error(e.pgerror)
                for att in [member for member in dir(Diagnostics) if not member.startswith("__")]:
                    print_error("%s: %s" % (att, getattr(e.diag, att)))
            except Exception:
                # Fall back to the plain exception, without swallowing
                # KeyboardInterrupt or SystemExit as a bare except would
                print_error(e)
            conn.rollback()
            tests['errors'] += 1
    return tests