Skip to content

Conversation

@giladturok
Copy link
Contributor

@giladturok giladturok commented Oct 30, 2024

Implement nested R-hat for Markov chain Monte Carlo (MCMC) diagnostic.

The potential scale reduction factor, also known as R-hat, is a popular MCMC diagnostic from Gelman and Rubin.
R-hat detects convergence of MCMC chains by comparing within chain variance to between chain variance.

Nested r-hat from Margossian et al. better predicts convergence when running thousands of short chains on modern hardware. Nested r-hat uses superchains, collections of MCMC chains, and compares within and between chain and superchain variance.

I am seeking feedback on the code style + API design. The code is somewhat complicated by requiring input_array to have 4 dimensions -- num_superchains, num_chains, num_samples, and num_params -- where most users may expect only 3 (or 2). Tests are also still needed, as well as a brief doc explanation of the math.

Quick nit: why does the existing R-hat function return the potential scale factor after flattening along the sample and chain dimensions? I followed this convention for my implementation.

Addresses issue #278 .

@giladturok
Copy link
Contributor Author

@charlesm93: you may be interested in this.

(For everyone else: Charles is the author of the nested R-hat paper.)

@giladturok
Copy link
Contributor Author

Code style tests are failing because Flake8 is finding extra spaces around operators on line 43 of smc/resampling:

Screenshot 2024-10-30 at 4 22 19 PM

However I do not see what the problem is. Here's the line:

If anyone sees how to fix this issue, please let me know.

Copy link
Member

@junpenglao junpenglao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Great start. Let me know when you add some test.

NDArray of the resulting statistics (r-hat), with the chain and sample dimensions squeezed.
"""
assert input_array.ndim == 4, "The input array must have 4 dimensions."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should relax the ndim, as our input could have multiple dimensions of event shape (ie the random variable is non-scaler).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use keepdims=True and it should works.

@junpenglao
Copy link
Member

Quick nit: why does the existing R-hat function return the potential scale factor after flattening along the sample and chain dimensions? I followed this convention for my implementation.

The current r-hat function does not flatten the sample, but rather squeeze it, so if you have a random variable with shape=(2, 5), the output result could be
shape=(1, 1, 2, 5) or (1, 2, 5, 1)
doing a squzze makes it return rhat the same shape as the random variable.

@giladturok
Copy link
Contributor Author

Thanks so much for the quick feedback! I'll continue working on this next week.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants