File: check_sql_script.py

package info (click to toggle)
timescaledb 2.24.0%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 13,552 kB
  • sloc: ansic: 58,664; sql: 24,761; sh: 1,742; python: 1,254; perl: 78; makefile: 14
file content (139 lines) | stat: -rw-r--r-- 4,077 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python

# Check SQL script components for problematic patterns. This script is
# intended to be run on the scripts that are added to every update script,
# but not the compiled update script or the pre_install scripts.
#
# This script will find patterns that are not idempotent and therefore
# should be moved to the pre_install part.

from pglast import parse_sql
from pglast.visitors import Visitor, Skip, Continue
from pglast.stream import RawStream
import sys
import re
import argparse

# Command-line interface: one or more SQL script files to check.
# argparse.FileType("r") makes the parser open every named file for reading,
# so `args.filename` is a list of open file objects rather than path strings.
# NOTE: this runs at module level, i.e. also whenever the module is imported.
parser = argparse.ArgumentParser()
parser.add_argument("filename", type=argparse.FileType("r"), nargs="+")
args = parser.parse_args()


class SQLVisitor(Visitor):
    """Visitor that reports top-level statements which should not be re-run.

    Statement types with a dedicated ``visit_*`` method below are either
    accepted unconditionally or checked for a re-runnable spelling
    (``OR REPLACE`` / ``IF EXISTS``).  Every other statement type ends up in
    ``visit`` and is reported.  ``errors`` holds the number of reports made.
    """

    def __init__(self, file):
        # `file` is only used to label the messages printed by error().
        self.file = file
        self.errors = 0
        super().__init__()

    def error(self, node, hint):
        """Count one finding and print the offending statement and `hint`."""
        self.errors += 1
        header = f"Invalid statement found in sql script({self.file}):\n"
        statement = RawStream()(node)
        print(header, statement)
        print(hint, "\n")

    def _expect_replace(self, node, hint):
        """Report `node` with `hint` when its `replace` flag is unset.

        Always returns Skip so the statement's children are not visited.
        """
        if not node.replace:
            self.error(node, hint)
        return Skip

    # --- traversal control -------------------------------------------------

    def visit_RawStmt(self, _ancestors, _node):
        # Every statement is wrapped in a RawStmt; keep descending so the
        # wrapped statement itself gets dispatched.
        return Continue

    def visit(self, _ancestors, node):
        # Fallback for statement types without a dedicated method below.
        self.error(node, "Consider moving the statement into a pre_install script")

        # Only top-level statements are of interest, so do not descend.
        return Skip

    # --- statement types accepted without further checks -------------------

    def visit_CommentStmt(self, _ancestors, _node):
        return Skip

    def visit_GrantStmt(self, _ancestors, _node):
        return Skip

    def visit_SelectStmt(self, _ancestors, _node):
        return Skip

    def visit_InsertStmt(self, _ancestors, _node):
        return Skip

    def visit_DeleteStmt(self, _ancestors, _node):
        return Skip

    def visit_DoStmt(self, _ancestors, _node):
        return Skip

    def visit_CreateEventTrigStmt(self, _ancestors, _node):
        return Skip

    # --- statement types accepted only in their re-runnable form -----------

    def visit_CreateTrigStmt(self, _ancestors, node):
        return self._expect_replace(node, "Consider using CREATE OR REPLACE TRIGGER")

    def visit_DefineStmt(self, _ancestors, node):
        return self._expect_replace(node, "Consider using CREATE OR REPLACE")

    def visit_DropStmt(self, _ancestors, node):
        if node.missing_ok:
            return Skip

        self.error(node, "Consider using DROP IF EXISTS")
        return Skip

    def visit_ViewStmt(self, _ancestors, node):
        return self._expect_replace(node, "Consider using CREATE OR REPLACE VIEW")

    def visit_CreateFunctionStmt(self, _ancestors, node):
        if node.replace:
            return Skip

        # The same node type covers functions and procedures; name the right
        # one in the hint.
        kind = "PROCEDURE" if node.is_procedure is True else "FUNCTION"
        self.error(node, f"Consider using CREATE OR REPLACE {kind}")
        return Skip


# copied from pgspot
def visit_sql(sql, file):
    """Parse `sql` and return how many problems SQLVisitor reported for it.

    `file` is passed on to SQLVisitor, which uses it to label its messages.
    Before parsing, the text is adjusted so that the SQL parser accepts it:
    extension placeholders are turned into plain identifiers and psql meta
    commands are commented out.
    """
    # Placeholders used in extension scripts (e.g. @extschema@ stands for the
    # schema the extension gets installed in) are not valid SQL as written.
    placeholders = (
        ("@extschema@", "extschema"),
        ("@extowner@", "extowner"),
        ("@database_owner@", "database_owner"),
    )
    for marker, identifier in placeholders:
        sql = sql.replace(marker, identifier)

    # postgres contrib modules guard their scripts with psql meta commands so
    # they cannot be run directly in psql.  Those lines are not SQL and would
    # make the parser fail, so every line starting with a backslash is turned
    # into a comment.
    meta_command = re.compile(r"^\\", flags=re.MULTILINE)
    sql = meta_command.sub("-- \\\\", sql)

    checker = SQLVisitor(file)
    for statement in parse_sql(sql):
        checker(statement)
    return checker.errors


def main(args):
    """Check every file in ``args.filename`` and exit with the overall status.

    Each entry of ``args.filename`` must be an open, readable file object with
    a ``name`` attribute (this script gets them from ``argparse.FileType``).
    The content of each file is handed to ``visit_sql`` and the reported
    error counts are summed up.

    This function never returns: it calls ``sys.exit(1)`` after printing a
    summary when at least one error was found, and ``sys.exit(0)`` otherwise.
    """
    errors = 0
    error_files = []
    for file in args.filename:
        try:
            sql = file.read()
        finally:
            # Nothing else closes these handles, so release each one as soon
            # as it has been read instead of keeping all of them open until
            # the process ends.  stdin is left alone: it is not ours to close.
            if file is not sys.stdin:
                file.close()

        result = visit_sql(sql, file.name)
        if result > 0:
            errors += result
            error_files.append(file.name)

    if errors > 0:
        numbering = "errors" if errors > 1 else "error"
        print(
            f"{errors} {numbering} detected in {len(error_files)} files({', '.join(error_files)})"
        )
        sys.exit(1)
    sys.exit(0)


if __name__ == "__main__":
    # main() always ends the process through sys.exit(), so no statement
    # placed after this call would ever run.
    main(args)