diff --git a/mycli/main.py b/mycli/main.py
index 7b8018ec..c38fa28a 100755
--- a/mycli/main.py
+++ b/mycli/main.py
@@ -48,7 +48,6 @@
 from mycli.packages.filepaths import dir_path_exists, guess_socket_location
 from mycli.packages.parseutils import is_destructive, is_dropping_database
 from mycli.packages.prompt_utils import confirm, confirm_destructive_query
-from mycli.packages.special.favoritequeries import FavoriteQueries
 from mycli.packages.special.main import NO_QUERY
 from mycli.packages.tabular_output import sql_format
 from mycli.packages.toolkit.history import FileHistoryWithTimestamp
@@ -125,8 +124,6 @@ def __init__(
         special.set_timing_enabled(c["main"].as_bool("timing"))
         self.beep_after_seconds = float(c["main"]["beep_after_seconds"] or 0)
 
-        FavoriteQueries.instance = FavoriteQueries.from_config(self.config)
-
         self.dsn_alias = None
         self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
         sql_format.register_new_formatter(self.formatter)
@@ -656,6 +653,47 @@ def get_continuation(width, *_):
         def show_suggestion_tip():
             return iterations < 2
 
+        def output_res(res, start):
+            result_count = 0
+            mutating = False
+            for title, cur, headers, status in res:
+                logger.debug("headers: %r", headers)
+                logger.debug("rows: %r", cur)
+                logger.debug("status: %r", status)
+                threshold = 1000
+                if is_select(status) and cur and cur.rowcount > threshold:
+                    self.echo(
+                        "The result set has more than {} rows.".format(threshold),
+                        fg="red",
+                    )
+                    if not confirm("Do you want to continue?"):
+                        self.echo("Aborted!", err=True, fg="red")
+                        break
+
+                if self.auto_vertical_output:
+                    max_width = self.prompt_app.output.get_size().columns
+                else:
+                    max_width = None
+
+                formatted = self.format_output(title, cur, headers, special.is_expanded_output(), max_width)
+
+                t = time() - start
+                try:
+                    if result_count > 0:
+                        self.echo("")
+                    try:
+                        self.output(formatted, status)
+                    except KeyboardInterrupt:
+                        pass
+                    self.echo("Time: %0.03fs" % t)
+                except KeyboardInterrupt:
+                    pass
+
+                start = time()
+                result_count += 1
+                mutating = mutating or is_mutating(status)
+            return mutating
+
         def one_iteration(text=None):
             if text is None:
                 try:
@@ -682,6 +720,27 @@ def one_iteration(text=None):
                 logger.error("traceback: %r", traceback.format_exc())
                 self.echo(str(e), err=True, fg="red")
                 return
+            # LLM command support
+            while special.is_llm_command(text):
+                try:
+                    start = time()
+                    cur = sqlexecute.conn.cursor()
+                    context, sql, duration = special.handle_llm(text, cur)
+                    if context:
+                        click.echo("LLM Response:")
+                        click.echo(context)
+                        click.echo("---")
+                    click.echo(f"Time: {duration:.2f} seconds")
+                    text = self.prompt_app.prompt(default=sql)
+                except KeyboardInterrupt:
+                    return
+                except special.FinishIteration as e:
+                    return output_res(e.results, start) if e.results else None
+                except RuntimeError as e:
+                    logger.error("sql: %r, error: %r", text, e)
+                    logger.error("traceback: %r", traceback.format_exc())
+                    self.echo(str(e), err=True, fg="red")
+                    return
 
             if not text.strip():
                 return
diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py
index 095ed1b3..1ebf55ee 100644
--- a/mycli/packages/completion_engine.py
+++ b/mycli/packages/completion_engine.py
@@ -103,6 +103,8 @@ def suggest_special(text):
         ]
     elif cmd in ["\\.", "source"]:
         return [{"type": "file_name"}]
+    if cmd in ["\\llm", "\\ai"]:
+        return [{"type": "llm"}]
 
     return [{"type": "keyword"}, {"type": "special"}]
 
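An aside, not part of the patch: a minimal sketch of how the new "llm" suggestion type is expected to flow from the completion engine into the completer, assuming the suggest_special() hunk above.

```python
# Illustrative only; assumes the suggest_special() change above.
from mycli.packages.completion_engine import suggest_type

# \llm / \ai prefixes now short-circuit to a single "llm" suggestion ...
suggest_type("\\llm models defa", "\\llm models defa")
# -> [{"type": "llm"}]
# ... which SQLCompleter.get_completions() (see the sqlcompleter.py hunk further
# down) resolves against the llm CLI's command tree instead of SQL keywords.
```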
diff --git a/mycli/packages/special/__init__.py b/mycli/packages/special/__init__.py
index 9f05514c..8405d0c3 100644
--- a/mycli/packages/special/__init__.py
+++ b/mycli/packages/special/__init__.py
@@ -11,4 +11,5 @@ def export(defn):
 from mycli.packages.special import (
     dbcommands,  # noqa: E402 F401
     iocommands,  # noqa: E402 F401
+    llm,  # noqa: E402 F401
 )
diff --git a/mycli/packages/special/favoritequeries.py b/mycli/packages/special/favoritequeries.py
index 3f8648cf..ef155006 100644
--- a/mycli/packages/special/favoritequeries.py
+++ b/mycli/packages/special/favoritequeries.py
@@ -30,9 +30,6 @@ class FavoriteQueries(object):
         simple: Deleted
     """
 
-    # Class-level variable, for convenience to use as a singleton.
-    instance = None
-
     def __init__(self, config):
         self.config = config
 
diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py
index fb593e11..c9ddfb1b 100644
--- a/mycli/packages/special/iocommands.py
+++ b/mycli/packages/special/iocommands.py
@@ -7,6 +7,7 @@
 from time import sleep
 
 import click
+from configobj import ConfigObj
 import pyperclip
 import sqlparse
 
@@ -27,6 +28,13 @@
 pipe_once_process = None
 written_to_pipe_once_process = False
 delimiter_command = DelimiterCommand()
+favoritequeries = FavoriteQueries(ConfigObj())
+
+
+@export
+def set_favorite_queries(config):
+    global favoritequeries
+    favoritequeries = FavoriteQueries(config)
 
 
 @export
@@ -233,7 +241,7 @@ def execute_favorite_query(cur, arg, **_):
     name, _, arg_str = arg.partition(" ")
     args = shlex.split(arg_str)
 
-    query = FavoriteQueries.instance.get(name)
+    query = favoritequeries.get(name)
     if query is None:
         message = "No favorite query: %s" % (name)
         yield (None, None, None, message)
@@ -258,10 +266,10 @@ def list_favorite_queries():
 
     Returns (title, rows, headers, status)"""
 
     headers = ["Name", "Query"]
-    rows = [(r, FavoriteQueries.instance.get(r)) for r in FavoriteQueries.instance.list()]
+    rows = [(r, favoritequeries.get(r)) for r in favoritequeries.list()]
 
     if not rows:
-        status = "\nNo favorite queries found." + FavoriteQueries.instance.usage
+        status = "\nNo favorite queries found." + favoritequeries.usage
     else:
         status = ""
     return [("", rows, headers, status)]
@@ -288,7 +296,7 @@ def save_favorite_query(arg, **_):
     """Save a new favorite query.
 
     Returns (title, rows, headers, status)"""
-    usage = "Syntax: \\fs name query.\n\n" + FavoriteQueries.instance.usage
+    usage = "Syntax: \\fs name query.\n\n" + favoritequeries.usage
     if not arg:
         return [(None, None, None, usage)]
 
@@ -298,18 +306,18 @@ def save_favorite_query(arg, **_):
     if (not name) or (not query):
         return [(None, None, None, usage + "Err: Both name and query are required.")]
 
-    FavoriteQueries.instance.save(name, query)
+    favoritequeries.save(name, query)
     return [(None, None, None, "Saved.")]
 
 
 @special_command("\\fd", "\\fd [name]", "Delete a favorite query.")
 def delete_favorite_query(arg, **_):
     """Delete an existing favorite query."""
-    usage = "Syntax: \\fd name.\n\n" + FavoriteQueries.instance.usage
+    usage = "Syntax: \\fd name.\n\n" + favoritequeries.usage
     if not arg:
         return [(None, None, None, usage)]
 
-    status = FavoriteQueries.instance.delete(arg)
+    status = favoritequeries.delete(arg)
     return [(None, None, None, status)]
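A sketch of how the favorites singleton is now expected to be wired up. The MyCli.__init__ call site is not shown in these hunks, so treat the exact hook below as an assumption.

```python
# Hypothetical startup wiring; the real call site is not part of the hunks above.
from configobj import ConfigObj

from mycli.packages import special  # @export re-exports set_favorite_queries here

config = ConfigObj("test/myclirc")    # whichever config object mycli already loaded
special.set_favorite_queries(config)  # replaces FavoriteQueries.instance = ...

# After that, the favorite-query commands go through the module-level instance:
#   \fs top_users SELECT * FROM users ORDER BY visits DESC LIMIT 10
#   \f  top_users
#   \fd top_users
```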
diff --git a/mycli/packages/special/llm.py b/mycli/packages/special/llm.py
new file mode 100644
index 00000000..f0016dfa
--- /dev/null
+++ b/mycli/packages/special/llm.py
@@ -0,0 +1,269 @@
+import contextlib
+import io
+import logging
+import os
+import re
+from runpy import run_module
+import shlex
+import sys
+from time import time
+from typing import Optional, Tuple
+
+import click
+import llm
+from llm.cli import cli
+
+from mycli.packages.special import export
+from mycli.packages.special.main import Verbosity, parse_special_command
+
+log = logging.getLogger(__name__)
+
+LLM_CLI_COMMANDS = list(cli.commands.keys())
+MODELS = {x.model_id: None for x in llm.get_models()}
+LLM_TEMPLATE_NAME = "mycli-llm-template"
+
+
+def run_external_cmd(cmd, *args, capture_output=False, restart_cli=False, raise_exception=True):
+    original_exe = sys.executable
+    original_args = sys.argv
+    try:
+        sys.argv = [cmd] + list(args)
+        code = 0
+        if capture_output:
+            buffer = io.StringIO()
+            redirect = contextlib.ExitStack()
+            redirect.enter_context(contextlib.redirect_stdout(buffer))
+            redirect.enter_context(contextlib.redirect_stderr(buffer))
+        else:
+            redirect = contextlib.nullcontext()
+        with redirect:
+            try:
+                run_module(cmd, run_name="__main__")
+            except SystemExit as e:
+                code = e.code
+                if code != 0 and raise_exception:
+                    if capture_output:
+                        raise RuntimeError(buffer.getvalue())
+                    else:
+                        raise RuntimeError(f"Command {cmd} failed with exit code {code}.")
+            except Exception as e:
+                code = 1
+                if raise_exception:
+                    if capture_output:
+                        raise RuntimeError(buffer.getvalue())
+                    else:
+                        raise RuntimeError(f"Command {cmd} failed: {e}")
+        if restart_cli and code == 0:
+            os.execv(original_exe, [original_exe] + original_args)
+        if capture_output:
+            return code, buffer.getvalue()
+        else:
+            return code, ""
+    finally:
+        sys.argv = original_args
+
+
+def build_command_tree(cmd):
+    tree = {}
+    if isinstance(cmd, click.Group):
+        for name, subcmd in cmd.commands.items():
+            if cmd.name == "models" and name == "default":
+                tree[name] = MODELS
+            else:
+                tree[name] = build_command_tree(subcmd)
+    else:
+        tree = None
+    return tree
+
+
+# Generate the command tree for autocompletion
+COMMAND_TREE = build_command_tree(cli) if cli else {}
+
+
+def get_completions(tokens, tree=COMMAND_TREE):
+    for token in tokens:
+        if token.startswith("-"):
+            continue
+        if tree and token in tree:
+            tree = tree[token]
+        else:
+            return []
+    return list(tree.keys()) if tree else []
+
+
+@export
+class FinishIteration(Exception):
+    def __init__(self, results=None):
+        self.results = results
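For orientation, a rough usage sketch of the helpers above (illustrative, not part of the patch); the completion candidates depend on which llm plugins are installed.

```python
from mycli.packages.special.llm import get_completions, run_external_cmd

# Re-run the `llm` CLI in-process with a patched sys.argv and capture its output,
# instead of spawning a subprocess.
code, output = run_external_cmd("llm", "models", capture_output=True)

# Walk the click command tree that build_command_tree() derived from llm.cli:
get_completions(["models"])             # e.g. ["default", "list", ...]
get_completions(["models", "default"])  # installed model ids, taken from MODELS
```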
+
+
+USAGE = """
+Use an LLM to create SQL queries to answer questions from your database.
+Examples:
+
+# Ask a question.
+> \llm 'Most visited urls?'
+
+# List available models
+> \llm models
+> gpt-4o
+> gpt-3.5-turbo
+
+# Change default model
+> \llm models default llama3
+
+# Set api key (not required for local models)
+> \llm keys set openai
+
+# Install a model plugin
+> \llm install llm-ollama
+> llm-ollama installed.
+
+# Plugins directory
+# https://llm.datasette.io/en/stable/plugins/directory.html
+"""
+_SQL_CODE_FENCE = r"```sql\n(.*?)\n```"
+PROMPT = """A MySQL database has the following schema:
+
+$db_schema
+
+Here is a sample row of data from each table:
+
+$sample_data
+
+Use the provided schema and the sample data to construct a SQL query that
+can be run in MySQL to answer
+
+$question
+
+Explain the reason for choosing each table in the SQL query you have
+written. Keep the explanation concise.
+Finally include a sql query in a code fence such as this one:
+
+```sql
+SELECT count(*) FROM table_name;
+```"""
+
+
+def ensure_mycli_template(replace=False):
+    if not replace:
+        code, _ = run_external_cmd("llm", "templates", "show", LLM_TEMPLATE_NAME, capture_output=True, raise_exception=False)
+        if code == 0:
+            return
+    run_external_cmd("llm", PROMPT, "--save", LLM_TEMPLATE_NAME)
+    return
+
+
+@export
+def handle_llm(text, cur) -> Tuple[str, Optional[str], float]:
+    _, verbosity, arg = parse_special_command(text)
+    if not arg.strip():
+        output = [(None, None, None, USAGE)]
+        raise FinishIteration(output)
+    parts = shlex.split(arg)
+    restart = False
+    if "-c" in parts:
+        capture_output = True
+        use_context = False
+    elif "prompt" in parts:
+        capture_output = True
+        use_context = True
+    elif "install" in parts or "uninstall" in parts:
+        capture_output = False
+        use_context = False
+        restart = True
+    elif parts and parts[0] in LLM_CLI_COMMANDS:
+        capture_output = False
+        use_context = False
+    elif parts and parts[0] == "--help":
+        capture_output = False
+        use_context = False
+    else:
+        capture_output = True
+        use_context = True
+    if not use_context:
+        args = parts
+        if capture_output:
+            click.echo("Calling llm command")
+            start = time()
+            _, result = run_external_cmd("llm", *args, capture_output=capture_output)
+            end = time()
+            match = re.search(_SQL_CODE_FENCE, result, re.DOTALL)
+            if match:
+                sql = match.group(1).strip()
+            else:
+                output = [(None, None, None, result)]
+                raise FinishIteration(output)
+            return (result if verbosity == Verbosity.VERBOSE else "", sql, end - start)
+        else:
+            run_external_cmd("llm", *args, restart_cli=restart)
+            raise FinishIteration(None)
+    try:
+        ensure_mycli_template()
+        start = time()
+        context, sql = sql_using_llm(cur=cur, question=arg)
+        end = time()
+        if verbosity == Verbosity.SUCCINCT:
+            context = ""
+        return (context, sql, end - start)
+    except Exception as e:
+        raise RuntimeError(e)
+
+
+@export
+def is_llm_command(command) -> bool:
+    cmd, _, _ = parse_special_command(command)
+    return cmd in ("\\llm", "\\ai")
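Roughly how handle_llm() is expected to behave from the REPL side (a sketch; the actual output depends on the configured model and installed plugins).

```python
from mycli.packages.special.llm import handle_llm, is_llm_command

is_llm_command(r"\llm 'Most visited urls?'")  # True
is_llm_command(r"\ai models")                 # True

# Plain llm subcommands (models, keys, install, ...) are forwarded to the llm CLI,
# and handle_llm raises FinishIteration so the main loop skips SQL execution.
# A free-form question goes through the mycli-llm-template path and returns
# (context, sql, duration); the \llm- form suppresses the explanation and keeps
# only the extracted SQL.
```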
+
+
+@export
+def sql_using_llm(cur, question=None) -> Tuple[str, Optional[str]]:
+    if cur is None:
+        raise RuntimeError("Connect to a database and try again.")
+    schema_query = """
+        SELECT CONCAT(table_name, '(', GROUP_CONCAT(column_name, ' ', COLUMN_TYPE SEPARATOR ', '),')')
+        FROM information_schema.columns
+        WHERE table_schema = DATABASE()
+        GROUP BY table_name
+        ORDER BY table_name
+    """
+    tables_query = "SHOW TABLES"
+    sample_row_query = "SELECT * FROM `{table}` LIMIT 1"
+    click.echo("Preparing schema information to feed the llm")
+    cur.execute(schema_query)
+    db_schema = "\n".join([row for (row,) in cur.fetchall()])
+    cur.execute(tables_query)
+    sample_data = {}
+    for (table_name,) in cur.fetchall():
+        try:
+            cur.execute(sample_row_query.format(table=table_name))
+        except Exception:
+            continue
+        cols = [desc[0] for desc in cur.description]
+        row = cur.fetchone()
+        if row is None:
+            continue
+        sample_data[table_name] = list(zip(cols, row))
+    args = [
+        "--template",
+        LLM_TEMPLATE_NAME,
+        "--param",
+        "db_schema",
+        db_schema,
+        "--param",
+        "sample_data",
+        sample_data,
+        "--param",
+        "question",
+        question,
+        " ",
+    ]
+    click.echo("Invoking llm command with schema information")
+    _, result = run_external_cmd("llm", *args, capture_output=True)
+    click.echo("Received response from the llm command")
+    match = re.search(_SQL_CODE_FENCE, result, re.DOTALL)
+    if match:
+        sql = match.group(1).strip()
+    else:
+        sql = ""
+    return (result, sql)
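The shape of the data sql_using_llm() feeds into the saved template, sketched with made-up values (the real ones come from information_schema and SHOW TABLES):

```python
rows = [("users(id int, name varchar(50))",), ("visits(user_id int, url text)",)]
db_schema = "\n".join([row for (row,) in rows])
# users(id int, name varchar(50))
# visits(user_id int, url text)

sample_data = {"users": [("id", 1), ("name", "alice")]}

# Both values end up as --param arguments to the in-process
# `llm --template mycli-llm-template` invocation; the trailing " " argument is
# presumably a dummy prompt that keeps llm from waiting on stdin.
```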
""" - command, verbose, arg = parse_special_command(sql) + command, verbosity, arg = parse_special_command(sql) if (command not in COMMANDS) and (command.lower() not in COMMANDS): raise CommandNotFound @@ -72,7 +83,7 @@ def execute(cur, sql): if special_cmd.arg_type == NO_QUERY: return special_cmd.handler() elif special_cmd.arg_type == PARSED_QUERY: - return special_cmd.handler(cur=cur, arg=arg, verbose=verbose) + return special_cmd.handler(cur=cur, arg=arg, verbose=(verbosity == Verbosity.VERBOSE)) elif special_cmd.arg_type == RAW_QUERY: return special_cmd.handler(cur=cur, query=sql) diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index 692cacae..ec998ee4 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -7,6 +7,7 @@ from mycli.packages.completion_engine import suggest_type from mycli.packages.filepaths import complete_path, parse_path, suggest_path from mycli.packages.parseutils import last_word +from mycli.packages.special import llm from mycli.packages.special.favoritequeries import FavoriteQueries _logger = logging.getLogger(__name__) @@ -1192,6 +1193,19 @@ def get_completions(self, document, complete_event, smart_completion=None): elif suggestion["type"] == "file_name": file_names = self.find_files(word_before_cursor) completions.extend(file_names) + elif suggestion["type"] == "llm": + if not word_before_cursor: + tokens = document.text.split()[1:] + else: + tokens = document.text.split()[1:-1] + possible_entries = llm.get_completions(tokens) + subcommands = self.find_matches( + word_before_cursor, + possible_entries, + start_only=False, + fuzzy=True, + ) + completions.extend(subcommands) return completions diff --git a/pyproject.toml b/pyproject.toml index 1276512c..03a5cdc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,9 @@ dependencies = [ "pyperclip >= 1.8.1", "pyaes >= 1.6.1", "pyfzf >= 0.3.1", - "importlib_resources >= 5.0.0; python_version<'3.9'", + "llm>=0.19.0", + "setuptools", # Required by llm commands to install models + "pip", ] [build-system] @@ -34,6 +36,7 @@ build-backend = "setuptools.build_meta" [tool.setuptools_scm] + [project.optional-dependencies] ssh = ["paramiko", "sshtunnel"] dev = [ @@ -44,6 +47,8 @@ dev = [ "pytest-cov>=4.1.0", "tox>=4.8.0", "pdbpp>=0.10.3", + "paramiko", + "sshtunnel", ] [project.scripts] @@ -60,34 +65,21 @@ target-version = 'py39' line-length = 140 [tool.ruff.lint] -select = [ - 'A', - 'I', - 'E', - 'W', - 'F', - 'C4', - 'PIE', - 'TID', -] +select = ['A', 'I', 'E', 'W', 'F', 'C4', 'PIE', 'TID'] ignore = [ 'E401', # Multiple imports on one line 'E402', # Module level import not at top of file 'PIE808', # range() starting with 0 # https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules - 'E111', # indentation-with-invalid-multiple - 'E114', # indentation-with-invalid-multiple-comment - 'E117', # over-indented - 'W191', # tab-indentation + 'E111', # indentation-with-invalid-multiple + 'E114', # indentation-with-invalid-multiple-comment + 'E117', # over-indented + 'W191', # tab-indentation ] [tool.ruff.lint.isort] force-sort-within-sections = true -known-first-party = [ - 'mycli', - 'test', - 'steps', -] +known-first-party = ['mycli', 'test', 'steps'] [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = 'all' @@ -95,7 +87,4 @@ ban-relative-imports = 'all' [tool.ruff.format] preview = true quote-style = 'preserve' -exclude = [ - 'build', - 'mycli_dev', -] +exclude = ['build', 'mycli_dev'] diff --git a/test/myclirc b/test/myclirc index fef49f2d..bd590158 100644 --- a/test/myclirc +++ 
diff --git a/test/myclirc b/test/myclirc
index fef49f2d..bd590158 100644
--- a/test/myclirc
+++ b/test/myclirc
@@ -160,8 +160,8 @@ foo_args = 'SELECT $1, "$2", "$3"'
 
 # Initial commands to execute when connecting to any database.
 [init-commands]
+global_limit = set sql_select_limit=9999
 # read_only = "SET SESSION TRANSACTION READ ONLY"
-global_limit = "set sql_select_limit=9999"
 
 
 # Use the -d option to reference a DSN.
diff --git a/test/test_llm_special.py b/test/test_llm_special.py
new file mode 100644
index 00000000..a7fa578a
--- /dev/null
+++ b/test/test_llm_special.py
@@ -0,0 +1,198 @@
+from unittest.mock import patch
+
+import pytest
+
+from mycli.packages.special.llm import (
+    USAGE,
+    FinishIteration,
+    handle_llm,
+    is_llm_command,
+    sql_using_llm,
+)
+
+
+# Override executor fixture to avoid real DB connections during llm tests
+@pytest.fixture
+def executor():
+    """Dummy executor fixture"""
+    return None
+
+
+@patch("mycli.packages.special.llm.llm")
+def test_llm_command_without_args(mock_llm, executor):
+    r"""
+    Invoking \llm without any arguments should print the usage and raise FinishIteration.
+    """
+    assert mock_llm is not None
+    test_text = r"\llm"
+    with pytest.raises(FinishIteration) as exc_info:
+        handle_llm(test_text, executor)
+    # Should return usage message when no args provided
+    assert exc_info.value.args[0] == [(None, None, None, USAGE)]
+
+
+@patch("mycli.packages.special.llm.llm")
+@patch("mycli.packages.special.llm.run_external_cmd")
+def test_llm_command_with_c_flag(mock_run_cmd, mock_llm, executor):
+    # Suppose the LLM returns some text without fenced SQL
+    mock_run_cmd.return_value = (0, "Hello, no SQL today.")
+    test_text = r"\llm -c 'Something?'"
+    with pytest.raises(FinishIteration) as exc_info:
+        handle_llm(test_text, executor)
+    # Expect raw output when no SQL fence found
+    assert exc_info.value.args[0] == [(None, None, None, "Hello, no SQL today.")]
+
+
+@patch("mycli.packages.special.llm.llm")
+@patch("mycli.packages.special.llm.run_external_cmd")
+def test_llm_command_with_c_flag_and_fenced_sql(mock_run_cmd, mock_llm, executor):
+    # Return text containing a fenced SQL block
+    sql_text = "SELECT * FROM users;"
+    fenced = f"Here you go:\n```sql\n{sql_text}\n```"
+    mock_run_cmd.return_value = (0, fenced)
+    test_text = r"\llm -c 'Rewrite SQL'"
+    result, sql, duration = handle_llm(test_text, executor)
+    # Without verbose, result is empty, sql extracted
+    assert sql == sql_text
+    assert result == ""
+    assert isinstance(duration, float)
+
+
+@patch("mycli.packages.special.llm.llm")
+@patch("mycli.packages.special.llm.run_external_cmd")
+def test_llm_command_known_subcommand(mock_run_cmd, mock_llm, executor):
+    # 'models' is a known subcommand
+    test_text = r"\llm models"
+    with pytest.raises(FinishIteration) as exc_info:
+        handle_llm(test_text, executor)
+    mock_run_cmd.assert_called_once_with("llm", "models", restart_cli=False)
+    assert exc_info.value.args[0] is None
+
+
+@patch("mycli.packages.special.llm.llm")
+@patch("mycli.packages.special.llm.run_external_cmd")
+def test_llm_command_with_help_flag(mock_run_cmd, mock_llm, executor):
+    test_text = r"\llm --help"
+    with pytest.raises(FinishIteration) as exc_info:
+        handle_llm(test_text, executor)
+    mock_run_cmd.assert_called_once_with("llm", "--help", restart_cli=False)
+    assert exc_info.value.args[0] is None
+
+
+@patch("mycli.packages.special.llm.llm")
+@patch("mycli.packages.special.llm.run_external_cmd")
+def test_llm_command_with_install_flag(mock_run_cmd, mock_llm, executor):
+    test_text = r"\llm install openai"
+    with pytest.raises(FinishIteration) as exc_info:
+        handle_llm(test_text, executor)
+    mock_run_cmd.assert_called_once_with("llm", "install", "openai", restart_cli=True)
+    assert exc_info.value.args[0] is None
+
+
+@patch("mycli.packages.special.llm.llm")
+@patch("mycli.packages.special.llm.ensure_mycli_template")
+@patch("mycli.packages.special.llm.sql_using_llm")
+def test_llm_command_with_prompt(mock_sql_using_llm, mock_ensure_template, mock_llm, executor):
+    r"""
+    \llm prompt 'question' should use the template and call sql_using_llm
+    """
+    mock_sql_using_llm.return_value = ("CTX", "SELECT 1;")
+    test_text = r"\llm prompt 'Test?'"
+    context, sql, duration = handle_llm(test_text, executor)
+    mock_ensure_template.assert_called_once()
+    mock_sql_using_llm.assert_called()
+    assert context == "CTX"
+    assert sql == "SELECT 1;"
+    assert isinstance(duration, float)
+
+
+@patch("mycli.packages.special.llm.llm")
+@patch("mycli.packages.special.llm.ensure_mycli_template")
+@patch("mycli.packages.special.llm.sql_using_llm")
+def test_llm_command_question_with_context(mock_sql_using_llm, mock_ensure_template, mock_llm, executor):
+    r"""
+    \llm 'question' is treated as a prompt and returns SQL
+    """
+    mock_sql_using_llm.return_value = ("CTX2", "SELECT 2;")
+    test_text = r"\llm 'Top 10?'"
+    context, sql, duration = handle_llm(test_text, executor)
+    mock_ensure_template.assert_called_once()
+    mock_sql_using_llm.assert_called()
+    assert context == "CTX2"
+    assert sql == "SELECT 2;"
+    assert isinstance(duration, float)
+
+
+@patch("mycli.packages.special.llm.llm")
+@patch("mycli.packages.special.llm.ensure_mycli_template")
+@patch("mycli.packages.special.llm.sql_using_llm")
+def test_llm_command_question_succinct(mock_sql_using_llm, mock_ensure_template, mock_llm, executor):
+    r"""
+    \llm- suppresses the context and returns only the SQL
+    """
+    mock_sql_using_llm.return_value = ("NO_CTX", "SELECT 42;")
+    test_text = r"\llm- 'Succinct?'"
+    context, sql, duration = handle_llm(test_text, executor)
+    assert context == ""
+    assert sql == "SELECT 42;"
+    assert isinstance(duration, float)
+
+
+def test_is_llm_command():
+    # Valid llm command variants
+    for cmd in ["\\llm", "\\ai"]:
+        assert is_llm_command(cmd + " 'x'")
+    # Invalid commands
+    assert not is_llm_command("select * from table;")
+
+
+def test_sql_using_llm_no_connection():
+    # Should error if no database cursor provided
+    with pytest.raises(RuntimeError) as exc_info:
+        sql_using_llm(None, question="test")
+    assert "Connect to a database" in str(exc_info.value)
+
+
+# Test sql_using_llm with a dummy cursor and fenced SQL output
+@patch("mycli.packages.special.llm.run_external_cmd")
+def test_sql_using_llm_success(mock_run_cmd):
+    # Dummy cursor simulating database schema and sample data
+    class DummyCursor:
+        def __init__(self):
+            self._last = []
+
+        def execute(self, query):
+            if "information_schema.columns" in query:
+                self._last = [("table1(col1 int,col2 text)",), ("table2(colA varchar(20))",)]
+            elif query.strip().upper().startswith("SHOW TABLES"):
+                self._last = [("table1",), ("table2",)]
+            elif query.strip().upper().startswith("SELECT * FROM"):
+                self.description = [("col1", None), ("col2", None)]
+                self._row = (1, "abc")
+
+        def fetchall(self):
+            return getattr(self, "_last", [])
+
+        def fetchone(self):
+            return getattr(self, "_row", None)
+
+    dummy_cur = DummyCursor()
+    # Simulate llm CLI returning a fenced SQL result
+    sql_text = "SELECT 1, 'abc';"
+    fenced = f"Note\n```sql\n{sql_text}\n```"
+    mock_run_cmd.return_value = (0, fenced)
+    result, sql = sql_using_llm(dummy_cur, question="dummy")
+    assert result == fenced
+    assert sql == sql_text
+
+
+# Test handle_llm supports alias prefixes without args
+@pytest.mark.parametrize("prefix", [r"\llm", r".llm", r"\ai", r".ai"])
+def test_handle_llm_aliases_without_args(prefix, executor, monkeypatch):
+    # Ensure llm is available
+    from mycli.packages.special import llm as llm_module
+
+    monkeypatch.setattr(llm_module, "llm", object())
+    with pytest.raises(FinishIteration) as exc_info:
+        handle_llm(prefix, executor)
+    assert exc_info.value.args[0] == [(None, None, None, USAGE)]