Add rewrite for softplus(log(x)) -> log1p(x)
#1452
base: main
Conversation
@@ -453,6 +453,13 @@ def local_exp_log_nan_switch(fgraph, node):
        new_out = switch(le(x, 0), neg(exp(x)), np.asarray(np.nan, old_out.dtype))
        return [new_out]

    # Case for softplus(log(x)) -> log1p(x)
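As a side note for readers skimming the diff: judging from the test assertions discussed further down, the replacement this case builds is a log1p guarded by a nan switch. A minimal standalone illustration (my own sketch, not the diff itself):

```python
# Sketch: softplus(log(x)) is only defined for x >= 0, so the rewrite discussed here
# replaces it with switch(x >= 0, log1p(x), nan) rather than a bare log1p(x).
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dscalar("x")
original = pt.softplus(pt.log(x))
guarded = pt.switch(pt.ge(x, 0), pt.log1p(x), np.nan)

f = pytensor.function([x], [original, guarded])
print(f(0.5))  # both ~ log(1.5) = 0.4054651...
print(f(3.0))  # both ~ log(4.0) = 1.3862943...
```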
Nitpick: I prefer to refer to it by log1pexp, which we have as an alias to softplus (pytensor/pytensor/tensor/math.py, line 2474 in ff98ab8: log1pexp = softplus).
Also, we can add a similar case for log1mexp?
I tested this (with the code below) and it works fine in its domain [0, 1]. I will add it too.
Hmm, I don't know if that introduces numerical precision issues... However, the somewhat converse
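For context, the identity behind the log1mexp case is log1mexp(log(x)) = log(1 - x) = log1p(-x), which only makes sense for x in [0, 1]. A minimal check (my own sketch, separate from the benchmark code referenced above):

```python
# Check log1mexp(log(x)) == log1p(-x) on its domain [0, 1) (sketch, not the PR code).
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
f = pytensor.function([x], [pt.log1mexp(pt.log(x)), pt.log1p(-x)])

vals = np.array([1e-8, 0.25, 0.5, 0.999])
lhs, rhs = f(vals)
np.testing.assert_allclose(lhs, rhs, rtol=1e-10)
```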
data_invalid = data_valid - 2

x = fmatrix()
f = function([x], softplus(log(x)), mode=self.mode)
If you want, you can check against the expected graph directly, something like:

assert equal_computations(f.maker.fgraph.outputs, [pt.switch(x > 0, pt.log1p(x), np.asarray([[np.nan]], dtype="float32"))])

Or something like that. This is not a request!
It took me a while to figure out how to make the test work.
It works fine if I apply it to a scalar rewritten output, like:
x = pt.scalar("x")
out = pt.softplus(pt.log(x))
new_out = rewrite_graph(out, include=("canonicalize", "stabilize", "specialize"))
equal_computations([new_out], [pt.switch(x >= 0, pt.log1p(x), pt.nan)])
But if I want to apply it within the test (where the function is also applied to test data), I need to fix a few things.
First of all, the constants have to be matrices of the right type: int8 and float32, respectively (which was not obvious to me).
Then, for some reason, I could only make it work after overriding the compile mode:

x = fmatrix()
mode = get_mode('FAST_COMPILE').including("local_exp_log", "local_exp_log_nan_switch")
f = function([x], softplus(log(x)), mode=mode)
assert equal_computations(f.maker.fgraph.outputs, [pt.switch(x >= np.array([[0]], dtype=np.int8), pt.log1p(x), np.array([[np.nan]], dtype=np.float32))])

For some reason, the mode used in the test (set at class level) seems to do something extra and equal_computations returns False, possibly due to:

nd_x.op: Elemwise(scalar_op=Switch,inplace_pattern=<frozendict {0: 1}>)
nd_y.op: Elemwise(scalar_op=Switch,inplace_pattern=<frozendict {}>)

I don't know enough of the internals of PyTensor to find a solution, other than forcing the mode above inside this test.
Is it fine to override the class compile mode with the compile mode above?
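One possible alternative to hard-coding a different mode (an untested suggestion; it assumes the mismatch really does come from the inplace rewrites, as the inplace_pattern difference suggests) is to keep the class-level mode but exclude them, continuing the snippet above:

```python
# Sketch: drop the inplace rewrites from the class-level mode so the rewritten Switch
# is not turned into an inplace Elemwise before the graph comparison.
mode = self.mode.excluding("inplace")
f = function([x], softplus(log(x)), mode=mode)
assert equal_computations(
    f.maker.fgraph.outputs,
    [pt.switch(x >= np.array([[0]], dtype=np.int8), pt.log1p(x), np.array([[np.nan]], dtype=np.float32))],
)
```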
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@ Coverage Diff @@
## main #1452 +/- ##
=======================================
Coverage 82.12% 82.12%
=======================================
Files 211 211
Lines 49757 49762 +5
Branches 8819 8820 +1
=======================================
+ Hits 40862 40867 +5
Misses 6715 6715
Partials 2180 2180
@ricardoV94 Thanks for the feedback, and thank you for the suggested changes - noted. Let me first report some results related to my question about rewriting the different expressions. I have used the code below to compare the original expression and the rewrite against the original expression computed with arbitrary precision and then converted back to the working floating-point precision.

import numpy as np
import pytensor
import pytensor.tensor as pt
from mpmath import mp

mp.dps = 200
dtype_fn = np.float64
test_rewrite = ...

smallest_subnormal = np.finfo(dtype_fn).smallest_subnormal
tiny = np.finfo(dtype_fn).tiny  # == np.finfo(dtype_fn).smallest_normal
eps = np.finfo(dtype_fn).eps

x = pt.scalar("x", dtype=dtype_fn)
if test_rewrite == 'log1pexp_log->log1p':
    f1 = lambda x: pt.log1pexp(pt.log(x))
    f2 = lambda x: pt.log1p(x)
    fr = lambda x: mp.log1p(mp.exp(mp.log(x.item())))
elif test_rewrite == 'exp_minus_log->recip':
    f1 = lambda x: pt.exp(-pt.log(x))
    f2 = lambda x: 1/x
    fr = lambda x: mp.exp(-mp.log(x.item()))
elif test_rewrite == 'log1pexp_minus_log->log1p_recip':
    f1 = lambda x: pt.log1pexp(-pt.log(x))
    f2 = lambda x: pt.switch(x > eps, pt.log1p(1/x), -np.log(x))
    fr = lambda x: mp.log1p(mp.exp(-mp.log(x.item())))

fn1 = pytensor.function([x], f1(x), mode="FAST_RUN")
fn2 = pytensor.function([x], f2(x), mode="FAST_RUN")

print(f'{"argument":24} | {"rel_err_orig":24} | {"rel_err_rewrite":24}')
for z in [smallest_subnormal, np.exp((np.log(smallest_subnormal) + np.log(tiny))/2), tiny / 2, tiny, 1e-30, eps/8, eps, 1e-3, .125, 1, 8, 1000, 1e30, np.finfo(dtype_fn).max]:
    z = np.array(z, dtype=dtype_fn)
    zi = np.array(fr(z), dtype=dtype_fn)
    print(f'{z:24} | {float((fn1(z) - zi)/zi):24} | {float((fn2(z) - zi) / zi):24}')

For the rewrite proposed in this PR (
For the
Regarding the other rewrite I proposed,
Thanks for the careful analysis, that's invaluable! Sounds like we can go ahead with
For the
It's something I see in my analysis, but apparently it does not make a big difference. I have the following two cases:
So, bottom line, I think it's worth implementing the rewrite in this PR plus I can give it a try. However, I would need some guidance about where the right place is (inside
The rewrites that require a nan switch can go in the same function. The log(1/x) -> -log(x) one doesn't, so it can go in a separate function.

I guess the more general pattern here is log(c / x) -> log(c) - log(x), with c a constant. Because it is a constant, one of the logs will disappear, and we replace a log and a division by a log and a subtraction, which should be more stable and faster? I guess we can also cover multiplication? Or do you see any reason to apply it to 1/x only?

The kind of rewrite is up to you; the explicit function approach is usually more efficient.
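A quick NumPy illustration of why removing the division helps numerically (my own example, just to make the stability argument concrete):

```python
# For very small x, the intermediate 1/x overflows to inf before the log is taken,
# whereas -log(x) stays finite and correct.
import numpy as np

x = np.float64(1e-310)   # a subnormal double
print(np.log(1.0 / x))   # inf (plus an overflow warning): 1/x is not representable
print(-np.log(x))        # ~713.8, the mathematically correct value
```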
Re the other special case you have: you can always implement it outside of PyTensor, at runtime, and use it in your workflow.
Description
This PR adds the simple rewrite softplus(log(x)) -> log1p(x). I could also extend it to cover the case softplus(-log(x)) -> log1p(1/x). However, I have noticed that even the simple exp(-log(x)) -> 1/x is missing, so I wonder if there is an underlying reason, that I am not aware of, to avoid such simplifications. Also, it would be helpful to get feedback from the community before proceeding further with the PR, to ensure that the proposed simplification(s) align with the library's design principles and don't introduce any unintended consequences.
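For reference, the identities behind these candidate rewrites follow directly from softplus(z) = log(1 + exp(z)); a minimal NumPy check (my own sketch, not part of the PR):

```python
# softplus(-log(x)) = log(1 + 1/x) = log1p(1/x), and exp(-log(x)) = 1/x.
import numpy as np

x = 0.37
assert np.isclose(np.log1p(np.exp(-np.log(x))), np.log1p(1 / x))
assert np.isclose(np.exp(-np.log(x)), 1 / x)
```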
Likewise, I would like to know whether the approach and the location (within the file and the test) of the change are appropriate. I am not sufficiently familiar with the code to understand whether the approach I have used is better or worse than a PatternNodeRewriter.
Related Issue
softplus(log(x)) -> log1p(x) #1451
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1452.org.readthedocs.build/en/1452/