1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331
|
from __future__ import unicode_literals
import re
import sys
import logging
import click
import io
import shlex
import sqlparse
import psycopg
from os.path import expanduser
from .namedqueries import NamedQueries
from . import export
from .main import show_extra_help_command, special_command
# A named query containing any of these tokens takes arguments: "$1" marks
# positional parameters, "$*" / "$@" aggregate the remaining arguments.
NAMED_QUERY_PLACEHOLDERS = frozenset(("$1", "$*", "$@"))
# Default interval, in seconds, between runs of a "\watch"ed query.
DEFAULT_WATCH_SECONDS = 2
_logger = logging.getLogger(__name__)
@export
def editor_command(command):
    """
    Is this an external editor command? (\\e or \\ev)

    Both the prefix form (``\\e filename``, ``\\ev name``, ``\\ef name``)
    and the suffix form (``SELECT * FROM foo \\e``) are recognised.

    :param command: string
    :return: the editor command found (``"\\e"``, ``"\\ev"`` or ``"\\ef"``),
        or None when *command* is not an editor command.
    """
    text = command.strip()
    # Prefix form: the command word must be followed by a space, so that an
    # unrelated command such as \echo is not mistaken for \e.
    prefix = next(
        (word for word in ("\\e", "\\ev", "\\ef") if text.startswith(word + " ")),
        None,
    )
    if prefix is not None:
        return prefix
    # Suffix form: a trailing \e after the query text.
    if text.endswith("\\e"):
        return "\\e"
    return None
@export
def get_filename(sql):
    """Return the argument given to a ``\\e`` command, if any.

    ``"\\e foo.sql"`` yields ``"foo.sql"``. A bare ``\\e``, or any *sql* that
    does not start with ``\\e``, yields None.

    :param sql: the raw command text.
    :return: the text following the command word, or None.
    """
    stripped = sql.strip()
    if not stripped.startswith("\\e"):
        return None
    # Split the *stripped* text. Partitioning the raw input breaks on leading
    # whitespace ("  \e foo.sql" used to yield "\e foo.sql"). Any whitespace
    # (not only a single space) separates the command word from its argument.
    parts = stripped.split(None, 1)
    if len(parts) < 2:
        return None
    return parts[1] or None
@export
@show_extra_help_command(
    "\\watch",
    f"\\watch [sec={DEFAULT_WATCH_SECONDS}]",
    "Execute query every `sec` seconds.",
)
def get_watch_command(command):
    """Split a ``\\watch`` command into its query and its interval.

    :param command: the full command text, e.g. ``"select 1 \\watch 5"``.
    :return: ``(query, seconds)``, where *seconds* falls back to
        ``DEFAULT_WATCH_SECONDS`` when no interval is given. Returns
        ``(None, None)`` when *command* does not end in ``\\watch``.
    """
    found = re.match(r"(.*?)[\s]*\\watch(\s+\d+)?\s*;?\s*$", command, re.DOTALL)
    if found is None:
        return None, None
    # The optional interval group keeps its leading whitespace; int() tolerates it.
    query, seconds = found.groups(default=f"{DEFAULT_WATCH_SECONDS}")
    return query, int(seconds)
@export
def get_editor_query(sql):
    """Get the query part of an editor command.

    Strips surrounding whitespace, then removes every leading/trailing
    ``\\e`` token (repeatedly) and returns what is left.
    """
    text = sql.strip()
    # .strip('\e') would strip *characters*, not the substring, and so would
    # also eat a trailing "e" of the query itself:
    # "select * from style\e" -> "select * from styl".
    edge_marker = re.compile(r"(^\\e|\\e$)")
    # Remove markers until a pass changes nothing.
    while True:
        trimmed = edge_marker.sub("", text)
        if trimmed == text:
            return text
        text = trimmed
