|
32 | 32 | " async>\n",
|
33 | 33 | "</script>"
|
34 | 34 | ]
|
| 35 | + }, |
| 36 | + { |
| 37 | + "cell_type": "code", |
| 38 | + "execution_count": null, |
| 39 | + "metadata": {}, |
| 40 | + "outputs": [], |
| 41 | + "source": [ |
| 42 | + "# 라이브러리와 난수 고정\n", |
| 43 | + "import numpy as np\n", |
| 44 | + "import matplotlib.pyplot as plt\n", |
| 45 | + "from matplotlib.patches import Rectangle\n", |
| 46 | + "rng = np.random.default_rng(42)\n", |
| 47 | + "\n", |
| 48 | + "# 데이터 생성\n", |
| 49 | + "n = 1000\n", |
| 50 | + "p = 4\n", |
| 51 | + "X = rng.random((n, p)) # 0~1 균등분포 공변량 4개\n", |
| 52 | + "W = rng.binomial(1, 0.5, size=n) # 무작위 처치(0/1), 확률 0.5\n", |
| 53 | + "Y = 0.5*(X[:, 0] - 0.5) + (X[:, 1] - 0.5)*W + 0.1*rng.normal(size=n)\n", |
| 54 | + "\n", |
| 55 | + "# 시각화\n", |
| 56 | + "y_norm = 1 - (Y - Y.min())/(Y.max() - Y.min()) # 0~1로 정규화\n", |
| 57 | + "gray_colors = np.array([str(v) for v in y_norm])\n", |
| 58 | + "\n", |
| 59 | + "plt.scatter(X[:, 0], X[:, 1], c=gray_colors, s=60, marker='o',\n", |
| 60 | + " edgecolors='k', linewidths=0.5)\n", |
| 61 | + "\n", |
| 62 | + "import matplotlib.pyplot as plt\n", |
| 63 | + "from matplotlib.patches import Rectangle\n", |
| 64 | + "\n", |
| 65 | + "plt.figure(figsize=(6, 5))\n", |
| 66 | + "\n", |
| 67 | + "# 1) 구역 칠하기 (사각형 3개)\n", |
| 68 | + "col_treat = (0.25, 0.69, 0.65, 0.35) # 초록 투명\n", |
| 69 | + "col_notreat = (0.996, 0.754, 0.027, 0.35) # 노랑 투명\n", |
| 70 | + "\n", |
| 71 | + "# 왼쪽(0~0.5, 전체 y)\n", |
| 72 | + "plt.gca().add_patch(Rectangle((-.1, -.1), 0.6, 1.2, facecolor=col_notreat, edgecolor='none', hatch='///'))\n", |
| 73 | + "# 오른쪽 아래(0.5~1, 0~0.5)\n", |
| 74 | + "plt.gca().add_patch(Rectangle((0.5, -.1), 0.6, 0.6, facecolor=col_notreat, edgecolor='none', hatch='///'))\n", |
| 75 | + "# 오른쪽 위(0.5~1, 0.5~1)\n", |
| 76 | + "plt.gca().add_patch(Rectangle((0.5, 0.5), 0.6, 0.6, facecolor=col_treat, edgecolor='none', hatch='///'))\n", |
| 77 | + "\n", |
| 78 | + "# 2) 점 찍기\n", |
| 79 | + "plt.scatter(X[W==0,0], X[W==0,1],\n", |
| 80 | + " c=y_norm[W==0], cmap='gray', vmin=0, vmax=1,\n", |
| 81 | + " s=60, marker='^', edgecolors='k', linewidths=0.5,\n", |
| 82 | + " label=\"Untreated\")\n", |
| 83 | + "plt.scatter(X[W==1,0], X[W==1,1],\n", |
| 84 | + " c=y_norm[W==1], cmap='gray', vmin=0, vmax=1,\n", |
| 85 | + " s=60, marker='o', edgecolors='k', linewidths=0.5,\n", |
| 86 | + " label=\"Treated\")\n", |
| 87 | + "\n", |
| 88 | + "# 3) 텍스트 라벨 붙이기\n", |
| 89 | + "plt.text(0.75, 0.75, \"TREAT (A)\", fontsize=14, ha='center', va='center')\n", |
| 90 | + "plt.text(0.25, 0.25, \"DO NOT TREAT (A^C)\", fontsize=12, ha='center', va='center')\n", |
| 91 | + "\n", |
| 92 | + "plt.xlim(-0.1, 1.1)\n", |
| 93 | + "plt.ylim(-0.1, 1.1)\n", |
| 94 | + "plt.xlabel(\"X1\"); plt.ylabel(\"X2\")\n", |
| 95 | + "plt.title(\"Policy Regions with Treated vs Untreated\")\n", |
| 96 | + "plt.legend()\n", |
| 97 | + "plt.tight_layout()\n", |
| 98 | + "plt.show()\n" |
| 99 | + ] |
35 | 100 | }
|
36 | 101 | ],
|
37 | 102 | "metadata": {
|
|
0 commit comments