Skip to content

Commit 2411f46

Browse files
committed
Fixed kernel
1 parent b8abc10 commit 2411f46

5 files changed

Lines changed: 449 additions & 27 deletions

File tree

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
build
2+
dist
3+
.ipynb_checkpoints
4+
*.egg-info/
5+
**/__pycache__

mysql_kernel/autocomplete.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
import re
2+
from sqlalchemy import inspect
3+
4+
class SQLAutocompleter:
5+
def __init__(self, engine, log):
6+
"""
7+
Initializes the autocompleter with an SQLAlchemy engine.
8+
9+
Parameters:
10+
- engine: SQLAlchemy engine connected to a database.
11+
"""
12+
self.engine = engine
13+
self.inspector = inspect(engine)
14+
self.default_schema = self.inspector.default_schema_name
15+
self.log = log
16+
self.log.info(f"Autocompleter initialized with engine: {engine}")
17+
18+
def get_real_previous_keyword(self, tokens):
19+
"""
20+
Identifies the real previous keyword in SQL syntax, i.e., section delimiters like `SELECT`, `FROM`, `WHERE`.
21+
22+
Parameters:
23+
- tokens (list): List of tokens (words) before the cursor position.
24+
25+
Returns:
26+
- str: The most recent real SQL keyword (e.g., `SELECT`, `FROM`).
27+
"""
28+
sql_keywords = {
29+
"SELECT", "FROM", "WHERE", "GROUP", "ORDER", "HAVING", "INSERT", "UPDATE", "DELETE",
30+
"JOIN", "ON", "LIMIT", "DISTINCT", "SET"
31+
}
32+
for token in reversed(tokens):
33+
if token.upper() in sql_keywords:
34+
return token.upper()
35+
return ""
36+
37+
38+
def get_completions(self, code, cursor_pos):
39+
"""
40+
Returns autocompletions based on the SQL context.
41+
42+
Parameters:
43+
- code (str): Full SQL query being typed.
44+
- cursor_pos (int): Cursor position in the query.
45+
46+
Returns:
47+
- list: Suggested completions.
48+
"""
49+
preceding_text = code[:cursor_pos]
50+
tokens = re.findall(r"[^ ;\(\)\r\n\t,]+", preceding_text, re.IGNORECASE)
51+
previous_keyword = self.get_real_previous_keyword(tokens)
52+
previous_word = tokens[-1].upper() if tokens else ""
53+
is_preceding_comma = preceding_text.rstrip().endswith(",")
54+
is_preceding_space = preceding_text.endswith(" ")
55+
is_completing_word = preceding_text[-1].isalpha()
56+
current_completing = ''
57+
if is_completing_word:
58+
previous_word = tokens[-2].upper() if len(tokens) > 1 else ""
59+
current_completing = tokens[-1].upper() if tokens else ""
60+
61+
if previous_keyword == "SELECT":
62+
if is_preceding_comma == False and is_preceding_space == True and previous_word != "SELECT":
63+
completions = ["FROM"]
64+
else:
65+
completions = self.get_columns(code) + self.get_functions()
66+
elif previous_keyword in {"FROM", "JOIN"}:
67+
completions = self.get_tables()
68+
elif previous_keyword == "WHERE":
69+
completions = self.get_columns(code) + self.get_functions()
70+
elif previous_word == "GROUP":
71+
completions = ["BY"]
72+
elif previous_word == "ORDER":
73+
completions = ["BY"]
74+
elif previous_word == "INSERT":
75+
completions = ["INTO"]
76+
elif previous_word == "UPDATE":
77+
completions = self.get_tables()
78+
elif previous_keyword == 'UPDATE':
79+
completions = ["SET"]
80+
completions += self.get_tables()
81+
elif previous_word == "DELETE":
82+
completions = ["FROM"]
83+
elif previous_word == "DISTINCT":
84+
completions = self.get_columns(code)
85+
elif previous_keyword == "DISTINCT":
86+
completions = self.get_columns(code) + self.get_functions()
87+
elif previous_keyword in {"GROUP", "ORDER"}:
88+
completions = self.get_columns(code)
89+
elif previous_keyword == "HAVING":
90+
completions = self.get_columns(code) + self.get_functions()
91+
elif previous_keyword == "SET":
92+
completions = self.get_columns(code)
93+
elif previous_word == "VALUES":
94+
completions = '('
95+
elif previous_keyword == "VALUES":
96+
completions = self.get_columns(code)
97+
elif previous_word in {"INNER", "LEFT", "RIGHT", "FULL"}:
98+
completions = ["JOIN"]
99+
elif previous_keyword == "DISTINCT" or previous_keyword == "LIMIT" or previous_keyword == "OFFSET":
100+
completions = []
101+
else:
102+
completions = self.get_sql_keywords()
103+
104+
if is_completing_word:
105+
filter_func = lambda x: x.lower().startswith(current_completing.lower())
106+
if is_preceding_comma == False and is_preceding_space == False:
107+
filtered_suggestions = [suggestion for suggestion in completions if filter_func(suggestion)]
108+
return sorted(filtered_suggestions)
109+
110+
return completions
111+
112+
113+
def get_tables(self):
114+
"""b
115+
Returns a list of available tables, excluding the default schema.
116+
117+
Returns:
118+
- list: Tables without default schema.
119+
"""
120+
121+
schemas = self.inspector.get_schema_names()
122+
tables = self.inspector.get_table_names(schema=self.default_schema) # Get tables in default schema
123+
124+
if self.default_schema:
125+
for schema in schemas:
126+
schema_tables = self.inspector.get_table_names(schema=schema)
127+
128+
if schema != self.default_schema:
129+
tables.extend([f"{schema}.{table}" for table in schema_tables]) # Keep schema.table
130+
131+
return tables
132+
133+
def get_columns(self, code):
134+
"""
135+
Extracts tables from the query and returns relevant columns.
136+
137+
Parameters:
138+
- code (str): SQL query.
139+
140+
Returns:
141+
- list: Column names from the tables used in the query.
142+
"""
143+
table_names = self.extract_table_names(code)
144+
columns = []
145+
for table in table_names:
146+
schema, table_name = self.split_schema_table(table)
147+
try:
148+
table_columns = [col["name"] for col in self.inspector.get_columns(table_name, schema=schema)]
149+
columns.extend(table_columns)
150+
except Exception:
151+
pass # Ignore missing tables
152+
return columns
153+
154+
def get_functions(self):
155+
"""Returns common SQL functions."""
156+
return [
157+
"COUNT()", "AVG()", "SUM()", "MIN()", "MAX()",
158+
"LOWER()", "UPPER()", "NOW()", "DATE()", "ROUND()"
159+
]
160+
161+
def get_sql_keywords(self):
162+
"""Returns a list of common SQL keywords."""
163+
return [
164+
"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "HAVING",
165+
"INSERT INTO", "VALUES", "UPDATE", "SET", "DELETE FROM",
166+
"JOIN", "INNER JOIN", "LEFT JOIN", "RIGHT JOIN", "FULL JOIN",
167+
"ON", "DISTINCT", "LIMIT", "OFFSET"
168+
]
169+
170+
def extract_table_names(self, code):
171+
"""
172+
Extracts table names (including schema-qualified) from an SQL query.
173+
174+
Parameters:
175+
- code (str): SQL query.
176+
177+
Returns:
178+
- list: Table names found in the query.
179+
"""
180+
matches = re.findall(r"FROM\s+([\w.]+)|JOIN\s+([\w.]+)|UPDATE\s+([\w.]+)", code, re.IGNORECASE)
181+
return [table for tup in matches for table in tup if table]
182+
183+
def split_schema_table(self, table):
184+
"""
185+
Splits a schema-qualified table into schema and table parts.
186+
187+
Parameters:
188+
- table (str): Table name (could be schema-qualified like 'schema.table').
189+
190+
Returns:
191+
- tuple: (schema, table_name) or (None, table_name) if no schema.
192+
"""
193+
parts = table.split(".")
194+
if len(parts) == 2:
195+
schema, table_name = parts
196+
if schema == self.default_schema: # Remove default schema
197+
return None, table_name
198+
return schema, table_name
199+
return self.default_schema, table # No schema
200+
201+
202+
# Example usage
203+
# from sqlalchemy import create_engine
204+
# engine = create_engine("postgresql://postgres@localhost/tume")
205+
# completer = SQLAutocompleter(engine)
206+
# completions = completer.get_completions("SELECT municipio, FROM tume.cadastro", 17)
207+
# print(completions)

0 commit comments

Comments
 (0)