diff --git a/edward/__init__.py b/edward/__init__.py
index 4997a89f3..4a2d0df35 100644
--- a/edward/__init__.py
+++ b/edward/__init__.py
@@ -14,7 +14,7 @@
     HMC, MetropolisHastings, SGLD, SGHMC, \
     KLpq, KLqp, ReparameterizationKLqp, ReparameterizationKLKLqp, \
     ReparameterizationEntropyKLqp, ScoreKLqp, ScoreKLKLqp, ScoreEntropyKLqp, \
-    ScoreRBKLqp, WakeSleep, GANInference, BiGANInference, WGANInference, \
+    ScoreRBKLqp, RejectionSamplingKLqp, WakeSleep, GANInference, BiGANInference, WGANInference, \
     ImplicitKLqp, MAP, Laplace, complete_conditional, Gibbs
 from edward.models import RandomVariable
 from edward.util import check_data, check_latent_vars, copy, dot, \
@@ -52,6 +52,7 @@
     'ScoreKLKLqp',
     'ScoreEntropyKLqp',
     'ScoreRBKLqp',
+    'RejectionSamplingKLqp',
     'WakeSleep',
     'GANInference',
     'BiGANInference',
diff --git a/edward/inferences/__init__.py b/edward/inferences/__init__.py
index 38262fcb7..fd22d9e8c 100644
--- a/edward/inferences/__init__.py
+++ b/edward/inferences/__init__.py
@@ -42,6 +42,7 @@
     'ScoreKLKLqp',
     'ScoreEntropyKLqp',
     'ScoreRBKLqp',
+    'RejectionSamplingKLqp',
     'Laplace',
     'MAP',
     'MetropolisHastings',
diff --git a/edward/inferences/inference.py b/edward/inferences/inference.py
index a7cea84d2..ac69bcac4 100644
--- a/edward/inferences/inference.py
+++ b/edward/inferences/inference.py
@@ -123,7 +123,6 @@ def run(self, variables=None, use_coordinator=True, *args, **kwargs):
         Passed into `initialize`.
     """
     self.initialize(*args, **kwargs)
-
     if variables is None:
       init = tf.global_variables_initializer()
     else:
@@ -144,6 +143,7 @@ def run(self, variables=None, use_coordinator=True, *args, **kwargs):
 
     for _ in range(self.n_iter):
       info_dict = self.update()
+      print(info_dict)
       self.print_progress(info_dict)
 
     self.finalize()
diff --git a/edward/inferences/klpq.py b/edward/inferences/klpq.py
index 11008ff4b..b2a14a309 100644
--- a/edward/inferences/klpq.py
+++ b/edward/inferences/klpq.py
@@ -32,7 +32,7 @@ class KLpq(VariationalInference):
 
   with respect to $\\theta$.
 
-  In conditional inference, we infer $z` in $p(z, \\beta
+  In conditional inference, we infer $z$ in $p(z, \\beta
   \mid x)$ while fixing inference over $\\beta$ using another
   distribution $q(\\beta)$.
   During gradient calculation, instead of using the model's density
diff --git a/edward/inferences/klqp.py b/edward/inferences/klqp.py
index 3370e1fcf..42838c79b 100644
--- a/edward/inferences/klqp.py
+++ b/edward/inferences/klqp.py
@@ -6,7 +6,8 @@
 import tensorflow as tf
 
 from edward.inferences.variational_inference import VariationalInference
-from edward.models import RandomVariable
+from edward.models import Normal, RandomVariable, Gamma
+from edward.samplers import GammaRejectionSampler
 from edward.util import copy, get_descendants
 
 try:
@@ -616,6 +617,62 @@ def build_loss_and_gradients(self, var_list):
     return build_score_rb_loss_and_gradients(self, var_list)
 
 
+class RejectionSamplingKLqp(VariationalInference):
+  """Variational inference with the KL divergence
+
+  $\\text{KL}( q(z; \\lambda) \\| p(z \\mid x) ),$
+
+  minimized with rejection sampling variational inference (RSVI;
+  Naesseth et al., 2017). Samples from $q$ are reparameterized through
+  the proposal of a rejection sampler, and the gradient adds a
+  score-function correction term to the usual reparameterization term
+  (see `build_rejection_sampling_loss_and_gradients`).
+  """
+
+  def __init__(self, latent_vars=None, data=None, rejection_sampler_vars=None):
+    """Create an inference algorithm.
+
+    Args:
+      latent_vars: list of RandomVariable or
+        dict of RandomVariable to RandomVariable.
+        Collection of random variables to perform inference on. If
+        list, each random variable will be implicitly optimized using a
+        `Normal` random variable that is defined internally with a
+        free parameter per location and scale and is initialized using
+        standard normal draws. The random variables to approximate
+        must be continuous.
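+      data: dict, optional.
+        Data dictionary. It binds observed random variables (of type
+        `RandomVariable` or `tf.Tensor`) to their realizations.
+      rejection_sampler_vars: dict, optional.
+        Hook for per-variable rejection sampler configuration. It is
+        stored on the instance; the current loss builder instead picks
+        a sampler from the class of each variational distribution
+        (e.g. `Gamma` -> `GammaRejectionSampler`).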
+ """ + if isinstance(latent_vars, list): + with tf.variable_scope(None, default_name="posterior"): + latent_vars_dict = {} + continuous = \ + ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') + for z in latent_vars: + if not hasattr(z, 'support') or z.support not in continuous: + raise AttributeError( + "Random variable {} is not continuous or a random " + "variable with supported continuous support.".format(z)) + batch_event_shape = z.batch_shape.concatenate(z.event_shape) + loc = tf.Variable(tf.random_normal(batch_event_shape)) + scale = tf.nn.softplus( + tf.Variable(tf.random_normal(batch_event_shape))) + latent_vars_dict[z] = Normal(loc=loc, scale=scale) + latent_vars = latent_vars_dict + del latent_vars_dict + super(RejectionSamplingKLqp, self).__init__(latent_vars, data) + self.rejection_sampler_vars = rejection_sampler_vars + + def initialize(self, n_samples=1, *args, **kwargs): + """Initialize inference algorithm. It initializes hyperparameters + and builds ops for the algorithm's computation graph. + + Args: + n_samples: int, optional. + Number of samples from variational model for calculating + stochastic gradients. + """ + self.n_samples = n_samples + return super(RejectionSamplingKLqp, self).initialize(*args, **kwargs) + + def build_loss_and_gradients(self, var_list): + return build_rejection_sampling_loss_and_gradients(self, var_list) + + def build_reparam_loss_and_gradients(inference, var_list): """Build loss function. Its automatic differentiation is a stochastic gradient of @@ -1127,3 +1184,90 @@ def build_score_rb_loss_and_gradients(inference, var_list): grads_vars.extend(model_vars) grads_and_vars = list(zip(grads, grads_vars)) return loss, grads_and_vars + + +def build_rejection_sampling_loss_and_gradients(inference, var_list, epsilon=None): + """ + """ + rej_samplers = { + Gamma: GammaRejectionSampler + } + + rep = [0.0] * inference.n_samples + cor = [0.0] * inference.n_samples + base_scope = tf.get_default_graph().unique_name("inference") + '/' + for s in range(inference.n_samples): + # Form dictionary in order to replace conditioning on prior or + # observed variable with conditioning on a specific value. + scope = base_scope + tf.get_default_graph().unique_name("sample") + dict_swap = {} + for x, qx in six.iteritems(inference.data): + if isinstance(x, RandomVariable): + if isinstance(qx, RandomVariable): + qx_copy = copy(qx, scope=scope) + dict_swap[x] = qx_copy.value() + else: + dict_swap[x] = qx + + p_log_prob = 0. + q_log_prob = 0. + r_log_prob = 0. + + for z, qz in six.iteritems(inference.latent_vars): + # Copy q(z) to obtain new set of posterior samples. 
+      qz_copy = copy(qz, scope=scope)
+      sampler = rej_samplers[qz_copy.__class__](density=qz)
+
+      if epsilon is not None:  # temporary
+        pass
+      else:
+        dict_swap[z] = qz_copy.value()
+        print('sample:', dict_swap[z])
+        epsilon = sampler.h_inverse(dict_swap[z])
+
+      dict_swap[z] = sampler.h(epsilon)
+      q_log_prob += tf.reduce_sum(
+          inference.scale.get(z, 1.0) * qz_copy.log_prob(dict_swap[z]))
+      r_log_prob += -tf.log(tf.gradients(dict_swap[z], epsilon))
+
+    for z in six.iterkeys(inference.latent_vars):
+      z_copy = copy(z, dict_swap, scope=scope)
+      p_log_prob += tf.reduce_sum(
+          inference.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z]))
+
+    for x in six.iterkeys(inference.data):
+      if isinstance(x, RandomVariable):
+        x_copy = copy(x, dict_swap, scope=scope)
+        p_log_prob += tf.reduce_sum(
+            inference.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))
+
+    rep[s] = p_log_prob
+    cor[s] = tf.stop_gradient(p_log_prob) * (q_log_prob - r_log_prob)
+
+  rep = tf.reduce_mean(rep)
+  cor = tf.reduce_mean(cor)
+  q_entropy = tf.reduce_sum([
+      tf.reduce_sum(qz.entropy())
+      for z, qz in six.iteritems(inference.latent_vars)])
+  reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())
+
+  loss = -(rep + q_entropy - reg_penalty)
+
+  if inference.logging:
+    tf.summary.scalar("loss/reparam_objective", rep,
+                      collections=[inference._summary_key])
+    tf.summary.scalar("loss/correction_term", cor,
+                      collections=[inference._summary_key])
+    tf.summary.scalar("loss/q_entropy", q_entropy,
+                      collections=[inference._summary_key])
+    tf.summary.scalar("loss/reg_penalty", reg_penalty,
+                      collections=[inference._summary_key])
+
+  g_rep = tf.gradients(rep, var_list)
+  g_cor = tf.gradients(cor, var_list)
+  g_entropy = tf.gradients(q_entropy, var_list)
+
+  grad_summands = zip(*[g_rep, g_cor, g_entropy])
+  grads = [tf.reduce_sum(summand) for summand in grad_summands]
+  grads_and_vars = list(zip(grads, var_list))
+  return loss, grads_and_vars
diff --git a/edward/inferences/variational_inference.py b/edward/inferences/variational_inference.py
index 171eca56b..fd1589782 100644
--- a/edward/inferences/variational_inference.py
+++ b/edward/inferences/variational_inference.py
@@ -67,6 +67,8 @@ def initialize(self, optimizer=None, var_list=None, use_prettytensor=False,
 
     self.loss, grads_and_vars = self.build_loss_and_gradients(var_list)
 
+    self.grads_and_vars = grads_and_vars
+
     if self.logging:
       tf.summary.scalar("loss", self.loss, collections=[self._summary_key])
       for grad, var in grads_and_vars:
@@ -151,7 +153,9 @@ def update(self, feed_dict=None):
         feed_dict[key] = value
 
     sess = get_session()
-    _, t, loss = sess.run([self.train, self.increment_t, self.loss], feed_dict)
+    # _, t, loss = sess.run([self.train, self.increment_t, self.loss], feed_dict)
+    # TODO: delete me
+    _, t, loss, grads_and_vars_debug = sess.run([self.train, self.increment_t, self.loss, self.grads_and_vars], feed_dict)
 
     if self.debug:
       sess.run(self.op_check, feed_dict)
@@ -161,7 +165,7 @@ def update(self, feed_dict=None):
       summary = sess.run(self.summarize, feed_dict)
       self.train_writer.add_summary(summary, t)
 
-    return {'t': t, 'loss': loss}
+    return {'t': t, 'loss': loss, 'grads_and_vars_debug': grads_and_vars_debug}
 
   def print_progress(self, info_dict):
     """Print progress to output.
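The new class plugs into the usual Edward workflow. Below is a minimal usage sketch based on the Poisson-Gamma conjugate test added in tests/inferences/test_klqp.py; the variational parameterization and optimizer settings mirror that test and are illustrative rather than required.

import numpy as np
import tensorflow as tf
import edward as ed
from edward.models import Gamma, Poisson

x_data = np.array([2, 8, 3, 6, 1], dtype=np.int32)

rate = Gamma(5.0, 1.0)                  # prior on the Poisson rate
x = Poisson(rate=rate, sample_shape=5)  # likelihood

# Positive variational parameters via softplus, as in the test.
qalpha = tf.nn.softplus(tf.Variable(tf.random_normal([])))
qbeta = tf.nn.softplus(tf.Variable(tf.random_normal([])))
qrate = Gamma(qalpha, qbeta)

# Conjugacy gives the exact posterior Gamma(5 + 20, 1 + 5) = Gamma(25, 6).
inference = ed.RejectionSamplingKLqp({rate: qrate}, data={x: x_data})
inference.run(n_samples=1, n_iter=1000, optimizer='rmsprop',
              global_step=tf.Variable(0, trainable=False))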
diff --git a/edward/samplers/__init__.py b/edward/samplers/__init__.py new file mode 100644 index 000000000..4bb20a107 --- /dev/null +++ b/edward/samplers/__init__.py @@ -0,0 +1,15 @@ +""" +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from edward.samplers.rejection import * + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'GammaRejectionSampler', +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/edward/samplers/rejection.py b/edward/samplers/rejection.py new file mode 100644 index 000000000..d4a82471e --- /dev/null +++ b/edward/samplers/rejection.py @@ -0,0 +1,34 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import tensorflow as tf + + +class GammaRejectionSampler: + + # As implemented in https://github.com/blei-lab/ars-reparameterization/blob/master/gamma/demo.ipynb + + def __init__(self, density): + self.alpha = density.parameters['concentration'] + self.beta = density.parameters['rate'] + + def h(self, epsilon): + a = self.alpha - (1. / 3) + b = tf.sqrt(9 * self.alpha - 3) + c = 1 + (epsilon / b) + d = a * c**3 + return d / self.beta + + def h_inverse(self, z): + a = self.alpha - (1. / 3) + b = tf.sqrt(9 * self.alpha - 3) + c = self.beta * z / a + d = c**(1 / 3) + return b * (d - 1) + + @staticmethod + def log_prob_s(epsilon): + return -0.5 * (tf.log(2 * math.pi) + epsilon**2) diff --git a/scrap.ipynb b/scrap.ipynb new file mode 100644 index 000000000..e87484cf2 --- /dev/null +++ b/scrap.ipynb @@ -0,0 +1,409 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from edward.models import Normal\n", + "import numpy as np\n", + "import tensorflow as tf" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# --- manually!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "t = 0.1\n", + "delta = 10e-3\n", + "eta = 1e-1\n", + "\n", + "vars_manual = [1., 2.]\n", + "grads_manual = [3.1018744, 1.5509372]\n", + "s_n_manual = [0., 0.]\n", + "n_manual = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# grads_manual = [2.7902498, 1.241244]\n", + "grads_manual = [2.6070995, 1.0711095]\n", + "n_manual = 3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "collapsed": false, + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n", + "s_n[i]: 2.1597428437294326\n", + "p_n_first: 0.05837280797536004\n", + "p_n_second: 0.404922832044\n", + "p_n: 0.0236364827198\n", + "var: 0.694741731383\n", + "\n", + "s_n[i]: 0.44822725824311604\n", + "p_n_first: 0.05837280797536004\n", + "p_n_second: 0.598982532688\n", + "p_n: 0.0349642923612\n", + "var: 1.80355358733\n", + "\n" + ] + } + ], + "source": [ + "print(n_manual)\n", + "\n", + "for i in [0, 1]:\n", + " s_n_manual[i] = (t * grads_manual[i]**2) + (1 - t)*s_n_manual[i]\n", + " print('s_n[i]:', s_n_manual[i])\n", + " p_n_first = (eta * n_manual**(-.5 + delta))\n", + " p_n_second = (1 + np.sqrt(s_n_manual[i]))**(-1)\n", + " p_n = p_n_first * p_n_second\n", + " print('p_n_first:', p_n_first)\n", + " print('p_n_second:', p_n_second)\n", + " print('p_n:', p_n)\n", + " \n", + " vars_manual[i] += -p_n * grads_manual[i]\n", + " print('var:', vars_manual[i])\n", + " \n", + " print('')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# tensorflow..." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "t = 0.1\n", + "delta = 10e-3\n", + "eta = 1e-1\n", + " \n", + "def alp_optimizer_apply_gradients(n, s_n, grads_and_vars):\n", + " ops = []\n", + " for i, (grad, var) in enumerate(grads_and_vars):\n", + " updated_s_n = s_n[i].assign( (t * grad**2) + (1 - t) * s_n[i] )\n", + "\n", + " p_n_first = eta * n**(-.5 + delta)\n", + " p_n_second = (1 + tf.sqrt(updated_s_n[i]))**(-1)\n", + " p_n = p_n_first * p_n_second\n", + "\n", + " updated_var = var.assign_add(-p_n * grad)\n", + " ops.append((updated_s_n[i], p_n_first, p_n_second, p_n, updated_var))\n", + "# increment_n = n.assign_add(1.)\n", + "# ops.append(increment_n)\n", + " return ops" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "collapsed": false, + "scrolled": false + }, + "outputs": [], + "source": [ + "w1 = tf.Variable(tf.constant(1.))\n", + "w2 = tf.Variable(tf.constant(2.))\n", + "var_list = [w1, w2]\n", + "\n", + "x = tf.constant([3., 4., 5.])\n", + "y = tf.constant([.8, .1, .1])\n", + "\n", + "pred = tf.nn.softmax(x * w1 * w2)\n", + "loss = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred)))\n", + "grads = tf.gradients(loss, var_list)\n", + "grads_and_vars = list(zip(grads, var_list))\n", + "\n", + "s_n = tf.Variable(tf.zeros(2))\n", + "n = tf.Variable(tf.constant(1.))\n", + "\n", + "train = alp_optimizer_apply_gradients(n, s_n, grads_and_vars)\n", + "increment_n = n.assign_add(1.)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "collapsed": false, + "scrolled": false + }, + "outputs": [], + "source": [ + "sess = tf.InteractiveSession()\n", + "init = tf.global_variables_initializer()\n", + "sess.run(init)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# starting vals:\n", + "\n", + "# grads_and_vars: [(3.1018744, 1.0), (1.5509372, 2.0)]\n", + "# s_n: array([ 0., 0.], dtype=float32)\n", + "# n: 1.0" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": { + "collapsed": false, + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the_grads_and_vars: [(3.1018744, 1.0), (1.5509372, 2.0)]\n", + "\n", + "s_n_i: 0.962162\n", + "p_n_first_i: 0.1\n", + "p_n_second_i: 0.504821\n", + "p_n_i: 0.0504821\n", + "updated_var: 0.843411\n", + "\n", + "s_n_i: 0.240541\n", + "p_n_first_i: 0.1\n", + "p_n_second_i: 0.670939\n", + "p_n_i: 0.0670939\n", + "updated_var: 1.89594\n", + "\n", + "the_grads_and_vars: [(2.7902498, 0.84341073), (1.241244, 1.8959416)]\n", + "\n", + "s_n_i: 1.6445\n", + "p_n_first_i: 0.0712025\n", + "p_n_second_i: 0.438139\n", + "p_n_i: 0.0311966\n", + "updated_var: 0.756364\n", + "\n", + "s_n_i: 0.370555\n", + "p_n_first_i: 0.0712025\n", + "p_n_second_i: 0.621607\n", + "p_n_i: 0.04426\n", + "updated_var: 1.841\n", + "\n" + ] + } + ], + "source": [ + "for _ in range(2):\n", + " the_n, the_s_n, the_grads_and_vars = sess.run([n, s_n, grads_and_vars])\n", + "# print('the_n:', the_n)\n", + "# print('the_s_n:', the_s_n)\n", + " print('the_grads_and_vars:', the_grads_and_vars)\n", + " \n", + "# the_grads_and_vars, results = sess.run([grads_and_vars, train])\n", + " results = sess.run(train)\n", + "\n", + " print('')\n", + " for result in results:\n", + " s_n_i, p_n_first_i, p_n_second_i, p_n_i, updated_var = result\n", + " print('s_n_i:', s_n_i)\n", + " print('p_n_first_i:', 
p_n_first_i)\n", + " print('p_n_second_i:', p_n_second_i)\n", + " print('p_n_i:', p_n_i)\n", + " print('updated_var:', updated_var)\n", + " print('')\n", + " sess.run(increment_n)\n", + " \n", + "\n", + "# \n", + "# np.testing.assert_almost_equal(w1_, var_manual[0], 7)\n", + "# np.testing.assert_almost_equal(w2_, var_manual[1], 7)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[(2.7902498, 0.75636435), (1.241244, 1.8410041)]" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "the_grads_and_vars" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# after 1\n", + "\n", + "# s_n[i]: 0.962162479337536\n", + "# p_n_first: 0.1\n", + "# p_n_second: 0.504821343702\n", + "# p_n: 0.0504821343702\n", + "# var: 0.84341075974\n", + "\n", + "# s_n[i]: 0.240540619834384\n", + "# p_n_first: 0.1\n", + "# p_n_second: 0.670938574622\n", + "# p_n: 0.0670938574622\n", + "# var: 1.89594164057\n", + "\n", + "s_n_i: 0.962162\n", + "p_n_first_i: 0.1\n", + "p_n_second_i: 0.504821\n", + "p_n_i: 0.0504821\n", + "updated_var: 0.843411\n", + "\n", + "s_n_i: 0.240541\n", + "p_n_first_i: 0.1\n", + "p_n_second_i: 0.670939\n", + "p_n_i: 0.0670939\n", + "updated_var: 1.89594\n", + "\n", + "# after 2\n", + "\n", + "# 2\n", + "# s_n[i]: 1.6444956260437862\n", + "# p_n_first: 0.07120250977985358\n", + "# p_n_second: 0.438139347908\n", + "# p_n: 0.0311966212044\n", + "# var: 0.756364393664\n", + "\n", + "# s_n[i]: 0.3705552246045456\n", + "# p_n_first: 0.07120250977985358\n", + "# p_n_second: 0.621607393593\n", + "# p_n: 0.0442600065216\n", + "# var: 1.84100417304\n", + "\n", + "s_n_i: 1.6445\n", + "p_n_first_i: 0.0712025\n", + "p_n_second_i: 0.438139\n", + "p_n_i: 0.0311966\n", + "updated_var: 0.756364\n", + "\n", + "s_n_i: 0.370555\n", + "p_n_first_i: 0.0712025\n", + "p_n_second_i: 0.621607\n", + "p_n_i: 0.04426\n", + "updated_var: 1.841\n", + " \n", + "# after 3\n", + "\n", + "# s_n[i]: 2.1597428437294326\n", + "# p_n_first: 0.05837280797536004\n", + "# p_n_second: 0.404922832044\n", + "# p_n: 0.0236364827198\n", + "# var: 0.694741731383\n", + "\n", + "# s_n[i]: 0.44822725824311604\n", + "# p_n_first: 0.05837280797536004\n", + "# p_n_second: 0.598982532688\n", + "# p_n: 0.0349642923612\n", + "# var: 1.80355358733" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/inferences/test_klqp.py b/tests/inferences/test_klqp.py index baf729394..81f23361a 100644 --- a/tests/inferences/test_klqp.py +++ b/tests/inferences/test_klqp.py @@ -6,7 +6,10 @@ import numpy as np import tensorflow as tf -from edward.models import Bernoulli, Normal +from edward.models import Bernoulli, Normal, Dirichlet, Multinomial, \ + Gamma, Poisson + +from edward.inferences.klqp import build_rejection_sampling_loss_and_gradients class test_klqp_class(tf.test.TestCase): @@ -43,6 +46,76 @@ def 
_test_normal_normal(self, Inference, default, *args, **kwargs): self.assertEqual(new_t, 0) self.assertNotEqual(old_variables, new_variables) + def _test_poisson_gamma(self, Inference, *args, **kwargs): + with self.test_session() as sess: + x_data = np.array([2, 8, 3, 6, 1], dtype=np.int32) + + rate = Gamma(5.0, 1.0) + x = Poisson(rate=rate, sample_shape=5) + + qalpha = tf.nn.softplus(tf.Variable(tf.random_normal([]), name='qalpha')) + qbeta = tf.nn.softplus(tf.Variable(tf.random_normal([]), name='qbeta')) + qgamma = Gamma(qalpha, qbeta, allow_nan_stats=False) + + # sum(x_data) = 20 + # len(x_data) = 5 + # analytic solution: Gamma(alpha=5+20, beta=1+5) + inference = Inference({rate: qgamma}, data={x: x_data}) + + inference.run(*args, **kwargs) + + self.assertAllClose(tf.nn.softplus(qalpha).eval(), 25., atol=1e-2) + self.assertAllClose(tf.nn.softplus(qbeta).eval(), 6., atol=1e-2) + + def _test_multinomial_dirichlet(self, Inference, *args, **kwargs): + with self.test_session() as sess: + x_data = tf.constant([2, 7, 1], dtype=np.float32) + + probs = Dirichlet([1., 1., 1.]) + x = Multinomial(total_count=10.0, probs=probs) + + qalpha = tf.Variable(tf.random_normal([3])) + qprobs = Dirichlet(qalpha) + + # analytic solution: Dirichlet(alpha=[1+2, 1+7, 1+1]) + inference = Inference({probs: qprobs}, data={x: x_data}) + + inference.run(*args, **kwargs) + + def _test_build_rejection_sampling_loss_and_gradients(self, *args, **kwargs): + with self.test_session() as sess: + x_data = np.array([3, 3, 3, 3, 0], dtype=np.float32) + + rate = Gamma(1.0, 1.0) + x = Poisson(rate=rate, sample_shape=5) + + _qalpha = tf.Variable(-0.52817175, name='qalpha') + _qbeta = tf.Variable(-1.07296862, name='qbeta') + var_list = [_qalpha, _qbeta] + + qalpha = tf.exp(_qalpha) + 1 + qbeta = tf.exp(_qbeta) + qgamma = Gamma(qalpha, qbeta, allow_nan_stats=False) + + class DummyInference: + n_samples = 1 + data = {x: x_data} + latent_vars = {rate: qgamma} + scale = {} + logging = False + + tf.global_variables_initializer().run() + + expected_g_reparam = np.array([-10.348131898560453, 31.81539831675293]) + expected_g_score = np.array([0.30550423741109256, 0.0]) + expected_g_entropy = np.array([0.28863888798339055, -1.0]) + + loss, grads_and_vars = build_rejection_sampling_loss_and_gradients(DummyInference(), + var_list, epsilon=tf.constant(0.86540763)) + + self.assertAllClose([g.eval() for g, v in grads_and_vars], + expected_g_reparam + expected_g_score + expected_g_entropy, rtol=1e-6, atol=1e-6) + def _test_model_parameter(self, Inference, *args, **kwargs): with self.test_session() as sess: x_data = np.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 1]) @@ -109,6 +182,19 @@ def test_score_rb_klqp(self): ed.ScoreRBKLqp, default=True, n_samples=5, n_iter=5000) self._test_model_parameter(ed.ScoreRBKLqp, n_iter=50) + def test_rejection_sampling_klqp(self): + self._test_build_rejection_sampling_loss_and_gradients() + self._test_poisson_gamma( + ed.RejectionSamplingKLqp, + n_samples=1, + n_iter=50, + optimizer='rmsprop', + global_step=tf.Variable(0, trainable=False, name="global_step") + ) + # self._test_multinomial_dirichlet( + # ed.RejectionSamplingKLqp, n_samples=5, n_iter=5000) + + if __name__ == '__main__': ed.set_seed(42) tf.test.main() diff --git a/tests/samplers/test_rejection.py b/tests/samplers/test_rejection.py new file mode 100644 index 000000000..3cc1f3eb4 --- /dev/null +++ b/tests/samplers/test_rejection.py @@ -0,0 +1,24 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + 
+import tensorflow as tf
+
+from edward.models import Gamma
+from edward.samplers import GammaRejectionSampler
+
+
+class test_rejection_samplers_class(tf.test.TestCase):
+
+  def test_gamma_rejection_sampler(self):
+    with self.test_session() as sess:
+      gamma = Gamma(4., 2.)
+      epsilon = tf.constant(.5)
+      sampler = GammaRejectionSampler(density=gamma)
+      z = sampler.h(epsilon)
+
+      self.assertAllClose(sampler.h_inverse(z).eval(),
+                          epsilon.eval(), atol=1e-6)
+      # scipy.stats.norm.logpdf(.5) = -1.0439385332046727
+      self.assertAllClose(sampler.log_prob_s(epsilon).eval(),
+                          -1.0439385332046727, atol=1e-6)
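As a standalone sanity check of the constants asserted in tests/samplers/test_rejection.py, the sampler's transform and its inverse can be reproduced in NumPy/SciPy; this sketch assumes SciPy is available, which the test itself does not require.

import numpy as np
from scipy import stats

alpha, beta, eps = 4.0, 2.0, 0.5

# Marsaglia-Tsang transform h(eps) used by GammaRejectionSampler.
a = alpha - 1.0 / 3
b = np.sqrt(9 * alpha - 3)
z = a * (1 + eps / b) ** 3 / beta

# h_inverse recovers the proposal noise.
eps_back = b * ((beta * z / a) ** (1.0 / 3) - 1)
assert np.isclose(eps_back, eps)

# The proposal noise is standard normal; log_prob_s(0.5) is its log density.
print(stats.norm.logpdf(eps))  # -1.0439385332046727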