@export
def open_external_editor(filename=None, sql=None, editor=None):
    """
    Open external editor, wait for the user to type in his query,
    return the query.

    Without *filename* the editor buffer starts with *sql* followed by a
    marker comment; everything from the marker onwards is discarded from the
    result. With *filename* the editor opens that file and its contents are
    read back afterwards.

    :param filename: optional file to edit. Only the text before the first
        space is used, and ``~`` is expanded.
    :param sql: optional partial query used to pre-populate the buffer.
    :param editor: optional editor command, passed through to ``click.edit``.
    :return: tuple ``(query, message)``. *message* is None unless the file
        could not be read back. *query* falls back to *sql* (or ``""``) when
        the editor returned nothing.
    """
    message = None
    if filename:
        # Only the first word is the file name. Expand "~" here so the editor
        # opens the same path that read_from_file() (which also expands "~")
        # reads back afterwards.
        filename = expanduser(filename.strip().split(" ", 1)[0]) or None
    else:
        filename = None
    sql = sql or ""
    marker = "# Type your query above this line.\n"

    # Populate the editor buffer with the partial sql (if available) and a
    # placeholder comment.
    query = click.edit(
        "{sql}\n\n{marker}".format(sql=sql, marker=marker),
        filename=filename,
        extension=".sql",
        editor=editor,
    )

    if filename:
        try:
            query = read_from_file(filename)
        except (OSError, UnicodeDecodeError):
            # Missing/unreadable file or content that is not valid UTF-8.
            message = "Error reading file: %s." % filename

    if query is not None:
        query = query.split(marker, 1)[0].rstrip("\n")
    else:
        # Don't return None for the caller to deal with.
        # Empty string is ok.
        query = sql

    return (query, message)
def read_from_file(path):
    """Return the whole text of the UTF-8 file at *path* (``~`` is expanded)."""
    resolved = expanduser(path)
    with io.open(resolved, encoding="utf-8") as handle:
        return handle.read()
def _index_of_file_name(tokenlist):
    """Locate the file-name token of a ``\\copy`` command.

    Scans the tokens right-to-left for the last ``TO``/``FROM`` keyword and
    returns the index two positions after it: the slot following the single
    token that separates the keyword from the file name. The final two tokens
    are not considered, so the returned index always lies inside *tokenlist*.

    :raises Exception: if no ``TO``/``FROM`` keyword is found.
    """
    searchable = tokenlist[:-2]
    position = len(searchable)
    while position > 0:
        position -= 1
        candidate = searchable[position]
        if candidate.is_keyword and candidate.value.upper() in ("TO", "FROM"):
            return position + 2
    raise Exception("Missing keyword in \\copy command. Either TO or FROM is required.")
@special_command(
    "\\copy",
    "\\copy [tablename] to/from [filename]",
    "Copy data between a file and a table.",
    case_sensitive=False,
)
def copy(cur, pattern, verbose):
    """Copies table data to/from files

    The client-side file named in *pattern* is replaced with STDIN/STDOUT and
    the resulting ``COPY`` statement is streamed through ``cur.copy()``. For
    ``FROM`` the file is read in chunks (``file.read(8192)``) and sent to the
    server. For ``TO`` every chunk received is written to the file.

    :param cur: cursor used to run the ``COPY`` statement.
    :param pattern: the text following ``\\copy``, e.g.
        ``"mytable to '~/out.csv' csv"``. The file name must be single-quoted
        (``~`` is expanded), or be an unquoted word containing stdin/stdout.
    :param verbose: unused.
    :return: ``[(None, None, headers, statusmessage)]``.
    :raises Exception: if TO/FROM is missing or the file name is not quoted.
    """
    # Replace the specified file destination with STDIN or STDOUT
    tokens = sqlparse.parse(pattern)[0].tokens
    idx = _index_of_file_name(tokens)
    file_name = tokens[idx].value
    before_file_name = "".join(t.value for t in tokens[:idx])
    after_file_name = "".join(t.value for t in tokens[idx + 1 :])

    direction = tokens[idx - 2].value.upper()
    replacement_file_name = "STDIN" if direction == "FROM" else "STDOUT"
    query = f"{before_file_name} {replacement_file_name} {after_file_name}"
    open_mode = "r" if direction == "FROM" else "wb"

    if file_name.startswith("'") and file_name.endswith("'"):
        file = io.open(expanduser(file_name.strip("'")), mode=open_mode)
        owns_file = True
    elif "stdin" in file_name.lower():
        file = sys.stdin.buffer
        owns_file = False
    elif "stdout" in file_name.lower():
        file = sys.stdout.buffer
        owns_file = False
    else:
        raise Exception("Enclose filename in single quotes")

    # Close a file we opened ourselves even when COPY fails, but never close
    # the process's stdin/stdout.
    try:
        with cur.copy("copy " + query) as pgcopy:
            if direction == "FROM":
                while True:
                    data = file.read(8192)
                    if not data:
                        break
                    pgcopy.write(data)
            else:
                for data in pgcopy:
                    file.write(bytes(data))
    finally:
        if owns_file:
            file.close()

    headers = [x.name for x in cur.description] if cur.description else None
    return [(None, None, headers, cur.statusmessage)]
