diff --git a/api/metadata/models.py b/api/metadata/models.py index 8cacbd9cf007..149342e41686 100644 --- a/api/metadata/models.py +++ b/api/metadata/models.py @@ -1,3 +1,4 @@ +from typing import cast from urllib.parse import urlparse from django.contrib.contenttypes.fields import GenericForeignKey @@ -30,41 +31,46 @@ class FieldType(models.TextChoices): class MetadataField(AbstractBaseExportableModel): """This model represents a metadata field(specific to an organisation) that can be attached to any model""" - name = models.CharField(max_length=255) - type = models.CharField( + id: models.AutoField[int, int] + name: models.CharField[str, str] = models.CharField(max_length=255) + type: models.CharField[str, str] = models.CharField( max_length=255, choices=FieldType.choices, default=FieldType.STRING ) - description = models.TextField(blank=True, null=True) - organisation = models.ForeignKey(Organisation, on_delete=models.CASCADE) + description: models.TextField[str | None, str | None] = models.TextField( + blank=True, null=True + ) + organisation: models.ForeignKey[Organisation, Organisation] = models.ForeignKey( + Organisation, on_delete=models.CASCADE + ) def is_field_value_valid(self, field_value: str) -> bool: if len(field_value) > FIELD_VALUE_MAX_LENGTH: return False - return self.__getattribute__(f"validate_{self.type}")(field_value) # type: ignore[no-any-return] + return cast(bool, self.__getattribute__(f"validate_{self.type}")(field_value)) - def validate_int(self, field_value: str): # type: ignore[no-untyped-def] + def validate_int(self, field_value: str) -> bool: try: int(field_value) except ValueError: return False return True - def validate_bool(self, field_value: str): # type: ignore[no-untyped-def] + def validate_bool(self, field_value: str) -> bool: if field_value.lower() in ["true", "false"]: return True return False - def validate_url(self, field_value: str): # type: ignore[no-untyped-def] + def validate_url(self, field_value: str) -> bool: try: result = urlparse(field_value) return all([result.scheme, result.netloc]) except ValueError: return False - def validate_str(self, field_value: str): # type: ignore[no-untyped-def] + def validate_str(self, field_value: str) -> bool: return True - def validate_multiline_str(self, field_value: str): # type: ignore[no-untyped-def] + def validate_multiline_str(self, field_value: str) -> bool: return True class Meta: @@ -114,3 +120,15 @@ class Metadata(AbstractBaseExportableModel): class Meta: unique_together = ("model_field", "content_type", "object_id") + + def deep_clone_for_new_entity(self, cloned_entity: models.Model) -> "Metadata": + content_type = ContentType.objects.get_for_model(cloned_entity) + return cast( + Metadata, + Metadata.objects.create( + model_field=self.model_field, + content_type=content_type, + object_id=cloned_entity.pk, + field_value=self.field_value, + ), + ) diff --git a/api/segments/models.py b/api/segments/models.py index f05cc5b9c8af..886845ee419c 100644 --- a/api/segments/models.py +++ b/api/segments/models.py @@ -147,6 +147,49 @@ def set_version_of_to_self_if_none(self): # type: ignore[no-untyped-def] self.version_of = self self.save_without_historical_record() + def clone_segment_rules(self, cloned_segment: "Segment") -> list["SegmentRule"]: + cloned_rules = [] + for rule in self.rules.all(): + cloned_rule = rule.deep_clone(cloned_segment) + cloned_rules.append(cloned_rule) + cloned_segment.refresh_from_db() + assert ( + len(self.rules.all()) + == len(cloned_rules) + == len(cloned_segment.rules.all()) + ), "Mismatch during rules creation" + + return cloned_rules + + def clone_segment_metadata(self, cloned_segment: "Segment") -> list["Metadata"]: + cloned_metadata = [] + for metadata in self.metadata.all(): + cloned_metadata.append(metadata.deep_clone_for_new_entity(cloned_segment)) + cloned_segment.refresh_from_db() + assert ( + len(self.metadata.all()) + == len(cloned_metadata) + == len(cloned_segment.metadata.all()) + ), "Mismatch during metadata creation" + + return cloned_metadata + + def clone(self, name: str) -> "Segment": + cloned_segment = Segment( + name=name, + version_of=None, + uuid=uuid.uuid4(), + description=self.description, + change_request=self.change_request, + project=self.project, + feature=self.feature, + ) + cloned_segment.save() + self.clone_segment_rules(cloned_segment) + self.clone_segment_metadata(cloned_segment) + cloned_segment.refresh_from_db() + return cloned_segment + def shallow_clone( self, name: str, @@ -177,18 +220,7 @@ def deep_clone(self) -> "Segment": self.version += 1 # type: ignore[operator] self.save_without_historical_record() - cloned_rules = [] - for rule in self.rules.all(): - cloned_rule = rule.deep_clone(cloned_segment) - cloned_rules.append(cloned_rule) - - cloned_segment.refresh_from_db() - - assert ( - len(self.rules.all()) - == len(cloned_rules) - == len(cloned_segment.rules.all()) - ), "Mismatch during rules creation" + self.clone_segment_rules(cloned_segment) return cloned_segment diff --git a/api/segments/serializers.py b/api/segments/serializers.py index 0b92c41035ea..355c1ba1454c 100644 --- a/api/segments/serializers.py +++ b/api/segments/serializers.py @@ -1,3 +1,6 @@ +from typing import Any, cast + +from common.segments.serializers import SegmentSerializer from rest_framework import serializers from segments.models import Segment @@ -19,3 +22,22 @@ class SegmentListQuerySerializer(serializers.Serializer): # type: ignore[type-a help_text="Optionally provide the id of an identity to get only the segments they match", ) include_feature_specific = serializers.BooleanField(required=False, default=True) + + +class CloneSegmentSerializer(SegmentSerializer): + class Meta: + model = Segment + fields = ("name",) + + def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: + if not attrs.get("name"): + raise serializers.ValidationError("Name is required to clone a segment") + return attrs + + def create(self, validated_data: dict[str, Any]) -> Segment: + name = validated_data.get("name") + source_segment = self.context.get("source_segment") + assert source_segment is not None, ( + "Source segment is required to clone a segment" + ) + return cast(Segment, source_segment.clone(name)) diff --git a/api/segments/views.py b/api/segments/views.py index 4e027ce54cff..e5995a484445 100644 --- a/api/segments/views.py +++ b/api/segments/views.py @@ -1,4 +1,5 @@ import logging +from typing import Any from common.projects.permissions import VIEW_PROJECT from common.segments.serializers import ( @@ -6,9 +7,10 @@ ) from django.utils.decorators import method_decorator from drf_yasg.utils import swagger_auto_schema # type: ignore[import-untyped] -from rest_framework import viewsets +from rest_framework import status, viewsets from rest_framework.decorators import action, api_view from rest_framework.generics import get_object_or_404 +from rest_framework.request import Request from rest_framework.response import Response from app.pagination import CustomPagination @@ -24,7 +26,10 @@ from .models import Segment from .permissions import SegmentPermissions -from .serializers import SegmentListQuerySerializer +from .serializers import ( + CloneSegmentSerializer, + SegmentListQuerySerializer, +) logger = logging.getLogger() @@ -119,6 +124,25 @@ def associated_features(self, request, *args, **kwargs): # type: ignore[no-unty serializer = self.get_serializer(queryset, many=True) return Response(serializer.data) + @swagger_auto_schema( + request_body=CloneSegmentSerializer, + responses={201: SegmentSerializer()}, + method="post", + ) # type: ignore[misc] + @action( + detail=True, + methods=["POST"], + url_path="clone", + serializer_class=CloneSegmentSerializer, + ) + def clone(self, request: Request, *args: Any, **kwargs: Any) -> Response: + serializer = CloneSegmentSerializer( + data=request.data, context={"source_segment": self.get_object()} + ) + serializer.is_valid(raise_exception=True) + clone = serializer.save() + return Response(SegmentSerializer(clone).data, status=status.HTTP_201_CREATED) + @swagger_auto_schema(responses={200: SegmentSerializer()}, method="get") @api_view(["GET"]) diff --git a/api/tests/unit/segments/test_unit_segments_serializers.py b/api/tests/unit/segments/test_unit_segments_serializers.py new file mode 100644 index 000000000000..053bb94d1062 --- /dev/null +++ b/api/tests/unit/segments/test_unit_segments_serializers.py @@ -0,0 +1,16 @@ +import pytest +from rest_framework.exceptions import ErrorDetail, ValidationError + +from segments.serializers import CloneSegmentSerializer + + +def test_clone_segment_serializer_validation_without_name_should_fail() -> None: + # Given + serializer = CloneSegmentSerializer() + # When + with pytest.raises(ValidationError) as exception: + serializer.validate({"name": ""}) + # Then + assert exception.value.detail == [ + ErrorDetail(string="Name is required to clone a segment", code="invalid") + ] diff --git a/api/tests/unit/segments/test_unit_segments_views.py b/api/tests/unit/segments/test_unit_segments_views.py index 878afb822952..648f32b4ed7a 100644 --- a/api/tests/unit/segments/test_unit_segments_views.py +++ b/api/tests/unit/segments/test_unit_segments_views.py @@ -1493,3 +1493,119 @@ def test_include_feature_specific_query_filter__false( # Then assert response.json()["count"] == 1 assert [res["id"] for res in response.json()["results"]] == [segment.id] + + +@pytest.mark.parametrize( + "source_segment", + [ + (lazy_fixture("segment")), + (lazy_fixture("feature_specific_segment")), + ], +) +def test_clone_segment( + project: Project, + admin_client: APIClient, + source_segment: Segment, + required_a_segment_metadata_field: MetadataModelField, +) -> None: + # Given + url = reverse( + "api-v1:projects:project-segments-clone", args=[project.id, source_segment.id] + ) + new_segment_name = "cloned_segment" + data = { + "name": new_segment_name, + } + # Preparing the rules + segment_rule = SegmentRule.objects.create( + segment=source_segment, + type=SegmentRule.ALL_RULE, + ) + sub_rule = SegmentRule.objects.create( + rule=segment_rule, + type=SegmentRule.ALL_RULE, + ) + + # Preparing the conditions + created_condition = Condition.objects.create( + rule=sub_rule, + property="foo", + operator=EQUAL, + value="bar", + created_with_segment=False, + ) + + # Preparing the metadata + segment_content_type = ContentType.objects.get_for_model(source_segment) + metadata = Metadata.objects.create( + object_id=source_segment.id, + content_type=segment_content_type, + model_field=required_a_segment_metadata_field, + field_value="test-clone-segment-metadata", + ) + + # When + response = admin_client.post( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + assert response.status_code == status.HTTP_201_CREATED + + response_data = response.json() + assert response_data["name"] == new_segment_name + assert response_data["project"] == project.id + assert response_data["id"] != source_segment.id + + # Testing cloned segment main attributes + cloned_segment = Segment.objects.get(id=response_data["id"]) + assert cloned_segment.name == new_segment_name + assert cloned_segment.project_id == project.id + assert cloned_segment.description == source_segment.description + assert cloned_segment.version == 1 + assert cloned_segment.version_of_id == cloned_segment.id + assert cloned_segment.change_request is None + assert cloned_segment.feature_id == source_segment.feature_id + + # Testing cloning of rules + assert cloned_segment.rules.count() == source_segment.rules.count() + + cloned_top_rule = cloned_segment.rules.first() + cloned_sub_rule = cloned_top_rule.rules.first() + + assert cloned_top_rule.type == segment_rule.type + assert cloned_sub_rule.type == segment_rule.type + + # Testing cloning of sub-rules conditions + cloned_condition = cloned_sub_rule.conditions.first() + + assert cloned_condition.property == created_condition.property + assert cloned_condition.operator == created_condition.operator + assert cloned_condition.value == created_condition.value + + # Testing cloning of metadata + cloned_metadata = cloned_segment.metadata.first() + assert cloned_metadata.model_field == metadata.model_field + assert cloned_metadata.field_value == metadata.field_value + assert cloned_metadata.id != metadata.id + + +def test_clone_segment_without_name_should_fail( + project: Project, + admin_client: APIClient, + segment: Segment, +) -> None: + # Given + url = reverse( + "api-v1:projects:project-segments-clone", args=[project.id, segment.id] + ) + data = { + "no-name": "", + } + # When + response = admin_client.post( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then + assert response.status_code == status.HTTP_400_BAD_REQUEST