Skip to content

Commit f24e276

Browse files
authored
Merge pull request #1384 from dbcli/RW/conserve-llm-tokens-and-cache
Reduce size of LLM prompts + cache per-schema context
2 parents bbeea07 + 999ec16 commit f24e276

File tree

6 files changed

+157
-32
lines changed

6 files changed

+157
-32
lines changed

changelog.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
TBD
2+
==============
3+
4+
Features
5+
--------
6+
* Options to limit size of LLM prompts; cache LLM prompt data.
7+
8+
9+
Bug Fixes
10+
--------
11+
* Correct mangled schema info sent in LLM prompts.
12+
13+
114
1.50.0 (2026/02/07)
215
==============
316

mycli/main.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,14 @@ def __init__(
169169
self.null_string = c['main'].get('null_string')
170170
self.numeric_alignment = c['main'].get('numeric_alignment', 'right')
171171
self.binary_display = c['main'].get('binary_display')
172+
if 'llm' in c and re.match(r'^\d+$', c['llm'].get('prompt_field_truncate', '')):
173+
self.llm_prompt_field_truncate = int(c['llm'].get('prompt_field_truncate'))
174+
else:
175+
self.llm_prompt_field_truncate = 0
176+
if 'llm' in c and re.match(r'^\d+$', c['llm'].get('prompt_section_truncate', '')):
177+
self.llm_prompt_section_truncate = int(c['llm'].get('prompt_section_truncate'))
178+
else:
179+
self.llm_prompt_section_truncate = 0
172180

173181
# set ssl_mode if a valid option is provided in a config file, otherwise None
174182
ssl_mode = c["main"].get("ssl_mode", None)
@@ -965,9 +973,16 @@ def one_iteration(text: str | None = None) -> None:
965973
while special.is_llm_command(text):
966974
start = time()
967975
try:
976+
assert isinstance(self.sqlexecute, SQLExecute)
968977
assert sqlexecute.conn is not None
969978
cur = sqlexecute.conn.cursor()
970-
context, sql, duration = special.handle_llm(text, cur)
979+
context, sql, duration = special.handle_llm(
980+
text,
981+
cur,
982+
sqlexecute.dbname or '',
983+
self.llm_prompt_field_truncate,
984+
self.llm_prompt_section_truncate,
985+
)
971986
if context:
972987
click.echo("LLM Response:")
973988
click.echo(context)

mycli/myclirc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,17 @@ default_ssl_cipher =
176176
# --ssl-verify-server-cert being set
177177
default_ssl_verify_server_cert = False
178178

179+
[llm]
180+
181+
# If set to a positive integer, truncate text/binary fields to that width
182+
# in bytes when sending sample data, to conserve tokens. Suggestion: 1024.
183+
prompt_field_truncate = None
184+
185+
# If set to a positive integer, attempt to truncate various sections of LLM
186+
# prompt input to that number in bytes, to conserve tokens. Suggestion:
187+
# 1000000.
188+
prompt_section_truncate = None
189+
179190
[keys]
180191
# possible values: auto, fzf, reverse_isearch
181192
control_r = auto

mycli/packages/special/llm.py

Lines changed: 95 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@
3838

3939
LLM_TEMPLATE_NAME = "mycli-llm-template"
4040

41+
SCHEMA_DATA_CACHE: dict[str, str] = {}
42+
43+
SAMPLE_DATA_CACHE: dict[str, dict] = {}
44+
4145

4246
def run_external_cmd(
4347
cmd: str,
@@ -212,7 +216,13 @@ def cli_commands() -> list[str]:
212216
return list(cli.commands.keys())
213217

214218

215-
def handle_llm(text: str, cur: Cursor) -> tuple[str, str | None, float]:
219+
def handle_llm(
220+
text: str,
221+
cur: Cursor,
222+
dbname: str,
223+
prompt_field_truncate: int,
224+
prompt_section_truncate: int,
225+
) -> tuple[str, str | None, float]:
216226
_, verbosity, arg = parse_special_command(text)
217227
if not LLM_IMPORTED:
218228
output = [(None, None, None, NEED_DEPENDENCIES)]
@@ -261,7 +271,13 @@ def handle_llm(text: str, cur: Cursor) -> tuple[str, str | None, float]:
261271
try:
262272
ensure_mycli_template()
263273
start = time()
264-
context, sql = sql_using_llm(cur=cur, question=arg)
274+
context, sql = sql_using_llm(
275+
cur=cur,
276+
question=arg,
277+
dbname=dbname,
278+
prompt_field_truncate=prompt_field_truncate,
279+
prompt_section_truncate=prompt_section_truncate,
280+
)
265281
end = time()
266282
if verbosity == Verbosity.SUCCINCT:
267283
context = ""
@@ -275,51 +291,110 @@ def is_llm_command(command: str) -> bool:
275291
return cmd in ("\\llm", "\\ai")
276292

277293

278-
def sql_using_llm(
279-
cur: Cursor | None,
280-
question: str | None = None,
281-
) -> tuple[str, str | None]:
282-
if cur is None:
283-
raise RuntimeError("Connect to a database and try again.")
284-
schema_query = """
285-
SELECT CONCAT(table_name, '(', GROUP_CONCAT(column_name, ' ', COLUMN_TYPE SEPARATOR ', '),')')
294+
def truncate_list_elements(row: list, prompt_field_truncate: int, prompt_section_truncate: int) -> list:
295+
if not prompt_section_truncate and not prompt_field_truncate:
296+
return row
297+
298+
width = prompt_field_truncate
299+
while width >= 0:
300+
truncated_row = [x[:width] if isinstance(x, (str, bytes)) else x for x in row]
301+
if prompt_section_truncate:
302+
if sum(sys.getsizeof(x) for x in truncated_row) <= prompt_section_truncate:
303+
break
304+
width -= 100
305+
else:
306+
break
307+
return truncated_row
308+
309+
310+
def truncate_table_lines(table: list[str], prompt_section_truncate: int) -> list[str]:
311+
if not prompt_section_truncate:
312+
return table
313+
314+
truncated_table = []
315+
running_sum = 0
316+
while table and running_sum <= prompt_section_truncate:
317+
line = table.pop(0)
318+
running_sum += sys.getsizeof(line)
319+
truncated_table.append(line)
320+
return truncated_table
321+
322+
323+
def get_schema(cur: Cursor, dbname: str, prompt_section_truncate: int) -> str:
324+
if dbname in SCHEMA_DATA_CACHE:
325+
return SCHEMA_DATA_CACHE[dbname]
326+
click.echo("Preparing schema information to feed the LLM")
327+
schema_query = f"""
328+
SELECT CONCAT(table_name, '(', GROUP_CONCAT(column_name, ' ', COLUMN_TYPE SEPARATOR ', '),')') AS `schema`
286329
FROM information_schema.columns
287-
WHERE table_schema = DATABASE()
330+
WHERE table_schema = '{dbname}'
288331
GROUP BY table_name
289332
ORDER BY table_name
290333
"""
291-
tables_query = "SHOW TABLES"
292-
sample_row_query = "SELECT * FROM `{table}` LIMIT 1"
293-
click.echo("Preparing schema information to feed the llm")
294334
cur.execute(schema_query)
295-
db_schema = "\n".join([row[0] for (row,) in cur.fetchall()])
335+
db_schema = [row for (row,) in cur.fetchall()]
336+
summary = '\n'.join(truncate_table_lines(db_schema, prompt_section_truncate))
337+
SCHEMA_DATA_CACHE[dbname] = summary
338+
return summary
339+
340+
341+
def get_sample_data(
342+
cur: Cursor,
343+
dbname: str,
344+
prompt_field_truncate: int,
345+
prompt_section_truncate: int,
346+
) -> dict[str, Any]:
347+
if dbname in SAMPLE_DATA_CACHE:
348+
return SAMPLE_DATA_CACHE[dbname]
349+
click.echo("Preparing sample data to feed the LLM")
350+
tables_query = "SHOW TABLES"
351+
sample_row_query = "SELECT * FROM `{dbname}`.`{table}` LIMIT 1"
296352
cur.execute(tables_query)
297353
sample_data = {}
298354
for (table_name,) in cur.fetchall():
299355
try:
300-
cur.execute(sample_row_query.format(table=table_name))
356+
cur.execute(sample_row_query.format(dbname=dbname, table=table_name))
301357
except Exception:
302358
continue
303359
cols = [desc[0] for desc in cur.description]
304360
row = cur.fetchone()
305361
if row is None:
306362
continue
307-
sample_data[table_name] = list(zip(cols, row, strict=True))
363+
sample_data[table_name] = list(
364+
zip(cols, truncate_list_elements(list(row), prompt_field_truncate, prompt_section_truncate), strict=False)
365+
)
366+
SAMPLE_DATA_CACHE[dbname] = sample_data
367+
return sample_data
368+
369+
370+
def sql_using_llm(
371+
cur: Cursor | None,
372+
question: str | None,
373+
dbname: str = '',
374+
prompt_field_truncate: int = 0,
375+
prompt_section_truncate: int = 0,
376+
) -> tuple[str, str | None]:
377+
if cur is None:
378+
raise RuntimeError("Connect to a database and try again.")
379+
if dbname == '':
380+
raise RuntimeError("Choose a schema and try again.")
308381
args = [
309382
"--template",
310383
LLM_TEMPLATE_NAME,
311384
"--param",
312385
"db_schema",
313-
db_schema,
386+
get_schema(cur, dbname, prompt_section_truncate),
314387
"--param",
315388
"sample_data",
316-
sample_data,
389+
get_sample_data(cur, dbname, prompt_field_truncate, prompt_section_truncate),
317390
"--param",
318391
"question",
319392
question,
320393
" ",
321394
]
322-
click.echo("Invoking llm command with schema information")
395+
click.echo(args[4])
396+
click.echo(args[7])
397+
click.echo("Invoking llm command with schema information and sample data")
323398
_, result = run_external_cmd("llm", *args, capture_output=True)
324399
click.echo("Received response from the llm command")
325400
match = re.search(_SQL_CODE_FENCE, result, re.DOTALL)

test/myclirc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,17 @@ default_ssl_cipher =
174174
# --ssl-verify-server-cert being set
175175
default_ssl_verify_server_cert = False
176176

177+
[llm]
178+
179+
# If set to a positive integer, truncate text/binary fields to that width
180+
# in bytes when sending sample data, to conserve tokens. Suggestion: 1024.
181+
prompt_field_truncate = None
182+
183+
# If set to a positive integer, attempt to truncate various sections of LLM
184+
# prompt input to that number in bytes, to conserve tokens. Suggestion:
185+
# 1000000.
186+
prompt_section_truncate = None
187+
177188
[keys]
178189
# possible values: auto, fzf, reverse_isearch
179190
control_r = auto

test/test_llm_special.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_llm_command_without_args(mock_llm, executor):
2626
assert mock_llm is not None
2727
test_text = r"\llm"
2828
with pytest.raises(FinishIteration) as exc_info:
29-
handle_llm(test_text, executor)
29+
handle_llm(test_text, executor, 'mysql', 0, 0)
3030
# Should return usage message when no args provided
3131
assert exc_info.value.args[0] == [(None, None, None, USAGE)]
3232

@@ -38,7 +38,7 @@ def test_llm_command_with_c_flag(mock_run_cmd, mock_llm, executor):
3838
mock_run_cmd.return_value = (0, "Hello, no SQL today.")
3939
test_text = r"\llm -c 'Something?'"
4040
with pytest.raises(FinishIteration) as exc_info:
41-
handle_llm(test_text, executor)
41+
handle_llm(test_text, executor, 'mysql', 0, 0)
4242
# Expect raw output when no SQL fence found
4343
assert exc_info.value.args[0] == [(None, None, None, "Hello, no SQL today.")]
4444

@@ -51,7 +51,7 @@ def test_llm_command_with_c_flag_and_fenced_sql(mock_run_cmd, mock_llm, executor
5151
fenced = f"Here you go:\n```sql\n{sql_text}\n```"
5252
mock_run_cmd.return_value = (0, fenced)
5353
test_text = r"\llm -c 'Rewrite SQL'"
54-
result, sql, duration = handle_llm(test_text, executor)
54+
result, sql, duration = handle_llm(test_text, executor, 'mysql', 0, 0)
5555
# Without verbose, result is empty, sql extracted
5656
assert sql == sql_text
5757
assert result == ""
@@ -64,7 +64,7 @@ def test_llm_command_known_subcommand(mock_run_cmd, mock_llm, executor):
6464
# 'models' is a known subcommand
6565
test_text = r"\llm models"
6666
with pytest.raises(FinishIteration) as exc_info:
67-
handle_llm(test_text, executor)
67+
handle_llm(test_text, executor, 'mysql', 0, 0)
6868
mock_run_cmd.assert_called_once_with("llm", "models", restart_cli=False)
6969
assert exc_info.value.args[0] is None
7070

@@ -74,7 +74,7 @@ def test_llm_command_known_subcommand(mock_run_cmd, mock_llm, executor):
7474
def test_llm_command_with_help_flag(mock_run_cmd, mock_llm, executor):
7575
test_text = r"\llm --help"
7676
with pytest.raises(FinishIteration) as exc_info:
77-
handle_llm(test_text, executor)
77+
handle_llm(test_text, executor, 'mysql', 0, 0)
7878
mock_run_cmd.assert_called_once_with("llm", "--help", restart_cli=False)
7979
assert exc_info.value.args[0] is None
8080

@@ -84,7 +84,7 @@ def test_llm_command_with_help_flag(mock_run_cmd, mock_llm, executor):
8484
def test_llm_command_with_install_flag(mock_run_cmd, mock_llm, executor):
8585
test_text = r"\llm install openai"
8686
with pytest.raises(FinishIteration) as exc_info:
87-
handle_llm(test_text, executor)
87+
handle_llm(test_text, executor, 'mysql', 0, 0)
8888
mock_run_cmd.assert_called_once_with("llm", "install", "openai", restart_cli=True)
8989
assert exc_info.value.args[0] is None
9090

@@ -98,7 +98,7 @@ def test_llm_command_with_prompt(mock_sql_using_llm, mock_ensure_template, mock_
9898
"""
9999
mock_sql_using_llm.return_value = ("CTX", "SELECT 1;")
100100
test_text = r"\llm prompt 'Test?'"
101-
context, sql, duration = handle_llm(test_text, executor)
101+
context, sql, duration = handle_llm(test_text, executor, 'mysql', 0, 0)
102102
mock_ensure_template.assert_called_once()
103103
mock_sql_using_llm.assert_called()
104104
assert context == "CTX"
@@ -115,7 +115,7 @@ def test_llm_command_question_with_context(mock_sql_using_llm, mock_ensure_templ
115115
"""
116116
mock_sql_using_llm.return_value = ("CTX2", "SELECT 2;")
117117
test_text = r"\llm 'Top 10?'"
118-
context, sql, duration = handle_llm(test_text, executor)
118+
context, sql, duration = handle_llm(test_text, executor, 'mysql', 0, 0)
119119
mock_ensure_template.assert_called_once()
120120
mock_sql_using_llm.assert_called()
121121
assert context == "CTX2"
@@ -132,7 +132,7 @@ def test_llm_command_question_verbose(mock_sql_using_llm, mock_ensure_template,
132132
"""
133133
mock_sql_using_llm.return_value = ("NO_CTX", "SELECT 42;")
134134
test_text = r"\llm- 'Succinct?'"
135-
context, sql, duration = handle_llm(test_text, executor)
135+
context, sql, duration = handle_llm(test_text, executor, 'mysql', 0, 0)
136136
assert context == ""
137137
assert sql == "SELECT 42;"
138138
assert isinstance(duration, float)
@@ -181,7 +181,7 @@ def fetchone(self):
181181
sql_text = "SELECT 1, 'abc';"
182182
fenced = f"Note\n```sql\n{sql_text}\n```"
183183
mock_run_cmd.return_value = (0, fenced)
184-
result, sql = sql_using_llm(dummy_cur, question="dummy")
184+
result, sql = sql_using_llm(dummy_cur, question="dummy", dbname='mysql')
185185
assert result == fenced
186186
assert sql == sql_text
187187

@@ -194,5 +194,5 @@ def test_handle_llm_aliases_without_args(prefix, executor, monkeypatch):
194194

195195
monkeypatch.setattr(llm_module, "llm", object())
196196
with pytest.raises(FinishIteration) as exc_info:
197-
handle_llm(prefix, executor)
197+
handle_llm(prefix, executor, 'mysql', 0, 0)
198198
assert exc_info.value.args[0] == [(None, None, None, USAGE)]

0 commit comments

Comments
 (0)