Skip to content
181 changes: 151 additions & 30 deletions DP/Gamblers Problem.ipynb

Large diffs are not rendered by default.

83 changes: 47 additions & 36 deletions DP/Policy Evaluation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
"cells": [
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": true
},
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
Expand All @@ -17,21 +15,17 @@
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": true
},
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"env = GridworldEnv()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": true
},
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):\n",
Expand All @@ -53,17 +47,27 @@
" # Start with a random (all 0) value function\n",
" V = np.zeros(env.nS)\n",
" while True:\n",
" # TODO: Implement!\n",
" break\n",
" delta = .0\n",
" for s in range(env.nS):\n",
" v = V[s]\n",
" v_new = 0\n",
" \n",
" for a, p_as in enumerate(policy[s]):\n",
" for (p_srsa, next_state, reward, done) in env.P[s][a]:\n",
" v_new += p_as * p_srsa * (reward + discount_factor * V[next_state])\n",
" V[s] = v_new\n",
" \n",
" delta = max(delta, np.abs(v - V[s]))\n",
"\n",
" if delta < theta:\n",
" break\n",
" return np.array(V)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": true
},
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"random_policy = np.ones([env.nS, env.nA]) / env.nA\n",
Expand All @@ -72,35 +76,42 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Test: Make sure the evaluated policy is what we expected\n",
"expected_v = np.array([0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14, -22, -20, -14, 0])\n",
"np.testing.assert_array_almost_equal(v, expected_v, decimal=2)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"ename": "AssertionError",
"evalue": "\nArrays are not almost equal to 2 decimals\n\n(mismatch 87.5%)\n x: array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0.])\n y: array([ 0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14, -22,\n -20, -14, 0])",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-22-235f39fb115c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Test: Make sure the evaluated policy is what we expected\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mexpected_v\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m14\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m22\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m14\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m18\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m18\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m14\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m22\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m14\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtesting\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0massert_array_almost_equal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexpected_v\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecimal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/Users/dennybritz/venvs/tf/lib/python3.5/site-packages/numpy/testing/utils.py\u001b[0m in \u001b[0;36massert_array_almost_equal\u001b[0;34m(x, y, decimal, err_msg, verbose)\u001b[0m\n\u001b[1;32m 914\u001b[0m assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,\n\u001b[1;32m 915\u001b[0m \u001b[0mheader\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Arrays are not almost equal to %d decimals'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mdecimal\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 916\u001b[0;31m precision=decimal)\n\u001b[0m\u001b[1;32m 917\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 918\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/dennybritz/venvs/tf/lib/python3.5/site-packages/numpy/testing/utils.py\u001b[0m in \u001b[0;36massert_array_compare\u001b[0;34m(comparison, x, y, err_msg, verbose, header, precision)\u001b[0m\n\u001b[1;32m 735\u001b[0m names=('x', 'y'), precision=precision)\n\u001b[1;32m 736\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcond\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 737\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mAssertionError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 738\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 739\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtraceback\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAssertionError\u001b[0m: \nArrays are not almost equal to 2 decimals\n\n(mismatch 87.5%)\n x: array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0.])\n y: array([ 0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14, -22,\n -20, -14, 0])"
]
"data": {
"text/plain": [
"array([ 0. , -13.99993529, -19.99990698, -21.99989761,\n",
" -13.99993529, -17.9999206 , -19.99991379, -19.99991477,\n",
" -19.99990698, -19.99991379, -17.99992725, -13.99994569,\n",
" -21.99989761, -19.99991477, -13.99994569, 0. ])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test: Make sure the evaluated policy is what we expected\n",
"expected_v = np.array([0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14, -22, -20, -14, 0])\n",
"np.testing.assert_array_almost_equal(v, expected_v, decimal=2)"
"v"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": []
}
Expand Down
135 changes: 71 additions & 64 deletions DP/Policy Iteration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
Expand All @@ -18,10 +16,8 @@
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"pp = pprint.PrettyPrinter(indent=2)\n",
Expand All @@ -30,10 +26,8 @@
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Taken from Policy Evaluation Exercise!\n",
Expand Down Expand Up @@ -78,10 +72,8 @@
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true
},
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def policy_improvement(env, policy_eval_fn=policy_eval, discount_factor=1.0):\n",
Expand All @@ -102,57 +94,81 @@
" V is the value function for the optimal policy.\n",
" \n",
" \"\"\"\n",
" \n",
" def greedy_action(s, V, env):\n",
" A = np.zeros(env.nA)\n",
" for a in range(env.nA):\n",
" for prob, next_state, reward, done in env.P[s][a]:\n",
" A[a] += prob * (reward + discount_factor * V[next_state])\n",
" \n",
" return A.argmax()\n",
" \n",
" # Start with a random policy\n",
" policy = np.ones([env.nS, env.nA]) / env.nA\n",
" Vstar = None\n",
" \n",
" while True:\n",
" # Implement this!\n",
" break\n",
" V = policy_eval_fn(policy, env, discount_factor)\n",
" \n",
" policy_stable = True\n",
" for s in range(env.nS):\n",
" # given the initial policy is a random policy we should instead\n",
" # sample from the corresponding probability distribution\n",
" action = policy[s].argmax()\n",
" best_action = greedy_action(s, V, env)\n",
" \n",
" policy[s] = np.eye(env.nA)[best_action]\n",
" if action != best_action:\n",
" policy_stable = False\n",
" \n",
" if policy_stable:\n",
" Vstar = V\n",
" break\n",
" \n",
" return policy, np.zeros(env.nS)"
" return policy, Vstar"
]
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Policy Probability Distribution:\n",
"[[ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]]\n",
"[[1. 0. 0. 0.]\n",
" [0. 0. 0. 1.]\n",
" [0. 0. 0. 1.]\n",
" [0. 0. 1. 0.]\n",
" [1. 0. 0. 0.]\n",
" [1. 0. 0. 0.]\n",
" [1. 0. 0. 0.]\n",
" [0. 0. 1. 0.]\n",
" [1. 0. 0. 0.]\n",
" [1. 0. 0. 0.]\n",
" [0. 1. 0. 0.]\n",
" [0. 0. 1. 0.]\n",
" [1. 0. 0. 0.]\n",
" [0. 1. 0. 0.]\n",
" [0. 1. 0. 0.]\n",
" [1. 0. 0. 0.]]\n",
"\n",
"Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):\n",
"[[0 0 0 0]\n",
" [0 0 0 0]\n",
" [0 0 0 0]\n",
" [0 0 0 0]]\n",
"[[0 3 3 2]\n",
" [0 0 0 2]\n",
" [0 0 1 2]\n",
" [0 1 1 0]]\n",
"\n",
"Value Function:\n",
"[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
"[ 0. -1. -2. -3. -1. -2. -3. -2. -2. -3. -2. -1. -3. -2. -1. 0.]\n",
"\n",
"Reshaped Grid Value Function:\n",
"[[ 0. 0. 0. 0.]\n",
" [ 0. 0. 0. 0.]\n",
" [ 0. 0. 0. 0.]\n",
" [ 0. 0. 0. 0.]]\n",
"[[ 0. -1. -2. -3.]\n",
" [-1. -2. -3. -2.]\n",
" [-2. -3. -2. -1.]\n",
" [-3. -2. -1. 0.]]\n",
"\n"
]
}
Expand All @@ -179,23 +195,9 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"ename": "AssertionError",
"evalue": "\nArrays are not almost equal to 2 decimals\n\n(mismatch 87.5%)\n x: array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0.])\n y: array([ 0, -1, -2, -3, -1, -2, -3, -2, -2, -3, -2, -1, -3, -2, -1, 0])",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-15-55581f8eb5c9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Test the value function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mexpected_v\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtesting\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0massert_array_almost_equal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexpected_v\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecimal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/Users/dennybritz/venvs/tf/lib/python3.5/site-packages/numpy/testing/utils.py\u001b[0m in \u001b[0;36massert_array_almost_equal\u001b[0;34m(x, y, decimal, err_msg, verbose)\u001b[0m\n\u001b[1;32m 914\u001b[0m assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,\n\u001b[1;32m 915\u001b[0m \u001b[0mheader\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Arrays are not almost equal to %d decimals'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mdecimal\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 916\u001b[0;31m precision=decimal)\n\u001b[0m\u001b[1;32m 917\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 918\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/dennybritz/venvs/tf/lib/python3.5/site-packages/numpy/testing/utils.py\u001b[0m in \u001b[0;36massert_array_compare\u001b[0;34m(comparison, x, y, err_msg, verbose, header, precision)\u001b[0m\n\u001b[1;32m 735\u001b[0m names=('x', 'y'), precision=precision)\n\u001b[1;32m 736\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcond\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 737\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mAssertionError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 738\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 739\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtraceback\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAssertionError\u001b[0m: \nArrays are not almost equal to 2 decimals\n\n(mismatch 87.5%)\n x: array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0.])\n y: array([ 0, -1, -2, -3, -1, -2, -3, -2, -2, -3, -2, -1, -3, -2, -1, 0])"
]
}
],
"outputs": [],
"source": [
"# Test the value function\n",
"expected_v = np.array([ 0, -1, -2, -3, -1, -2, -3, -2, -2, -3, -2, -1, -3, -2, -1, 0])\n",
Expand All @@ -205,9 +207,14 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
Expand Down
Loading