Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions openwisp_radius/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from .. import settings as app_settings
from ..base.forms import PasswordResetForm
from ..counters.exceptions import MaxQuotaReached, SkipCheck
from ..counters.exceptions import SkipCheck
from ..registration import REGISTRATION_METHOD_CHOICES
from ..utils import (
get_group_checks,
Expand Down Expand Up @@ -304,12 +304,11 @@ def get_result(self, obj):
group=self.context["group"],
group_check=obj,
)
# Python can handle 64 bit numbers and
# hence we don't need to display Gigawords
remaining = counter.check(gigawords=False)
return int(obj.value) - remaining
except MaxQuotaReached:
return int(obj.value)
consumed = counter.consumed()
value = int(obj.value)
if consumed > value:
consumed = value
return consumed
except (SkipCheck, ValueError, KeyError):
return None

Expand Down
29 changes: 20 additions & 9 deletions openwisp_radius/counters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,18 @@ def check_name(self): # pragma: no cover
pass

@property
@abstractmethod
def reply_name(self): # pragma: no cover
pass
def reply_names(self):
# BACKWARD COMPATIBILITY: In previous versions of openwisp-radius,
# the Counter.reply_name was a string instead of a tuple. Thus,
# we need to convert it to a tuple if it's a string.
reply_name = getattr(self, "reply_name", None)
if not reply_name:
raise NotImplementedError(
"Counter classes must define 'reply_names' property."
)
if isinstance(reply_name, str):
return (reply_name,)
return reply_name

@property
@abstractmethod
Expand All @@ -43,7 +52,6 @@ def get_sql_params(self, start_time, end_time): # pragma: no cover
# sqlcounter module, now we can translate it with gettext
# or customize it (in new counter classes) if needed
reply_message = _("Your maximum daily usage time has been reached")
gigawords = False

def __init__(self, user, group, group_check):
self.user = user
Expand Down Expand Up @@ -72,7 +80,7 @@ def get_attribute_type(self):

