Skip to content

fix: Uses SQL tables if available for script calls #4785

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions frontend/src/core/codemirror/completion/utils.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/* Copyright 2024 Marimo. All rights reserved. */
import type { CompletionSource } from "@codemirror/autocomplete";

/**
Expand Down
35 changes: 35 additions & 0 deletions frontend/src/core/codemirror/language/__tests__/sql.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,41 @@ _df = mo.sql(
expect(offset).toBe(26);
});

it("should format local tables", () => {
const testDatasets = [
{
name: "in_sql",
source_type: "local",
},
{
name: "not_in_sql",
source_type: "local",
},
{
name: "other",
source_type: "duckdb",
},
];
const mockStore = store;
mockStore.set(datasetsAtom, { tables: testDatasets } as DatasetsState);

const code = "SELECT * FROM in_sql join other";
adapter.showOutput = false;
const [wrappedCode, offset] = adapter.transformOut(code);
expect(wrappedCode).toMatchInlineSnapshot(`
"_df = mo.sql(
f"""
SELECT * FROM in_sql join other
""",
output=False,
tables={
"in_sql": in_sql
}
)"
`);
expect(offset).toBe(24);
});

it("should preserve Python comments", () => {
const pythonCode = '# hello\n_df = mo.sql("""SELECT * FROM {df}""")';
const [innerCode] = adapter.transformIn(pythonCode);
Expand Down
23 changes: 22 additions & 1 deletion frontend/src/core/codemirror/language/sql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,29 @@ export class SQLLanguageAdapter implements LanguageAdapter {
const showOutputParam = this.showOutput ? "" : ",\n output=False";
const engineParam =
this.engine === this.defaultEngine ? "" : `,\n engine=${this.engine}`;
const end = `\n """${showOutputParam}${engineParam}\n)`;

const localTables = store.get(datasetTablesAtom);
// Need to check table exists with with boundaries
const tablesString = localTables
.filter((table) => {
const matched = Boolean(
new RegExp(`\\b${table.name}\\b`, "g").test(escapedCode),
);
return table.source_type === "local" && matched;
})
.map((table) => `"${table.name}": ${table.name}`)
.join(",\n ");
// Table params only valid in DuckDB
const tablesParam =
this.engine === this.defaultEngine && tablesString
? `,\n tables={\n ${tablesString}\n }`
: "";

const end = `\n """${showOutputParam}${engineParam}${tablesParam}\n)`;

// TODO: Ruff-wasm is now more main stream (adopted by jupyter-ruff)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this manual formatting is quite easy (and gets reformatted with black in the backend). im not sure this is a TODO we want to do for the extra dependency

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noticed that the line length varies the behavior (ruff rather fit everything on one line if it can). I think the manual formatting may eventually be more effort than it is worth

// we may consider using it opposed to the current approach of manually
// formatting.
return [
[...commentLines, start].join("\n") + indentOneTab(escapedCode) + end,
start.length + 1,
Expand Down
6 changes: 3 additions & 3 deletions marimo/_data/preview_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from marimo._plugins.ui._impl.tables.utils import get_table_manager_or_none
from marimo._runtime.requests import PreviewDatasetColumnRequest
from marimo._sql.utils import wrapped_sql
from marimo._sql.utils import fetch_one, wrapped_sql

LOGGER = _loggers.marimo_logger()

Expand Down Expand Up @@ -148,10 +148,10 @@ def get_column_preview_for_duckdb(
from altair import MaxRowsError

try:
total_rows: int = wrapped_sql(
total_rows: int = fetch_one(
f"SELECT COUNT(*) FROM {fully_qualified_table_name}",
connection=None,
).fetchone()[0] # type: ignore[index]
)[0] # type: ignore[index]

if total_rows <= CHART_MAX_ROWS:
relation = wrapped_sql(
Expand Down
10 changes: 5 additions & 5 deletions marimo/_data/sql_summaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from marimo._data.get_datasets import _db_type_to_data_type
from marimo._data.models import ColumnSummary, DataType
from marimo._sql.utils import wrapped_sql
from marimo._sql.utils import fetch_one


def get_sql_summary(
Expand Down Expand Up @@ -64,9 +64,9 @@ def get_sql_summary(
FROM {table_name}
""" # noqa: E501

stats_result: tuple[int, ...] | None = wrapped_sql(
stats_result: tuple[int, ...] | None = fetch_one(
stats_query, connection=None
).fetchone()
)
if stats_result is None:
raise ValueError(
f"Column {column_name} not found in table {table_name}"
Expand Down Expand Up @@ -159,9 +159,9 @@ def get_column_type(
AND column_name = '{column_name}'
"""

column_info_result: tuple[str] | None = wrapped_sql(
column_info_result: tuple[str] | None = fetch_one(
column_info_query, connection=None
).fetchone()
)
if column_info_result is None:
raise ValueError(
f"Column {column_name} not found in table {table_name}"
Expand Down
8 changes: 6 additions & 2 deletions marimo/_sql/engines/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def source(self) -> str:
def dialect(self) -> str:
return "clickhouse"

def execute(self, query: str) -> Any:
def execute(
self, query: str, _tables: Optional[dict[str, Any]] = None
) -> Any:
import chdb # type: ignore

query = query.strip()
Expand Down Expand Up @@ -234,7 +236,9 @@ def source(self) -> str:
def dialect(self) -> str:
return "clickhouse"

def execute(self, query: str) -> Any:
def execute(
self, query: str, _tables: Optional[dict[str, Any]] = None
) -> Any:
if self._connection is None:
return None

Expand Down
6 changes: 4 additions & 2 deletions marimo/_sql/engines/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ def source(self) -> str:
def dialect(self) -> str:
return "duckdb"

def execute(self, query: str) -> Any:
relation = wrapped_sql(query, self._connection)
def execute(
self, query: str, tables: Optional[dict[str, Any]] = None
) -> Any:
relation = wrapped_sql(query, self._connection, tables)

# Invalid / empty query
if relation is None:
Expand Down
4 changes: 3 additions & 1 deletion marimo/_sql/engines/ibis.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def dialect(self) -> str:

return str(self._backend.dialect)

def execute(self, query: str) -> Any:
def execute(
self, query: str, _tables: Optional[dict[str, Any]] = None
) -> Any:
query_expr = self._backend.sql(query)

sql_output_format = self.sql_output_format()
Expand Down
4 changes: 3 additions & 1 deletion marimo/_sql/engines/pyiceberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def source(self) -> str:
def dialect(self) -> str:
return "iceberg"

def execute(self, query: str) -> Any:
def execute(
self, _query: str, _tables: Optional[dict[str, Any]] = None
) -> Any:
raise NotImplementedError(
"PyIceberg does not support direct SQL execution"
)
Expand Down
4 changes: 3 additions & 1 deletion marimo/_sql/engines/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def source(self) -> str:
def dialect(self) -> str:
return str(self._engine.dialect.name)

def execute(self, query: str) -> Any:
def execute(
self, query: str, _tables: Optional[dict[str, Any]] = None
) -> Any:
sql_output_format = self.sql_output_format()

from sqlalchemy import text
Expand Down
4 changes: 3 additions & 1 deletion marimo/_sql/engines/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ def inference_config(self) -> InferenceConfig:
pass

@abstractmethod
def execute(self, query: str) -> Any:
def execute(
self, query: str, tables: Optional[dict[str, Any]] = None
) -> Any:
"""Execute a SQL query and return a dataframe."""
pass

Expand Down
16 changes: 14 additions & 2 deletions marimo/_sql/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def sql(
| ChdbConnection
| IbisEngine
] = None,
tables: Optional[dict[str, Any]] = None,
) -> Any:
"""
Execute a SQL query.
Expand All @@ -54,7 +55,10 @@ def sql(
query: The SQL query to execute.
output: Whether to display the result in the UI. Defaults to True.
engine: Optional SQL engine to use. Can be a SQLAlchemy, Clickhouse, or DuckDB engine.
If None, uses DuckDB.
If None, uses DuckDB.
tables: Optional dictionary of tables to use in the query. This is only
used for dataframe queries in for DuckDB, with global variables
as a fallback.

Returns:
The result of the query.
Expand All @@ -81,7 +85,15 @@ def sql(
"Unsupported engine. Must be a SQLAlchemy, Ibis, Clickhouse, or DuckDB engine."
)

df = sql_engine.execute(query)
if isinstance(sql_engine, DuckDBEngine):
df = sql_engine.execute(query, tables=tables)
elif tables is not None:
raise ValueError(
"The tables argument is only supported for DuckDB engine."
)
else:
df = sql_engine.execute(query)

if df is None:
return None

Expand Down
38 changes: 33 additions & 5 deletions marimo/_sql/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast

from marimo._data.models import DataType
from marimo._dependencies.dependencies import DependencyManager
Expand All @@ -17,7 +17,8 @@
def wrapped_sql(
query: str,
connection: Optional[duckdb.DuckDBPyConnection],
) -> duckdb.DuckDBPyRelation:
tables: Optional[dict[str, Any]] = None,
) -> Optional[duckdb.DuckDBPyRelation]:
DependencyManager.duckdb.require("to execute sql")

# In Python globals() are scoped to modules; since this function
Expand All @@ -31,19 +32,46 @@ def wrapped_sql(

connection = cast(duckdb.DuckDBPyConnection, duckdb)

if tables is None:
tables = {}

previous_globals = {}
try:
ctx = get_context()
previous_globals = ctx.globals.copy()
ctx.globals.update(tables)
tables = ctx.globals
except ContextNotInitializedError:
relation = connection.sql(query=query)
else:
pass

relation = None
try:
relation = eval(
"connection.sql(query=query)",
ctx.globals,
tables,
{"query": query, "connection": connection},
)
import duckdb

assert isinstance(relation, (type(None), duckdb.DuckDBPyRelation))
finally:
if previous_globals:
ctx.globals.clear()
ctx.globals.update(previous_globals)
return relation


def fetch_one(
query: str,
connection: Optional[duckdb.DuckDBPyConnection] = None,
tables: Optional[dict[str, Any]] = None,
) -> tuple[Any, ...] | None:
stats_table = wrapped_sql(query, connection=connection, tables=tables)
if stats_table is None:
return None
return stats_table.fetchone()


def raise_df_import_error(pkg: str) -> None:
raise ModuleNotFoundError(
"pandas or polars is required to execute sql. "
Expand Down
54 changes: 54 additions & 0 deletions tests/_sql/external_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "duckdb==1.2.2",
# "marimo",
# "pandas==2.2.3",
# "sqlglot==26.16.4",
# ]
# ///

import marimo

__generated_with = "0.13.4"
app = marimo.App(width="medium")

with app.setup:
import pandas as pd

import marimo as mo

df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
not_used = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})


@app.cell
def script_hook_args(df1):
_df = mo.sql(
"""
SELECT * FROM df1
""",
tables={"df1": df1},
)
return


@app.cell
def script_hook_no_args():
_df = mo.sql(
"""
SELECT * FROM df
""",
tables={"df": df},
)
return


@app.cell
def _():
df1 = df
return (df1,)


if __name__ == "__main__":
app.run()
Loading
Loading