Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions DP/Policy Evaluation Solution.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from IPython.core.debugger import set_trace\n",
Expand All @@ -18,7 +20,9 @@
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"pp = pprint.PrettyPrinter(indent=2)\n",
Expand All @@ -28,7 +32,9 @@
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):\n",
Expand Down Expand Up @@ -57,9 +63,11 @@
" # Look at the possible next actions\n",
" for a, action_prob in enumerate(policy[s]):\n",
" # For each action, look at the possible next states...\n",
" v_tmp = 0\n",
" for prob, next_state, reward, done in env.P[s][a]:\n",
" # Calculate the expected value. Ref: Sutton book eq. 4.6.\n",
" v += action_prob * prob * (reward + discount_factor * V[next_state])\n",
" # Calculate v_k+1(current_state). Ref: page 8 of David Silver's slides for lecture 3\n",
" v_tmp += prob * V[next_state]\n",
" v += action_prob * (reward + discount_factor * v_tmp)\n",
" # How much our value function changed (across any states)\n",
" delta = max(delta, np.abs(v - V[s]))\n",
" V[s] = v\n",
Expand All @@ -72,7 +80,9 @@
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"random_policy = np.ones([env.nS, env.nA]) / env.nA\n",
Expand All @@ -91,8 +101,7 @@
"Value Function:\n",
"[ 0. -13.99993529 -19.99990698 -21.99989761 -13.99993529\n",
" -17.9999206 -19.99991379 -19.99991477 -19.99990698 -19.99991379\n",
" -17.99992725 -13.99994569 -21.99989761 -19.99991477 -13.99994569\n",
" 0. ]\n",
" -17.99992725 -13.99994569 -21.99989761 -19.99991477 -13.99994569 0. ]\n",
"\n",
"Reshaped Grid Value Function:\n",
"[[ 0. -13.99993529 -19.99990698 -21.99989761]\n",
Expand All @@ -116,7 +125,9 @@
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Test: Make sure the evaluated policy is what we expected\n",
Expand All @@ -127,28 +138,30 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 2",
"language": "python",
"name": "python3"
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
"pygments_lexer": "ipython2",
"version": "2.7.14"
}
},
"nbformat": 4,
Expand Down