diff --git a/.gitignore b/.gitignore index 39ca09e..203fcd0 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist/ *.egg-info/ build/ .tox/ +.idea diff --git a/drf_dynamic_fields/__init__.py b/drf_dynamic_fields/__init__.py index 97694e4..8a4082f 100644 --- a/drf_dynamic_fields/__init__.py +++ b/drf_dynamic_fields/__init__.py @@ -1,11 +1,14 @@ """ Mixin to dynamically select only a subset of fields per DRF resource. """ + import warnings from django.conf import settings from django.utils.functional import cached_property +from rest_framework import serializers + class DynamicFieldsMixin(object): """ @@ -13,6 +16,15 @@ class DynamicFieldsMixin(object): which fields should be displayed. """ + @property + def is_preventing_nested_serializers(self): + is_root = self.root == self + parent_is_list_root = self.parent == self.root and getattr( + self.parent, "many", False + ) + + return not (is_root or parent_is_list_root) + @cached_property def fields(self): """ @@ -29,13 +41,7 @@ def fields(self): # We are being called before a request cycle return fields - # Only filter if this is the root serializer, or if the parent is the - # root serializer with many=True - is_root = self.root == self - parent_is_list_root = self.parent == self.root and getattr( - self.parent, "many", False - ) - if not (is_root or parent_is_list_root): + if self.is_preventing_nested_serializers: return fields try: @@ -55,15 +61,11 @@ def fields(self): if params is None: warnings.warn("Request object does not contain query parameters") - try: - filter_fields = params.get("fields", None).split(",") - except AttributeError: - filter_fields = None + source = get_source_path(self) + level = compute_level(self) - try: - omit_fields = params.get("omit", None).split(",") - except AttributeError: - omit_fields = [] + filter_fields = self.get_filter_fields(params.get("fields", None), level, source) + omit_fields = self.get_omit_fields(params.get("omit", None), level, source) # Drop any fields that are not specified in the `fields` argument. existing = set(fields.keys()) @@ -85,3 +87,73 @@ def fields(self): fields.pop(field, None) return fields + + def get_filter_fields(self, params, level, source, default=None, include_parent=True): + try: + return params.split(",") + except AttributeError: + return default + + + def get_omit_fields(self, params, level, source): + return self.get_filter_fields(params, level, source, default=[], include_parent=False) + + +class NestedDynamicFieldsMixin(DynamicFieldsMixin): + + @property + def is_preventing_nested_serializers(self): + return False + + def get_filter_fields(self, params, level, source, default=None, include_parent=True): + fields = super().get_filter_fields(params, level, source, default, include_parent) + return get_fields_for_level_and_prefix( + fields, + level, + source, + default=default, + include_parent=include_parent + ) + +def get_source_path(serializer): + parts = [] + current = serializer + while current.parent is not None: + if hasattr(current, 'field_name'): + parts.insert(0, current.field_name) + current = current.parent + return "__".join(filter(None, parts)) + +def get_fields_for_level_and_prefix(fields_list, level, source, include_parent, default): + if not fields_list: + return default + + allowed = set() + prefix = source.split("__") if source else [] + for f in fields_list: + parts = f.split("__") + if parts[:level] != prefix: + continue + if len(parts) <= level + 1: + allowed.add(parts[-1]) + elif len(parts) > level + 1 and include_parent: + # include parent field to ensure nesting proceeds + allowed.add(parts[level]) + if set(prefix) == allowed: + return default + return allowed + +def compute_level(serializer): + level = 0 + current = serializer + while hasattr(current, 'parent') and current.parent is not None: + parent = current.parent + + # Handle ListSerializer by skipping over it + if isinstance(parent, serializers.ListSerializer): + current = parent.parent + else: + current = parent + + level += 1 + return level diff --git a/tests/models.py b/tests/models.py index 714098e..884697a 100644 --- a/tests/models.py +++ b/tests/models.py @@ -14,3 +14,12 @@ class School(models.Model): name = models.CharField(max_length=30) teachers = models.ManyToManyField(Teacher) + + +class Child(models.Model): + secret = models.CharField(max_length=100) + public = models.CharField(max_length=100) + + +class Parent(models.Model): + child = models.ForeignKey(Child, on_delete=models.CASCADE) \ No newline at end of file diff --git a/tests/serializers.py b/tests/serializers.py index 3619d59..9e37b49 100644 --- a/tests/serializers.py +++ b/tests/serializers.py @@ -3,16 +3,12 @@ """ from rest_framework import serializers -from drf_dynamic_fields import DynamicFieldsMixin +from drf_dynamic_fields import DynamicFieldsMixin, NestedDynamicFieldsMixin -from .models import Teacher, School +from .models import Teacher, School, Child -class TeacherSerializer(DynamicFieldsMixin, serializers.ModelSerializer): - """ - The request_info field is to highlight the issue accessing request during - a nested serializer. - """ +class BaseTeacherSerializer(serializers.ModelSerializer): request_info = serializers.SerializerMethodField() @@ -29,14 +25,44 @@ def get_request_info(self, teacher): return request.build_absolute_uri("/api/v1/teacher/{}".format(teacher.pk)) -class SchoolSerializer(DynamicFieldsMixin, serializers.ModelSerializer): +class TeacherSerializer(DynamicFieldsMixin, BaseTeacherSerializer): + pass + + +class NestableTeacherSerializer(NestedDynamicFieldsMixin, BaseTeacherSerializer): """ - Interesting enough serializer because the TeacherSerializer - will use ListSerializer due to the `many=True` + The request_info field is to highlight the issue accessing request during + a nested serializer. + """ - teachers = TeacherSerializer(many=True, read_only=True) +class BaseSchoolSerializer(serializers.ModelSerializer): + class Meta: model = School fields = ("id", "teachers", "name") + + +class SchoolSerializer(DynamicFieldsMixin, BaseSchoolSerializer): + teachers = TeacherSerializer(many=True, read_only=True) + + +class NestableSchoolSerializer(NestedDynamicFieldsMixin, BaseSchoolSerializer): + """ + Interesting enough serializer because the TeacherSerializer + will use ListSerializer due to the `many=True` + """ + teachers = NestableTeacherSerializer(many=True, read_only=True) + +class ChildSerializer(NestedDynamicFieldsMixin, serializers.Serializer): + secret = serializers.CharField() + public = serializers.CharField() + + class Meta: + model = Child + + +class ParentSerializer(NestedDynamicFieldsMixin, serializers.Serializer): + id = serializers.IntegerField() + child = ChildSerializer() diff --git a/tests/test_mixins.py b/tests/test_mixins.py index f54e179..8c9a1a1 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -11,8 +11,14 @@ from django.test import TestCase, RequestFactory -from .serializers import SchoolSerializer, TeacherSerializer -from .models import Teacher, School +from .serializers import ( + NestableSchoolSerializer, + NestableTeacherSerializer, + SchoolSerializer, + TeacherSerializer, + ParentSerializer, +) +from .models import Teacher, School, Child, Parent class TestDynamicFieldsMixin(TestCase): @@ -20,13 +26,16 @@ class TestDynamicFieldsMixin(TestCase): Test case for the DynamicFieldsMixin """ + SchoolSerializer = SchoolSerializer + TeacherSerializer = TeacherSerializer + def test_removes_fields(self): """ Does it actually remove fields? """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?fields=id") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), set(("id",))) @@ -36,7 +45,7 @@ def test_fields_left_alone(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual( set(serializer.fields.keys()), set(("id", "request_info", "age", "name")) @@ -48,7 +57,7 @@ def test_fields_all_gone(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?fields") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), set()) @@ -60,7 +69,7 @@ def test_ordinary_serializer(self): request = rf.get("/api/v1/schools/1/?fields=id,age") teacher = Teacher.objects.create(name="Susan", age=34) - serializer = TeacherSerializer(teacher, context={"request": request}) + serializer = self.TeacherSerializer(teacher, context={"request": request}) self.assertEqual(serializer.data, {"id": teacher.id, "age": teacher.age}) @@ -70,7 +79,7 @@ def test_omit(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?omit=request_info") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), set(("id", "name", "age"))) @@ -80,7 +89,7 @@ def test_omit_and_fields_used(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?fields=id,request_info&omit=request_info") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), set(("id",))) @@ -90,7 +99,7 @@ def test_omit_everything(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?omit=id,request_info,age,name") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), set()) @@ -100,7 +109,7 @@ def test_omit_nothing(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?omit") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual( set(serializer.fields.keys()), set(("id", "request_info", "name", "age")) @@ -109,7 +118,7 @@ def test_omit_nothing(self): def test_omit_non_existant_field(self): rf = RequestFactory() request = rf.get("/api/v1/schools/1/?omit=pretend") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual( set(serializer.fields.keys()), set(("id", "request_info", "name", "age")) @@ -129,7 +138,7 @@ def test_as_nested_serializer(self): ] school.teachers.add(*teachers) - serializer = SchoolSerializer(school, context={"request": request}) + serializer = self.SchoolSerializer(school, context={"request": request}) request_info = "http://testserver/api/v1/teacher/{}" @@ -168,10 +177,208 @@ def test_serializer_reuse_with_changing_request(self): rf = RequestFactory() request = rf.get("/api/v1/schools/1/?fields=id") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), {"id"}) # now change the request on this instantiated serializer. request2 = rf.get("/api/v1/schools/1/?fields=id,name") serializer.context["request"] = request2 self.assertEqual(set(serializer.fields.keys()), {"id"}) + +class TestNestedDynamicFieldsMixin(TestDynamicFieldsMixin): + """ + Test case for the NestedDynamicFieldsMixin + """ + SchoolSerializer = NestableSchoolSerializer + TeacherSerializer = NestableTeacherSerializer + + def _assert_nested_fields(self, data, expected_fields): + """ + Assert nested fields match the expected fields. + """ + for parent, nested_fields in expected_fields.items(): + with self.subTest(parent=parent): + items = data[parent] + if nested_fields is None: + continue + expected_set = set(nested_fields) + for obj in items: + with self.subTest(parent=parent): + actual_set = set(obj.keys()) + self.assertEqual( + actual_set, + expected_set, + f"{parent} fields mismatch: expected " + f"exactly {nested_fields}, got {list(obj.keys())}", + ) + + @staticmethod + def _prepare_school_instance(): + """Prepare school instance for testing.""" + school = School.objects.create(name="Python Heights High") + teachers = [ + Teacher.objects.create(name="Shane", age=45), + Teacher.objects.create(name="Kaz", age=29), + ] + school.teachers.add(*teachers) + return school + + def test_omit_nested_field(self): + """Omitting a nested field""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit=invalid,name,teachers__age,teachers__invalid") + + school = self._prepare_school_instance() + serializer = self.SchoolSerializer(school, context={"request": request}) + data = serializer.data + + # Confirm omitted fields are in deferred list + deferred = set(serializer.get_model_fields_to_defer()) + self.assertEqual({"name", "teachers__age"}, deferred) + + expected_fields = {"id": None, "teachers": ["id", "name", "request_info"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_everything_nested_field(self): + """Omitting all fields within a nested field""" + rf = RequestFactory() + request = rf.get( + "/api/v1/schools/1/?omit=teachers__id,teachers__age,teachers__name,teachers__request_info" + ) + + school = self._prepare_school_instance() + serializer = self.SchoolSerializer(school, context={"request": request}) + data = serializer.data + + expected_fields = {"id": None, "name": None, "teachers": []} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_top_field_and_keep_all_nested_fields(self): + """Omitting a top-level field while keeping all nested fields""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit=name") + + school = self._prepare_school_instance() + serializer = self.SchoolSerializer(school, context={"request": request}) + data = serializer.data + + expected_fields = { + "id": None, + "teachers": ["id", "name", "request_info", "age"], + } + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_allow_nested_field(self): + """Select only the requested fields, including nested-level fields.""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?fields=invalid,id,teachers__age,teachers__invalid") + school = self._prepare_school_instance() + serializer = self.SchoolSerializer(school, context={"request": request}) + + # Confirm omitted fields are in deferred list + deferred = set(serializer.get_model_fields_to_defer()) + self.assertEqual({"name","teachers__name", "teachers__id"}, deferred) + + data = serializer.data + expected_fields = {"id": None, "teachers": ["age"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_fields_all_gone_nested(self): + """If no fields are selected, all fields are omitted, including those + from the nested serializer. + """ + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?fields") + school = self._prepare_school_instance() + serializer = self.SchoolSerializer(school, context={"request": request}) + + data = serializer.data + expected_fields = {} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_nested_omit_and_fields_used(self): + """Omit and fields can be used together at the nested field level.""" + rf = RequestFactory() + request = rf.get( + "/api/v1/schools/1/?fields=id,name,teachers__name,teachers__age&omit=name,teachers__name" + ) + school = self._prepare_school_instance() + serializer = self.SchoolSerializer(school, context={"request": request}) + + data = serializer.data + expected_fields = {"id": None, "teachers": ["age"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_nothing_nested(self): + """ + Blank omit doesn't affect nested fields. + """ + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit") + school = self._prepare_school_instance() + serializer = self.SchoolSerializer(school, context={"request": request}) + + data = serializer.data + expected_fields = { + "id": None, + "name": None, + "teachers": ["id", "age", "name", "request_info"], + } + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_single_nested_instance_omit_field(self): + """Omit also works for filtering fields on single nested instances""" + child = Child(secret="secret_key", public="public_key") + parent = Parent(id=1, child=child) + rf = RequestFactory() + request = rf.get("/api/v1/parent/1/?omit=id,child__secret") + serializer = ParentSerializer(parent, context={"request": request}) + data = serializer.data + + self.assertNotIn("id", data) + self.assertNotIn("secret", data["child"]) + self.assertEqual(data["child"]["public"], "public_key") + + def test_single_nested_instance_allow_field(self): + """Fields selection also works for filtering fields on single nested instances""" + child = Child(secret="secret_key", public="public_key") + parent = Parent(id=1, child=child) + rf = RequestFactory() + request = rf.get("/api/v1/parent/1/?fields=id,child__secret") + serializer = ParentSerializer(parent, context={"request": request}) + data = serializer.data + + self.assertEqual(data["id"], 1) + self.assertIn("secret", data["child"]) + self.assertEqual(data["child"]["secret"], "secret_key") + self.assertNotIn("public", data["child"]) + +