1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
|
# -*- coding: utf-8 -*-
# :Project: pglast — Tests on the stream.py module
# :Created: sab 05 ago 2017 10:31:23 CEST
# :Author: Lele Gaifax <lele@metapensiero.it>
# :License: GNU General Public License version 3 or later
# :Copyright: © 2017, 2018, 2019, 2021, 2022, 2024 Lele Gaifax
#
import pytest
from pglast import ast, parse_sql
from pglast.printers import NODE_PRINTERS, PrinterAlreadyPresentError, SPECIAL_FUNCTIONS
from pglast.printers import node_printer, special_function
from pglast.stream import IndentedStream, OutputStream, RawStream
def test_output_stream():
    """Chunks emitted through ``writes()`` end up separated by exactly one space.

    Chunks that already carry a leading or trailing space must not receive a
    second one, and a final plain ``write()`` is separated from the previous
    ``writes()`` chunk as well.
    """
    stream = OutputStream()
    for chunk in ('SELECT *', ' FROM', 'table ', 'WHERE'):
        stream.writes(chunk)
    stream.write('id = 1')
    expected = 'SELECT * FROM table WHERE id = 1'
    assert stream.getvalue() == expected
def test_raw_stream_with_sql():
    """A RawStream fed SQL text prints each statement, joined by ``'; '``."""
    # Temporarily take the real RawStmt printer out of the registry so a
    # trivial stand-in can be installed in its place.
    saved_printer = NODE_PRINTERS.pop(ast.RawStmt, None)
    try:
        @node_printer(ast.RawStmt)
        def fake_raw_stmt(node, output):
            output.write('Yeah')

        assert RawStream()('SELECT 1; SELECT 2') == 'Yeah; Yeah'
    finally:
        # Restore the registry exactly as it was before the test ran.
        if saved_printer is None:
            NODE_PRINTERS.pop(ast.RawStmt, None)
        else:
            NODE_PRINTERS[ast.RawStmt] = saved_printer
def test_raw_stream():
    """A RawStream can print the result of ``parse_sql()`` directly."""
    # Swap in a stand-in RawStmt printer for the duration of the test.
    saved_printer = NODE_PRINTERS.pop(ast.RawStmt, None)
    try:
        @node_printer(ast.RawStmt)
        def fake_raw_stmt(node, output):
            output.write('Yeah')

        statements = parse_sql('SELECT 1')
        assert RawStream()(statements) == 'Yeah'
    finally:
        # Restore the registry exactly as it was before the test ran.
        if saved_printer is None:
            NODE_PRINTERS.pop(ast.RawStmt, None)
        else:
            NODE_PRINTERS[ast.RawStmt] = saved_printer
def test_raw_stream_invalid_call():
    """Calling a RawStream with an integer argument raises ValueError."""
    with pytest.raises(ValueError):
        stream = RawStream()
        stream(1)
def test_indented_stream_with_sql():
    """IndentedStream puts an empty line between statements unless disabled.

    With default options statements are separated by a blank line; passing
    ``separate_statements=False`` leaves only the newline after the ``;``.
    """
    # Swap in a stand-in RawStmt printer for the duration of the test.
    saved_printer = NODE_PRINTERS.pop(ast.RawStmt, None)
    try:
        @node_printer(ast.RawStmt)
        def fake_raw_stmt(node, output):
            output.write('Yeah')

        cases = (
            ({}, 'Yeah;\n\nYeah'),
            ({'separate_statements': False}, 'Yeah;\nYeah'),
        )
        for options, expected in cases:
            assert IndentedStream(**options)('SELECT 1; SELECT 2') == expected
    finally:
        # Restore the registry exactly as it was before the test ran.
        if saved_printer is None:
            NODE_PRINTERS.pop(ast.RawStmt, None)
        else:
            NODE_PRINTERS[ast.RawStmt] = saved_printer
def test_separate_statements():
    """An integer ``separate_statements`` inserts that many empty lines between statements."""
    # Swap in a stand-in RawStmt printer for the duration of the test.
    saved_printer = NODE_PRINTERS.pop(ast.RawStmt, None)
    try:
        @node_printer(ast.RawStmt)
        def fake_raw_stmt(node, output):
            output.write('Yeah')

        stream = IndentedStream(separate_statements=2)
        # Two empty lines means three newlines after the terminating ';'.
        assert stream('SELECT 1; SELECT 2') == 'Yeah;\n\n\nYeah'
    finally:
        # Restore the registry exactly as it was before the test ran.
        if saved_printer is None:
            NODE_PRINTERS.pop(ast.RawStmt, None)
        else:
            NODE_PRINTERS[ast.RawStmt] = saved_printer
def test_special_function():
    """Special-function printers can be registered, refused when duplicated, and overridden.

    The test registers a printer for ``foo.test_function`` and checks that a
    second registration without ``override=True`` raises
    PrinterAlreadyPresentError without replacing the first one. It then
    overrides the printer and verifies that RawStream uses it when printing.
    """
    output = RawStream(special_functions=True)
    assert output.get_printer_for_function('foo.test_function') is None
    try:
        @special_function('foo.test_function')
        def first_printer(node, output):
            pass

        assert output.get_printer_for_function('foo.test_function') is first_printer

        # A second registration without override=True must be refused...
        with pytest.raises(PrinterAlreadyPresentError):
            @special_function('foo.test_function')
            def duplicate_printer(node, output):
                pass
        # ...and must leave the previously registered printer in place.
        assert output.get_printer_for_function('foo.test_function') is first_printer

        @special_function('foo.test_function', override=True)
        def test_function(node, output):
            output.print_list(node.args, '-')

        result = output('SELECT foo.test_function(x, "Y") FROM sometable')
        assert result == 'SELECT x - "Y" FROM sometable'
    finally:
        # Pop with a default: if registration failed above, a bare pop() would
        # raise KeyError here and hide the original failure.
        SPECIAL_FUNCTIONS.pop('foo.test_function', None)
|