diff --git a/src/routers/search_router.py b/src/routers/search_router.py index 127523f76..68746b47f 100644 --- a/src/routers/search_router.py +++ b/src/routers/search_router.py @@ -1,8 +1,8 @@ import abc -from typing import TypeVar, Generic, Any, Type, Literal, Annotated, TypeAlias +from typing import TypeVar, Generic, Any, Type, Literal, Annotated, TypeAlias, Optional -from fastapi import APIRouter, HTTPException, Query -from pydantic import BaseModel +from fastapi import APIRouter, HTTPException, Query, Depends +from pydantic import BaseModel, create_model from pydantic.generics import GenericModel from sqlmodel import SQLModel, select, Field from starlette import status @@ -83,11 +83,20 @@ def create(self, url_prefix: str) -> APIRouter: read_class = resource_read(self.resource_class) # type: ignore indexed_fields: TypeAlias = Literal[tuple(self.indexed_fields)] # type: ignore + # Dynamically create a query model from indexed and linked fields + def make_dynamic_query_model() -> Type[BaseModel]: + fields = { + field: (Optional[str], None) + for field in self.indexed_fields.union(self.linked_fields) + } + return create_model(f"{self.resource_name_plural.capitalize()}SearchParams", **fields) + + SearchQueryModel = make_dynamic_query_model() + @router.get( f"{url_prefix}/search/{self.resource_name_plural}/v1", tags=["search"], description=f"""Search for {self.resource_name_plural}.""", - # response_model=SearchResult[read_class], # This gives errors, so not used. ) def search( search_query: Annotated[ @@ -97,62 +106,15 @@ def search( examples=["Name of the resource"], ), ], - exact_match: Annotated[ - bool, - Query( - description="If true, it searches for an exact match.", - ), - ] = False, - search_fields: Annotated[ - list[indexed_fields] | None, - Query( - description="Search in these fields. If empty, the query will be matched " - "against all fields. Do not use the '--' option in Swagger, it is a Swagger " - "artifact.", - ), - ] = None, - platforms: Annotated[ - list[str] | None, - Query( - description="Search for resources of these platforms. If empty, results from " - "all platforms will be returned.", - examples=["huggingface", "openml"], - ), - ] = None, - date_modified_after: Annotated[ - str | None, - Query( - description="Search for resources modified after this date " - "(yyyy-mm-dd, inclusive).", - pattern="[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]", - examples=["2023-01-01"], - ), - ] = None, - date_modified_before: Annotated[ - str | None, - Query( - description="Search for resources modified before this date " - "(yyyy-mm-dd, not inclusive).", - pattern="[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]", - examples=["2023-01-01"], - ), - ] = None, - sort_by_id: Annotated[ - bool, - Query( - description="If true, the results are sorted by id." - "By default they are sorted by best score.", - ), - ] = False, + dynamic_params: SearchQueryModel = Depends(), + exact_match: bool = False, + platforms: list[str] | None = None, + date_modified_after: str | None = None, + date_modified_before: str | None = None, + sort_by_id: bool = False, limit: Annotated[int, Query(ge=1, le=LIMIT_MAX)] = 10, offset: Annotated[int, Query(ge=0)] = 0, - get_all: Annotated[ - bool, - Query( - description="If true, a request to the database is made to retrieve all data. " - "If false, only the indexed information is returned.", - ), - ] = False, + get_all: bool = False, ): try: with DbSession() as session: @@ -168,21 +130,27 @@ def search( detail=f"The available platforms are: {platform_names}", ) - fields = search_fields if search_fields else self.indexed_fields + # Use provided dynamic search fields if set + extra_search_fields = {k: v for k, v in dynamic_params.dict().items() if v is not None} + + fields = list(extra_search_fields.keys()) or list(self.indexed_fields) query_matches: list[dict[str, dict[str, str | dict[str, str]]]] = [] - if exact_match: - query_matches = [ - {"match": {f: {"query": search_query, "operator": "and"}}} for f in fields - ] - else: - query_matches = [{"match": {f: search_query}} for f in fields] + + for f in fields: + if exact_match: + query_matches.append({"match": {f: {"query": search_query, "operator": "and"}}}) + else: + query_matches.append({"match": {f: search_query}}) + query = {"bool": {"should": query_matches, "minimum_should_match": 1}} must_clause = [] + if platforms: platform_matches = [{"match": {"platform": p}} for p in platforms] must_clause.append( {"bool": {"should": platform_matches, "minimum_should_match": 1}} ) + if date_modified_after or date_modified_before: date_range = {} if date_modified_after: @@ -190,8 +158,10 @@ def search( if date_modified_before: date_range["lt"] = date_modified_before must_clause.append({"range": {"date_modified": date_range}}) + if must_clause: query["bool"]["must"] = must_clause + sort: dict[str, str | dict[str, str]] = {} if sort_by_id: sort = {"identifier": "asc"} @@ -201,7 +171,9 @@ def search( result = ElasticsearchSingleton().client.search( index=self.es_index, query=query, from_=offset, size=limit, sort=sort ) + total_hits = result["hits"]["total"]["value"] + if get_all: identifiers = [hit["_source"]["identifier"] for hit in result["hits"]["hits"]] resources: list[SQLModel] = self._db_query( @@ -212,6 +184,7 @@ def search( self._cast_resource(read_class, hit["_source"]) for hit in result["hits"]["hits"] ] + return SearchResult[read_class]( # type: ignore total_hits=total_hits, resources=resources,