Skip to content

Commit 1a6025b

Browse files
committed
policy aipw v2
1 parent eb7f133 commit 1a6025b

File tree

1 file changed

+227
-42
lines changed

1 file changed

+227
-42
lines changed

book/cate_and_policy/policy_learning.ipynb

Lines changed: 227 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -39,63 +39,248 @@
3939
"metadata": {},
4040
"outputs": [],
4141
"source": [
42-
"# 라이브러리와 난수 고정\n",
4342
"import numpy as np\n",
4443
"import matplotlib.pyplot as plt\n",
4544
"from matplotlib.patches import Rectangle\n",
46-
"rng = np.random.default_rng(42)\n",
45+
"import matplotlib.patches as mpatches"
46+
]
47+
},
48+
{
49+
"cell_type": "code",
50+
"execution_count": null,
51+
"metadata": {},
52+
"outputs": [],
53+
"source": [
# Seed NumPy's global (legacy) generator so that re-running the notebook
# from a fresh kernel reproduces exactly the same simulated sample.
np.random.seed(42)

# Simulated randomized experiment: n units, p covariates each.
n = 1000
p = 4

# Covariates: independent Uniform(0, 1) draws, shape (n, p).
X = np.random.uniform(low=0, high=1, size=(n, p))

# Treatment indicator: a fair coin flip per unit, drawn without looking at X.
W = np.random.binomial(1, 0.5, size=n)

# Outcome = baseline in X[:, 0] + treatment effect (X[:, 1] - 0.5) for treated
# units + Gaussian noise with standard deviation 0.1. The noise is drawn last,
# so the order of random draws is X, then W, then noise.
noise = 0.1 * np.random.randn(n)
baseline = 0.5 * (X[:, 0] - 0.5)
effect = X[:, 1] - 0.5
Y = baseline + effect * W + noise
63+
]
64+
},
65+
{
66+
"cell_type": "code",
67+
"execution_count": null,
68+
"metadata": {},
69+
"outputs": [],
70+
"source": [
# Normalize Y to [0, 1] for plotting. The scale is inverted so that, with the
# 'gray' colormap (0 = black, 1 = white), larger outcomes are drawn darker.
y_norm = 1 - (Y - Y.min()) / (Y.max() - Y.min())

# First plot: all data points (treated = circles, untreated = diamonds).
fig1, ax1 = plt.subplots(1, 1, figsize=(8, 6))

# One vectorized scatter call per treatment arm instead of one call per
# observation: the per-point loop created n separate artists, which is slow
# to build and render. Markers, sizes and colors are unchanged; the only
# visual difference is that the untreated arm is now drawn on top of the
# treated arm where points overlap.
for arm, marker, size in ((1, 'o', 100), (0, 'D', 80)):
    in_arm = W == arm
    ax1.scatter(X[in_arm, 0], X[in_arm, 1], marker=marker, s=size,
                c=y_norm[in_arm], cmap='gray', vmin=0, vmax=1,
                edgecolors='black', linewidths=1)

ax1.set_xlabel('X1', fontsize=12)
ax1.set_ylabel('X2', fontsize=12)
ax1.set_title('All Data Points (○: Treated, ◇: Untreated)', fontsize=14)
plt.show()
89+
]
90+
},
91+
{
92+
"cell_type": "code",
93+
"execution_count": null,
94+
"metadata": {},
95+
"outputs": [],
96+
"source": [
# Second plot: the same points, split into one panel per treatment arm.
fig2, (ax2, ax3) = plt.subplots(1, 2, figsize=(14, 6))

# Boolean masks selecting each arm.
untreated_idx = W == 0
treated_idx = W == 1

# (axis, mask, marker, marker size, panel title) for the left and right panels.
panels = (
    (ax2, untreated_idx, 'D', 80, 'Untreated'),
    (ax3, treated_idx, 'o', 100, 'Treated'),
)
for axis, mask, marker, size, title in panels:
    # Gray level encodes the normalized outcome y_norm on a fixed 0-1 scale,
    # so shades are comparable across the two panels.
    axis.scatter(X[mask, 0], X[mask, 1], marker=marker, s=size,
                 c=y_norm[mask], cmap='gray', vmin=0, vmax=1,
                 edgecolors='black', linewidths=1)
    axis.set_xlabel('X1', fontsize=12)
    axis.set_ylabel('X2', fontsize=12)
    axis.set_title(title, fontsize=14)
plt.show()
118+
]
119+
},
120+
{
121+
"cell_type": "code",
122+
"execution_count": null,
123+
"metadata": {},
124+
"outputs": [],
125+
"source": [
# Third plot: data points drawn over the shaded policy regions.
fig3, ax4 = plt.subplots(1, 1, figsize=(8, 6))

# Region colors as RGBA tuples (alpha 0.35 keeps the points visible).
col1 = (0.9960938, 0.7539062, 0.0273438, 0.35)  # Yellow-ish: do not treat
col2 = (0.250980, 0.690196, 0.650980, 0.35)     # Teal-ish: treat

# Policy regions as (lower-left corner, width, height, fill color). The
# rectangles extend to -0.1 / 1.1 so they fill the padded axis limits below.
# The treat region is X1 > 0.5 and X2 > 0.5; everything else is "do not treat".
regions = (
    ((-0.1, -0.1), 0.6, 1.2, col1),  # left half, full height
    ((0.5, -0.1), 0.6, 0.6, col1),   # lower-right quadrant
    ((0.5, 0.5), 0.6, 0.6, col2),    # upper-right quadrant
)
for corner, width, height, color in regions:
    ax4.add_patch(Rectangle(corner, width, height, linewidth=0,
                            edgecolor='none', facecolor=color, hatch='///'))

# Plot data points with one vectorized scatter call per treatment arm instead
# of one call per observation (the per-point loop created n separate artists).
# Markers, sizes and colors are unchanged; the untreated arm is now drawn on
# top of the treated arm where points overlap.
for arm, marker, size in ((1, 'o', 100), (0, 'D', 80)):
    in_arm = W == arm
    ax4.scatter(X[in_arm, 0], X[in_arm, 1], marker=marker, s=size,
                c=y_norm[in_arm], cmap='gray', vmin=0, vmax=1,
                edgecolors='black', linewidths=1)

# Add text labels
ax4.text(0.75, 0.75, 'TREAT (A)', fontsize=16, ha='center', va='center')
ax4.text(0.25, 0.25, 'DO NOT TREAT (A^C)', fontsize=16, ha='left', va='center')
ax4.set_xlabel('X1', fontsize=12)
ax4.set_ylabel('X2', fontsize=12)
ax4.set_xlim(-0.1, 1.1)
ax4.set_ylim(-0.1, 1.1)
ax4.set_title('Policy Regions', fontsize=14)
plt.show()
164+
]
165+
},
166+
{
167+
"cell_type": "code",
168+
"execution_count": null,
169+
"metadata": {},
170+
"outputs": [],
171+
"source": [
# Policy Evaluation Methods
print("=" * 60)
print("POLICY EVALUATION RESULTS")
print("=" * 60)

# Method 1: value of policy A, i.e. "treat exactly the units with X1 > 0.5 and
# X2 > 0.5". Only valid in a randomized setting: it uses the raw mean outcome
# of treated units inside A and of untreated units outside A, each weighted by
# the share of the sample in that region.
A = (X[:, 0] > 0.5) & (X[:, 1] > 0.5)
treated_in_A = A & (W == 1)
control_out_A = ~A & (W == 0)

value_estimate = (np.mean(Y[treated_in_A]) * np.mean(A)
                  + np.mean(Y[control_out_A]) * np.mean(~A))

# Standard error: the two cells are disjoint, so their variances add.
# ddof=1 gives the unbiased sample variance; NumPy's default (ddof=0) is the
# population variance and understates the standard error.
value_stderr = np.sqrt(
    np.var(Y[treated_in_A], ddof=1) / np.sum(treated_in_A) * np.mean(A)**2
    + np.var(Y[control_out_A], ddof=1) / np.sum(control_out_A) * np.mean(~A)**2
)
print(f"\nMethod 1: Value of Policy A")
print(f"Value estimate: {value_estimate:.6f}")
print(f"Std. Error: {value_stderr:.6f}")
188+
]
189+
},
190+
{
191+
"cell_type": "code",
192+
"execution_count": null,
193+
"metadata": {},
194+
"outputs": [],
195+
"source": [
# Method 2: value of a random policy that treats each unit with probability
# p_treat regardless of covariates: a p_treat / (1 - p_treat) mix of the
# treated-arm and untreated-arm mean outcomes.
p_treat = 0.75
y_treated = Y[W == 1]
y_untreated = Y[W == 0]

value_estimate2 = p_treat * np.mean(y_treated) + (1 - p_treat) * np.mean(y_untreated)

# ddof=1 gives the unbiased sample variance; NumPy's default (ddof=0) is the
# population variance and understates the standard error.
value_stderr2 = np.sqrt(
    np.var(y_treated, ddof=1) / y_treated.size * p_treat**2
    + np.var(y_untreated, ddof=1) / y_untreated.size * (1 - p_treat)**2
)
print(f"\nMethod 2: Value of Fixed Treatment Proportion (p={p_treat})")
print(f"Value estimate: {value_estimate2:.6f}")
print(f"Std. Error: {value_stderr2:.6f}")
206+
]
207+
},
208+
{
209+
"cell_type": "code",
210+
"execution_count": null,
211+
"metadata": {},
212+
"outputs": [],
213+
"source": [
# Method 3: treated-minus-untreated difference in mean outcome inside region A,
# scaled by the share of units in A. This is the gain in value from following
# policy A rather than treating no one. Relies on `A` from the Method 1 cell.
y_treated_in_A = Y[A & (W == 1)]
y_control_in_A = Y[A & (W == 0)]

diff_estimate = (np.mean(y_treated_in_A) - np.mean(y_control_in_A)) * np.mean(A)

# ddof=1 gives the unbiased sample variance; NumPy's default (ddof=0) is the
# population variance and understates the standard error.
diff_stderr = np.sqrt(
    np.var(y_treated_in_A, ddof=1) / y_treated_in_A.size
    + np.var(y_control_in_A, ddof=1) / y_control_in_A.size
) * np.mean(A)
print(f"\nMethod 3: Treatment Effect within Policy Region A")
print(f"Difference estimate: {diff_estimate:.6f}")
print(f"Std. Error: {diff_stderr:.6f}")
223+
]
224+
},
225+
{
226+
"cell_type": "code",
227+
"execution_count": null,
228+
"metadata": {},
229+
"outputs": [],
230+
"source": [
# Method 4: difference in value between policy A and a coin-flip policy that
# treats every unit with probability 1/2. Inside A the policy gains half the
# treated-minus-untreated difference; outside A it gains half the
# untreated-minus-treated difference. Relies on `A` from the Method 1 cell.
y_t_in = Y[A & (W == 1)]
y_c_in = Y[A & (W == 0)]
y_t_out = Y[~A & (W == 1)]
y_c_out = Y[~A & (W == 0)]

diff_estimate2 = ((np.mean(y_t_in) - np.mean(y_c_in)) * np.mean(A) / 2
                  + (np.mean(y_c_out) - np.mean(y_t_out)) * np.mean(~A) / 2)

# The four cells are disjoint, so their variances add. ddof=1 gives the
# unbiased sample variance; NumPy's default (ddof=0) is the population
# variance and understates the standard error.
diff_stderr2 = np.sqrt(
    (np.mean(A) / 2)**2 * (
        np.var(y_t_in, ddof=1) / y_t_in.size
        + np.var(y_c_in, ddof=1) / y_c_in.size
    )
    + (np.mean(~A) / 2)**2 * (
        np.var(y_t_out, ddof=1) / y_t_out.size
        + np.var(y_c_out, ddof=1) / y_c_out.size
    )
)
print(f"\nMethod 4: Optimal Policy Difference")
print(f"Difference estimate: {diff_estimate2:.6f}")
print(f"Std. Error: {diff_stderr2:.6f}")

print("\n" + "=" * 60)
249+
]
250+
},
251+
{
252+
"cell_type": "code",
253+
"execution_count": null,
254+
"metadata": {},
255+
"outputs": [],
256+
"source": [
# Additional analysis: compare the raw treated-minus-untreated difference in
# mean outcome inside region A with the same difference outside it.
# Relies on `A` from the Method 1 cell.
print("\nADDITIONAL ANALYSIS")
print("=" * 60)

is_treated = W == 1
is_control = W == 0

# Difference in means by region.
te_in_A = np.mean(Y[A & is_treated]) - np.mean(Y[A & is_control])
te_out_A = np.mean(Y[~A & is_treated]) - np.mean(Y[~A & is_control])

print("\nTreatment Effect Heterogeneity:")
print(f"Treatment effect in region A: {te_in_A:.6f}")
print(f"Treatment effect outside region A: {te_out_A:.6f}")
print(f"Difference in treatment effects: {te_in_A - te_out_A:.6f}")
269+
]
270+
},
271+
{
272+
"cell_type": "code",
273+
"execution_count": null,
274+
"metadata": {},
275+
"outputs": [],
276+
"source": [
# Summary statistics: region and treatment shares, arm-level mean outcomes,
# and their raw difference. Relies on `A` from the Method 1 cell.
mean_treated = np.mean(Y[W == 1])
mean_untreated = np.mean(Y[W == 0])

print("\nSummary Statistics:")
print(f"Proportion in region A: {np.mean(A):.3f}")
print(f"Proportion treated: {np.mean(W):.3f}")
print(f"Mean outcome (treated): {mean_treated:.6f}")
print(f"Mean outcome (untreated): {mean_untreated:.6f}")
print(f"Overall treatment effect: {mean_treated - mean_untreated:.6f}")
99284
]
100285
}
101286
],

0 commit comments

Comments
 (0)