Skip to content

feat: backend-clone-segments #5393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
6168412
feat: implemented-clone-segment-endpoint
Zaimwa9 Apr 21, 2025
6771d8d
feat: reworked-clone-response-as-single-line
Zaimwa9 Apr 21, 2025
e941015
feat: implemented-metadata-duplication-for-entity
Zaimwa9 Apr 23, 2025
517a499
feat: added-tests-on-segment-cloning
Zaimwa9 Apr 24, 2025
56e52bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2025
cbb1809
feat: added-types
Zaimwa9 Apr 25, 2025
64db882
feat: added-types
Zaimwa9 Apr 25, 2025
518dda2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2025
1167012
feat: added-metadata-tests
Zaimwa9 Apr 25, 2025
874d6e2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2025
6a60789
feat: uncommented-test
Zaimwa9 Apr 25, 2025
d2a2256
feat: lint-condition-test
Zaimwa9 Apr 25, 2025
8544901
feat: changed-name-missing-test
Zaimwa9 Apr 25, 2025
1c55325
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2025
24ed9fb
feat: added-serializer-test
Zaimwa9 Apr 28, 2025
055d169
Merge branch 'feat/backend-clone-segments' of github.com:Flagsmith/fl…
Zaimwa9 Apr 28, 2025
1103561
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 28, 2025
355a48c
feat: test-validation-raise-error
Zaimwa9 Apr 28, 2025
eb2b8a6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 28, 2025
31fb1a1
feat: moved-content-type-selection-within-metadata-deep-cloning
Zaimwa9 Apr 28, 2025
cf1a84b
Merge branch 'feat/backend-clone-segments' of github.com:Flagsmith/fl…
Zaimwa9 Apr 28, 2025
682fd2b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions api/metadata/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import cast
from urllib.parse import urlparse

from django.contrib.contenttypes.fields import GenericForeignKey
Expand Down Expand Up @@ -30,41 +31,46 @@ class FieldType(models.TextChoices):
class MetadataField(AbstractBaseExportableModel):
"""This model represents a metadata field(specific to an organisation) that can be attached to any model"""

