Skip to content

Commit eb7f133

Browse files
committed
policy aipw v1
1 parent f97fe53 commit eb7f133

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed

book/cate_and_policy/policy_learning.ipynb

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,71 @@
3232
" async>\n",
3333
"</script>"
3434
]
35+
},
36+
{
37+
"cell_type": "code",
38+
"execution_count": null,
39+
"metadata": {},
40+
"outputs": [],
41+
"source": [
42+
"# 라이브러리와 난수 고정\n",
43+
"import numpy as np\n",
44+
"import matplotlib.pyplot as plt\n",
45+
"from matplotlib.patches import Rectangle\n",
46+
"rng = np.random.default_rng(42)\n",
47+
"\n",
48+
"# 데이터 생성\n",
49+
"n = 1000\n",
50+
"p = 4\n",
51+
"X = rng.random((n, p)) # 0~1 균등분포 공변량 4개\n",
52+
"W = rng.binomial(1, 0.5, size=n) # 무작위 처치(0/1), 확률 0.5\n",
53+
"Y = 0.5*(X[:, 0] - 0.5) + (X[:, 1] - 0.5)*W + 0.1*rng.normal(size=n)\n",
54+
"\n",
55+
"# 시각화\n",
56+
"y_norm = 1 - (Y - Y.min())/(Y.max() - Y.min()) # 0~1로 정규화\n",
57+
"gray_colors = np.array([str(v) for v in y_norm])\n",
58+
"\n",
59+
"plt.scatter(X[:, 0], X[:, 1], c=gray_colors, s=60, marker='o',\n",
60+
" edgecolors='k', linewidths=0.5)\n",
61+
"\n",
62+
"import matplotlib.pyplot as plt\n",
63+
"from matplotlib.patches import Rectangle\n",
64+
"\n",
65+
"plt.figure(figsize=(6, 5))\n",
66+
"\n",
67+
"# 1) 구역 칠하기 (사각형 3개)\n",
68+
"col_treat = (0.25, 0.69, 0.65, 0.35) # 초록 투명\n",
69+
"col_notreat = (0.996, 0.754, 0.027, 0.35) # 노랑 투명\n",
70+
"\n",
71+
"# 왼쪽(0~0.5, 전체 y)\n",
72+
"plt.gca().add_patch(Rectangle((-.1, -.1), 0.6, 1.2, facecolor=col_notreat, edgecolor='none', hatch='///'))\n",
73+
"# 오른쪽 아래(0.5~1, 0~0.5)\n",
74+
"plt.gca().add_patch(Rectangle((0.5, -.1), 0.6, 0.6, facecolor=col_notreat, edgecolor='none', hatch='///'))\n",
75+
"# 오른쪽 위(0.5~1, 0.5~1)\n",
76+
"plt.gca().add_patch(Rectangle((0.5, 0.5), 0.6, 0.6, facecolor=col_treat, edgecolor='none', hatch='///'))\n",
77+
"\n",
78+
"# 2) 점 찍기\n",
79+
"plt.scatter(X[W==0,0], X[W==0,1],\n",
80+
" c=y_norm[W==0], cmap='gray', vmin=0, vmax=1,\n",
81+
" s=60, marker='^', edgecolors='k', linewidths=0.5,\n",
82+
" label=\"Untreated\")\n",
83+
"plt.scatter(X[W==1,0], X[W==1,1],\n",
84+
" c=y_norm[W==1], cmap='gray', vmin=0, vmax=1,\n",
85+
" s=60, marker='o', edgecolors='k', linewidths=0.5,\n",
86+
" label=\"Treated\")\n",
87+
"\n",
88+
"# 3) 텍스트 라벨 붙이기\n",
89+
"plt.text(0.75, 0.75, \"TREAT (A)\", fontsize=14, ha='center', va='center')\n",
90+
"plt.text(0.25, 0.25, \"DO NOT TREAT (A^C)\", fontsize=12, ha='center', va='center')\n",
91+
"\n",
92+
"plt.xlim(-0.1, 1.1)\n",
93+
"plt.ylim(-0.1, 1.1)\n",
94+
"plt.xlabel(\"X1\"); plt.ylabel(\"X2\")\n",
95+
"plt.title(\"Policy Regions with Treated vs Untreated\")\n",
96+
"plt.legend()\n",
97+
"plt.tight_layout()\n",
98+
"plt.show()\n"
99+
]
35100
}
36101
],
37102
"metadata": {

0 commit comments

Comments
 (0)