def subst_favorite_query_args(query, args):
    """replace positional parameters ($1,$2,...$n) in query.

    Placeholders are matched as whole tokens, so ``$1`` never matches the
    start of ``$10``. Values are substituted in a single pass over the
    original query, so a value that itself contains e.g. ``$2`` is inserted
    verbatim instead of being substituted again.

    If the query contains ``$*`` or ``$@`` (aggregation), positional
    parameters are filled until the first one the query does not use. All
    remaining arguments are then joined with ``", "`` and put in place of
    ``$*`` as-is, or of ``$@`` with each wrapped in single quotes.

    :param query: the named query text.
    :param args: sequence of argument strings.
    :return: ``[substituted_query, None]`` on success, or ``[None, message]``
        describing the first problem found.
    """
    positional = re.compile(r"\$(\d+)")
    is_query_with_aggregation = ("$*" in query) or ("$@" in query)
    # Digits of every positional placeholder in the query, compared as whole
    # tokens: "$10" does not make "$1" count as present.
    used = set(positional.findall(query))
    values = {}

    def render_positional():
        # The query with the positional values collected so far substituted.
        return positional.sub(lambda m: values.get(m.group(1), m.group(0)), query)

    # In case of arguments aggregation we replace all positional arguments until the
    # first one not present in the query. Then we aggregate all the remaining ones and
    # replace the placeholder with them.
    remaining = []
    for idx, val in enumerate(args, start=1):
        key = str(idx)
        if key not in used:
            if is_query_with_aggregation:
                # remove consumed arguments ( - 1 to include current value)
                remaining = args[idx - 1 :]
                break
            return [
                None,
                "query does not have substitution parameter $" + key + ":\n " + render_positional(),
            ]
        values[key] = val

    if is_query_with_aggregation and not remaining:
        return [None, "missing substitution for $* or $@ in query:\n" + render_positional()]

    # "$*" takes precedence over "$@" when both appear.
    aggregate_token = None
    aggregate_value = None
    if "$*" in query:
        aggregate_token, aggregate_value = "*", ", ".join(remaining)
    elif "$@" in query:
        aggregate_token, aggregate_value = "@", ", ".join(map("'{}'".format, remaining))

    def substitute(match):
        token = match.group(1)
        if token == aggregate_token:
            return aggregate_value
        return values.get(token, match.group(0))

    # One pass over the original query: inserted values are never re-scanned,
    # and a callable replacement keeps backslashes in values literal.
    result = re.sub(r"\$(\d+|\*|@)", substitute, query)

    # Report the first placeholder of the original query left without a value.
    for match in positional.finditer(query):
        if match.group(1) not in values:
            return [
                None,
                "missing substitution for " + match.group(0) + " in query:\n " + result,
            ]
    return [result, None]
