diff --git a/graphene_django_extras/directives/base.py b/graphene_django_extras/directives/base.py index 6798999..841558b 100644 --- a/graphene_django_extras/directives/base.py +++ b/graphene_django_extras/directives/base.py @@ -6,17 +6,21 @@ class BaseExtraGraphQLDirective(GraphQLDirective): + default_locations = [ + DirectiveLocation.FIELD, + DirectiveLocation.FRAGMENT_SPREAD, + DirectiveLocation.INLINE_FRAGMENT, + ] + + locations = [] + def __init__(self): registry = get_global_registry() super(BaseExtraGraphQLDirective, self).__init__( name=self.get_name(), description=self.__doc__, args=self.get_args(), - locations=[ - DirectiveLocation.FIELD, - DirectiveLocation.FRAGMENT_SPREAD, - DirectiveLocation.INLINE_FRAGMENT, - ], + locations=self.get_locations(), ) registry.register_directive(self.get_name(), self) @@ -27,3 +31,7 @@ def get_name(cls): @staticmethod def get_args(): return {} + + @classmethod + def get_locations(cls): + return cls.locations if cls.locations else cls.default_locations diff --git a/tests/conftest.py b/tests/conftest.py index 8c11bd2..81270a9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -63,7 +63,9 @@ def pytest_configure(config): "tests", ), PASSWORD_HASHERS=("django.contrib.auth.hashers.MD5PasswordHasher",), - GRAPHENE={"SCHEMA": "tests.schema.schema"}, + GRAPHENE={"SCHEMA": "tests.schema.schema", "MIDDLEWARE":[ + "graphene_django_extras.middleware.ExtraGraphQLDirectiveMiddleware", + ]}, AUTHENTICATION_BACKENDS=( "django.contrib.auth.backends.ModelBackend", "guardian.backends.ObjectPermissionBackend", diff --git a/tests/queries.py b/tests/queries.py index ef88975..51edd9e 100644 --- a/tests/queries.py +++ b/tests/queries.py @@ -22,6 +22,14 @@ } } """ + +ALL_USERS_DIR = """query{ + allUsers1{ + username @uppercase + } +} +""" + ALL_USERS3 = """query { allUsers3 { id diff --git a/tests/schema.py b/tests/schema.py index e5a0ecc..9016dab 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -2,6 +2,7 @@ from django.contrib.auth.models import User from django.utils.translation import ugettext_lazy as _ +from graphene_django_extras.directives import all_directives from graphene_django_extras.types import ( DjangoListObjectType, DjangoSerializerType, @@ -93,4 +94,4 @@ class Query(graphene.ObjectType): user2, users = UserModelType.QueryFields() -schema = graphene.Schema(query=Query) +schema = graphene.Schema(query=Query, directives=all_directives) diff --git a/tests/test_fields.py b/tests/test_fields.py index a2425e8..0a78ba0 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -103,6 +103,11 @@ def test_filter_charfield_iexact(self): ) +class DirectiveQueryTest(ParentTest, TestCase): + query = queries.ALL_USERS_DIR + expected_return_payload = {"data": {"allUsers1": [{"username": "GRAPHQL"}]}} + + class DjangoSerializerTypeTest(ParentTest, TestCase): expected_return_payload = { "data": {