name = models.CharField(max_length=255)
type = models.CharField(
id: models.AutoField[int, int]
name: models.CharField[str, str] = models.CharField(max_length=255)
type: models.CharField[str, str] = models.CharField(
max_length=255, choices=FieldType.choices, default=FieldType.STRING
)
description = models.TextField(blank=True, null=True)
organisation = models.ForeignKey(Organisation, on_delete=models.CASCADE)
description: models.TextField[str | None, str | None] = models.TextField(
blank=True, null=True
)
organisation: models.ForeignKey[Organisation, Organisation] = models.ForeignKey(
Organisation, on_delete=models.CASCADE
)

def is_field_value_valid(self, field_value: str) -> bool:
if len(field_value) > FIELD_VALUE_MAX_LENGTH:
return False
return self.__getattribute__(f"validate_{self.type}")(field_value) # type: ignore[no-any-return]
return cast(bool, self.__getattribute__(f"validate_{self.type}")(field_value))

def validate_int(self, field_value: str): # type: ignore[no-untyped-def]
def validate_int(self, field_value: str) -> bool:
try:
int(field_value)
except ValueError:
return False
return True

def validate_bool(self, field_value: str): # type: ignore[no-untyped-def]
def validate_bool(self, field_value: str) -> bool:
if field_value.lower() in ["true", "false"]:
return True
return False

def validate_url(self, field_value: str): # type: ignore[no-untyped-def]
def validate_url(self, field_value: str) -> bool:
try:
result = urlparse(field_value)
return all([result.scheme, result.netloc])
except ValueError:
return False

def validate_str(self, field_value: str): # type: ignore[no-untyped-def]
def validate_str(self, field_value: str) -> bool:
return True

def validate_multiline_str(self, field_value: str): # type: ignore[no-untyped-def]
def validate_multiline_str(self, field_value: str) -> bool:
return True

class Meta:
Expand Down Expand Up @@ -114,3 +120,15 @@ class Metadata(AbstractBaseExportableModel):

class Meta:
unique_together = ("model_field", "content_type", "object_id")

def deep_clone_for_new_entity(self, cloned_entity: models.Model) -> "Metadata":
content_type = ContentType.objects.get_for_model(cloned_entity)
return cast(
Metadata,
Metadata.objects.create(
model_field=self.model_field,
content_type=content_type,
object_id=cloned_entity.pk,
field_value=self.field_value,
),
)
56 changes: 44 additions & 12 deletions api/segments/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,49 @@ def set_version_of_to_self_if_none(self): # type: ignore[no-untyped-def]
self.version_of = self
self.save_without_historical_record()

def clone_segment_rules(self, cloned_segment: "Segment") -> list["SegmentRule"]:
cloned_rules = []
for rule in self.rules.all():
cloned_rule = rule.deep_clone(cloned_segment)
cloned_rules.append(cloned_rule)
cloned_segment.refresh_from_db()
assert (
len(self.rules.all())
== len(cloned_rules)
== len(cloned_segment.rules.all())
), "Mismatch during rules creation"

return cloned_rules

def clone_segment_metadata(self, cloned_segment: "Segment") -> list["Metadata"]:
cloned_metadata = []
for metadata in self.metadata.all():
cloned_metadata.append(metadata.deep_clone_for_new_entity(cloned_segment))
cloned_segment.refresh_from_db()
assert (
len(self.metadata.all())
== len(cloned_metadata)
== len(cloned_segment.metadata.all())
), "Mismatch during metadata creation"

return cloned_metadata

def clone(self, name: str) -> "Segment":
cloned_segment = Segment(
name=name,
version_of=None,
uuid=uuid.uuid4(),
description=self.description,
change_request=self.change_request,
project=self.project,
feature=self.feature,
)
cloned_segment.save()
self.clone_segment_rules(cloned_segment)
self.clone_segment_metadata(cloned_segment)
cloned_segment.refresh_from_db()
return cloned_segment

def shallow_clone(
self,
name: str,
Expand Down Expand Up @@ -177,18 +220,7 @@ def deep_clone(self) -> "Segment":
self.version += 1 # type: ignore[operator]
self.save_without_historical_record()

cloned_rules = []
for rule in self.rules.all():
cloned_rule = rule.deep_clone(cloned_segment)
cloned_rules.append(cloned_rule)

cloned_segment.refresh_from_db()

assert (
len(self.rules.all())
== len(cloned_rules)
== len(cloned_segment.rules.all())
), "Mismatch during rules creation"
self.clone_segment_rules(cloned_segment)

return cloned_segment

Expand Down
22 changes: 22 additions & 0 deletions api/segments/serializers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Any, cast

from common.segments.serializers import SegmentSerializer
from rest_framework import serializers

from segments.models import Segment
Expand All @@ -19,3 +22,22 @@ class SegmentListQuerySerializer(serializers.Serializer): # type: ignore[type-a
help_text="Optionally provide the id of an identity to get only the segments they match",
)
include_feature_specific = serializers.BooleanField(required=False, default=True)


class CloneSegmentSerializer(SegmentSerializer):
class Meta:
model = Segment
fields = ("name",)

def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
if not attrs.get("name"):
raise serializers.ValidationError("Name is required to clone a segment")
return attrs

def create(self, validated_data: dict[str, Any]) -> Segment:
name = validated_data.get("name")
source_segment = self.context.get("source_segment")
assert source_segment is not None, (
"Source segment is required to clone a segment"
)
return cast(Segment, source_segment.clone(name))
28 changes: 26 additions & 2 deletions api/segments/views.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import logging
from typing import Any

from common.projects.permissions import VIEW_PROJECT
from common.segments.serializers import (
SegmentSerializer,
)
from django.utils.decorators import method_decorator
from drf_yasg.utils import swagger_auto_schema # type: ignore[import-untyped]
from rest_framework import viewsets
from rest_framework import status, viewsets
from rest_framework.decorators import action, api_view
from rest_framework.generics import get_object_or_404
from rest_framework.request import Request
from rest_framework.response import Response

from app.pagination import CustomPagination
Expand All @@ -24,7 +26,10 @@

from .models import Segment
from .permissions import SegmentPermissions
from .serializers import SegmentListQuerySerializer
from .serializers import (
CloneSegmentSerializer,
SegmentListQuerySerializer,
)

logger = logging.getLogger()

Expand Down Expand Up @@ -119,6 +124,25 @@ def associated_features(self, request, *args, **kwargs): # type: ignore[no-unty
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)

@swagger_auto_schema(
request_body=CloneSegmentSerializer,
responses={201: SegmentSerializer()},
method="post",
) # type: ignore[misc]
@action(
detail=True,
methods=["POST"],
url_path="clone",
serializer_class=CloneSegmentSerializer,
)
def clone(self, request: Request, *args: Any, **kwargs: Any) -> Response:
serializer = CloneSegmentSerializer(
data=request.data, context={"source_segment": self.get_object()}
)
serializer.is_valid(raise_exception=True)
clone = serializer.save()
return Response(SegmentSerializer(clone).data, status=status.HTTP_201_CREATED)


