Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 44 additions & 34 deletions orchestrator/api/api_v1/endpoints/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,35 +21,32 @@
PathsResponse,
SearchResultsSchema,
)
from orchestrator.schemas.search_requests import SearchRequest
from orchestrator.search.core.exceptions import InvalidCursorError, QueryStateNotFoundError
from orchestrator.search.core.types import EntityType, UIType
from orchestrator.search.filters.definitions import generate_definitions
from orchestrator.search.retrieval import SearchQueryState, execute_search, execute_search_for_export
from orchestrator.search.retrieval.builder import build_paths_query, create_path_autocomplete_lquery, process_path_rows
from orchestrator.search.filters.definitions import TypeDefinition, generate_definitions
from orchestrator.search.query import QueryState, engine
from orchestrator.search.query.builder import build_paths_query, create_path_autocomplete_lquery, process_path_rows
from orchestrator.search.query.queries import ExportQuery, SelectQuery
from orchestrator.search.query.results import SearchResult
from orchestrator.search.query.validation import is_lquery_syntactically_valid
from orchestrator.search.retrieval.pagination import PageCursor, encode_next_page_cursor
from orchestrator.search.retrieval.validation import is_lquery_syntactically_valid
from orchestrator.search.schemas.parameters import (
ProcessSearchParameters,
ProductSearchParameters,
SearchParameters,
SubscriptionSearchParameters,
WorkflowSearchParameters,
)
from orchestrator.search.schemas.results import SearchResult, TypeDefinition

router = APIRouter()
logger = structlog.get_logger(__name__)


async def _perform_search_and_fetch(
search_params: SearchParameters | None = None,
entity_type: EntityType | None = None,
request: SearchRequest | None = None,
cursor: str | None = None,
query_id: str | None = None,
) -> SearchResultsSchema[SearchResult]:
"""Execute search with optional pagination.
Args:
search_params: Search parameters for new search
entity_type: Entity type to search
request: Search request for new search
cursor: Pagination cursor (loads saved query state)
query_id: Saved query ID to retrieve and execute
Expand All @@ -58,27 +55,31 @@ async def _perform_search_and_fetch(
"""
try:
page_cursor: PageCursor | None = None
query: SelectQuery

if cursor:
page_cursor = PageCursor.decode(cursor)
query_state = SearchQueryState.load_from_id(page_cursor.query_id)
query_state = QueryState.load_from_id(page_cursor.query_id, SelectQuery)
query = query_state.query

elif query_id:
query_state = SearchQueryState.load_from_id(query_id)
elif search_params:
query_state = SearchQueryState(parameters=search_params, query_embedding=None)
query_state = QueryState.load_from_id(query_id, SelectQuery)
query = query_state.query

elif request and entity_type:
query = request.to_query(entity_type)
query_state = QueryState(query=query, query_embedding=None)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Either search_params, cursor, or query_id must be provided",
detail="Either (request + entity_type), cursor, or query_id must be provided",
)

search_response = await execute_search(
query_state.parameters, db.session, page_cursor, query_state.query_embedding
)
search_response = await engine.execute_search(query, db.session, page_cursor, query_state.query_embedding)
if not search_response.results:
return SearchResultsSchema(search_metadata=search_response.metadata)

next_page_cursor = encode_next_page_cursor(search_response, page_cursor, query_state.parameters)
next_page_cursor = encode_next_page_cursor(search_response, page_cursor, query)
has_next_page = next_page_cursor is not None
page_info = PageInfoSchema(has_next_page=has_next_page, next_page_cursor=next_page_cursor)

