Skip to content

Implement \llm command. #1229

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 62 additions & 3 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
from mycli.packages.filepaths import dir_path_exists, guess_socket_location
from mycli.packages.parseutils import is_destructive, is_dropping_database
from mycli.packages.prompt_utils import confirm, confirm_destructive_query
from mycli.packages.special.favoritequeries import FavoriteQueries
from mycli.packages.special.main import NO_QUERY
from mycli.packages.tabular_output import sql_format
from mycli.packages.toolkit.history import FileHistoryWithTimestamp
Expand Down Expand Up @@ -125,8 +124,6 @@ def __init__(
special.set_timing_enabled(c["main"].as_bool("timing"))
self.beep_after_seconds = float(c["main"]["beep_after_seconds"] or 0)

FavoriteQueries.instance = FavoriteQueries.from_config(self.config)

self.dsn_alias = None
self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
sql_format.register_new_formatter(self.formatter)
Expand Down Expand Up @@ -656,6 +653,47 @@ def get_continuation(width, *_):
def show_suggestion_tip():
return iterations < 2

def output_res(res, start):
result_count = 0
mutating = False
for title, cur, headers, status in res:
logger.debug("headers: %r", headers)
logger.debug("rows: %r", cur)
logger.debug("status: %r", status)
threshold = 1000
if is_select(status) and cur and cur.rowcount > threshold:
self.echo(
"The result set has more than {} rows.".format(threshold),
fg="red",
)
if not confirm("Do you want to continue?"):
self.echo("Aborted!", err=True, fg="red")
break

if self.auto_vertical_output:
max_width = self.prompt_app.output.get_size().columns
else:
max_width = None

formatted = self.format_output(title, cur, headers, special.is_expanded_output(), max_width)

t = time() - start
try:
if result_count > 0:
self.echo("")
try:
self.output(formatted, status)
except KeyboardInterrupt:
pass
self.echo("Time: %0.03fs" % t)
except KeyboardInterrupt:
pass

start = time()
result_count += 1
mutating = mutating or is_mutating(status)
return mutating

def one_iteration(text=None):
if text is None:
try:
Expand All @@ -682,6 +720,27 @@ def one_iteration(text=None):
logger.error("traceback: %r", traceback.format_exc())
self.echo(str(e), err=True, fg="red")
return
# LLM command support
while special.is_llm_command(text):
try:
start = time()
cur = sqlexecute.conn.cursor()
context, sql, duration = special.handle_llm(text, cur)
if context:
click.echo("LLM Response:")
click.echo(context)
click.echo("---")
click.echo(f"Time: {duration:.2f} seconds")
text = self.prompt_app.prompt(default=sql)
except KeyboardInterrupt:
return
except special.FinishIteration as e:
return output_res(e.results, start) if e.results else None
except RuntimeError as e:
logger.error("sql: %r, error: %r", text, e)
logger.error("traceback: %r", traceback.format_exc())
self.echo(str(e), err=True, fg="red")
return

if not text.strip():
return
Expand Down
2 changes: 2 additions & 0 deletions mycli/packages/completion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def suggest_special(text):
]
elif cmd in ["\\.", "source"]:
return [{"type": "file_name"}]
if cmd in ["\\llm", "\\ai"]:
return [{"type": "llm"}]

return [{"type": "keyword"}, {"type": "special"}]

Expand Down
1 change: 1 addition & 0 deletions mycli/packages/special/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ def export(defn):
from mycli.packages.special import (
dbcommands, # noqa: E402 F401
iocommands, # noqa: E402 F401
llm, # noqa: E402 F401
)
3 changes: 0 additions & 3 deletions mycli/packages/special/favoritequeries.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ class FavoriteQueries(object):
simple: Deleted
"""

# Class-level variable, for convenience to use as a singleton.
instance = None

def __init__(self, config):
self.config = config

Expand Down
22 changes: 15 additions & 7 deletions mycli/packages/special/iocommands.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from time import sleep

import click
from configobj import ConfigObj
import pyperclip
import sqlparse

Expand All @@ -27,6 +28,13 @@
pipe_once_process = None
written_to_pipe_once_process = False
delimiter_command = DelimiterCommand()
favoritequeries = FavoriteQueries(ConfigObj())


@export
def set_favorite_queries(config):
global favoritequeries
favoritequeries = FavoriteQueries(config)


@export
Expand Down Expand Up @@ -233,7 +241,7 @@ def execute_favorite_query(cur, arg, **_):
name, _, arg_str = arg.partition(" ")
args = shlex.split(arg_str)

query = FavoriteQueries.instance.get(name)
query = favoritequeries.get(name)
if query is None:
message = "No favorite query: %s" % (name)
yield (None, None, None, message)
Expand All @@ -258,10 +266,10 @@ def list_favorite_queries():
Returns (title, rows, headers, status)"""

headers = ["Name", "Query"]
rows = [(r, FavoriteQueries.instance.get(r)) for r in FavoriteQueries.instance.list()]
rows = [(r, favoritequeries.get(r)) for r in favoritequeries.list()]

if not rows:
status = "\nNo favorite queries found." + FavoriteQueries.instance.usage
status = "\nNo favorite queries found." + favoritequeries.usage
else:
status = ""
return [("", rows, headers, status)]
Expand All @@ -288,7 +296,7 @@ def save_favorite_query(arg, **_):
"""Save a new favorite query.
Returns (title, rows, headers, status)"""

usage = "Syntax: \\fs name query.\n\n" + FavoriteQueries.instance.usage
usage = "Syntax: \\fs name query.\n\n" + favoritequeries.usage
if not arg:
return [(None, None, None, usage)]

Expand All @@ -298,18 +306,18 @@ def save_favorite_query(arg, **_):
if (not name) or (not query):
return [(None, None, None, usage + "Err: Both name and query are required.")]

FavoriteQueries.instance.save(name, query)
favoritequeries.save(name, query)
return [(None, None, None, "Saved.")]


@special_command("\\fd", "\\fd [name]", "Delete a favorite query.")
def delete_favorite_query(arg, **_):
"""Delete an existing favorite query."""
usage = "Syntax: \\fd name.\n\n" + FavoriteQueries.instance.usage
usage = "Syntax: \\fd name.\n\n" + favoritequeries.usage
if not arg:
return [(None, None, None, usage)]

status = FavoriteQueries.instance.delete(arg)
status = favoritequeries.delete(arg)

return [(None, None, None, status)]

Expand Down
Loading
Loading