diff --git a/poetry.lock b/poetry.lock index c2a1d39a..a05a305e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -476,6 +476,7 @@ files = [ {file = "psycopg2-2.9.10-cp311-cp311-win_amd64.whl", hash = "sha256:0435034157049f6846e95103bd8f5a668788dd913a7c30162ca9503fdf542cb4"}, {file = "psycopg2-2.9.10-cp312-cp312-win32.whl", hash = "sha256:65a63d7ab0e067e2cdb3cf266de39663203d38d6a8ed97f5ca0cb315c73fe067"}, {file = "psycopg2-2.9.10-cp312-cp312-win_amd64.whl", hash = "sha256:4a579d6243da40a7b3182e0430493dbd55950c493d8c68f4eec0b302f6bbf20e"}, + {file = "psycopg2-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:91fd603a2155da8d0cfcdbf8ab24a2d54bca72795b90d2a3ed2b6da8d979dee2"}, {file = "psycopg2-2.9.10-cp39-cp39-win32.whl", hash = "sha256:9d5b3b94b79a844a986d029eee38998232451119ad653aea42bb9220a8c5066b"}, {file = "psycopg2-2.9.10-cp39-cp39-win_amd64.whl", hash = "sha256:88138c8dedcbfa96408023ea2b0c369eda40fe5d75002c0964c78f46f11fa442"}, {file = "psycopg2-2.9.10.tar.gz", hash = "sha256:12ec0b40b0273f95296233e8750441339298e6a572f7039da5b260e3c8b60e11"}, @@ -536,6 +537,7 @@ files = [ {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"}, + {file = "psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"}, diff --git a/strawberry_django/__init__.py b/strawberry_django/__init__.py index 74e0a557..45090c89 100644 --- a/strawberry_django/__init__.py +++ b/strawberry_django/__init__.py @@ -7,6 +7,7 @@ DateFilterLookup, DatetimeFilterLookup, FilterLookup, + GeometryFilterLookup, RangeLookup, TimeFilterLookup, ) @@ -37,6 +38,8 @@ "DjangoImageType", "DjangoModelType", "FilterLookup", + "GeometryFilterLookup", + "GeometryFilterLookup", "ListInput", "ManyToManyInput", "ManyToOneInput", diff --git a/strawberry_django/fields/filter_types.py b/strawberry_django/fields/filter_types.py index 1c05ccbb..4a12124e 100644 --- a/strawberry_django/fields/filter_types.py +++ b/strawberry_django/fields/filter_types.py @@ -2,12 +2,15 @@ import decimal import uuid from typing import ( + TYPE_CHECKING, + Annotated, Generic, Optional, TypeVar, ) import strawberry +from django.core.exceptions import ImproperlyConfigured from django.db.models import Q from strawberry import UNSET @@ -15,6 +18,9 @@ from .filter_order import filter_field +if TYPE_CHECKING: + from .types import Geometry + T = TypeVar("T") _SKIP_MSG = "Filter will be skipped on `null` value" @@ -123,3 +129,59 @@ class DatetimeFilterLookup(DateFilterLookup[T], TimeFilterLookup[T]): str: FilterLookup, uuid.UUID: FilterLookup, } + + +GeometryFilterLookup = None + +try: + pass +except ImproperlyConfigured: + # If gdal is not available, skip. + pass +else: + + @strawberry.input + class GeometryFilterLookup(Generic[T]): + bbcontains: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = UNSET + bboverlaps: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = UNSET + contained: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = UNSET + contains: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = UNSET + contains_properly: Optional[ + Annotated["Geometry", strawberry.lazy(".types")] + ] = UNSET + coveredby: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = UNSET + covers: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = UNSET + crosses: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = UNSET + disjoint: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = UNSET + equals: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = UNSET + exacts: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = UNSET + intersects: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = UNSET + isempty: Optional[bool] = filter_field( + description=f"Test whether it's empty. {_SKIP_MSG}" + ) + isvalid: Optional[bool] = filter_field( + description=f"Test whether it's valid. {_SKIP_MSG}" + ) + overlaps: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = UNSET + touches: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = UNSET + within: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = UNSET + left: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = UNSET + right: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = UNSET + overlaps_left: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = ( + UNSET + ) + overlaps_right: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = ( + UNSET + ) + overlaps_above: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = ( + UNSET + ) + overlaps_below: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = ( + UNSET + ) + strictly_above: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = ( + UNSET + ) + strictly_below: Optional[Annotated["Geometry", strawberry.lazy(".types")]] = ( + UNSET + ) diff --git a/strawberry_django/fields/types.py b/strawberry_django/fields/types.py index fe54f045..2fdbd324 100644 --- a/strawberry_django/fields/types.py +++ b/strawberry_django/fields/types.py @@ -24,6 +24,7 @@ from strawberry.file_uploads.scalars import Upload from strawberry.scalars import JSON from strawberry.types.enum import EnumValueDefinition +from strawberry.types.scalar import ScalarWrapper from strawberry.utils.str_converters import capitalize_first, to_camel_case from strawberry_django import filters @@ -348,7 +349,7 @@ def __hash__(self): Geometry = strawberry.scalar( NewType("Geometry", geos.GEOSGeometry), serialize=lambda v: v.tuple if isinstance(v, geos.GEOSGeometry) else v, # type: ignore - parse_value=lambda v: geos.GeometryCollection, + parse_value=lambda v: geos.GEOSGeometry(v), description=( "An arbitrary geographical object. One of Point, " "LineString, LinearRing, Polygon, MultiPoint, MultiLineString, MultiPolygon." @@ -556,7 +557,20 @@ def resolve_model_field_type( and (field_type is not bool or not using_old_filters) ): if using_old_filters: - field_type = filters.FilterLookup[field_type] + field_type = filters.FilterLookup[field_type] # type: ignore + elif type( + field_type + ) is ScalarWrapper and field_type._scalar_definition.name in ( + "Point", + "LineString", + "LinearRing", + "Polygon", + "MultiPoint", + "MultilineString", + "MultiPolygon", + "Geometry", + ): + field_type = filter_types.GeometryFilterLookup[field_type] else: field_type = filter_types.type_filter_map.get( # type: ignore field_type, filter_types.FilterLookup diff --git a/tests/test_queries.py b/tests/test_queries.py index 1dea5bc7..7aa760cf 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -6,6 +6,7 @@ import pytest import strawberry from asgiref.sync import sync_to_async +from django.conf import settings from django.core.files.uploadedfile import SimpleUploadedFile from django.test import override_settings from graphql import GraphQLError @@ -15,7 +16,7 @@ import strawberry_django from strawberry_django.settings import StrawberryDjangoSettings -from . import models, utils +from . import models, types, utils @pytest.fixture @@ -79,6 +80,8 @@ class Query: fruit: Fruit = strawberry_django.field() berries: list[BerryFruit] = strawberry_django.field() bananas: list[BananaFruit] = strawberry_django.field() + if settings.GEOS_IMPORTED: + geometries: list[types.GeoField] = strawberry_django.field() @pytest.fixture @@ -314,3 +317,29 @@ def fruit(self) -> Fruit: } """) assert result.data == {"fruit": {"colorId": mock.ANY, "name": "Banana"}} + + +@pytest.mark.skipif(not settings.GEOS_IMPORTED, reason="GeoDjango is not available.") +async def test_geos(query): + from django.contrib.gis.geos import GEOSGeometry + + result = await query( + """ + query GeosQuery($filter: GeoFieldFilter) { + geometries(filters: $filter) { + geometry + } + } + """, + variable_values={ + "filter": { + "geometry": { + "contains": GEOSGeometry( + "POLYGON(( 10 10, 10 20, 20 20, 20 15, 10 10))" + ) + } + } + }, + ) + + assert not result.errors diff --git a/tests/types.py b/tests/types.py index 42648a36..bae717ea 100644 --- a/tests/types.py +++ b/tests/types.py @@ -38,7 +38,11 @@ class TomatoWithRequiredPictureType: if settings.GEOS_IMPORTED: - @strawberry_django.type(models.GeosFieldsModel) + @strawberry_django.filters.filter(models.GeosFieldsModel, lookups=True) + class GeoFieldFilter: + geometry: auto + + @strawberry_django.type(models.GeosFieldsModel, filters=GeoFieldFilter) class GeoField: id: auto point: auto @@ -47,6 +51,7 @@ class GeoField: multi_point: auto multi_line_string: auto multi_polygon: auto + geometry: auto @strawberry_django.input(models.GeosFieldsModel) class GeoFieldInput(GeoField):