From c21dec44adbd9a013997873e336a25be9071c612 Mon Sep 17 00:00:00 2001 From: QikeLi Date: Wed, 4 Jul 2018 18:04:49 -0700 Subject: [PATCH] Modify Policy Evaluation Solution according to David Silver's slides. --- DP/Policy Evaluation Solution.ipynb | 43 +++++++++++++++++++---------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/DP/Policy Evaluation Solution.ipynb b/DP/Policy Evaluation Solution.ipynb index 0b06f87e7..c6305795e 100644 --- a/DP/Policy Evaluation Solution.ipynb +++ b/DP/Policy Evaluation Solution.ipynb @@ -3,7 +3,9 @@ { "cell_type": "code", "execution_count": 1, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "from IPython.core.debugger import set_trace\n", @@ -18,7 +20,9 @@ { "cell_type": "code", "execution_count": 2, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "pp = pprint.PrettyPrinter(indent=2)\n", @@ -28,7 +32,9 @@ { "cell_type": "code", "execution_count": 3, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):\n", @@ -57,9 +63,11 @@ " # Look at the possible next actions\n", " for a, action_prob in enumerate(policy[s]):\n", " # For each action, look at the possible next states...\n", + " v_tmp = 0\n", " for prob, next_state, reward, done in env.P[s][a]:\n", - " # Calculate the expected value. Ref: Sutton book eq. 4.6.\n", - " v += action_prob * prob * (reward + discount_factor * V[next_state])\n", + " # Calculate v_k+1(current_state). Ref: page 8 of David Silver's slides for lecture 3\n", + " v_tmp += prob * V[next_state]\n", + " v += action_prob * (reward + discount_factor * v_tmp)\n", " # How much our value function changed (across any states)\n", " delta = max(delta, np.abs(v - V[s]))\n", " V[s] = v\n", @@ -72,7 +80,9 @@ { "cell_type": "code", "execution_count": 4, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "random_policy = np.ones([env.nS, env.nA]) / env.nA\n", @@ -91,8 +101,7 @@ "Value Function:\n", "[ 0. -13.99993529 -19.99990698 -21.99989761 -13.99993529\n", " -17.9999206 -19.99991379 -19.99991477 -19.99990698 -19.99991379\n", - " -17.99992725 -13.99994569 -21.99989761 -19.99991477 -13.99994569\n", - " 0. ]\n", + " -17.99992725 -13.99994569 -21.99989761 -19.99991477 -13.99994569 0. ]\n", "\n", "Reshaped Grid Value Function:\n", "[[ 0. -13.99993529 -19.99990698 -21.99989761]\n", @@ -116,7 +125,9 @@ { "cell_type": "code", "execution_count": 6, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# Test: Make sure the evaluated policy is what we expected\n", @@ -127,28 +138,30 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 2", "language": "python", - "name": "python3" + "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 3 + "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.4" + "pygments_lexer": "ipython2", + "version": "2.7.14" } }, "nbformat": 4,