Expand All @@ -98,34 +99,34 @@ async def _perform_search_and_fetch(

@router.post("/subscriptions", response_model=SearchResultsSchema[SearchResult])
async def search_subscriptions(
search_params: SubscriptionSearchParameters,
request: SearchRequest,
cursor: str | None = None,
) -> SearchResultsSchema[SearchResult]:
return await _perform_search_and_fetch(search_params, cursor)
return await _perform_search_and_fetch(EntityType.SUBSCRIPTION, request, cursor)


@router.post("/workflows", response_model=SearchResultsSchema[SearchResult])
async def search_workflows(
search_params: WorkflowSearchParameters,
request: SearchRequest,
cursor: str | None = None,
) -> SearchResultsSchema[SearchResult]:
return await _perform_search_and_fetch(search_params, cursor)
return await _perform_search_and_fetch(EntityType.WORKFLOW, request, cursor)


@router.post("/products", response_model=SearchResultsSchema[SearchResult])
async def search_products(
search_params: ProductSearchParameters,
request: SearchRequest,
cursor: str | None = None,
) -> SearchResultsSchema[SearchResult]:
return await _perform_search_and_fetch(search_params, cursor)
return await _perform_search_and_fetch(EntityType.PRODUCT, request, cursor)


@router.post("/processes", response_model=SearchResultsSchema[SearchResult])
async def search_processes(
search_params: ProcessSearchParameters,
request: SearchRequest,
cursor: str | None = None,
) -> SearchResultsSchema[SearchResult]:
return await _perform_search_and_fetch(search_params, cursor)
return await _perform_search_and_fetch(EntityType.PROCESS, request, cursor)


@router.get(
Expand Down Expand Up @@ -191,7 +192,7 @@ async def export_by_query_id(query_id: str) -> ExportResponse:
as flattened records suitable for CSV download.
Args:
query_id: Query UUID
query_id: QueryTypes UUID
Returns:
ExportResponse containing 'page' with an array of flattened entity records.
Expand All @@ -200,8 +201,17 @@ async def export_by_query_id(query_id: str) -> ExportResponse:
HTTPException: 404 if query not found, 400 if invalid data
"""
try:
query_state = SearchQueryState.load_from_id(query_id)
export_records = await execute_search_for_export(query_state, db.session)
# Load SelectQuery from the database (what gets saved during search)
query_state = QueryState.load_from_id(query_id, SelectQuery)

# Convert to ExportQuery with export-appropriate limit
export_query = ExportQuery(
entity_type=query_state.query.entity_type,
filters=query_state.query.filters,
query_text=query_state.query.query_text,
)

export_records = await engine.execute_export(export_query, db.session, query_state.query_embedding)
return ExportResponse(page=export_records)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# limitations under the License.

import json
import re

import structlog
from sqlalchemy import and_
Expand All @@ -22,45 +21,21 @@
from orchestrator.db.models import AiSearchIndex
from orchestrator.search.core.types import EntityType
from orchestrator.search.indexing.registry import ENTITY_CONFIG_REGISTRY
from orchestrator.search.schemas.parameters import BaseSearchParameters
from orchestrator.search.schemas.results import SearchResult
from orchestrator.search.query.queries import BaseQuery
from orchestrator.search.query.results import SearchResult

logger = structlog.get_logger(__name__)


def generate_highlight_indices(text: str, term: str) -> list[tuple[int, int]]:
"""Finds all occurrences of individual words from the term, including both word boundary and substring matches."""
if not text or not term:
return []

all_matches = []
words = [w.strip() for w in term.split() if w.strip()]

for word in words:
# First find word boundary matches
word_boundary_pattern = rf"\b{re.escape(word)}\b"
word_matches = list(re.finditer(word_boundary_pattern, text, re.IGNORECASE))
all_matches.extend([(m.start(), m.end()) for m in word_matches])

# Then find all substring matches
substring_pattern = re.escape(word)
substring_matches = list(re.finditer(substring_pattern, text, re.IGNORECASE))
all_matches.extend([(m.start(), m.end()) for m in substring_matches])

return sorted(set(all_matches))


def display_filtered_paths_only(
results: list[SearchResult], search_params: BaseSearchParameters, db_session: WrappedSession
) -> None:
def display_filtered_paths_only(results: list[SearchResult], query: BaseQuery, db_session: WrappedSession) -> None:
"""Display only the paths that were searched for in the results."""
if not results:
logger.info("No results found.")
return

logger.info("--- Search Results ---")

searched_paths = search_params.filters.get_all_paths() if search_params.filters else []
searched_paths = query.filters.get_all_paths() if query.filters else []
if not searched_paths:
return

Expand Down
46 changes: 22 additions & 24 deletions orchestrator/cli/search/search_explore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import typer
from pydantic import ValidationError

from orchestrator.cli.search.display import display_filtered_paths_only, display_results
from orchestrator.db import db
from orchestrator.search.core.types import EntityType, FilterOp, UIType
from orchestrator.search.filters import EqualityFilter, FilterTree, LtreeFilter, PathFilter
from orchestrator.search.retrieval import execute_search
from orchestrator.search.retrieval.utils import display_filtered_paths_only, display_results
from orchestrator.search.retrieval.validation import get_structured_filter_schema
from orchestrator.search.schemas.parameters import BaseSearchParameters
from orchestrator.search.query import engine
from orchestrator.search.query.queries import SelectQuery
from orchestrator.search.query.validation import get_structured_filter_schema

app = typer.Typer(help="Experiment with the subscription search indexes.")

Expand All @@ -31,16 +31,14 @@ def structured(path: str, value: str, entity_type: EntityType = EntityType.SUBSC
...
"""
path_filter = PathFilter(path=path, condition=EqualityFilter(op=FilterOp.EQ, value=value), value_kind=UIType.STRING)
search_params = BaseSearchParameters.create(
entity_type=entity_type, filters=FilterTree.from_flat_and([path_filter]), limit=limit
)
search_response = asyncio.run(execute_search(search_params=search_params, db_session=db.session))
display_filtered_paths_only(search_response.results, search_params, db.session)
query = SelectQuery(entity_type=entity_type, filters=FilterTree.from_flat_and([path_filter]), limit=limit)
search_response = asyncio.run(engine.execute_search(query=query, db_session=db.session))
display_filtered_paths_only(search_response.results, query, db.session)
display_results(search_response.results, db.session, "Match")


@app.command()
def semantic(query: str, entity_type: EntityType = EntityType.SUBSCRIPTION, limit: int = 10) -> None:
def semantic(query_text: str, entity_type: EntityType = EntityType.SUBSCRIPTION, limit: int = 10) -> None:
"""Finds subscriptions that are conceptually most similar to the query text.

Example:
Expand All @@ -52,8 +50,8 @@ def semantic(query: str, entity_type: EntityType = EntityType.SUBSCRIPTION, limi
},
...
"""
search_params = BaseSearchParameters.create(entity_type=entity_type, query=query, limit=limit)
search_response = asyncio.run(execute_search(search_params=search_params, db_session=db.session))
query = SelectQuery(entity_type=entity_type, query_text=query_text, limit=limit)
search_response = asyncio.run(engine.execute_search(query=query, db_session=db.session))
display_results(search_response.results, db.session, "Distance")


Expand All @@ -70,16 +68,16 @@ def fuzzy(term: str, entity_type: EntityType = EntityType.SUBSCRIPTION, limit: i
},
...
"""
search_params = BaseSearchParameters.create(entity_type=entity_type, query=term, limit=limit)
search_response = asyncio.run(execute_search(search_params=search_params, db_session=db.session))
query = SelectQuery(entity_type=entity_type, query_text=term, limit=limit)
search_response = asyncio.run(engine.execute_search(query=query, db_session=db.session))
display_results(search_response.results, db.session, "Similarity")


@app.command()
def hierarchical(
op: str = typer.Argument(..., help="The hierarchical operation to perform."),
path: str = typer.Argument(..., help="The ltree path or lquery pattern for the operation."),
query: str | None = typer.Option(None, "--query", "-f", help="An optional fuzzy term to rank the results."),
query_text: str | None = typer.Option(None, "--query", "-q", help="An optional fuzzy term to rank the results."),
entity_type: EntityType = EntityType.SUBSCRIPTION,
limit: int = 10,
) -> None:
Expand All @@ -96,23 +94,23 @@ def hierarchical(

path_filter = PathFilter(path="ltree_hierarchical_filter", condition=condition, value_kind=UIType.STRING)

search_params = BaseSearchParameters.create(
entity_type=entity_type, filters=FilterTree.from_flat_and([path_filter]), query=query, limit=limit
query = SelectQuery(
entity_type=entity_type, filters=FilterTree.from_flat_and([path_filter]), query_text=query_text, limit=limit
)
search_response = asyncio.run(execute_search(search_params=search_params, db_session=db.session))
search_response = asyncio.run(engine.execute_search(query=query, db_session=db.session))
display_results(search_response.results, db.session, "Hierarchical Score")


@app.command()
def hybrid(query: str, term: str, entity_type: EntityType = EntityType.SUBSCRIPTION, limit: int = 10) -> None:
def hybrid(query_text: str, term: str, entity_type: EntityType = EntityType.SUBSCRIPTION, limit: int = 10) -> None:
"""Performs a hybrid search, combining semantic and fuzzy matching.

Example:
dotenv run python main.py search hybrid "reptile store" "Kingswood"
"""
search_params = BaseSearchParameters.create(entity_type=entity_type, query=query, limit=limit)
logger.info("Executing Hybrid Search", query=query, term=term)
search_response = asyncio.run(execute_search(search_params=search_params, db_session=db.session))
query = SelectQuery(entity_type=entity_type, query_text=query_text, limit=limit)
logger.info("Executing Hybrid Search", query_text=query_text, term=term)
search_response = asyncio.run(engine.execute_search(query=query, db_session=db.session))
display_results(search_response.results, db.session, "Hybrid Score")


Expand Down Expand Up @@ -198,8 +196,8 @@ def nested_demo(entity_type: EntityType = EntityType.SUBSCRIPTION, limit: int =
}
)

params = BaseSearchParameters.create(entity_type=entity_type, filters=tree, limit=limit)
search_response = asyncio.run(execute_search(params, db.session))
query = SelectQuery(entity_type=entity_type, filters=tree, limit=limit)
search_response = asyncio.run(engine.execute_search(query=query, db_session=db.session))

display_results(search_response.results, db.session, "Score")

Expand Down
20 changes: 11 additions & 9 deletions orchestrator/cli/search/speedtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from orchestrator.search.core.embedding import QueryEmbedder
from orchestrator.search.core.types import EntityType
from orchestrator.search.core.validators import is_uuid
from orchestrator.search.retrieval.engine import execute_search
from orchestrator.search.schemas.parameters import BaseSearchParameters
from orchestrator.search.query import engine
from orchestrator.search.query.queries import SelectQuery

logger = structlog.get_logger(__name__)
console = Console()
Expand Down Expand Up @@ -50,23 +50,25 @@ async def generate_embeddings_for_queries(queries: list[str]) -> dict[str, list[
return embedding_lookup


async def run_single_query(query: str, embedding_lookup: dict[str, list[float]]) -> dict[str, Any]:
search_params = BaseSearchParameters(entity_type=EntityType.SUBSCRIPTION, query=query, limit=30)
async def run_single_query(query_text: str, embedding_lookup: dict[str, list[float]]) -> dict[str, Any]:
query = SelectQuery(entity_type=EntityType.SUBSCRIPTION, query_text=query_text, limit=30)

query_embedding = None

if is_uuid(query):
logger.debug("Using fuzzy-only ranking for full UUID", query=query)
if is_uuid(query_text):
logger.debug("Using fuzzy-only ranking for full UUID", query_text=query_text)
else:
query_embedding = embedding_lookup[query]
query_embedding = embedding_lookup[query_text]

with db.session as session:
start_time = time.perf_counter()
response = await execute_search(search_params, session, cursor=None, query_embedding=query_embedding)
response = await engine.execute_search(
query=query, db_session=session, cursor=None, query_embedding=query_embedding
)
end_time = time.perf_counter()

return {
"query": query,
"query": query_text,
"time": end_time - start_time,
"results": len(response.results),
"search_type": response.metadata.search_type if hasattr(response, "metadata") else "unknown",
Expand Down
Loading