Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion graphene_django_optimizer/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields.reverse_related import ManyToOneRel
from graphene import InputObjectType
from graphene.relay.node import GlobalID
from graphene.types.generic import GenericScalar
from graphene.types.resolver import default_resolver
from graphene_django import DjangoObjectType
Expand Down Expand Up @@ -289,7 +290,12 @@ def _is_resolver_for_id_field(self, resolver):
# For python 2 unbound method:
if hasattr(resolve_id, 'im_func'):
resolve_id = resolve_id.im_func
return resolver == resolve_id

if (isinstance(resolver, functools.partial)
and resolver.func == GlobalID.id_resolver):
return resolver.args[0] == resolve_id
else:
return resolver == resolve_id

def _get_model_field_from_name(self, model, name):
try:
Expand Down
39 changes: 39 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,3 +520,42 @@ def test_should_only_use_the_only_and_not_select_related():
items = gql_optimizer.query(qs, info)
optimized_items = qs.only('id', 'name')
assert_query_equality(items, optimized_items)


@pytest.mark.django_db
def test_should_use_only_with_node_interface():
info = create_resolve_info(schema, '''
query {
relayItems {
edges {
node {
id
}
}
}
}
''')
qs = Item.objects.all()
items = gql_optimizer.query(qs, info)
optimized_items = qs.only('id')
assert_query_equality(items, optimized_items)


@pytest.mark.django_db
def test_should_not_try_to_optimize_non_field_model_fields_on_relay_node():
info = create_resolve_info(schema, '''
query {
relayItems {
edges {
node {
id
foo
}
}
}
}
''')
qs = Item.objects.all()
items = gql_optimizer.query(qs, info)
optimized_items = qs
assert_query_equality(items, optimized_items)