Skip to content

Simplify implementation and tests in #1534 #1555

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 24, 2022
Merged

Conversation

devmotion
Copy link
Member

@devmotion devmotion commented May 23, 2022

@matbesancon This PR simplifies the implementation and tests in #1534. It also makes the code a bit more consistent with the conventions in ChainRules, fixes some type stability issue, and simplifies the handling of non-finite values in _logpdf.

Tests pass locally with Julia 1.7 and JuliaDiff/ChainRulesTestUtils.jl#247.

Edit: The CRTU PR was merged and released.

This was referenced May 23, 2022
@devmotion
Copy link
Member Author

@matbesancon Tests all pass it seems, with Julia 1.3, 1.7, and nightly.

@matbesancon matbesancon merged commit 9234155 into cr-dirichlet May 24, 2022
@matbesancon matbesancon deleted the dw/dirichlet branch May 24, 2022 13:57
matbesancon added a commit that referenced this pull request Jul 31, 2022
* constructor frule

* frule tested

* rrule tests

* logpdf test

* signature for conflict

* TestUtils out of Project

* ChainRules itself not needed (yet?)

* remove checkarg

* Update src/multivariate/dirichlet.jl

Co-authored-by: David Widmann <[email protected]>

* Update test/dirichlet.jl

Co-authored-by: David Widmann <[email protected]>

* Update test/dirichlet.jl

Co-authored-by: David Widmann <[email protected]>

* Update test/dirichlet.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/multivariate/dirichlet.jl

Co-authored-by: David Widmann <[email protected]>

* conflict

* eltype instability

* single loop

* fix tests

* forward finite diff

* switch to broadcast

* fix broadcast

* switch off-support value to NaN

* Update src/multivariate/dirichlet.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/multivariate/dirichlet.jl

Co-authored-by: David Widmann <[email protected]>

* do not assume inplace

* fixed temp

* Simplify implementation and tests in #1534 (#1555)

* Simplify implementation and tests

* Precompute `digamma(alpha0)`

* Relax type signature

Co-authored-by: David Widmann <[email protected]>
@AlexRobson
Copy link

AlexRobson commented Aug 23, 2022

fwiw, running something completely different these tests were failing in 1.6, though all passing in 1.8:

In 1.6:

Got exception outside of a @test
  DomainError with [0.100746, 0.151104, 0.198606, 0.106677, 0.417981, 0.234742, 0.910606, 0.501805, 0.861964, -0.00290635]:
  Dirichlet: alpha must be a positive vector.

Is this an RNG issue, or something tied up with the julia version?

If this diagnosis is correct, assuming it's the former, and if chainrules are going to start being added here, It may be worth specifying the tangents in the the test_rrule to avoid RNG issues. It looks like a lot of the testsets are seeded here, so I think the equivalent would be:

Random.MersenneTwister(1)
test_rrule(Distributions._logpdf, d, x ⊢ CRTU.rand_tangent(x); fdm=fdm, rtol=1e-5, nans=true)

@devmotion
Copy link
Member Author

Since tests pass (see e.g. #1606) I don't see any reason to tweak seeds right now. If it becomes necessary, we could possibly specify a seed but that is also not guaranteed to yield the same random numbers on different Julia versions (and OS IIRC). (BTW @testset already resets the seed, so it might not be necessary to specify a seed for every test but could be done once initially: https://docs.julialang.org/en/v1/stdlib/Test/#Test.@testset) Generally, I don't think we want to specify tangents manually if it can be avoided - specific choices make it easier to miss broken cases, and I feel like by the test_rrule API it's intended to provide tangents manually only if specific choices are needed.

@devmotion
Copy link
Member Author

Is this an RNG issue, or something tied up with the julia version?

Possibly finite differencing fails for some unfortunate cases. We could maybe also try to pick values that are further away from zero or tweak the finite differencing method.

@AlexRobson
Copy link

Since tests pass (see e.g. #1606) I don't see any reason to tweak seeds right now. If it becomes necessary, we could possibly specify a seed but that is also not guaranteed to yield the same random numbers on different Julia versions (and OS IIRC). (BTW @testset already resets the seed, so it might not be necessary to specify a seed for every test but could be done once initially: https://docs.julialang.org/en/v1/stdlib/Test/#Test.@testset) Generally, I don't think we want to specify tangents manually if it can be avoided - specific choices make it easier to miss broken cases, and I feel like by the test_rrule API it's intended to provide tangents manually only if specific choices are needed.

Good points. I was considering the cases where adding in other unrelated testsets can change the RNG and I'd actually forgotten testset resets the seed, so this is essentially equivalent. For version changes, technically StableRNGs.jl could be used, but for here I also now don't see any particular reason to go out of the way to tweak. I just happened to be using the wrong version of julia when playing around with what I was working on when testing :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants