Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions users/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from mitol.scim.adapters import UserAdapter

from b2b.models import ContractPage
from openedx.models import OpenEdxUser
from users.models import LegalAddress, UserProfile

Expand All @@ -25,6 +26,7 @@ class LearnUserAdapter(UserAdapter):
user_profile: UserProfile
legal_address: LegalAddress
openedx_user: OpenEdxUser
b2b_contracts: ContractPage

def __init__(self, obj, request=None):
super().__init__(obj, request=request)
Expand All @@ -33,9 +35,12 @@ def __init__(self, obj, request=None):
self.obj, "user_profile", UserProfile()
)

self.legal_address = self.obj.legal_address = getattr(
self.obj, "legal_address", LegalAddress()
)
try:
self.legal_address = self.obj.legal_address # triggers DB fetch if needed
except LegalAddress.DoesNotExist:
self.legal_address = LegalAddress()

self.b2b_contracts = self.obj.b2b_contracts

self.openedx_user = self.obj.openedx_user
if self.openedx_user is None:
Expand All @@ -54,24 +59,28 @@ def from_dict(self, d):
Consume a ``dict`` conforming to the SCIM User Schema, updating the
internal user object with data from the ``dict``.

Please note, the user object is not saved within this method. To
persist the changes made by this method, please call ``.save()`` on the
adapter. Eg::

scim_user.from_dict(d)
scim_user.save()
Note: This method does NOT save the user object. To persist changes,
call ``.save()`` on the adapter.
"""
super().from_dict(d)

self.obj.name = d.get("fullName", "")

first_name = d.get("name", {}).get("given_name", "")
if first_name:
self.legal_address.first_name = first_name
name_data = d.get("name", {})

last_name = d.get("name", {}).get("last_name", "")
if last_name:
self.legal_address.last_name = last_name
self.legal_address.first_name = (
name_data.get("given_name") or self.legal_address.first_name
)
self.legal_address.last_name = (
name_data.get("last_name") or self.legal_address.last_name
)

organization_name = d.get("organization")
if organization_name:
contract_pages = ContractPage.objects.filter(
organization__name=organization_name
)
self.b2b_contracts.add(*contract_pages)

def _save_related(self):
self.user_profile.user = self.obj
Expand All @@ -82,3 +91,5 @@ def _save_related(self):

self.openedx_user.user = self.obj
self.openedx_user.save()

self.obj.b2b_contracts.add(*self.b2b_contracts)
84 changes: 84 additions & 0 deletions users/adapters_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from unittest import mock

import pytest

from b2b.factories import ContractPageFactory
from openedx.models import OpenEdxUser
from users.adapters import LearnUserAdapter
from users.factories import UserFactory
from users.models import LegalAddress, UserProfile


@pytest.mark.django_db
def test_init_sets_related_objects():
user = UserFactory()
adapter = LearnUserAdapter(user)

assert isinstance(adapter.user_profile, UserProfile)
assert isinstance(adapter.legal_address, LegalAddress)
assert isinstance(adapter.openedx_user, OpenEdxUser)


@pytest.mark.django_db
def test_display_name_returns_name():
user = UserFactory(name="John Doe")
adapter = LearnUserAdapter(user)

assert adapter.display_name == "John Doe"


@pytest.mark.django_db
def test_from_dict_updates_user_and_related():
user = UserFactory.create(name="Old Name")
user.legal_address.first_name = "OldFirst"
user.legal_address.last_name = "OldLast"
user.legal_address.save()
adapter = LearnUserAdapter(user)

contract_page = ContractPageFactory.create(organization__name="Acme Corp")
data = {
"fullName": "New Name",
"name": {"given_name": "NewFirst", "last_name": "NewLast"},
"organization": "Acme Corp",
}

adapter.from_dict(data)
adapter._save_related() # noqa: SLF001
adapter.legal_address.refresh_from_db()
assert adapter.obj.name == "New Name"
assert adapter.legal_address.first_name == "NewFirst"
assert adapter.legal_address.last_name == "NewLast"
assert user.b2b_contracts.filter(id=contract_page.id).exists()


@pytest.mark.django_db
def test_from_dict_keeps_existing_names_if_missing():
user = UserFactory.create(name="Old Name")
user.legal_address.first_name = "OldFirst"
user.legal_address.last_name = "OldLast"
user.legal_address.save()
adapter = LearnUserAdapter(user)

data = {"fullName": "Another Name", "name": {}}
adapter.from_dict(data)

adapter.legal_address.refresh_from_db()

assert adapter.legal_address.first_name == "OldFirst"
assert adapter.legal_address.last_name == "OldLast"


@pytest.mark.django_db
def test_save_related_saves_all():
user = UserFactory()
adapter = LearnUserAdapter(user)

adapter.user_profile = mock.MagicMock()
adapter.legal_address = mock.MagicMock()
adapter.openedx_user = mock.MagicMock()

adapter._save_related() # noqa: SLF001

adapter.user_profile.save.assert_called_once()
adapter.legal_address.save.assert_called_once()
adapter.openedx_user.save.assert_called_once()
Loading