diff --git a/openwisp_radius/api/serializers.py b/openwisp_radius/api/serializers.py index cdcaa22f..c0da1253 100644 --- a/openwisp_radius/api/serializers.py +++ b/openwisp_radius/api/serializers.py @@ -33,7 +33,7 @@ from .. import settings as app_settings from ..base.forms import PasswordResetForm -from ..counters.exceptions import MaxQuotaReached, SkipCheck +from ..counters.exceptions import SkipCheck from ..registration import REGISTRATION_METHOD_CHOICES from ..utils import ( get_group_checks, @@ -304,12 +304,11 @@ def get_result(self, obj): group=self.context["group"], group_check=obj, ) - # Python can handle 64 bit numbers and - # hence we don't need to display Gigawords - remaining = counter.check(gigawords=False) - return int(obj.value) - remaining - except MaxQuotaReached: - return int(obj.value) + consumed = counter.consumed() + value = int(obj.value) + if consumed > value: + consumed = value + return consumed except (SkipCheck, ValueError, KeyError): return None diff --git a/openwisp_radius/counters/base.py b/openwisp_radius/counters/base.py index eac69fc3..d3265640 100644 --- a/openwisp_radius/counters/base.py +++ b/openwisp_radius/counters/base.py @@ -21,9 +21,18 @@ def check_name(self): # pragma: no cover pass @property - @abstractmethod - def reply_name(self): # pragma: no cover - pass + def reply_names(self): + # BACKWARD COMPATIBILITY: In previous versions of openwisp-radius, + # the Counter.reply_name was a string instead of a tuple. Thus, + # we need to convert it to a tuple if it's a string. + reply_name = getattr(self, "reply_name", None) + if not reply_name: + raise NotImplementedError( + "Counter classes must define 'reply_names' property." + ) + if isinstance(reply_name, str): + return (reply_name,) + return reply_name @property @abstractmethod @@ -43,7 +52,6 @@ def get_sql_params(self, start_time, end_time): # pragma: no cover # sqlcounter module, now we can translate it with gettext # or customize it (in new counter classes) if needed reply_message = _("Your maximum daily usage time has been reached") - gigawords = False def __init__(self, user, group, group_check): self.user = user @@ -72,7 +80,7 @@ def get_attribute_type(self): def get_reset_timestamps(self): try: - return resets[self.reset](self.user) + return resets[self.reset](self.user, counter=self) except KeyError: raise SkipCheck( message=f'Reset time with key "{self.reset}" not available.', @@ -93,7 +101,7 @@ def get_counter(self): # or if nothing is returned (no sessions present), return zero return row[0] or 0 - def check(self, gigawords=gigawords): + def check(self): if not self.group_check: raise SkipCheck( message=( @@ -134,12 +142,15 @@ def check(self, gigawords=gigawords): reply_message=self.reply_message, ) - return int(remaining) + return (int(remaining),) + + def consumed(self): + return int(self.get_counter()) class BaseDailyCounter(BaseCounter): check_name = "Max-Daily-Session" - reply_name = "Session-Timeout" + reply_names = ("Session-Timeout",) reset = "daily" def get_sql_params(self, start_time, end_time): @@ -152,7 +163,7 @@ def get_sql_params(self, start_time, end_time): class BaseTrafficCounter(BaseCounter): - reply_name = app_settings.TRAFFIC_COUNTER_REPLY_NAME + reply_names = (app_settings.TRAFFIC_COUNTER_REPLY_NAME,) def get_sql_params(self, start_time, end_time): return [ diff --git a/openwisp_radius/counters/resets.py b/openwisp_radius/counters/resets.py index 08cc6eeb..5288390f 100644 --- a/openwisp_radius/counters/resets.py +++ b/openwisp_radius/counters/resets.py @@ -11,14 +11,14 @@ def _timestamp(start, end): return int(start.timestamp()), int(end.timestamp()) -def _daily(user=None): +def _daily(user=None, **kwargs): dt = _today() start = datetime(dt.year, dt.month, dt.day) end = datetime(dt.year, dt.month, dt.day) + timedelta(days=1) return _timestamp(start, end) -def _weekly(user=None): +def _weekly(user=None, **kwargs): dt = _today() start = dt - timedelta(days=dt.weekday()) start = datetime(start.year, start.month, start.day) @@ -26,14 +26,14 @@ def _weekly(user=None): return _timestamp(start, end) -def _monthly(user=None): +def _monthly(user=None, **kwargs): dt = _today() start = datetime(dt.year, dt.month, 1) end = datetime(dt.year, dt.month, 1) + relativedelta(months=1) return _timestamp(start, end) -def _monthly_subscription(user): +def _monthly_subscription(user, **kwargs): dt = _today() day_joined = user.date_joined.day # subscription cycle starts on the day of month the user joined @@ -45,7 +45,7 @@ def _monthly_subscription(user): return _timestamp(start, end) -def _never(user=None): +def _never(user=None, **kwargs): return 0, None diff --git a/openwisp_radius/tests/test_api/test_api.py b/openwisp_radius/tests/test_api/test_api.py index cb7b11de..8cf77235 100644 --- a/openwisp_radius/tests/test_api/test_api.py +++ b/openwisp_radius/tests/test_api/test_api.py @@ -1170,6 +1170,44 @@ def test_user_radius_usage_view(self): }, ) + data3 = self.acct_post_data + data3.update( + dict( + session_id="40111117", + unique_id="12234f70", + input_octets=1000000000, + output_octets=1000000000, + username="tester", + ) + ) + self._create_radius_accounting(**data3) + + with self.subTest("User consumed more than allowed limit"): + response = self.client.get(usage_url, HTTP_AUTHORIZATION=authorization) + self.assertEqual(response.status_code, 200) + self.assertIn("checks", response.data) + checks = response.data["checks"] + self.assertDictEqual( + dict(checks[0]), + { + "attribute": "Max-Daily-Session", + "op": ":=", + "value": "10800", + "result": 783, + "type": "seconds", + }, + ) + self.assertDictEqual( + dict(checks[1]), + { + "attribute": "Max-Daily-Session-Traffic", + "op": ":=", + "value": "3000000000", + "result": 3000000000, + "type": "bytes", + }, + ) + with self.subTest("Test user does not have RadiusUserGroup"): RadiusUserGroup.objects.all().delete() response = self.client.get(usage_url, HTTP_AUTHORIZATION=authorization) diff --git a/openwisp_radius/tests/test_counters/test_base_counter.py b/openwisp_radius/tests/test_counters/test_base_counter.py index 58e4dee2..03fc4cdf 100644 --- a/openwisp_radius/tests/test_counters/test_base_counter.py +++ b/openwisp_radius/tests/test_counters/test_base_counter.py @@ -42,6 +42,37 @@ def test_abstract_instantiation(self): BaseCounter(**opts) self.assertIn("abstract class BaseCounter", str(ctx.exception)) + def test_reply_name_backward_compatibility(self): + options = self._get_kwargs("Session-Timeout") + + class BackwardCompatibleCounter(BaseCounter): + check_name = "Max-Daily-Session" + counter_name = "BackwardCompatibleCounter" + reset = "daily" + sql = "SELECT 1" + + def get_sql_params(self, start_time, end_time): + return [] + + with self.subTest("Counter does not implement reply_names or reply_name"): + counter = BackwardCompatibleCounter(**options) + with self.assertRaises(NotImplementedError) as ctx: + counter.reply_names + self.assertIn( + "Counter classes must define 'reply_names' property.", + str(ctx.exception), + ) + + BackwardCompatibleCounter.reply_name = "Session-Timeout" + with self.subTest("Counter does not implement reply_names, uses reply_name"): + counter = BackwardCompatibleCounter(**options) + self.assertEqual(counter.reply_names, ("Session-Timeout",)) + + BackwardCompatibleCounter.reply_name = ("Session-Timeout",) + with self.subTest("Counter implements reply_names as tuple"): + counter = BackwardCompatibleCounter(**options) + self.assertEqual(counter.reply_names, ("Session-Timeout",)) + @freeze_time("2021-11-03T08:21:44-04:00") def test_resets(self): with self.subTest("daily"): diff --git a/openwisp_radius/tests/test_counters/test_sqlite_counters.py b/openwisp_radius/tests/test_counters/test_sqlite_counters.py index 1e3a4290..b3457599 100644 --- a/openwisp_radius/tests/test_counters/test_sqlite_counters.py +++ b/openwisp_radius/tests/test_counters/test_sqlite_counters.py @@ -28,14 +28,14 @@ def test_time_counter_repr(self): def test_time_counter_no_sessions(self): opts = self._get_kwargs("Max-Daily-Session") counter = DailyCounter(**opts) - self.assertEqual(counter.check(), int(opts["group_check"].value)) + self.assertEqual(counter.check(), (int(opts["group_check"].value),)) def test_time_counter_with_sessions(self): opts = self._get_kwargs("Max-Daily-Session") counter = DailyCounter(**opts) self._create_radius_accounting(**_acct_data) expected = int(opts["group_check"].value) - int(_acct_data["session_time"]) - self.assertEqual(counter.check(), expected) + self.assertEqual(counter.check(), (expected,)) _acct_data2 = _acct_data.copy() _acct_data2.update({"session_id": "2", "unique_id": "2", "session_time": "500"}) self._create_radius_accounting(**_acct_data2) @@ -43,7 +43,7 @@ def test_time_counter_with_sessions(self): _acct_data2["session_time"] ) expected = int(opts["group_check"].value) - session_time - self.assertEqual(counter.check(), expected) + self.assertEqual(counter.check(), (expected,)) @capture_any_output() def test_counter_skip_exceptions(self): @@ -88,7 +88,7 @@ def test_counter_skip_exceptions(self): def test_traffic_counter_no_sessions(self): opts = self._get_kwargs("Max-Daily-Session-Traffic") counter = DailyTrafficCounter(**opts) - self.assertEqual(counter.check(), int(opts["group_check"].value)) + self.assertEqual(counter.check(), (int(opts["group_check"].value),)) def test_traffic_counter_with_sessions(self): opts = self._get_kwargs("Max-Daily-Session-Traffic") @@ -98,13 +98,13 @@ def test_traffic_counter_with_sessions(self): self._create_radius_accounting(**acct) traffic = int(acct["input_octets"]) + int(acct["output_octets"]) expected = int(opts["group_check"].value) - traffic - self.assertEqual(counter.check(), expected) + self.assertEqual(counter.check(), (expected,)) def test_traffic_counter_reply_and_check_name(self): opts = self._get_kwargs("Max-Daily-Session-Traffic") counter = DailyTrafficCounter(**opts) self.assertEqual(counter.check_name, "Max-Daily-Session-Traffic") - self.assertEqual(counter.reply_name, "CoovaChilli-Max-Total-Octets") + self.assertEqual(counter.reply_names[0], "CoovaChilli-Max-Total-Octets") def test_monthly_traffic_counter_with_sessions(self): rg = RadiusGroup.objects.filter(name="test-org-users").first() @@ -121,7 +121,7 @@ def test_monthly_traffic_counter_with_sessions(self): self._create_radius_accounting(**acct) traffic = int(acct["input_octets"]) + int(acct["output_octets"]) expected = int(opts["group_check"].value) - traffic - self.assertEqual(counter.check(), expected) + self.assertEqual(counter.check(), (expected,)) del BaseTransactionTestCase diff --git a/openwisp_radius/tests/test_selenium.py b/openwisp_radius/tests/test_selenium.py index b266c14b..642850bf 100644 --- a/openwisp_radius/tests/test_selenium.py +++ b/openwisp_radius/tests/test_selenium.py @@ -1,5 +1,6 @@ from django.contrib.auth import get_user_model from django.contrib.staticfiles.testing import StaticLiveServerTestCase +from django.test import tag from django.urls import reverse from selenium.webdriver.common.by import By from selenium.webdriver.support.ui import Select @@ -15,6 +16,7 @@ OrganizationRadiusSettings = load_model("OrganizationRadiusSettings") +@tag("selenium_tests") class BasicTest( SeleniumTestMixin, FileMixin, StaticLiveServerTestCase, TestOrganizationMixin ): @@ -44,8 +46,7 @@ def test_batch_user_creation(self): # Select the previously created organization option = self.find_element( By.XPATH, - "//li[contains(@class, 'select2-results__option') and " - "text()='test org']", + "//li[contains(@class, 'select2-results__option') and text()='test org']", 10, ) option.click() @@ -87,8 +88,7 @@ def test_standard_csv_import(self): organization.click() option = self.find_element( By.XPATH, - "//li[contains(@class, 'select2-results__option') and " - "text()='test org']", + "//li[contains(@class, 'select2-results__option') and text()='test org']", 10, ) option.click() @@ -135,8 +135,7 @@ def test_import_with_hashed_passwords(self): organization.click() option = self.find_element( By.XPATH, - "//li[contains(@class, 'select2-results__option') and " - "text()='test org']", + "//li[contains(@class, 'select2-results__option') and text()='test org']", 10, ) option.click() @@ -179,8 +178,7 @@ def test_csv_user_generation(self): organization.click() option = self.find_element( By.XPATH, - "//li[contains(@class, 'select2-results__option') and " - "text()='test org']", + "//li[contains(@class, 'select2-results__option') and text()='test org']", 10, ) option.click() diff --git a/openwisp_radius/utils.py b/openwisp_radius/utils.py index 360313be..a3405efa 100644 --- a/openwisp_radius/utils.py +++ b/openwisp_radius/utils.py @@ -333,7 +333,7 @@ def execute_counter_checks( continue try: counter = Counter(user=user, group=group, group_check=group_check) - remaining = counter.check() + results = counter.check() except SkipCheck: continue except Exception as e: @@ -342,32 +342,41 @@ def execute_counter_checks( continue if raise_quota_exceeded: raise - if remaining is None: + if results is None: continue - reply_name = counter.reply_name - # Send remaining value in RADIUS reply, if needed. - # This emulates the implementation of sqlcounter in freeradius - # which sends the reply message only if the value is smaller - # than what was defined to a previous reply message - if reply_name not in counter_data or remaining < _get_reply_value( - counter_data, counter - ): - counter_data[reply_name] = remaining + # BACKWARD COMPATIBILITY: The previous implementation of counters + # returned a single value instead of a tuple/list when there was + # only one reply name defined. + # We need to handle this case to avoid breaking existing counters. + if not isinstance(results, (list, tuple)): + results = (results,) + # We need to map the value to the correct reply name. + # This allows counters to define multiple reply names. + for reply_name, value in zip(counter.reply_names, results): + # Send remaining value in RADIUS reply, if needed. + # This emulates the implementation of sqlcounter in freeradius + # which sends the reply message only if the value is smaller + # than what was defined to a previous reply message + if reply_name not in counter_data or value < _get_reply_value( + counter_data, reply_name + ): + counter_data[reply_name] = value + return counter_data -def _get_reply_value(data, counter): +def _get_reply_value(data, reply_name): """ Helper function to get reply value from counter data for comparison. Args: data: Dictionary containing RADIUS attributes - counter: Counter instance + reply_name: Name of the reply attribute Returns: int or float: Reply value as integer, or math.inf if conversion fails """ - reply_entry = data.get(counter.reply_name, {}) + reply_entry = data.get(reply_name, {}) value = reply_entry.get("value") if value is None: return math.inf @@ -375,7 +384,7 @@ def _get_reply_value(data, counter): return int(value) except (ValueError, TypeError): logger.warning( - f'{counter.reply_name} value ("{value}") ' "cannot be converted to integer." + f'{reply_name} value ("{value}") cannot be converted to integer.' ) return math.inf