1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332
|
from __future__ import annotations
from typing import Any
import sqlparse
from sqlparse.sql import Comparison, Identifier, Where, Token
from .parseutils import last_word, extract_tables, find_prev_keyword
from .special.main import parse_special_command
def suggest_type(full_text: str, text_before_cursor: str) -> list[dict[str, Any]]:
    """Takes the full_text that is typed so far and also the text before the
    cursor to suggest completion type and scope.

    Returns a list of suggestion dicts, each with a "type" key naming the
    entity kind ('table', 'column', etc.) plus scope keys where applicable
    (e.g. a "tables" list for column suggestions, a "schema" for tables).
    """
    word_before_cursor = last_word(text_before_cursor, include="many_punctuations")

    identifier: Identifier | None = None

    # FIXME: this try/except workaround should be removed once sqlparse has
    # been fixed (it guards against sqlparse raising on malformed input).
    try:
        # If we've partially typed a word then word_before_cursor won't be an empty
        # string. In that case we want to remove the partially typed string before
        # sending it to the sqlparser. Otherwise the last token will always be the
        # partially typed string which renders the smart completion useless because
        # it will always return the list of keywords as completion.
        if word_before_cursor:
            if word_before_cursor.endswith("(") or word_before_cursor.startswith("\\"):
                # Open paren (function call) or client command: parse as-is.
                parsed = sqlparse.parse(text_before_cursor)
            else:
                parsed = sqlparse.parse(text_before_cursor[: -len(word_before_cursor)])

                # word_before_cursor may include a schema qualification, like
                # "schema_name.partial_name" or "schema_name.", so parse it
                # separately
                p = sqlparse.parse(word_before_cursor)[0]

                if p.tokens and isinstance(p.tokens[0], Identifier):
                    identifier = p.tokens[0]
        else:
            parsed = sqlparse.parse(text_before_cursor)
    except (TypeError, AttributeError):
        # Parsing blew up -- fall back to plain keyword suggestions.
        return [{"type": "keyword"}]

    if len(parsed) > 1:
        # Multiple statements being edited -- isolate the current one by
        # cumulatively summing statement lengths to find the one that bounds the
        # current position
        current_pos = len(text_before_cursor)
        stmt_start, stmt_end = 0, 0

        for statement in parsed:
            stmt_len = len(str(statement))
            stmt_start, stmt_end = stmt_end, stmt_end + stmt_len

            if stmt_end >= current_pos:
                # Narrow both texts to the statement under the cursor.
                text_before_cursor = full_text[stmt_start:current_pos]
                full_text = full_text[stmt_start:]
                break

    elif parsed:
        # A single statement
        statement = parsed[0]
    else:
        # The empty string
        statement = None

    # Check for special commands and handle those separately
    if statement:
        # Be careful here because trivial whitespace is parsed as a statement,
        # but the statement won't have a first token
        tok1 = statement.token_first()
        if tok1 and tok1.value.startswith("."):
            return suggest_special(text_before_cursor)
        elif tok1 and tok1.value.startswith("\\"):
            return suggest_special(text_before_cursor)
        elif tok1 and tok1.value.startswith("source"):
            return suggest_special(text_before_cursor)
        elif text_before_cursor and text_before_cursor.startswith(".open "):
            return suggest_special(text_before_cursor)

    # token_prev(...) returns (idx, token); "" stands in for "no last token".
    last_token = statement and statement.token_prev(len(statement.tokens))[1] or ""

    return suggest_based_on_last_token(last_token, text_before_cursor, full_text, identifier)
def suggest_special(text: str) -> list[dict[str, Any]]:
    """Return suggestions for a special (dot/backslash) client command.

    *text* is the text before the cursor; the command name and its argument
    string are split out by ``parse_special_command``.
    """
    text = text.lstrip()
    cmd, _, arg = parse_special_command(text)

    if cmd == text:
        # Trying to complete the special command itself
        return [{"type": "special"}]

    if cmd in ("\\u", "\\r"):
        return [{"type": "database"}]

    # BUGFIX: was `cmd in ("\\T")` -- parens without a comma make a plain
    # string, so this was a *substring* test (matching e.g. cmd == "T").
    if cmd in ("\\T",):
        return [{"type": "table_format"}]

    if cmd in ["\\f", "\\fs", "\\fd"]:
        return [{"type": "favoritequery"}]

    if cmd in ["\\d", "\\dt", "\\dt+", ".schema", ".indexes"]:
        return [
            {"type": "table", "schema": []},
            {"type": "view", "schema": []},
            {"type": "schema"},
        ]

    if cmd in ["\\.", "source", ".open", ".read"]:
        return [{"type": "file_name"}]

    if cmd in [".import"]:
        # Usage: .import filename table
        if _expecting_arg_idx(arg, text) == 1:
            return [{"type": "file_name"}]
        else:
            return [{"type": "table", "schema": []}]

    return [{"type": "keyword"}, {"type": "special"}]
