Skip to content

Commit 6725dc5

Browse files
authored
Bug Fix: Inaccurate Tails of Incomplete Gamma (#481)
1 parent a5b5a21 commit 6725dc5

File tree

6 files changed

+5781
-30
lines changed

6 files changed

+5781
-30
lines changed

BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ swift_cc_test(
199199
"tests/test_block_utils.cc",
200200
"tests/test_call_trace.cc",
201201
"tests/test_callers.cc",
202+
"tests/test_chi_squared_versus_gsl.cc",
202203
"tests/test_compression.cc",
203204
"tests/test_concatenate.cc",
204205
"tests/test_conditional_gaussian.cc",

include/albatross/Stats

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef ALBATROSS_STATS_H
1414
#define ALBATROSS_STATS_H
1515

16+
#include <albatross/src/details/typecast.hpp>
1617
#include <albatross/src/stats/gaussian.hpp>
1718
#include <albatross/src/stats/gauss_legendre.hpp>
1819
#include <albatross/src/stats/incomplete_gamma.hpp>

include/albatross/src/stats/chi_squared.hpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ namespace albatross {
2626

2727
namespace details {
2828

29-
inline double chi_squared_cdf_unsafe(double x, std::size_t degrees_of_freedom) {
30-
return incomplete_gamma(0.5 * cast::to_double(degrees_of_freedom), 0.5 * x);
29+
inline double chi_squared_cdf_unsafe(double x, double degrees_of_freedom) {
30+
return incomplete_gamma(0.5 * degrees_of_freedom, 0.5 * x);
3131
}
3232

33-
inline double chi_squared_cdf_safe(double x, std::size_t degrees_of_freedom) {
33+
inline double chi_squared_cdf_safe(double x, double degrees_of_freedom) {
3434

3535
if (std::isnan(x) || x < 0.) {
3636
return NAN;
@@ -53,7 +53,16 @@ inline double chi_squared_cdf_safe(double x, std::size_t degrees_of_freedom) {
5353

5454
} // namespace details
5555

56-
inline double chi_squared_cdf(double x, std::size_t degrees_of_freedom) {
56+
template <typename IntType,
57+
typename = std::enable_if_t<std::is_integral<IntType>::value>>
58+
inline double chi_squared_cdf(double x, IntType degrees_of_freedom) {
59+
// due to implicit argument conversions we can't directly use cast::to_double
60+
// here.
61+
return details::chi_squared_cdf_safe(x,
62+
static_cast<double>(degrees_of_freedom));
63+
}
64+
65+
inline double chi_squared_cdf(double x, double degrees_of_freedom) {
5766
return details::chi_squared_cdf_safe(x, degrees_of_freedom);
5867
}
5968

include/albatross/src/stats/incomplete_gamma.hpp

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -75,32 +75,11 @@ inline double incomplete_gamma_quadrature_recursive(double lb, double ub,
7575

7676
inline std::pair<double, double> incomplete_gamma_quadrature_bounds(double a,
7777
double z) {
78-
79-
if (a > 800) {
80-
return std::make_pair(std::max(0., std::min(z, a) - 11 * sqrt(a)),
81-
std::min(z, a + 10 * sqrt(a)));
82-
} else if (a > 300) {
83-
return std::make_pair(std::max(0., std::min(z, a) - 10 * sqrt(a)),
84-
std::min(z, a + 9 * sqrt(a)));
85-
} else if (a > 90) {
86-
return std::make_pair(std::max(0., std::min(z, a) - 9 * sqrt(a)),
87-
std::min(z, a + 8 * sqrt(a)));
88-
} else if (a > 70) {
89-
return std::make_pair(std::max(0., std::min(z, a) - 8 * sqrt(a)),
90-
std::min(z, a + 7 * sqrt(a)));
91-
} else if (a > 50) {
92-
return std::make_pair(std::max(0., std::min(z, a) - 7 * sqrt(a)),
93-
std::min(z, a + 6 * sqrt(a)));
94-
} else if (a > 40) {
95-
return std::make_pair(std::max(0., std::min(z, a) - 6 * sqrt(a)),
96-
std::min(z, a + 5 * sqrt(a)));
97-
} else if (a > 30) {
98-
return std::make_pair(std::max(0., std::min(z, a) - 5 * sqrt(a)),
99-
std::min(z, a + 4 * sqrt(a)));
100-
} else {
101-
return std::make_pair(std::max(0., std::min(z, a) - 4 * sqrt(a)),
102-
std::min(z, a + 4 * sqrt(a)));
103-
}
78+
// NOTE: GCEM uses a large conditional block to select tighter bounds, but in
79+
// practice those bounds were not tight enough, particularly on the upper
80+
// bound, so we've modified this function to be more conservative
81+
return std::make_pair(std::max(0., std::min(z, a) - 12 * sqrt(a)),
82+
std::min(z, a + 13 * sqrt(a + 1)));
10483
}
10584

10685
inline double incomplete_gamma_quadrature(double a, double z) {

0 commit comments

Comments
 (0)