Skip to content

Commit 8894464

Browse files
committed
fixes bug in PoissonIntervals distribution where if last rate is zero, it cycles back to beginning too early
1 parent 24a9025 commit 8894464

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

ciw/dists/distributions.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -561,9 +561,7 @@ def __init__(self, rates, endpoints, max_sample_date):
561561
self.max_sample_date = max_sample_date
562562
self.get_intervals()
563563
self.get_dates()
564-
self.inter_arrivals = [t - s for s, t in zip(self.dates, self.dates[1:])]
565-
if self.inter_arrivals == []:
566-
self.inter_arrivals = [float("inf")]
564+
self.inter_arrivals = [t - s for s, t in zip(self.dates, self.dates[1:])] + [float('inf')]
567565
super().__init__(self.inter_arrivals)
568566

569567
def __repr__(self):

ciw/tests/test_sampling.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -1434,7 +1434,7 @@ def test_poissoninterval_dist_object(self):
14341434
expected_dates = [0]
14351435
for t in Pi.inter_arrivals:
14361436
expected_dates.append(expected_dates[-1] + t)
1437-
self.assertEqual(Pi.dates, expected_dates)
1437+
self.assertEqual(Pi.dates, expected_dates[:-1])
14381438
self.assertLessEqual(Pi.dates[-1], Pi.max_sample_date)
14391439

14401440
self.assertRaises(
@@ -1483,19 +1483,38 @@ def test_poissoninterval_dist_object(self):
14831483
def test_poissoninterval_rate_zero(self):
14841484
ciw.seed(5)
14851485
Pi = ciw.dists.PoissonIntervals(
1486-
rates=[10, 0], endpoints=[1, 2], max_sample_date=15
1486+
rates=[10, 0], endpoints=[1, 2], max_sample_date=15.5
14871487
)
14881488
arrivals_when_rate_is_zero = [date for date in Pi.dates if int(date) % 2 == 1]
14891489
self.assertEqual(arrivals_when_rate_is_zero, [])
1490+
final_date = Pi.dates[-1]
1491+
self.assertTrue(final_date < 15)
14901492

14911493
ciw.seed(5)
14921494
Pi = ciw.dists.PoissonIntervals(
1493-
rates=[0, 0], endpoints=[1, 2], max_sample_date=15
1495+
rates=[0, 0], endpoints=[1, 2], max_sample_date=15.5
14941496
)
14951497
arrivals_when_rate_is_zero = [date for date in Pi.dates if int(date) % 2 == 1]
14961498
self.assertEqual(arrivals_when_rate_is_zero, [])
14971499
self.assertEqual(Pi.sample(), float("inf"))
14981500

1501+
ciw.seed(5)
1502+
Pi = ciw.dists.PoissonIntervals(
1503+
rates=[3, 2, 0], endpoints=[1, 2, 10], max_sample_date=9
1504+
)
1505+
final_date = Pi.dates[-1]
1506+
self.assertTrue(final_date < 2)
1507+
1508+
t = 0
1509+
arrivals = []
1510+
while t < 8:
1511+
t += Pi.sample()
1512+
arrivals.append(t)
1513+
final_sample = arrivals[-1]
1514+
final_arrival = arrivals[-2]
1515+
self.assertEqual(final_sample, float('inf'))
1516+
self.assertEqual(final_arrival, final_date)
1517+
14991518
def test_poissoninterval_against_theory(self):
15001519
"""
15011520
rates = [5, 1.5, 3]

0 commit comments

Comments
 (0)