1+ import re
2+ from sqlalchemy import inspect
3+
4+ class SQLAutocompleter :
5+ def __init__ (self , engine , log ):
6+ """
7+ Initializes the autocompleter with an SQLAlchemy engine.
8+
9+ Parameters:
10+ - engine: SQLAlchemy engine connected to a database.
11+ """
12+ self .engine = engine
13+ self .inspector = inspect (engine )
14+ self .default_schema = self .inspector .default_schema_name
15+ self .log = log
16+ self .log .info (f"Autocompleter initialized with engine: { engine } " )
17+
18+ def get_real_previous_keyword (self , tokens ):
19+ """
20+ Identifies the real previous keyword in SQL syntax, i.e., section delimiters like `SELECT`, `FROM`, `WHERE`.
21+
22+ Parameters:
23+ - tokens (list): List of tokens (words) before the cursor position.
24+
25+ Returns:
26+ - str: The most recent real SQL keyword (e.g., `SELECT`, `FROM`).
27+ """
28+ sql_keywords = {
29+ "SELECT" , "FROM" , "WHERE" , "GROUP" , "ORDER" , "HAVING" , "INSERT" , "UPDATE" , "DELETE" ,
30+ "JOIN" , "ON" , "LIMIT" , "DISTINCT" , "SET"
31+ }
32+ for token in reversed (tokens ):
33+ if token .upper () in sql_keywords :
34+ return token .upper ()
35+ return ""
36+
37+
38+ def get_completions (self , code , cursor_pos ):
39+ """
40+ Returns autocompletions based on the SQL context.
41+
42+ Parameters:
43+ - code (str): Full SQL query being typed.
44+ - cursor_pos (int): Cursor position in the query.
45+
46+ Returns:
47+ - list: Suggested completions.
48+ """
49+ preceding_text = code [:cursor_pos ]
50+ tokens = re .findall (r"[^ ;\(\)\r\n\t,]+" , preceding_text , re .IGNORECASE )
51+ previous_keyword = self .get_real_previous_keyword (tokens )
52+ previous_word = tokens [- 1 ].upper () if tokens else ""
53+ is_preceding_comma = preceding_text .rstrip ().endswith ("," )
54+ is_preceding_space = preceding_text .endswith (" " )
55+ is_completing_word = preceding_text [- 1 ].isalpha ()
56+ current_completing = ''
57+ if is_completing_word :
58+ previous_word = tokens [- 2 ].upper () if len (tokens ) > 1 else ""
59+ current_completing = tokens [- 1 ].upper () if tokens else ""
60+
61+ if previous_keyword == "SELECT" :
62+ if is_preceding_comma == False and is_preceding_space == True and previous_word != "SELECT" :
63+ completions = ["FROM" ]
64+ else :
65+ completions = self .get_columns (code ) + self .get_functions ()
66+ elif previous_keyword in {"FROM" , "JOIN" }:
67+ completions = self .get_tables ()
68+ elif previous_keyword == "WHERE" :
69+ completions = self .get_columns (code ) + self .get_functions ()
70+ elif previous_word == "GROUP" :
71+ completions = ["BY" ]
72+ elif previous_word == "ORDER" :
73+ completions = ["BY" ]
74+ elif previous_word == "INSERT" :
75+ completions = ["INTO" ]
76+ elif previous_word == "UPDATE" :
77+ completions = self .get_tables ()
78+ elif previous_keyword == 'UPDATE' :
79+ completions = ["SET" ]
80+ completions += self .get_tables ()
81+ elif previous_word == "DELETE" :
82+ completions = ["FROM" ]
83+ elif previous_word == "DISTINCT" :
84+ completions = self .get_columns (code )
85+ elif previous_keyword == "DISTINCT" :
86+ completions = self .get_columns (code ) + self .get_functions ()
87+ elif previous_keyword in {"GROUP" , "ORDER" }:
88+ completions = self .get_columns (code )
89+ elif previous_keyword == "HAVING" :
90+ completions = self .get_columns (code ) + self .get_functions ()
91+ elif previous_keyword == "SET" :
92+ completions = self .get_columns (code )
93+ elif previous_word == "VALUES" :
94+ completions = '('
95+ elif previous_keyword == "VALUES" :
96+ completions = self .get_columns (code )
97+ elif previous_word in {"INNER" , "LEFT" , "RIGHT" , "FULL" }:
98+ completions = ["JOIN" ]
99+ elif previous_keyword == "DISTINCT" or previous_keyword == "LIMIT" or previous_keyword == "OFFSET" :
100+ completions = []
101+ else :
102+ completions = self .get_sql_keywords ()
103+
104+ if is_completing_word :
105+ filter_func = lambda x : x .lower ().startswith (current_completing .lower ())
106+ if is_preceding_comma == False and is_preceding_space == False :
107+ filtered_suggestions = [suggestion for suggestion in completions if filter_func (suggestion )]
108+ return sorted (filtered_suggestions )
109+
110+ return completions
111+
112+
113+ def get_tables (self ):
114+ """b
115+ Returns a list of available tables, excluding the default schema.
116+
117+ Returns:
118+ - list: Tables without default schema.
119+ """
120+
121+ schemas = self .inspector .get_schema_names ()
122+ tables = self .inspector .get_table_names (schema = self .default_schema ) # Get tables in default schema
123+
124+ if self .default_schema :
125+ for schema in schemas :
126+ schema_tables = self .inspector .get_table_names (schema = schema )
127+
128+ if schema != self .default_schema :
129+ tables .extend ([f"{ schema } .{ table } " for table in schema_tables ]) # Keep schema.table
130+
131+ return tables
132+
133+ def get_columns (self , code ):
134+ """
135+ Extracts tables from the query and returns relevant columns.
136+
137+ Parameters:
138+ - code (str): SQL query.
139+
140+ Returns:
141+ - list: Column names from the tables used in the query.
142+ """
143+ table_names = self .extract_table_names (code )
144+ columns = []
145+ for table in table_names :
146+ schema , table_name = self .split_schema_table (table )
147+ try :
148+ table_columns = [col ["name" ] for col in self .inspector .get_columns (table_name , schema = schema )]
149+ columns .extend (table_columns )
150+ except Exception :
151+ pass # Ignore missing tables
152+ return columns
153+
154+ def get_functions (self ):
155+ """Returns common SQL functions."""
156+ return [
157+ "COUNT()" , "AVG()" , "SUM()" , "MIN()" , "MAX()" ,
158+ "LOWER()" , "UPPER()" , "NOW()" , "DATE()" , "ROUND()"
159+ ]
160+
161+ def get_sql_keywords (self ):
162+ """Returns a list of common SQL keywords."""
163+ return [
164+ "SELECT" , "FROM" , "WHERE" , "GROUP BY" , "ORDER BY" , "HAVING" ,
165+ "INSERT INTO" , "VALUES" , "UPDATE" , "SET" , "DELETE FROM" ,
166+ "JOIN" , "INNER JOIN" , "LEFT JOIN" , "RIGHT JOIN" , "FULL JOIN" ,
167+ "ON" , "DISTINCT" , "LIMIT" , "OFFSET"
168+ ]
169+
170+ def extract_table_names (self , code ):
171+ """
172+ Extracts table names (including schema-qualified) from an SQL query.
173+
174+ Parameters:
175+ - code (str): SQL query.
176+
177+ Returns:
178+ - list: Table names found in the query.
179+ """
180+ matches = re .findall (r"FROM\s+([\w.]+)|JOIN\s+([\w.]+)|UPDATE\s+([\w.]+)" , code , re .IGNORECASE )
181+ return [table for tup in matches for table in tup if table ]
182+
183+ def split_schema_table (self , table ):
184+ """
185+ Splits a schema-qualified table into schema and table parts.
186+
187+ Parameters:
188+ - table (str): Table name (could be schema-qualified like 'schema.table').
189+
190+ Returns:
191+ - tuple: (schema, table_name) or (None, table_name) if no schema.
192+ """
193+ parts = table .split ("." )
194+ if len (parts ) == 2 :
195+ schema , table_name = parts
196+ if schema == self .default_schema : # Remove default schema
197+ return None , table_name
198+ return schema , table_name
199+ return self .default_schema , table # No schema
200+
201+
202+ # Example usage
203+ # from sqlalchemy import create_engine
204+ # engine = create_engine("postgresql://postgres@localhost/tume")
205+ # completer = SQLAutocompleter(engine)
206+ # completions = completer.get_completions("SELECT municipio, FROM tume.cadastro", 17)
207+ # print(completions)
0 commit comments