Skip to content

dynamic filters for search router - test code #513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: feature/educational_resources_search
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 38 additions & 65 deletions src/routers/search_router.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import abc
from typing import TypeVar, Generic, Any, Type, Literal, Annotated, TypeAlias
from typing import TypeVar, Generic, Any, Type, Literal, Annotated, TypeAlias, Optional

from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from fastapi import APIRouter, HTTPException, Query, Depends
from pydantic import BaseModel, create_model
from pydantic.generics import GenericModel
from sqlmodel import SQLModel, select, Field
from starlette import status
Expand Down Expand Up @@ -83,11 +83,20 @@ def create(self, url_prefix: str) -> APIRouter:
read_class = resource_read(self.resource_class) # type: ignore
indexed_fields: TypeAlias = Literal[tuple(self.indexed_fields)] # type: ignore

# Dynamically create a query model from indexed and linked fields
def make_dynamic_query_model() -> Type[BaseModel]:
fields = {
field: (Optional[str], None)
for field in self.indexed_fields.union(self.linked_fields)
}
return create_model(f"{self.resource_name_plural.capitalize()}SearchParams", **fields)

SearchQueryModel = make_dynamic_query_model()

@router.get(
f"{url_prefix}/search/{self.resource_name_plural}/v1",
tags=["search"],
description=f"""Search for {self.resource_name_plural}.""",
# response_model=SearchResult[read_class], # This gives errors, so not used.
)
def search(
search_query: Annotated[
Expand All @@ -97,62 +106,15 @@ def search(
examples=["Name of the resource"],
),
],
exact_match: Annotated[
bool,
Query(
description="If true, it searches for an exact match.",
),
] = False,
search_fields: Annotated[
list[indexed_fields] | None,
Query(
description="Search in these fields. If empty, the query will be matched "
"against all fields. Do not use the '--' option in Swagger, it is a Swagger "
"artifact.",
),
] = None,
platforms: Annotated[
list[str] | None,
Query(
description="Search for resources of these platforms. If empty, results from "
"all platforms will be returned.",
examples=["huggingface", "openml"],
),
] = None,
date_modified_after: Annotated[
str | None,
Query(
description="Search for resources modified after this date "
"(yyyy-mm-dd, inclusive).",
pattern="[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]",
examples=["2023-01-01"],
),
] = None,
date_modified_before: Annotated[
str | None,
Query(
description="Search for resources modified before this date "
"(yyyy-mm-dd, not inclusive).",
pattern="[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]",
examples=["2023-01-01"],
),
] = None,
sort_by_id: Annotated[
bool,
Query(
description="If true, the results are sorted by id."
"By default they are sorted by best score.",
),
] = False,
dynamic_params: SearchQueryModel = Depends(),
exact_match: bool = False,
platforms: list[str] | None = None,
date_modified_after: str | None = None,
date_modified_before: str | None = None,
sort_by_id: bool = False,
limit: Annotated[int, Query(ge=1, le=LIMIT_MAX)] = 10,
offset: Annotated[int, Query(ge=0)] = 0,
get_all: Annotated[
bool,
Query(
description="If true, a request to the database is made to retrieve all data. "
"If false, only the indexed information is returned.",
),
] = False,
get_all: bool = False,
):
try:
with DbSession() as session:
Expand All @@ -168,30 +130,38 @@ def search(
detail=f"The available platforms are: {platform_names}",
)

fields = search_fields if search_fields else self.indexed_fields
# Use provided dynamic search fields if set
extra_search_fields = {k: v for k, v in dynamic_params.dict().items() if v is not None}

fields = list(extra_search_fields.keys()) or list(self.indexed_fields)
query_matches: list[dict[str, dict[str, str | dict[str, str]]]] = []
if exact_match:
query_matches = [
{"match": {f: {"query": search_query, "operator": "and"}}} for f in fields
]
else:
query_matches = [{"match": {f: search_query}} for f in fields]

for f in fields:
if exact_match:
query_matches.append({"match": {f: {"query": search_query, "operator": "and"}}})
else:
query_matches.append({"match": {f: search_query}})

query = {"bool": {"should": query_matches, "minimum_should_match": 1}}
must_clause = []

if platforms:
platform_matches = [{"match": {"platform": p}} for p in platforms]
must_clause.append(
{"bool": {"should": platform_matches, "minimum_should_match": 1}}
)

if date_modified_after or date_modified_before:
date_range = {}
if date_modified_after:
date_range["gte"] = date_modified_after
if date_modified_before:
date_range["lt"] = date_modified_before
must_clause.append({"range": {"date_modified": date_range}})

if must_clause:
query["bool"]["must"] = must_clause

sort: dict[str, str | dict[str, str]] = {}
if sort_by_id:
sort = {"identifier": "asc"}
Expand All @@ -201,7 +171,9 @@ def search(
result = ElasticsearchSingleton().client.search(
index=self.es_index, query=query, from_=offset, size=limit, sort=sort
)

total_hits = result["hits"]["total"]["value"]

if get_all:
identifiers = [hit["_source"]["identifier"] for hit in result["hits"]["hits"]]
resources: list[SQLModel] = self._db_query(
Expand All @@ -212,6 +184,7 @@ def search(
self._cast_resource(read_class, hit["_source"])
for hit in result["hits"]["hits"]
]

return SearchResult[read_class]( # type: ignore
total_hits=total_hits,
resources=resources,
Expand Down