File: test_printers.py

package info (click to toggle)
pglast 7.11-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 5,368 kB
  • sloc: python: 13,349; sql: 2,405; makefile: 159
file content (123 lines) | stat: -rw-r--r-- 3,538 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# -*- coding: utf-8 -*-
# :Project:   pglast — Tests on the printers registry
# :Created:   sab 05 ago 2017 10:31:23 CEST
# :Author:    Lele Gaifax <lele@metapensiero.it>
# :License:   GNU General Public License version 3 or later
# :Copyright: © 2017, 2018, 2019, 2021, 2022, 2024 Lele Gaifax
#

import warnings

import pytest

from pglast import ast, enums, prettify
from pglast.printers import IntEnumPrinter, NODE_PRINTERS, PrinterAlreadyPresentError
from pglast.printers import get_printer_for_node, node_printer


def test_registry():
    """Exercise the argument validation and duplicate detection of ``node_printer``.

    Each malformed registration must be rejected with ``ValueError``. A valid
    registration for ``ast.RawStmt`` must end up in ``NODE_PRINTERS``. A second
    one for the same node, without ``override``, must raise
    ``PrinterAlreadyPresentError``.
    """

    # A lookup with None instead of a node is rejected
    with pytest.raises(ValueError):
        get_printer_for_node(None)

    # No node class at all
    with pytest.raises(ValueError):
        @node_printer()
        def missing_node(node, output):
            pass

    # An integer is not a node class
    with pytest.raises(ValueError):
        @node_printer(1)
        def integer_node(node, output):
            pass

    # Three node classes are too many
    with pytest.raises(ValueError):
        @node_printer(ast.RawStmt, ast.SelectStmt, ast.UpdateStmt)
        def too_many_nodes(node, output):
            pass

    # A node class *name* given as a string is not accepted
    with pytest.raises(ValueError):
        @node_printer('RawStmt')
        def string_node(node, output):
            pass

    # Every element of a parents tuple must be a node class too
    with pytest.raises(ValueError):
        @node_printer((ast.RawStmt, 'foo'), ast.SelectStmt)
        def invalid_parents(node, output):
            pass

    # Temporarily unregister the RawStmt printer so a fresh one can be added.
    # Using pop() with a default keeps the test meaningful even when no printer
    # was registered for RawStmt. The finally clause then restores the exact
    # previous state.
    saved_printer = NODE_PRINTERS.pop(ast.RawStmt, None)
    try:
        @node_printer(ast.RawStmt)
        def raw(node, output):
            pass

        assert NODE_PRINTERS[ast.RawStmt] is raw

        # A second registration for the same node, without override, is refused
        with pytest.raises(PrinterAlreadyPresentError):
            @node_printer(ast.RawStmt)
            def other_raw(node, output):
                pass
    finally:
        if saved_printer is not None:
            NODE_PRINTERS[ast.RawStmt] = saved_printer
        else:
            NODE_PRINTERS.pop(ast.RawStmt, None)


def test_prettify_safety_belt():
    """Check that ``prettify()``'s *safety_belt* guards against a misbehaving printer.

    The ``RawStmt`` printer is temporarily replaced by fake ones. The first
    writes ``'Yeah'`` instead of the statement; the second writes a different
    statement (``'select 1'``). Without the safety belt the fake output is
    returned verbatim. With it, the original input comes back and a warning
    describes what was detected.
    """

    def prettify_with_belt(statement):
        # Return prettify()'s result together with the text of every warning
        # it emitted. simplefilter('always') guarantees the warnings are
        # recorded even when an "ignore" filter (which would leave the list
        # empty) or an "error" filter (which would raise instead) is
        # configured elsewhere, e.g. on the command line or in pytest settings.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter('always')
            result = prettify(statement, safety_belt=True)
        return result, [str(warning.message) for warning in caught]

    saved_printer = NODE_PRINTERS.pop(ast.RawStmt, None)
    try:
        @node_printer(ast.RawStmt)
        def raw_stmt_1(node, output):
            output.write('Yeah')

        output = prettify('select 42')
        assert output == 'Yeah'

        output, messages = prettify_with_belt('select 42')
        assert output == 'select 42'
        assert any('Detected a bug' in message for message in messages), messages

        # override=True allows replacing the fake printer registered above
        @node_printer(ast.RawStmt, override=True)
        def raw_stmt_2(node, output):
            output.write('select 1')

        output = prettify('select 42')
        assert output == 'select 1'

        output, messages = prettify_with_belt('select 42')
        assert output == 'select 42'
        assert any('Detected a non-cosmetic difference' in message
                   for message in messages), messages
    finally:
        # Leave the registry as it was found: put the original printer back,
        # or drop the fake one if RawStmt had no printer to begin with.
        if saved_printer is not None:
            NODE_PRINTERS[ast.RawStmt] = saved_printer
        else:
            NODE_PRINTERS.pop(ast.RawStmt, None)


def test_int_enum_printer():
    """An ``IntEnumPrinter`` subclass dispatches on the given member name."""

    class BlockOnlyPrinter(IntEnumPrinter):
        # Only the LockWaitBlock handler is implemented
        enum = enums.LockWaitPolicy

        def LockWaitBlock(self, node, output):
            output.append('block')

    printer = BlockOnlyPrinter()
    sink = []

    printer('LockWaitBlock', object(), sink)
    assert sink == ['block']

    # A name with no handler method on the subclass
    with pytest.raises(NotImplementedError):
        printer('LockWaitError', object(), sink)

    # A name the printer does not accept at all
    with pytest.raises(ValueError):
        printer('FooBar', object(), sink)

    # None as the value ends up in the LockWaitBlock handler as well
    printer(None, object(), sink)
    assert sink == ['block', 'block']


def test_not_int_enum_printer():
    """Instantiating an ``IntEnumPrinter`` bound to a non-int enum raises ``ValueError``."""

    class NonIntEnumPrinter(IntEnumPrinter):
        enum = enums.FunctionParameterMode

    with pytest.raises(ValueError):
        NonIntEnumPrinter()