def _expecting_arg_idx(arg: str, text: str) -> int:
"""Return the index of expecting argument.
>>> _expecting_arg_idx("./da", ".import ./da")
1
>>> _expecting_arg_idx("./data.csv", ".import ./data.csv")
1
>>> _expecting_arg_idx("./data.csv", ".import ./data.csv ")
2
>>> _expecting_arg_idx("./data.csv t", ".import ./data.csv t")
2
"""
args = arg.split()
return len(args) + int(text[-1].isspace())
def suggest_based_on_last_token(
    token: str | Token | None,
    text_before_cursor: str,
    full_text: str,
    identifier: Identifier | None,
) -> list[dict[str, Any]]:
    """Map the last significant token before the cursor to suggestion dicts.

    *token* may be a plain keyword string, a sqlparse token, or None (empty
    statement).  The function recurses after stripping a WHERE clause or a
    trailing comma/operator so the preceding keyword drives the suggestion.
    """
    if isinstance(token, str):
        token_v = token.lower()
    elif isinstance(token, Comparison):
        # If 'token' is a Comparison type such as
        # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
        # token.value on the comparison type will only return the lhs of the
        # comparison. In this case a.id. So we need to do token.tokens to get
        # both sides of the comparison and pick the last token out of that
        # list.
        token_v = token.tokens[-1].value.lower()
    elif isinstance(token, Where):
        # sqlparse groups all tokens from the where clause into a single token
        # list. This means that token.value may be something like
        # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
        # suggestions in complicated where clauses correctly
        prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
        return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier)
    else:
        assert token is not None
        token_v = token.value.lower()

    def is_operand(x: str | None) -> bool:
        """True when *x* ends with an arithmetic operator (e.g. "col +")."""
        if not x:
            return False
        return any(x.endswith(op) for op in ["+", "-", "*", "/"])

    if not token:
        return [{"type": "keyword"}, {"type": "special"}]
    elif token_v.endswith("("):
        p = sqlparse.parse(text_before_cursor)[0]

        if p.tokens and isinstance(p.tokens[-1], Where):
            # Four possibilities:
            #  1 - Parenthesized clause like "WHERE foo AND ("
            #        Suggest columns/functions
            #  2 - Function call like "WHERE foo("
            #        Suggest columns/functions
            #  3 - Subquery expression like "WHERE EXISTS ("
            #        Suggest keywords, in order to do a subquery
            #  4 - Subquery OR array comparison like "WHERE foo = ANY("
            #        Suggest columns/functions AND keywords. (If we wanted to be
            #        really fancy, we could suggest only array-typed columns)
            column_suggestions = suggest_based_on_last_token("where", text_before_cursor, full_text, identifier)

            # Check for a subquery expression (cases 3 & 4)
            where = p.tokens[-1]
            idx, prev_tok = where.token_prev(len(where.tokens) - 1)

            if isinstance(prev_tok, Comparison):
                # e.g. "SELECT foo FROM bar WHERE foo = ANY("
                prev_tok = prev_tok.tokens[-1]

            prev_tok = prev_tok.value.lower()
            if prev_tok == "exists":
                return [{"type": "keyword"}]
            else:
                return column_suggestions

        # Get the token before the parens
        idx, prev_tok = p.token_prev(len(p.tokens) - 1)
        if prev_tok and prev_tok.value and prev_tok.value.lower() == "using":
            # tbl1 INNER JOIN tbl2 USING (col1, col2)
            tables = extract_tables(full_text)

            # suggest columns that are present in more than one table
            return [{"type": "column", "tables": tables, "drop_unique": True}]
        elif p.token_first().value.lower() == "select":
            # If the lparen is preceded by a space chances are we're about to
            # do a sub-select.
            if last_word(text_before_cursor, "all_punctuations").startswith("("):
                return [{"type": "keyword"}]
        elif p.token_first().value.lower() == "show":
            return [{"type": "show"}]

        # We're probably in a function argument list
        return [{"type": "column", "tables": extract_tables(full_text)}]
    elif token_v in ("set", "order by", "distinct"):
        return [{"type": "column", "tables": extract_tables(full_text)}]
    elif token_v == "as":
        # Don't suggest anything for an alias
        return []
    elif token_v in ("show",):
        # BUGFIX: was `token_v in ("show")` -- parens without a comma make a
        # plain string, so this was a *substring* test (matching e.g. "s" or
        # "how"), not a membership test against a one-element tuple.
        return [{"type": "show"}]
    elif token_v in ("to",):
        p = sqlparse.parse(text_before_cursor)[0]
        if p.token_first().value.lower() == "change":
            return [{"type": "change"}]
        else:
            return [{"type": "user"}]
    elif token_v in ("user", "for"):
        return [{"type": "user"}]
    elif token_v in ("select", "where", "having"):
        # Check for a table alias or schema qualification
        parent = (identifier and identifier.get_parent_name()) or []
        tables = extract_tables(full_text)
        if parent:
            tables = [t for t in tables if identifies(parent, *t)]
            return [
                {"type": "column", "tables": tables},
                {"type": "table", "schema": parent},
                {"type": "view", "schema": parent},
                {"type": "function", "schema": parent},
            ]
        else:
            aliases = [alias or table for (schema, table, alias) in tables]
            return [
                {"type": "column", "tables": tables},
                {"type": "function", "schema": []},
                {"type": "alias", "aliases": aliases},
                {"type": "keyword"},
            ]
    elif (token_v.endswith("join") and isinstance(token, Token) and token.is_keyword) or (
        token_v in ("copy", "from", "update", "into", "describe", "truncate", "desc", "explain")
    ):
        schema = (identifier and identifier.get_parent_name()) or []

        # Suggest tables from either the currently-selected schema or the
        # public schema if no schema has been specified
        suggest = [{"type": "table", "schema": schema}]

        if not schema:
            # Suggest schemas
            suggest.insert(0, {"type": "schema"})

        # Only tables can be TRUNCATED, otherwise suggest views
        if token_v != "truncate":
            suggest.append({"type": "view", "schema": schema})

        return suggest
    elif token_v in ("table", "view", "function"):
        # E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>'
        rel_type = token_v
        schema = (identifier and identifier.get_parent_name()) or []
        if schema:
            return [{"type": rel_type, "schema": schema}]
        else:
            return [{"type": "schema"}, {"type": rel_type, "schema": []}]
    elif token_v == "on":
        tables = extract_tables(full_text)  # [(schema, table, alias), ...]
        parent = (identifier and identifier.get_parent_name()) or []
        if parent:
            # "ON parent.<suggestion>"
            # parent can be either a schema name or table alias
            tables = [t for t in tables if identifies(parent, *t)]
            return [
                {"type": "column", "tables": tables},
                {"type": "table", "schema": parent},
                {"type": "view", "schema": parent},
                {"type": "function", "schema": parent},
            ]
        else:
            # ON <suggestion>
            # Use table alias if there is one, otherwise the table name
            aliases = [alias or table for (schema, table, alias) in tables]
            suggest = [{"type": "alias", "aliases": aliases}]

            # The lists of 'aliases' could be empty if we're trying to complete
            # a GRANT query. eg: GRANT SELECT, INSERT ON <tab>
            # In that case we just suggest all tables.
            if not aliases:
                suggest.append({"type": "table", "schema": parent})
            return suggest
    elif token_v in ("use", "database", "template", "connect"):
        # "\c <db", "use <db>", "DROP DATABASE <db>",
        # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
        return [{"type": "database"}]
    elif token_v == "tableformat":
        return [{"type": "table_format"}]
    elif token_v.endswith(",") or is_operand(token_v) or token_v in ["=", "and", "or"]:
        # Mid-expression: re-evaluate against the keyword that precedes the
        # trailing comma/operator (e.g. "SELECT a, " behaves like "SELECT ").
        prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
        if prev_keyword:
            return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier)
        else:
            return []
    else:
        return [{"type": "keyword"}]
def identifies(id: Any, schema: str | None, table: str, alias: str | None) -> bool:
return (id == alias) or (id == table) or (schema is not None and (id == schema + "." + table))
|