@swagger_auto_schema(responses={200: SegmentSerializer()}, method="get")
@api_view(["GET"])
Expand Down
16 changes: 16 additions & 0 deletions api/tests/unit/segments/test_unit_segments_serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest
from rest_framework.exceptions import ErrorDetail, ValidationError

from segments.serializers import CloneSegmentSerializer


def test_clone_segment_serializer_validation_without_name_should_fail() -> None:
# Given
serializer = CloneSegmentSerializer()
# When
with pytest.raises(ValidationError) as exception:
serializer.validate({"name": ""})
# Then
assert exception.value.detail == [
ErrorDetail(string="Name is required to clone a segment", code="invalid")
]
116 changes: 116 additions & 0 deletions api/tests/unit/segments/test_unit_segments_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,3 +1493,119 @@ def test_include_feature_specific_query_filter__false(
# Then
assert response.json()["count"] == 1
assert [res["id"] for res in response.json()["results"]] == [segment.id]


@pytest.mark.parametrize(
"source_segment",
[
(lazy_fixture("segment")),
(lazy_fixture("feature_specific_segment")),
],
)
def test_clone_segment(
project: Project,
admin_client: APIClient,
source_segment: Segment,
required_a_segment_metadata_field: MetadataModelField,
) -> None:
# Given
url = reverse(
"api-v1:projects:project-segments-clone", args=[project.id, source_segment.id]
)
new_segment_name = "cloned_segment"
data = {
"name": new_segment_name,
}
# Preparing the rules
segment_rule = SegmentRule.objects.create(
segment=source_segment,
type=SegmentRule.ALL_RULE,
)
sub_rule = SegmentRule.objects.create(
rule=segment_rule,
type=SegmentRule.ALL_RULE,
)

# Preparing the conditions
created_condition = Condition.objects.create(
rule=sub_rule,
property="foo",
operator=EQUAL,
value="bar",
created_with_segment=False,
)

# Preparing the metadata
segment_content_type = ContentType.objects.get_for_model(source_segment)
metadata = Metadata.objects.create(
object_id=source_segment.id,
content_type=segment_content_type,
model_field=required_a_segment_metadata_field,
field_value="test-clone-segment-metadata",
)

# When
response = admin_client.post(
url, data=json.dumps(data), content_type="application/json"
)

# Then
assert response.status_code == status.HTTP_201_CREATED

response_data = response.json()
assert response_data["name"] == new_segment_name
assert response_data["project"] == project.id
assert response_data["id"] != source_segment.id

# Testing cloned segment main attributes
cloned_segment = Segment.objects.get(id=response_data["id"])
assert cloned_segment.name == new_segment_name
assert cloned_segment.project_id == project.id
assert cloned_segment.description == source_segment.description
assert cloned_segment.version == 1
assert cloned_segment.version_of_id == cloned_segment.id
assert cloned_segment.change_request is None
assert cloned_segment.feature_id == source_segment.feature_id

# Testing cloning of rules
assert cloned_segment.rules.count() == source_segment.rules.count()

cloned_top_rule = cloned_segment.rules.first()
cloned_sub_rule = cloned_top_rule.rules.first()

assert cloned_top_rule.type == segment_rule.type
assert cloned_sub_rule.type == segment_rule.type

# Testing cloning of sub-rules conditions
cloned_condition = cloned_sub_rule.conditions.first()

assert cloned_condition.property == created_condition.property
assert cloned_condition.operator == created_condition.operator
assert cloned_condition.value == created_condition.value

# Testing cloning of metadata
cloned_metadata = cloned_segment.metadata.first()
assert cloned_metadata.model_field == metadata.model_field
assert cloned_metadata.field_value == metadata.field_value
assert cloned_metadata.id != metadata.id


def test_clone_segment_without_name_should_fail(
project: Project,
admin_client: APIClient,
segment: Segment,
) -> None:
# Given
url = reverse(
"api-v1:projects:project-segments-clone", args=[project.id, segment.id]
)
data = {
"no-name": "",
}
# When
response = admin_client.post(
url, data=json.dumps(data), content_type="application/json"
)

# Then
assert response.status_code == status.HTTP_400_BAD_REQUEST
Loading