@special_command("\\n", "\\n[+] [name] [param1 param2 ...]", "List or execute named queries.")
def execute_named_query(cur, pattern, **_):
    """Returns (title, rows, headers, status)

    An empty (or whitespace-only) *pattern* lists every named query.
    Otherwise *pattern* is split shell-style into a query name followed by
    its arguments. When the saved query contains one of
    NAMED_QUERY_PLACEHOLDERS the arguments are substituted via
    subst_favorite_query_args(), then the query is executed on *cur*.

    :raises Exception: on bad or missing arguments, or when the query uses
        ``%s`` instead of ``$1``-style placeholders and execution fails with
        a syntax error.
    """
    # A whitespace-only pattern makes shlex.split() return no tokens, and the
    # pop() below would raise IndexError, so treat it like an empty pattern.
    if not pattern or not pattern.strip():
        return list_named_queries(True)

    params = shlex.split(pattern)
    pattern = params.pop(0)
    query = NamedQueries.instance.get(pattern)
    if query is None:
        message = "No named query: {}".format(pattern)
        return [(None, None, None, message)]
    title = "> {}".format(query)

    try:
        if any(p in query for p in NAMED_QUERY_PLACEHOLDERS):
            query, params = subst_favorite_query_args(query, params)
            if query is None:
                # On failure the second element is the error message.
                raise Exception("Bad arguments\n" + params)
        cur.execute(query)
    except psycopg.errors.SyntaxError:
        if "%s" in query:
            raise Exception('Bad arguments: please use "$1", "$2", etc. for named queries instead of "%s"')
        else:
            raise
    except (IndexError, TypeError):
        raise Exception("Bad arguments")

    if cur.description:
        headers = [x.name for x in cur.description]
        return [(title, cur, headers, cur.statusmessage)]
    else:
        return [(title, None, None, cur.statusmessage)]
def list_named_queries(verbose):
    """List of all named queries.

    :param verbose: when true, each row also carries the query text.
    :return: ``[(title, rows, headers, status)]``. *status* holds the
        NamedQueries usage text when there are no saved queries, else ``""``.
    """
    store = NamedQueries.instance
    names = store.list()
    if verbose:
        headers = ["Name", "Query"]
        rows = [[name, store.get(name)] for name in names]
    else:
        headers = ["Name"]
        rows = [[name] for name in names]
    status = "" if rows else store.usage
    return [("", rows, headers, status)]
@special_command("\\np", "\\np name_pattern", "Print a named query.")
def get_named_query(pattern, **_):
    """Get a named query that matches name_pattern.

    The named pattern can be a regular expression. Returns (title,
    rows, headers, status). A pattern that is not a valid regular expression
    yields the usage text plus an error message instead of raising.
    """
    usage = "Syntax: \\np name.\n\n" + NamedQueries.instance.usage
    if not pattern:
        return [(None, None, None, usage)]

    name = pattern.strip()
    if not name:
        return [(None, None, None, usage + "Err: A name is required.")]

    # The pattern comes straight from the user: report a malformed regular
    # expression instead of letting re.error escape the command.
    try:
        matcher = re.compile(name)
    except re.error as exc:
        return [(None, None, None, usage + "Err: Invalid name pattern: {}".format(exc))]

    headers = ["Name", "Query"]
    rows = [(r, NamedQueries.instance.get(r)) for r in NamedQueries.instance.list() if matcher.search(r)]
    status = "" if rows else "No match found"
    return [("", rows, headers, status)]
@special_command("\\ns", "\\ns name query", "Save a named query.")
def save_named_query(pattern, **_):
    """Save a new named query.

    *pattern* has the form ``"<name> <query>"``: the text before the first
    space is the name, and everything after it is stored as the query.

    Returns (title, rows, headers, status)"""
    usage = "Syntax: \\ns name query.\n\n" + NamedQueries.instance.usage
    if not pattern:
        return [(None, None, None, usage)]

    head, _sep, body = pattern.partition(" ")
    # Both a name and a query are mandatory; otherwise show usage and complain.
    if not (head and body):
        return [(None, None, None, usage + "Err: Both name and query are required.")]

    NamedQueries.instance.save(head, body)
    return [(None, None, None, "Saved.")]
@special_command("\\nd", "\\nd [name]", "Delete a named query.")
def delete_named_query(pattern, **_):
    """Delete an existing named query.

    :param pattern: name of the query to delete.
    :return: ``[(None, None, None, status)]``. *status* is the value returned
        by ``NamedQueries.instance.delete``, or the usage text when no name
        was given.
    """
    usage = "Syntax: \\nd name.\n\n" + NamedQueries.instance.usage
    if pattern:
        return [(None, None, None, NamedQueries.instance.delete(pattern))]
    return [(None, None, None, usage)]
|