
Add rewrite for softplus(log(x)) -> log1p(x) #1452


Open · wants to merge 1 commit into main

Conversation

@lciti commented Jun 6, 2025

Description

This PR adds the simple rewrite softplus(log(x)) -> log1p(x).
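
For context, the identity behind the rewrite: softplus(z) = log(1 + exp(z)), so substituting z = log(x) gives softplus(log(x)) = log(1 + x) = log1p(x) for x >= 0 (and nan for x < 0, hence the switch in the implementation). A quick illustrative check with NumPy (not part of the PR):

import numpy as np

x = np.array([1e-300, 1e-16, 0.5, 1e3])
naive = np.log1p(np.exp(np.log(x)))  # softplus(log(x)) evaluated naively
assert np.allclose(naive, np.log1p(x))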

I could also extend it to cover the case softplus(-log(x)) -> log1p(1/x). However, I have noticed that even the simple exp(-log(x)) -> 1/x is missing, so I wonder if there is an underlying reason to avoid such simplifications that I am not aware of.

Also, it would be helpful to get feedback from the community before proceeding further with the PR, to ensure that the proposed simplification(s) align with the library's design principles and don't introduce any unintended consequences.
Likewise, I would like to know whether the approach and the location of the change (within the file and the test) are appropriate. I am not sufficiently familiar with the code to understand whether the approach I have used is better or worse than a PatternNodeRewriter.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1452.org.readthedocs.build/en/1452/

@@ -453,6 +453,13 @@ def local_exp_log_nan_switch(fgraph, node):
new_out = switch(le(x, 0), neg(exp(x)), np.asarray(np.nan, old_out.dtype))
return [new_out]

# Case for softplus(log(x)) -> log1p(x)
ricardoV94 (Member) commented:

Nitpick: I prefer to refer to it as log1pexp, which we have as an alias of softplus:

log1pexp = softplus

Also we can add a similar case for log1mexp?

lciti (Author) replied:

I tested this (with the code below) and it works fine in its domain [0, 1]. I will add it too.
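
(For reference, a minimal sketch of how the two cases might slot into local_exp_log_nan_switch, modeled on the neighbouring cases visible in the diff above. The prev_op/node_op variables and the ps scalar-op classes are assumptions about the surrounding function, not code from this PR:)

# Case for softplus(log(x)) -> log1p(x), with log1pexp as an alias of softplus
if isinstance(prev_op, ps.Log) and isinstance(node_op, ps.Softplus):
    x = x.owner.inputs[0]
    old_out = node.outputs[0]
    new_out = switch(ge(x, 0), log1p(x), np.asarray(np.nan, old_out.dtype))
    return [new_out]

# Case for log1mexp(log(x)) -> log1p(-x), valid on [0, 1]
if isinstance(prev_op, ps.Log) and isinstance(node_op, ps.Log1mexp):
    x = x.owner.inputs[0]
    old_out = node.outputs[0]
    new_out = switch(ge(x, 0), log1p(neg(x)), np.asarray(np.nan, old_out.dtype))
    return [new_out]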

@ricardoV94 (Member) commented Jun 7, 2025

> However, I have noticed that even the simple exp(-log(x)) -> 1/x is missing, so I wonder if there is an underlying reason to avoid such simplifications that I am not aware of.

Hmm, I don't know if that introduces numerical precision issues...

However, the somewhat converse log(1/x) -> -log(x) seems stable, and may be something we want to do?
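
(An illustrative way to probe that claim, not from the thread: compare log(1/x) against -log(x) over the float64 range; the sampling scheme is arbitrary:)

import numpy as np

rng = np.random.default_rng(0)
x = np.exp(rng.uniform(-700.0, 700.0, size=100_000))  # roughly log-uniform over float64
diff = np.abs(np.log(1.0 / x) - (-np.log(x)))
print(diff.max())  # stays within a few ulps of log(x), consistent with the rewrite being stable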

@ricardoV94 added the graph rewriting and enhancement (New feature or request) labels on Jun 7, 2025
data_invalid = data_valid - 2

x = fmatrix()
f = function([x], softplus(log(x)), mode=self.mode)
ricardoV94 (Member) commented:

If you want, you can check against the expected graph directly, something like:

assert equal_computations(f.maker.fgraph.outputs, [pt.switch(x > 0, pt.log1p(x), np.asarray([[np.nan]], dtype="float32"))])

Or something like that. This is not a request!

lciti (Author) replied:

It took me a while to figure out how to make the test work.

It works fine if I apply it to a scalar rewritten output, like:

x = pt.scalar("x")
out = pt.softplus(pt.log(x))
new_out = rewrite_graph(out, include=("canonicalize", "stabilize", "specialize"))
assert equal_computations([new_out], [pt.switch(x >= 0, pt.log1p(x), pt.nan)])

But if I want to apply it within the test (where the function is also applied to test data), I need to fix a few things.
First of all, the constants have to be matrices of the right type: int8 and float32, respectively (which was not obvious to me).
Then, for some reason, I could only make it work after overriding the compile mode:

x = fmatrix()
mode = get_mode("FAST_COMPILE").including("local_exp_log", "local_exp_log_nan_switch")
f = function([x], softplus(log(x)), mode=mode)
assert equal_computations(
    f.maker.fgraph.outputs,
    [pt.switch(x >= np.array([[0]], dtype=np.int8), pt.log1p(x), np.array([[np.nan]], dtype=np.float32))],
)

For some reason the mode used in the test (set at class level) seems to do something extra, and equal_computations returns False, possibly due to:

nd_x.op: Elemwise(scalar_op=Switch,inplace_pattern=<frozendict {0: 1}>)
nd_y.op: Elemwise(scalar_op=Switch,inplace_pattern=<frozendict {}>)

I don't know enough of the internals of PyTensor to find a solution, except forcing the mode above inside this test.
Is it fine to override the class compile mode with the compile mode above?
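
(One possible workaround, untested here: since the graphs differ only in the inplace_pattern of the Switch, excluding inplace rewrites from the class-level mode might make them compare equal:)

mode = self.mode.excluding("inplace")
f = function([x], softplus(log(x)), mode=mode)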

codecov bot commented Jun 7, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.12%. Comparing base (ff98ab8) to head (10bcb41).

Additional details and impacted files

@@           Coverage Diff           @@
##             main    #1452   +/-   ##
=======================================
  Coverage   82.12%   82.12%           
=======================================
  Files         211      211           
  Lines       49757    49762    +5     
  Branches     8819     8820    +1     
=======================================
+ Hits        40862    40867    +5     
  Misses       6715     6715           
  Partials     2180     2180           
Files with missing lines              | Coverage Δ
pytensor/tensor/rewriting/math.py     | 89.28% <100.00%> (+0.03%) ⬆️

@lciti (Author) commented Jun 9, 2025

@ricardoV94 Thanks for the feedback!

Thank you for the suggested changes - noted.

Let me first report some results related to my question about rewriting the different expressions.

I have used the code below to compare the original expression and the rewrite against the original expression computed with arbitrary precision and then converted back to the working floating point precision.

import numpy as np
import pytensor
import pytensor.tensor as pt
from mpmath import mp
mp.dps = 200

dtype_fn = np.float64
test_rewrite = ...  # set to one of the three case names below, e.g. 'log1pexp_log->log1p'

smallest_subnormal = np.finfo(dtype_fn).smallest_subnormal
tiny = np.finfo(dtype_fn).tiny  # == np.finfo(dtype_fn).smallest_normal
eps = np.finfo(dtype_fn).eps

x = pt.scalar("x", dtype=dtype_fn)

if test_rewrite == 'log1pexp_log->log1p':
    f1 = lambda x: pt.log1pexp(pt.log(x))
    f2 = lambda x: pt.log1p(x)
    fr = lambda x: mp.log1p(mp.exp(mp.log(x.item())))
elif test_rewrite == 'exp_minus_log->recip':
    f1 = lambda x: pt.exp(-pt.log(x))
    f2 = lambda x: 1/x
    fr = lambda x: mp.exp(-mp.log(x.item()))
elif test_rewrite == 'log1pexp_minus_log->log1p_recip':
    f1 = lambda x: pt.log1pexp(-pt.log(x))
    f2 = lambda x: pt.switch(x > eps, pt.log1p(1/x), -pt.log(x))
    fr = lambda x: mp.log1p(mp.exp(-mp.log(x.item())))

fn1 = pytensor.function([x], f1(x), mode="FAST_RUN")
fn2 = pytensor.function([x], f2(x), mode="FAST_RUN")

print(f'{"argument":24} | {"rel_err_orig":24} | {"rel_err_rewrite":24}')
for z in [smallest_subnormal, np.exp((np.log(smallest_subnormal) + np.log(tiny)) / 2),
          tiny / 2, tiny, 1e-30, eps / 8, eps, 1e-3, 0.125, 1, 8, 1000, 1e30,
          np.finfo(dtype_fn).max]:
    z = np.array(z, dtype=dtype_fn)
    zi = np.array(fr(z), dtype=dtype_fn)
    print(f'{z:24} | {float((fn1(z) - zi)/zi):24} | {float((fn2(z) - zi) / zi):24}')

For the rewrite proposed in this PR (test_rewrite = 'log1pexp_log->log1p'), the results confirm that the rewrite simplifies the expression without affecting its numerical accuracy. There is sometimes a minor improvement, within approximately one or two decimal digits, which is probably not enough to warrant a @register_stabilize.

argument                 | rel_err_orig             | rel_err_rewrite         
                  5e-324 |                      0.0 |                      0.0
          3.3156184e-316 |                      0.0 |                      0.0
 1.1125369292536007e-308 |   -3.108624468950438e-14 |                      0.0
 2.2250738585072014e-308 |   2.7533531010703882e-14 |                      0.0
                   1e-30 |   2.2771100045278277e-15 |                      0.0
  2.7755575615628914e-17 |  -3.1086244689504383e-15 |                      0.0
   2.220446049250313e-16 |    2.331468351712829e-15 |                      0.0
                   0.001 |   2.1694883665334253e-16 |                      0.0
                   0.125 |   2.3565002770519656e-16 |                      0.0
                     1.0 |                      0.0 |                      0.0
                     8.0 |  -2.0211370946362213e-16 |                      0.0
                  1000.0 |                      0.0 |                      0.0
                   1e+30 |                      0.0 |                      0.0
 1.7976931348623157e+308 |                      0.0 |                      0.0

For the suggested exp(-log(x)) rewrite (test_rewrite = 'exp_minus_log->recip'), it looks like a safe rewrite. In general, floating point hardware and libraries are expected to implement the reciprocal so that it returns the correctly rounded floating point representation of the reciprocal of the input (this does not mean that the error is zero, but it is the smallest possible in floating point). The same is generally true for exp and log, but because there are two of them the errors might compound and end up slightly larger than for 1/x. The results are as follows. Again, the improvement is minor, so the main advantages are (marginally better) speed, simplicity, and opportunities for further optimisations down the line.

argument                 | rel_err_orig             | rel_err_rewrite         
                  5e-324 |                      nan |                      nan
          3.3156184e-316 |                      nan |                      nan
 1.1125369292536007e-308 |   3.1308289294429414e-14 |                      0.0
 2.2250738585072014e-308 |  -2.7422508708241367e-14 |                      0.0
                   1e-30 |   -2.251799813685248e-15 |                      0.0
  2.7755575615628914e-17 |   3.1086244689504383e-15 |                      0.0
   2.220446049250313e-16 |    -2.55351295663786e-15 |                      0.0
                   0.001 |  -2.2737367544323206e-16 |                      0.0
                   0.125 |   -2.220446049250313e-16 |                      0.0
                     1.0 |                      0.0 |                      0.0
                     8.0 |    2.220446049250313e-16 |                      0.0
                  1000.0 |   2.1684043449710089e-16 |                      0.0
                   1e+30 |     2.45227231256843e-15 |                      0.0
 1.7976931348623157e+308 |    2.398081733190338e-14 |                      0.0

Regarding the other rewrite I proposed, log1pexp(-log(x)) (i.e. test_rewrite = 'log1pexp_minus_log->log1p_recip'), rewriting it as log1p(1/x) works fine for numbers larger than 1/DBL_MAX ≅ 5.6e-309 (which includes all normal floating point numbers). For values between the smallest representable float (smallest_subnormal = 5e-324) and 5.6e-309, the problem is that 1/x overflows to +Inf, and so does log1p(1/x). This is not the case for the original expression, thanks to the log.

One simple solution is to notice that log1p(1/x) = log1p(x) - log(x), and that for small enough x we have log1p(x) ≅ x << abs(log(x)). For example, for x = eps we have log1p(x) - log(x) = -log(x) in floating point (x = eps is just one possible threshold; x = tiny would also work). So the rewrite could be pt.log1pexp(-pt.log(x)) -> pt.switch(x > np.finfo(old_out.dtype).eps, pt.log1p(1/x), -pt.log(x)) (this also takes care of negative x automatically, without the need for an extra switch). The results are as follows:

argument                 | rel_err_orig             | rel_err_rewrite         
                  5e-324 |                      0.0 |                      0.0
          3.3156184e-316 |                      0.0 |                      0.0
 1.1125369292536007e-308 |                      0.0 |                      0.0
 2.2250738585072014e-308 |                      0.0 |                      0.0
                   1e-30 |                      0.0 |                      0.0
  2.7755575615628914e-17 |                      0.0 |                      0.0
   2.220446049250313e-16 |                      0.0 |                      0.0
                   0.001 |                      0.0 |                      0.0
                   0.125 |  -2.0211370946362213e-16 |                      0.0
                     1.0 |                      0.0 |                      0.0
                     8.0 |   2.3565002770519656e-16 |                      0.0
                  1000.0 |   2.1694883665334253e-16 |                      0.0
                   1e+30 |     2.45227231256843e-15 |                      0.0
 1.7976931348623157e+308 |    2.398081733190338e-14 |                      0.0
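
An illustrative NumPy demonstration of the overflow issue and the -log(x) fallback described above (not part of the benchmark script):

import numpy as np

eps = np.finfo(np.float64).eps
for x in [5e-324, 1e-310, 1e-20, eps]:
    naive = np.log1p(1.0 / x)  # 1/x overflows to +inf for x < 1/DBL_MAX ≅ 5.6e-309
    fallback = -np.log(x)      # = log1p(1/x) - log1p(x), with log1p(x) ≅ x negligible here
    print(f"{x:10.3e}  naive={naive}  fallback={fallback}")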

@ricardoV94 (Member) commented Jun 10, 2025

Thanks for the careful analysis, that's invaluable!

Sounds like we can go ahead with exp(-log(x)) -> 1/x then.

For log1pexp(-log(x)), I'm not sure the extra switch / graph complexity is worth it. Is this expression something you were seeing in your applications?

@lciti (Author) commented Jun 11, 2025

It's something I see in my analysis but apparently it does not make a big difference.

I have the following two cases:

  1. pt.log1pexp(pt.log(pt.expm1(f(x)))),
    which eventually becomes just f(x) (with a switch) thanks to the rewrite in this PR;
    this simplification is useful because it avoids hitting +inf as soon as pt.expm1(f(x)) becomes +inf.
  2. pt.log1pexp(-pt.log(pt.expm1(g(x)))),
    which in principle is equal to -pt.log1mexp(-g(x)) (see the numerical check after this list),
    but even if we implemented log1pexp(-log(x)) -> log1p(1/x) it would not eventually lead to the simplified form without further super-specialised rewrites;
    however, this simplification has very minor (if any) numerical benefits as far as I can tell.
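
The equivalence mentioned in point 2, checked numerically with NumPy equivalents (an illustrative aside, not code from the PR; g stands for any positive-valued expression):

import numpy as np

g = np.array([0.1, 1.0, 5.0, 20.0])
lhs = np.log1p(np.exp(-np.log(np.expm1(g))))  # log1pexp(-log(expm1(g)))
rhs = -np.log1p(-np.exp(-g))                  # -log1mexp(-g)
assert np.allclose(lhs, rhs)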

So, bottom line, I think it's worth implementing the rewrite in this PR, plus log1mexp(log(x)), exp(-log(x)) -> 1/x, and possibly log(1/x) -> -log(x).

I can give it a try. However, I would need some guidance on the right place for the change (inside local_exp_log_nan_switch for exp(-log(x)) -> 1/x, I assume) and on the approach (handcrafted vs PatternNodeRewriter). Thanks!

@ricardoV94 (Member) commented Jun 11, 2025

The rewrites that require a nan switch can go in the same function.

The log(1/x) -> -log(x) rewrite doesn't, so it can go in a separate function. I guess the more general pattern here is log(const/x) -> log(const) - log(x) and log(x/const) -> log(x) - log(const), provided const is positive (so that signs can't change and we don't risk creating nans).

Because it is a constant, one of the logs will disappear, so we replace a log and a division by a log and a subtraction, which should be more stable and faster? I guess we can also cover multiplication? Or do you see any reason to apply it to 1/x only?

The kind of rewrite is up to you; the explicit function approach is usually more efficient.
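
(For illustration, a hypothetical sketch of the pattern-based variant for log(1/x) -> -log(x), assuming PatternNodeRewriter takes nested (op, args...) tuples with string placeholders as in existing rewrites; the import paths are assumptions, and this is not code from the PR:)

from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.tensor import log, neg, true_div

# Match log(1 / x) and replace it with -log(x); "x" is a pattern placeholder.
local_log_reciprocal = PatternNodeRewriter(
    (log, (true_div, 1.0, "x")),
    (neg, (log, "x")),
)
# The rewriter would still need to be registered (e.g. via register_specialize) to take effect.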

@ricardoV94 (Member) commented:

Re the other special case you have, you can always implement it outside of PyTensor at runtime and use it in your workflow.

Labels: enhancement (New feature or request), graph rewriting
Projects: None yet
Development

Successfully merging this pull request may close these issues.

Missing rewrite softplus(log(x)) -> log1p(x)
2 participants