def get_reset_timestamps(self):
try:
return resets[self.reset](self.user)
return resets[self.reset](self.user, counter=self)
except KeyError:
raise SkipCheck(
message=f'Reset time with key "{self.reset}" not available.',
Expand All @@ -93,7 +101,7 @@ def get_counter(self):
# or if nothing is returned (no sessions present), return zero
return row[0] or 0

def check(self, gigawords=gigawords):
def check(self):
if not self.group_check:
raise SkipCheck(
message=(
Expand Down Expand Up @@ -134,12 +142,15 @@ def check(self, gigawords=gigawords):
reply_message=self.reply_message,
)

return int(remaining)
return (int(remaining),)

def consumed(self):
return int(self.get_counter())


class BaseDailyCounter(BaseCounter):
check_name = "Max-Daily-Session"
reply_name = "Session-Timeout"
reply_names = ("Session-Timeout",)
reset = "daily"

def get_sql_params(self, start_time, end_time):
Expand All @@ -152,7 +163,7 @@ def get_sql_params(self, start_time, end_time):


class BaseTrafficCounter(BaseCounter):
reply_name = app_settings.TRAFFIC_COUNTER_REPLY_NAME
reply_names = (app_settings.TRAFFIC_COUNTER_REPLY_NAME,)

def get_sql_params(self, start_time, end_time):
return [
Expand Down
10 changes: 5 additions & 5 deletions openwisp_radius/counters/resets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,29 @@ def _timestamp(start, end):
return int(start.timestamp()), int(end.timestamp())


def _daily(user=None):
def _daily(user=None, **kwargs):
dt = _today()
start = datetime(dt.year, dt.month, dt.day)
end = datetime(dt.year, dt.month, dt.day) + timedelta(days=1)
return _timestamp(start, end)


def _weekly(user=None):
def _weekly(user=None, **kwargs):
dt = _today()
start = dt - timedelta(days=dt.weekday())
start = datetime(start.year, start.month, start.day)
end = start + timedelta(days=7)
return _timestamp(start, end)


def _monthly(user=None):
def _monthly(user=None, **kwargs):
dt = _today()
start = datetime(dt.year, dt.month, 1)
end = datetime(dt.year, dt.month, 1) + relativedelta(months=1)
return _timestamp(start, end)


def _monthly_subscription(user):
def _monthly_subscription(user, **kwargs):
dt = _today()
day_joined = user.date_joined.day
# subscription cycle starts on the day of month the user joined
Expand All @@ -45,7 +45,7 @@ def _monthly_subscription(user):
return _timestamp(start, end)


def _never(user=None):
def _never(user=None, **kwargs):
return 0, None


Expand Down
38 changes: 38 additions & 0 deletions openwisp_radius/tests/test_api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,44 @@ def test_user_radius_usage_view(self):
},
)

data3 = self.acct_post_data
data3.update(
dict(
session_id="40111117",
unique_id="12234f70",
input_octets=1000000000,
output_octets=1000000000,
username="tester",
)
)
self._create_radius_accounting(**data3)

with self.subTest("User consumed more than allowed limit"):
response = self.client.get(usage_url, HTTP_AUTHORIZATION=authorization)
self.assertEqual(response.status_code, 200)
self.assertIn("checks", response.data)
checks = response.data["checks"]
self.assertDictEqual(
dict(checks[0]),
{
"attribute": "Max-Daily-Session",
"op": ":=",
"value": "10800",
"result": 783,
"type": "seconds",
},
)
self.assertDictEqual(
dict(checks[1]),
{
"attribute": "Max-Daily-Session-Traffic",
"op": ":=",
"value": "3000000000",
"result": 3000000000,
"type": "bytes",
},
)

with self.subTest("Test user does not have RadiusUserGroup"):
RadiusUserGroup.objects.all().delete()
response = self.client.get(usage_url, HTTP_AUTHORIZATION=authorization)
Expand Down
31 changes: 31 additions & 0 deletions openwisp_radius/tests/test_counters/test_base_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,37 @@ def test_abstract_instantiation(self):
BaseCounter(**opts)
self.assertIn("abstract class BaseCounter", str(ctx.exception))

def test_reply_name_backward_compatibility(self):
options = self._get_kwargs("Session-Timeout")

class BackwardCompatibleCounter(BaseCounter):
check_name = "Max-Daily-Session"
counter_name = "BackwardCompatibleCounter"
reset = "daily"
sql = "SELECT 1"

def get_sql_params(self, start_time, end_time):
return []

with self.subTest("Counter does not implement reply_names or reply_name"):
counter = BackwardCompatibleCounter(**options)
with self.assertRaises(NotImplementedError) as ctx:
counter.reply_names
self.assertIn(
"Counter classes must define 'reply_names' property.",
str(ctx.exception),
)

BackwardCompatibleCounter.reply_name = "Session-Timeout"
with self.subTest("Counter does not implement reply_names, uses reply_name"):
counter = BackwardCompatibleCounter(**options)
self.assertEqual(counter.reply_names, ("Session-Timeout",))

BackwardCompatibleCounter.reply_name = ("Session-Timeout",)
with self.subTest("Counter implements reply_names as tuple"):
counter = BackwardCompatibleCounter(**options)
self.assertEqual(counter.reply_names, ("Session-Timeout",))

@freeze_time("2021-11-03T08:21:44-04:00")
def test_resets(self):
with self.subTest("daily"):
Expand Down
14 changes: 7 additions & 7 deletions openwisp_radius/tests/test_counters/test_sqlite_counters.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,22 @@ def test_time_counter_repr(self):
def test_time_counter_no_sessions(self):
opts = self._get_kwargs("Max-Daily-Session")
counter = DailyCounter(**opts)
self.assertEqual(counter.check(), int(opts["group_check"].value))
self.assertEqual(counter.check(), (int(opts["group_check"].value),))

def test_time_counter_with_sessions(self):
opts = self._get_kwargs("Max-Daily-Session")
counter = DailyCounter(**opts)
self._create_radius_accounting(**_acct_data)
expected = int(opts["group_check"].value) - int(_acct_data["session_time"])
self.assertEqual(counter.check(), expected)
self.assertEqual(counter.check(), (expected,))
_acct_data2 = _acct_data.copy()
_acct_data2.update({"session_id": "2", "unique_id": "2", "session_time": "500"})
self._create_radius_accounting(**_acct_data2)
session_time = int(_acct_data["session_time"]) + int(
_acct_data2["session_time"]
)
expected = int(opts["group_check"].value) - session_time
self.assertEqual(counter.check(), expected)
self.assertEqual(counter.check(), (expected,))

@capture_any_output()
def test_counter_skip_exceptions(self):
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_counter_skip_exceptions(self):
def test_traffic_counter_no_sessions(self):
opts = self._get_kwargs("Max-Daily-Session-Traffic")
counter = DailyTrafficCounter(**opts)
self.assertEqual(counter.check(), int(opts["group_check"].value))
self.assertEqual(counter.check(), (int(opts["group_check"].value),))

def test_traffic_counter_with_sessions(self):
opts = self._get_kwargs("Max-Daily-Session-Traffic")
Expand All @@ -98,13 +98,13 @@ def test_traffic_counter_with_sessions(self):
self._create_radius_accounting(**acct)
traffic = int(acct["input_octets"]) + int(acct["output_octets"])
expected = int(opts["group_check"].value) - traffic
self.assertEqual(counter.check(), expected)
self.assertEqual(counter.check(), (expected,))

def test_traffic_counter_reply_and_check_name(self):
opts = self._get_kwargs("Max-Daily-Session-Traffic")
counter = DailyTrafficCounter(**opts)
self.assertEqual(counter.check_name, "Max-Daily-Session-Traffic")
self.assertEqual(counter.reply_name, "CoovaChilli-Max-Total-Octets")
self.assertEqual(counter.reply_names[0], "CoovaChilli-Max-Total-Octets")

def test_monthly_traffic_counter_with_sessions(self):
rg = RadiusGroup.objects.filter(name="test-org-users").first()
Expand All @@ -121,7 +121,7 @@ def test_monthly_traffic_counter_with_sessions(self):
self._create_radius_accounting(**acct)
traffic = int(acct["input_octets"]) + int(acct["output_octets"])
expected = int(opts["group_check"].value) - traffic
self.assertEqual(counter.check(), expected)
self.assertEqual(counter.check(), (expected,))


del BaseTransactionTestCase
14 changes: 6 additions & 8 deletions openwisp_radius/tests/test_selenium.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.contrib.auth import get_user_model
from django.contrib.staticfiles.testing import StaticLiveServerTestCase
from django.test import tag
from django.urls import reverse
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import Select
Expand All @@ -15,6 +16,7 @@
OrganizationRadiusSettings = load_model("OrganizationRadiusSettings")


@tag("selenium_tests")
class BasicTest(
SeleniumTestMixin, FileMixin, StaticLiveServerTestCase, TestOrganizationMixin
):
Expand Down Expand Up @@ -44,8 +46,7 @@ def test_batch_user_creation(self):
# Select the previously created organization
option = self.find_element(
By.XPATH,
"//li[contains(@class, 'select2-results__option') and "
"text()='test org']",
"//li[contains(@class, 'select2-results__option') and text()='test org']",
10,
)
option.click()
Expand Down Expand Up @@ -87,8 +88,7 @@ def test_standard_csv_import(self):
organization.click()
option = self.find_element(
By.XPATH,
"//li[contains(@class, 'select2-results__option') and "
"text()='test org']",
"//li[contains(@class, 'select2-results__option') and text()='test org']",
10,
)
option.click()
Expand Down Expand Up @@ -135,8 +135,7 @@ def test_import_with_hashed_passwords(self):
organization.click()
option = self.find_element(
By.XPATH,
"//li[contains(@class, 'select2-results__option') and "
"text()='test org']",
"//li[contains(@class, 'select2-results__option') and text()='test org']",
10,
)
option.click()
Expand Down Expand Up @@ -179,8 +178,7 @@ def test_csv_user_generation(self):
organization.click()
option = self.find_element(
By.XPATH,
"//li[contains(@class, 'select2-results__option') and "
"text()='test org']",
"//li[contains(@class, 'select2-results__option') and text()='test org']",
10,
)
option.click()
Expand Down
Loading
Loading