|
39 | 39 | "metadata": {},
|
40 | 40 | "outputs": [],
|
41 | 41 | "source": [
|
42 |
| - "# 라이브러리와 난수 고정\n", |
43 | 42 | "import numpy as np\n",
|
44 | 43 | "import matplotlib.pyplot as plt\n",
|
45 | 44 | "from matplotlib.patches import Rectangle\n",
|
46 |
| - "rng = np.random.default_rng(42)\n", |
| 45 | + "import matplotlib.patches as mpatches" |
| 46 | + ] |
| 47 | + }, |
| 48 | + { |
| 49 | + "cell_type": "code", |
| 50 | + "execution_count": null, |
| 51 | + "metadata": {}, |
| 52 | + "outputs": [], |
| 53 | + "source": [ |
| 54 | + "# Set random seed for reproducibility\n", |
| 55 | + "np.random.seed(42)\n", |
47 | 56 | "\n",
|
48 |
| - "# 데이터 생성\n", |
| 57 | + "# Generate data\n", |
49 | 58 | "n = 1000\n",
|
50 | 59 | "p = 4\n",
|
51 |
| - "X = rng.random((n, p)) # 0~1 균등분포 공변량 4개\n", |
52 |
| - "W = rng.binomial(1, 0.5, size=n) # 무작위 처치(0/1), 확률 0.5\n", |
53 |
| - "Y = 0.5*(X[:, 0] - 0.5) + (X[:, 1] - 0.5)*W + 0.1*rng.normal(size=n)\n", |
| 60 | + "X = np.random.uniform(0, 1, (n, p))\n", |
| 61 | + "W = np.random.binomial(1, 0.5, n) # Independent from X and Y\n", |
| 62 | + "Y = 0.5 * (X[:, 0] - 0.5) + (X[:, 1] - 0.5) * W + 0.1 * np.random.randn(n)" |
| 63 | + ] |
| 64 | + }, |
| 65 | + { |
| 66 | + "cell_type": "code", |
| 67 | + "execution_count": null, |
| 68 | + "metadata": {}, |
| 69 | + "outputs": [], |
| 70 | + "source": [ |
| 71 | + "# Normalize Y for plotting\n", |
| 72 | + "y_norm = 1 - (Y - Y.min()) / (Y.max() - Y.min())\n", |
| 73 | + "\n", |
| 74 | + "# First plot: All data points\n", |
| 75 | + "fig1, ax1 = plt.subplots(1, 1, figsize=(8, 6))\n", |
| 76 | + "for i in range(n):\n", |
| 77 | + " if W[i] == 1:\n", |
| 78 | + " ax1.scatter(X[i, 0], X[i, 1], marker='o', s=100, \n", |
| 79 | + " c=[y_norm[i]], cmap='gray', vmin=0, vmax=1, \n", |
| 80 | + " edgecolors='black', linewidths=1)\n", |
| 81 | + " else:\n", |
| 82 | + " ax1.scatter(X[i, 0], X[i, 1], marker='D', s=80, \n", |
| 83 | + " c=[y_norm[i]], cmap='gray', vmin=0, vmax=1,\n", |
| 84 | + " edgecolors='black', linewidths=1)\n", |
| 85 | + "ax1.set_xlabel('X1', fontsize=12)\n", |
| 86 | + "ax1.set_ylabel('X2', fontsize=12)\n", |
| 87 | + "ax1.set_title('All Data Points (○: Treated, ◇: Untreated)', fontsize=14)\n", |
| 88 | + "plt.show()" |
| 89 | + ] |
| 90 | + }, |
| 91 | + { |
| 92 | + "cell_type": "code", |
| 93 | + "execution_count": null, |
| 94 | + "metadata": {}, |
| 95 | + "outputs": [], |
| 96 | + "source": [ |
| 97 | + "# Second plot: Separated by treatment\n", |
| 98 | + "fig2, (ax2, ax3) = plt.subplots(1, 2, figsize=(14, 6))\n", |
54 | 99 | "\n",
|
55 |
| - "# 시각화\n", |
56 |
| - "y_norm = 1 - (Y - Y.min())/(Y.max() - Y.min()) # 0~1로 정규화\n", |
57 |
| - "gray_colors = np.array([str(v) for v in y_norm])\n", |
| 100 | + "# Untreated group\n", |
| 101 | + "untreated_idx = W == 0\n", |
| 102 | + "ax2.scatter(X[untreated_idx, 0], X[untreated_idx, 1], marker='D', s=80, \n", |
| 103 | + " c=y_norm[untreated_idx], cmap='gray', vmin=0, vmax=1,\n", |
| 104 | + " edgecolors='black', linewidths=1)\n", |
| 105 | + "ax2.set_xlabel('X1', fontsize=12)\n", |
| 106 | + "ax2.set_ylabel('X2', fontsize=12)\n", |
| 107 | + "ax2.set_title('Untreated', fontsize=14)\n", |
58 | 108 | "\n",
|
59 |
| - "plt.scatter(X[:, 0], X[:, 1], c=gray_colors, s=60, marker='o',\n", |
60 |
| - " edgecolors='k', linewidths=0.5)\n", |
| 109 | + "# Treated group\n", |
| 110 | + "treated_idx = W == 1\n", |
| 111 | + "ax3.scatter(X[treated_idx, 0], X[treated_idx, 1], marker='o', s=100, \n", |
| 112 | + " c=y_norm[treated_idx], cmap='gray', vmin=0, vmax=1,\n", |
| 113 | + " edgecolors='black', linewidths=1)\n", |
| 114 | + "ax3.set_xlabel('X1', fontsize=12)\n", |
| 115 | + "ax3.set_ylabel('X2', fontsize=12)\n", |
| 116 | + "ax3.set_title('Treated', fontsize=14)\n", |
| 117 | + "plt.show()" |
| 118 | + ] |
| 119 | + }, |
| 120 | + { |
| 121 | + "cell_type": "code", |
| 122 | + "execution_count": null, |
| 123 | + "metadata": {}, |
| 124 | + "outputs": [], |
| 125 | + "source": [ |
| 126 | + "# Third plot: Policy regions\n", |
| 127 | + "fig3, ax4 = plt.subplots(1, 1, figsize=(8, 6))\n", |
61 | 128 | "\n",
|
62 |
| - "import matplotlib.pyplot as plt\n", |
63 |
| - "from matplotlib.patches import Rectangle\n", |
| 129 | + "# Define colors with transparency\n", |
| 130 | + "col1 = (0.9960938, 0.7539062, 0.0273438, 0.35) # Yellow-ish\n", |
| 131 | + "col2 = (0.250980, 0.690196, 0.650980, 0.35) # Teal-ish\n", |
64 | 132 | "\n",
|
65 |
| - "plt.figure(figsize=(6, 5))\n", |
| 133 | + "# Draw policy regions\n", |
| 134 | + "rect1 = Rectangle((-0.1, -0.1), 0.6, 1.2, linewidth=0, \n", |
| 135 | + " edgecolor='none', facecolor=col1, hatch='///')\n", |
| 136 | + "rect2 = Rectangle((0.5, -0.1), 0.6, 0.6, linewidth=0, \n", |
| 137 | + " edgecolor='none', facecolor=col1, hatch='///')\n", |
| 138 | + "rect3 = Rectangle((0.5, 0.5), 0.6, 0.6, linewidth=0, \n", |
| 139 | + " edgecolor='none', facecolor=col2, hatch='///')\n", |
| 140 | + "ax4.add_patch(rect1)\n", |
| 141 | + "ax4.add_patch(rect2)\n", |
| 142 | + "ax4.add_patch(rect3)\n", |
66 | 143 | "\n",
|
67 |
| - "# 1) 구역 칠하기 (사각형 3개)\n", |
68 |
| - "col_treat = (0.25, 0.69, 0.65, 0.35) # 초록 투명\n", |
69 |
| - "col_notreat = (0.996, 0.754, 0.027, 0.35) # 노랑 투명\n", |
| 144 | + "# Plot data points\n", |
| 145 | + "for i in range(n):\n", |
| 146 | + " if W[i] == 1:\n", |
| 147 | + " ax4.scatter(X[i, 0], X[i, 1], marker='o', s=100, \n", |
| 148 | + " c=[y_norm[i]], cmap='gray', vmin=0, vmax=1, \n", |
| 149 | + " edgecolors='black', linewidths=1)\n", |
| 150 | + " else:\n", |
| 151 | + " ax4.scatter(X[i, 0], X[i, 1], marker='D', s=80, \n", |
| 152 | + " c=[y_norm[i]], cmap='gray', vmin=0, vmax=1,\n", |
| 153 | + " edgecolors='black', linewidths=1)\n", |
| 154 | + "\n", |
| 155 | + "# Add text labels\n", |
| 156 | + "ax4.text(0.75, 0.75, 'TREAT (A)', fontsize=16, ha='center', va='center')\n", |
| 157 | + "ax4.text(0.25, 0.25, 'DO NOT TREAT (A^C)', fontsize=16, ha='left', va='center')\n", |
| 158 | + "ax4.set_xlabel('X1', fontsize=12)\n", |
| 159 | + "ax4.set_ylabel('X2', fontsize=12)\n", |
| 160 | + "ax4.set_xlim(-0.1, 1.1)\n", |
| 161 | + "ax4.set_ylim(-0.1, 1.1)\n", |
| 162 | + "ax4.set_title('Policy Regions', fontsize=14)\n", |
| 163 | + "plt.show()" |
| 164 | + ] |
| 165 | + }, |
| 166 | + { |
| 167 | + "cell_type": "code", |
| 168 | + "execution_count": null, |
| 169 | + "metadata": {}, |
| 170 | + "outputs": [], |
| 171 | + "source": [ |
| 172 | + "# Policy Evaluation Methods\n", |
| 173 | + "print(\"=\" * 60)\n", |
| 174 | + "print(\"POLICY EVALUATION RESULTS\")\n", |
| 175 | + "print(\"=\" * 60)\n", |
70 | 176 | "\n",
|
71 |
| - "# 왼쪽(0~0.5, 전체 y)\n", |
72 |
| - "plt.gca().add_patch(Rectangle((-.1, -.1), 0.6, 1.2, facecolor=col_notreat, edgecolor='none', hatch='///'))\n", |
73 |
| - "# 오른쪽 아래(0.5~1, 0~0.5)\n", |
74 |
| - "plt.gca().add_patch(Rectangle((0.5, -.1), 0.6, 0.6, facecolor=col_notreat, edgecolor='none', hatch='///'))\n", |
75 |
| - "# 오른쪽 위(0.5~1, 0.5~1)\n", |
76 |
| - "plt.gca().add_patch(Rectangle((0.5, 0.5), 0.6, 0.6, facecolor=col_treat, edgecolor='none', hatch='///'))\n", |
| 177 | + "# Method 1: Value of policy A (only valid in randomized setting)\n", |
| 178 | + "A = (X[:, 0] > 0.5) & (X[:, 1] > 0.5)\n", |
| 179 | + "value_estimate = np.mean(Y[A & (W == 1)]) * np.mean(A) + \\\n", |
| 180 | + " np.mean(Y[~A & (W == 0)]) * np.mean(~A)\n", |
| 181 | + "value_stderr = np.sqrt(\n", |
| 182 | + " np.var(Y[A & (W == 1)]) / np.sum(A & (W == 1)) * np.mean(A)**2 + \n", |
| 183 | + " np.var(Y[~A & (W == 0)]) / np.sum(~A & (W == 0)) * np.mean(~A)**2\n", |
| 184 | + ")\n", |
| 185 | + "print(f\"\\nMethod 1: Value of Policy A\")\n", |
| 186 | + "print(f\"Value estimate: {value_estimate:.6f}\")\n", |
| 187 | + "print(f\"Std. Error: {value_stderr:.6f}\")" |
| 188 | + ] |
| 189 | + }, |
| 190 | + { |
| 191 | + "cell_type": "code", |
| 192 | + "execution_count": null, |
| 193 | + "metadata": {}, |
| 194 | + "outputs": [], |
| 195 | + "source": [ |
| 196 | + "# Method 2: Value of fixed treatment proportion (p=0.75)\n", |
| 197 | + "p_treat = 0.75\n", |
| 198 | + "value_estimate2 = p_treat * np.mean(Y[W == 1]) + (1 - p_treat) * np.mean(Y[W == 0])\n", |
| 199 | + "value_stderr2 = np.sqrt(\n", |
| 200 | + " np.var(Y[W == 1]) / np.sum(W == 1) * p_treat**2 + \n", |
| 201 | + " np.var(Y[W == 0]) / np.sum(W == 0) * (1 - p_treat)**2\n", |
| 202 | + ")\n", |
| 203 | + "print(f\"\\nMethod 2: Value of Fixed Treatment Proportion (p={p_treat})\")\n", |
| 204 | + "print(f\"Value estimate: {value_estimate2:.6f}\")\n", |
| 205 | + "print(f\"Std. Error: {value_stderr2:.6f}\")" |
| 206 | + ] |
| 207 | + }, |
| 208 | + { |
| 209 | + "cell_type": "code", |
| 210 | + "execution_count": null, |
| 211 | + "metadata": {}, |
| 212 | + "outputs": [], |
| 213 | + "source": [ |
| 214 | + "# Method 3: Treatment effect within policy region A\n", |
| 215 | + "diff_estimate = (np.mean(Y[A & (W == 1)]) - np.mean(Y[A & (W == 0)])) * np.mean(A)\n", |
| 216 | + "diff_stderr = np.sqrt(\n", |
| 217 | + " np.var(Y[A & (W == 1)]) / np.sum(A & (W == 1)) + \n", |
| 218 | + " np.var(Y[A & (W == 0)]) / np.sum(A & (W == 0))\n", |
| 219 | + ") * np.mean(A)\n", |
| 220 | + "print(f\"\\nMethod 3: Treatment Effect within Policy Region A\")\n", |
| 221 | + "print(f\"Difference estimate: {diff_estimate:.6f}\")\n", |
| 222 | + "print(f\"Std. Error: {diff_stderr:.6f}\")" |
| 223 | + ] |
| 224 | + }, |
| 225 | + { |
| 226 | + "cell_type": "code", |
| 227 | + "execution_count": null, |
| 228 | + "metadata": {}, |
| 229 | + "outputs": [], |
| 230 | + "source": [ |
| 231 | + "# Method 4: Optimal policy difference\n", |
| 232 | + "diff_estimate2 = (np.mean(Y[A & (W == 1)]) - np.mean(Y[A & (W == 0)])) * np.mean(A) / 2 + \\\n", |
| 233 | + " (np.mean(Y[~A & (W == 0)]) - np.mean(Y[~A & (W == 1)])) * np.mean(~A) / 2\n", |
| 234 | + "diff_stderr2 = np.sqrt(\n", |
| 235 | + " (np.mean(A) / 2)**2 * (\n", |
| 236 | + " np.var(Y[A & (W == 1)]) / np.sum(A & (W == 1)) + \n", |
| 237 | + " np.var(Y[A & (W == 0)]) / np.sum(A & (W == 0))\n", |
| 238 | + " ) + \n", |
| 239 | + " (np.mean(~A) / 2)**2 * (\n", |
| 240 | + " np.var(Y[~A & (W == 1)]) / np.sum(~A & (W == 1)) + \n", |
| 241 | + " np.var(Y[~A & (W == 0)]) / np.sum(~A & (W == 0))\n", |
| 242 | + " )\n", |
| 243 | + ")\n", |
| 244 | + "print(f\"\\nMethod 4: Optimal Policy Difference\")\n", |
| 245 | + "print(f\"Difference estimate: {diff_estimate2:.6f}\")\n", |
| 246 | + "print(f\"Std. Error: {diff_stderr2:.6f}\")\n", |
77 | 247 | "\n",
|
78 |
| - "# 2) 점 찍기\n", |
79 |
| - "plt.scatter(X[W==0,0], X[W==0,1],\n", |
80 |
| - " c=y_norm[W==0], cmap='gray', vmin=0, vmax=1,\n", |
81 |
| - " s=60, marker='^', edgecolors='k', linewidths=0.5,\n", |
82 |
| - " label=\"Untreated\")\n", |
83 |
| - "plt.scatter(X[W==1,0], X[W==1,1],\n", |
84 |
| - " c=y_norm[W==1], cmap='gray', vmin=0, vmax=1,\n", |
85 |
| - " s=60, marker='o', edgecolors='k', linewidths=0.5,\n", |
86 |
| - " label=\"Treated\")\n", |
| 248 | + "print(\"\\n\" + \"=\" * 60)" |
| 249 | + ] |
| 250 | + }, |
| 251 | + { |
| 252 | + "cell_type": "code", |
| 253 | + "execution_count": null, |
| 254 | + "metadata": {}, |
| 255 | + "outputs": [], |
| 256 | + "source": [ |
| 257 | + "# Additional analysis: Treatment effect heterogeneity\n", |
| 258 | + "print(\"\\nADDITIONAL ANALYSIS\")\n", |
| 259 | + "print(\"=\" * 60)\n", |
87 | 260 | "\n",
|
88 |
| - "# 3) 텍스트 라벨 붙이기\n", |
89 |
| - "plt.text(0.75, 0.75, \"TREAT (A)\", fontsize=14, ha='center', va='center')\n", |
90 |
| - "plt.text(0.25, 0.25, \"DO NOT TREAT (A^C)\", fontsize=12, ha='center', va='center')\n", |
| 261 | + "# Calculate treatment effects by region\n", |
| 262 | + "te_in_A = np.mean(Y[A & (W == 1)]) - np.mean(Y[A & (W == 0)])\n", |
| 263 | + "te_out_A = np.mean(Y[~A & (W == 1)]) - np.mean(Y[~A & (W == 0)])\n", |
91 | 264 | "\n",
|
92 |
| - "plt.xlim(-0.1, 1.1)\n", |
93 |
| - "plt.ylim(-0.1, 1.1)\n", |
94 |
| - "plt.xlabel(\"X1\"); plt.ylabel(\"X2\")\n", |
95 |
| - "plt.title(\"Policy Regions with Treated vs Untreated\")\n", |
96 |
| - "plt.legend()\n", |
97 |
| - "plt.tight_layout()\n", |
98 |
| - "plt.show()\n" |
| 265 | + "print(f\"\\nTreatment Effect Heterogeneity:\")\n", |
| 266 | + "print(f\"Treatment effect in region A: {te_in_A:.6f}\")\n", |
| 267 | + "print(f\"Treatment effect outside region A: {te_out_A:.6f}\")\n", |
| 268 | + "print(f\"Difference in treatment effects: {te_in_A - te_out_A:.6f}\")" |
| 269 | + ] |
| 270 | + }, |
| 271 | + { |
| 272 | + "cell_type": "code", |
| 273 | + "execution_count": null, |
| 274 | + "metadata": {}, |
| 275 | + "outputs": [], |
| 276 | + "source": [ |
| 277 | + "# Summary statistics\n", |
| 278 | + "print(f\"\\nSummary Statistics:\")\n", |
| 279 | + "print(f\"Proportion in region A: {np.mean(A):.3f}\")\n", |
| 280 | + "print(f\"Proportion treated: {np.mean(W):.3f}\")\n", |
| 281 | + "print(f\"Mean outcome (treated): {np.mean(Y[W == 1]):.6f}\")\n", |
| 282 | + "print(f\"Mean outcome (untreated): {np.mean(Y[W == 0]):.6f}\")\n", |
| 283 | + "print(f\"Overall treatment effect: {np.mean(Y[W == 1]) - np.mean(Y[W == 0]):.6f}\")" |
99 | 284 | ]
|
100 | 285 | }
|
101 | 286 | ],
|
|
0 commit comments