diff --git a/1-setup.ipynb b/1-setup.ipynb deleted file mode 100644 index 4569e36..0000000 --- a/1-setup.ipynb +++ /dev/null @@ -1,391 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "2c160c9c-a445-493e-9948-7ba507c606fb", - "metadata": {}, - "source": [ - "# Running bash commands from your notebook\n", - "\n", - "First, let's install all the dependencies. \n", - "\n", - "You can directly run bash commands in your notebook, by either prefixing your commands with an exclamation mark `!`:\n", - "```ipython\n", - "[1] !echo \"this is a bash command\"\n", - "this is a bash command\n", - "\n", - "[2] !ls\n", - "/home/user/git_repos/FNO_workshop\n", - "```\n", - "\n", - "or by starting your cell with the `%%bash` ipython magic. \n", - "\n", - "Let's see a simple example:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "24e20734-97e5-4295-9952-d67ac36b63a0", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Couldn't find program: 'bash'\n" - ] - } - ], - "source": [ - "%%bash\n", - "\n", - "for var in hello world\n", - "do\n", - " echo ${var} \n", - "done" - ] - }, - { - "cell_type": "markdown", - "id": "5b47acb6-a558-40bf-bd76-872941fdf879", - "metadata": {}, - "source": [ - "# Installing the dependencies\n", - "\n", - "Now, let's install the dependencies." 
- ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "bcce1c3e-b4d7-44ea-8b98-bce1f04182cf", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Couldn't find program: 'bash'\n" - ] - } - ], - "source": [ - "%%bash \n", - "\n", - "target_folder='./temp'\n", - "[ -d ${target_folder} ] || mkdir -p ${target_folder}\n", - "cd temp\n", - "\n", - "git clone https://github.com/tensorly/tensorly \n", - "cd tensorly\n", - "python -m pip install -e .\n", - "cd ..\n", - "\n", - "git clone https://github.com/tensorly/torch\n", - "cd torch\n", - "python -m pip install -e .\n", - "cd ..\n", - "\n", - "git clone https://github.com/NeuralOperator/neuraloperator\n", - "cd neuraloperator\n", - "python -m pip install -e ." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "e0bb548e-6e98-4fac-935e-52a8115c4aac", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting gpustat\n", - " Downloading gpustat-1.0.0.tar.gz (90 kB)\n", - "Requirement already satisfied: six>=1.7 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from gpustat) (1.16.0)\n", - "Collecting nvidia-ml-py<=11.495.46,>=11.450.129\n", - " Downloading nvidia_ml_py-11.495.46-py3-none-any.whl (25 kB)\n", - "Requirement already satisfied: psutil>=5.6.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from gpustat) (5.8.0)\n", - "Collecting blessed>=1.17.1\n", - " Downloading blessed-1.20.0-py2.py3-none-any.whl (58 kB)\n", - "Collecting jinxed>=1.1.0\n", - " Downloading jinxed-1.2.0-py2.py3-none-any.whl (33 kB)\n", - "Requirement already satisfied: wcwidth>=0.1.4 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from blessed>=1.17.1->gpustat) (0.2.5)\n", - "Collecting ansicon\n", - " Downloading ansicon-1.89.0-py2.py3-none-any.whl (63 kB)\n", - "Building wheels for collected packages: gpustat\n", - " Building wheel for gpustat (setup.py): started\n", - " Building wheel for gpustat 
(setup.py): finished with status 'done'\n", - " Created wheel for gpustat: filename=gpustat-1.0.0-py3-none-any.whl size=19886 sha256=647135e0be6c489fa67d18d54e79c7dca544dfd5496efe4d20129d52a8c8803f\n", - " Stored in directory: c:\\users\\devzh\\appdata\\local\\pip\\cache\\wheels\\1b\\ed\\14\\0d513c962b25da841c42022cb5847c2ef835902c8563b8fb01\n", - "Successfully built gpustat\n", - "Installing collected packages: ansicon, jinxed, nvidia-ml-py, blessed, gpustat\n", - "Successfully installed ansicon-1.89.0 blessed-1.20.0 gpustat-1.0.0 jinxed-1.2.0 nvidia-ml-py-11.495.46\n", - "Collecting gdown\n", - " Downloading gdown-4.6.4-py3-none-any.whl (14 kB)\n", - "Requirement already satisfied: requests[socks] in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from gdown) (2.26.0)\n", - "Requirement already satisfied: beautifulsoup4 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from gdown) (4.10.0)\n", - "Requirement already satisfied: filelock in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from gdown) (3.3.1)\n", - "Requirement already satisfied: six in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from gdown) (1.16.0)\n", - "Requirement already satisfied: tqdm in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from gdown) (4.62.3)\n", - "Requirement already satisfied: soupsieve>1.2 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from beautifulsoup4->gdown) (2.2.1)\n", - "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests[socks]->gdown) (2.0.4)\n", - "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests[socks]->gdown) (1.26.7)\n", - "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests[socks]->gdown) (2021.10.8)\n", - "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests[socks]->gdown) (3.2)\n", - 
"Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests[socks]->gdown) (1.7.1)\n", - "Requirement already satisfied: colorama in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from tqdm->gdown) (0.4.4)\n", - "Installing collected packages: gdown\n", - "Successfully installed gdown-4.6.4\n", - "Requirement already satisfied: opt-einsum in c:\\users\\devzh\\anaconda3\\lib\\site-packages (3.3.0)\n", - "Requirement already satisfied: numpy>=1.7 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from opt-einsum) (1.21.2)\n", - "Requirement already satisfied: h5py in c:\\users\\devzh\\anaconda3\\lib\\site-packages (3.6.0)\n", - "Requirement already satisfied: wandb in c:\\users\\devzh\\anaconda3\\lib\\site-packages (0.12.1)\n", - "Requirement already satisfied: ruamel.yaml in c:\\users\\devzh\\anaconda3\\lib\\site-packages (0.17.21)\n", - "Requirement already satisfied: zarr in c:\\users\\devzh\\anaconda3\\lib\\site-packages (2.14.1)\n", - "Requirement already satisfied: numpy>=1.14.5 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from h5py) (1.21.2)\n", - "Requirement already satisfied: Click!=8.0.0,>=7.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (8.0.3)\n", - "Requirement already satisfied: promise<3,>=2.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (2.3)\n", - "Requirement already satisfied: sentry-sdk>=1.0.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (1.3.1)\n", - "Requirement already satisfied: GitPython>=1.0.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (3.1.18)\n", - "Requirement already satisfied: docker-pycreds>=0.4.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (0.4.0)\n", - "Requirement already satisfied: protobuf>=3.12.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (3.17.3)\n", - "Requirement already satisfied: psutil>=5.0.0 in 
c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (5.8.0)\n", - "Requirement already satisfied: subprocess32>=3.5.3 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (3.5.4)\n", - "Requirement already satisfied: python-dateutil>=2.6.1 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (2.8.2)\n", - "Requirement already satisfied: shortuuid>=0.5.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (1.0.1)\n", - "Requirement already satisfied: requests<3,>=2.0.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (2.26.0)\n", - "Requirement already satisfied: configparser>=3.8.1 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (5.0.2)\n", - "Requirement already satisfied: PyYAML in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (6.0)\n", - "Requirement already satisfied: six>=1.13.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (1.16.0)\n", - "Requirement already satisfied: pathtools in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from wandb) (0.1.2)\n", - "Requirement already satisfied: ruamel.yaml.clib>=0.2.6 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from ruamel.yaml) (0.2.7)\n", - "Requirement already satisfied: asciitree in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from zarr) (0.3.3)\n", - "Requirement already satisfied: numcodecs>=0.10.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from zarr) (0.11.0)\n", - "Requirement already satisfied: fasteners in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from zarr) (0.18)\n", - "Requirement already satisfied: colorama in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from Click!=8.0.0,>=7.0->wandb) (0.4.4)\n", - "Requirement already satisfied: gitdb<5,>=4.0.1 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from GitPython>=1.0.0->wandb) (4.0.7)\n", - "Requirement already satisfied: smmap<5,>=3.0.1 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from 
gitdb<5,>=4.0.1->GitPython>=1.0.0->wandb) (4.0.0)\n", - "Requirement already satisfied: entrypoints in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from numcodecs>=0.10.0->zarr) (0.3)\n", - "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (1.26.7)\n", - "Requirement already satisfied: charset-normalizer~=2.0.0 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (2.0.4)\n", - "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (2021.10.8)\n", - "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\devzh\\anaconda3\\lib\\site-packages (from requests<3,>=2.0.0->wandb) (3.2)\n" - ] - } - ], - "source": [ - "!pip install gpustat\n", - "!pip install gdown\n", - "!pip install opt-einsum\n", - "!pip install h5py wandb ruamel.yaml zarr" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "f4ed3b9d-fffd-4d5d-852c-7dc95dad086f", - "metadata": {}, - "source": [ - "# Prepare data " - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "3a2484ab-0f02-45c9-acce-cb0bbe803dbb", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import requests\n", - "import hashlib\n", - "url_dict = {\n", - " 'darcyflow-1':'https://caltech-pde-data.s3.us-west-2.amazonaws.com/piececonst_r241_N1024_smooth1.mat', \n", - " 'darcyflow-2': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/piececonst_r241_N1024_smooth2.mat', \n", - " 'Navier-Stokes': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/ns_V1e-3_N5000_T50.mat', \n", - " 'darcy-test-32': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_test_32.pt', \n", - " 'darcy-test-64': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_test_64.pt', \n", - " 'darcy-train-32': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_train_32.pt', \n", - " 'darcy-train-64': 
'https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_train_64.pt', \n", - " 'KF-Re100': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/KFvorticity_Re100_N50_T500.npy'\n", - "}\n", - "\n", - "chksum_dict = {\n", - " 'piececonst_r241_N1024_smooth1.mat': '5ab3edf67bb5fd6d49ebf308cd79ed70340106d1a18af8a8439d3e7fc8e82d21', \n", - " 'piececonst_r241_N1024_smooth2.mat': '51a818ed2e4f08752eea5d3f137f0e00271589c48297a46c641382a51eb80acf', \n", - " 'ns_V1e-3_N5000_T50.mat': '78b8d9e83d767dc7050fb8145ee7e7f11e2d18d325bff9abc7f108ec3292ee78', \n", - " 'darcy_train_64.pt': 'c05770239c91ebf093ea971e4d724008a49c9f21b5363fcf182e80499fae7fb4', \n", - " 'darcy_train_32.pt': 'b8d8095d3832ed67f55b4a8fcb1970618b4ca2c6fc91aee2fe49b9c9b2c071ae', \n", - " 'darcy_test_64.pt': '2220bb25c920109e9565a7fc07b675de16d124d563996f6e7256e2faa1fde24f', \n", - " 'darcy_test_32.pt': '65137910193a553295c26e3d8273761daa44766597f4b34cfb12299fc6e3f311', \n", - " 'KFvorticity_Re100_N50_T500.npy': '55f5af44a732a7843d631ace6384ac75c787d4fb36765b2e83ce1febb52d5463'\n", - "}\n", - "\n", - "def download_file(url, file_path):\n", - " with requests.get(url, stream=True) as r:\n", - " r.raise_for_status()\n", - " with open(file_path, 'wb') as f:\n", - " for chunk in r.iter_content(chunk_size=1024 * 1024 * 1024):\n", - " f.write(chunk)\n", - " print('Complete')\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "d36ba93d", - "metadata": {}, - "source": [ - "## Download Darcy datasets" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "70b2c9d0-990d-43fd-9a80-7af9dbc8dd64", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_train_64.pt...\n", - "Complete\n", - "Downloading https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_train_32.pt...\n", - "Complete\n", - "Downloading 
https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_test_64.pt...\n", - "Complete\n", - "Downloading https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_test_32.pt...\n", - "Complete\n" - ] - } - ], - "source": [ - "data_root = 'data'\n", - "darcy_dir = os.path.join(data_root, 'darcy_flow')\n", - "os.makedirs(darcy_dir, exist_ok=True)\n", - "\n", - "day1_data = ['darcy-train-64', 'darcy-train-32', 'darcy-test-64', 'darcy-test-32']\n", - "\n", - "for key in day1_data:\n", - " value = url_dict[key]\n", - " print(f'Downloading {value}...')\n", - " filename = os.path.basename(value)\n", - " save_path = os.path.join(darcy_dir, filename)\n", - " download_file(url=value, file_path=save_path)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "db98503a", - "metadata": {}, - "outputs": [], - "source": [ - "# verify data integrity\n", - "for data_file in os.listdir(darcy_dir):\n", - " data_path = os.path.join(darcy_dir, data_file)\n", - " with open(data_path, 'rb') as f:\n", - " data = f.read()\n", - " sha256 = hashlib.sha256(data).hexdigest()\n", - " if sha256 == chksum_dict[data_file]:\n", - " print(f'{data_file} verified!')\n", - " else:\n", - " print(f'{data_file} verfication failed!')" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "6a5cc551", - "metadata": {}, - "source": [ - "### Download KF datasets (2d NS)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "817f3d48", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading https://caltech-pde-data.s3.us-west-2.amazonaws.com/KFvorticity_Re100_N50_T500.npy to data\\kf\n", - "Complete\n" - ] - } - ], - "source": [ - "data_root = 'data'\n", - "kf_dir = os.path.join(data_root, 'kf')\n", - "os.makedirs(kf_dir, exist_ok=True)\n", - "\n", - "kf_data = ['KF-Re100']\n", - "for key in kf_data:\n", - " value = url_dict[key]\n", - " print(f'Downloading {value} to {kf_dir}')\n", - " filename = 
os.path.basename(value)\n", - " save_path = os.path.join(kf_dir, filename)\n", - " download_file(url=value, file_path=save_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "73acfd3d-23d4-4f02-9bc7-167438ac2de4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "KFvorticity_Re100_N50_T500.npy verified!\n" - ] - } - ], - "source": [ - "for data_file in os.listdir(kf_dir):\n", - " data_path = os.path.join(kf_dir, data_file)\n", - " with open(data_path, 'rb') as f:\n", - " data = f.read()\n", - " sha256 = hashlib.sha256(data).hexdigest()\n", - " if sha256 == chksum_dict[data_file]:\n", - " print(f'{data_file} verified!')\n", - " else:\n", - " print(f'{data_file} verfication failed!')\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "base", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.8" - }, - "vscode": { - "interpreter": { - "hash": "95d4b27ba6bfea4a66eebe0e0159b214d32a94d313a7f4c98bd9b87f5ee37cbe" - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/2-intro_FNO.ipynb b/2-intro_FNO.ipynb deleted file mode 100644 index 1f6c68b..0000000 --- a/2-intro_FNO.ipynb +++ /dev/null @@ -1,472 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "59194c45-83c9-4a77-a1b0-185eca26afd5", - "metadata": {}, - "source": [ - "# Check the dependencies " - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "af7a5c4c-b3a5-4f32-aee9-55290566ff56", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tl.__version__='0.8.0'\n", - "tltorch.__version__='0.3.0'\n", - "no.__version__='0.1.0'\n" - ] - } - ], - "source": [ - "import tensorly as tl\n", - "import tltorch\n", - 
"import neuralop as no\n", - "\n", - "print(f'{tl.__version__=}')\n", - "print(f'{tltorch.__version__=}')\n", - "print(f'{no.__version__=}')" - ] - }, - { - "cell_type": "markdown", - "id": "a36bb3e2-c158-497c-babe-5eead700cbf1", - "metadata": { - "tags": [] - }, - "source": [ - "# FFT and Spectral Convolution\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "4efa0d7f-e39c-496e-891d-6b34c62fbd9d", - "metadata": {}, - "outputs": [], - "source": [ - "from neuralop.models.fno_block import FactorizedSpectralConv\n", - "from neuralop.models import TFNO2d\n", - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "2c8c3eb7-82e1-4df7-b6e6-3f34331637c4", - "metadata": {}, - "outputs": [], - "source": [ - "fourier_conv = FactorizedSpectralConv(in_channels=3, out_channels=10, n_modes=(4, 4),\n", - " factorization=None, implementation='reconstructed')" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "4eaf645a-b7f5-4dcc-b8de-fb388ccc9b26", - "metadata": {}, - "outputs": [], - "source": [ - "in_data = torch.randn((2, 3, 16, 16))" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "016b33e0-88a6-4215-99e0-19da4f8fd5f5", - "metadata": {}, - "outputs": [], - "source": [ - "out = fourier_conv(in_data)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "36d0f546-9fa9-4936-a6b6-19d7bde03639", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 10, 16, 16])" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "out.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "4936746b-5abb-4a8b-9e74-238502c65930", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "FactorizedSpectralConv(\n", - " (weight): ModuleList(\n", - " (0): ComplexDenseTensor(shape=torch.Size([3, 10, 2, 2]), rank=None)\n", - " (1): ComplexDenseTensor(shape=torch.Size([3, 10, 2, 2]), 
rank=None)\n", - " )\n", - ")" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fourier_conv" - ] - }, - { - "cell_type": "markdown", - "id": "a616d68d-677a-4e6f-abd5-9e631ebf7fb6", - "metadata": {}, - "source": [ - "The way the spectral convolution works is that it multiplies (complex) coefficients with (complex) weights, learned end-to-end." - ] - }, - { - "cell_type": "markdown", - "id": "0c8d9860-d43d-47f3-a6aa-c7ed4522684e", - "metadata": { - "tags": [] - }, - "source": [ - "# Tensorized Spectral Convolutions\n", - "\n", - "It is possible to express the weights of one or more layers as in factorized form, as a low-rank decomposition of the full weights.\n", - "\n", - "`neuralop` comes with support for tensorization out of the box, you can simply specify, e.g., to use a Tucker factorization, `factorization='tucker'`." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "b3f919de-97c2-4f0b-bb40-8e47cd2c1e0e", - "metadata": {}, - "outputs": [], - "source": [ - "fourier_conv = FactorizedSpectralConv(in_channels=3, out_channels=10, n_modes=(4, 4),\n", - " factorization='tucker', implementation='reconstructed')" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "91a7aa04-9cc3-4f8c-b34f-54fbc625b718", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "FactorizedSpectralConv(\n", - " (weight): ModuleList(\n", - " (0): ComplexTuckerTensor(shape=(3, 10, 2, 2), rank=(1, 5, 1, 1))\n", - " (1): ComplexTuckerTensor(shape=(3, 10, 2, 2), rank=(1, 5, 1, 1))\n", - " )\n", - ")" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fourier_conv" - ] - }, - { - "cell_type": "markdown", - "id": "f8df876d-72e1-40cd-9a86-330a57dc0e8d", - "metadata": {}, - "source": [ - "## Efficient forward pass\n", - "\n", - "When factorizing the weights, have two main options during the forward pass:\n", - "1. 
reconstruct the full weights and use that for the forward pass \n", - "2. contract the input directly with the factorized weights to predict the output\n", - "\n", - "When the factorized weights are small, the second option can lead to large speedups or memory reduction, particularly when coupled with checkpointing. \n", - "\n", - "In `neuralop`, you can use those simply by specifying `implementation='reconstructed'` or `implementation='factorized'`:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "a0667a6b-1efe-47e0-8908-29c5fb0cf45a", - "metadata": {}, - "outputs": [], - "source": [ - "fourier_conv = FactorizedSpectralConv(in_channels=3, out_channels=10, n_modes=(4, 4),\n", - " factorization='tucker', implementation='factorized')" - ] - }, - { - "cell_type": "markdown", - "id": "ec3ab24a-09fe-4864-b2ed-e96b54792e9f", - "metadata": {}, - "source": [ - "# Full Tensorized Fourier Neural Operator \n", - "\n", - "The full architecture is composed of \n", - "\n", - "i) a lifting layer taking the number of input channels and lifting that to the desired number of hidden channels\n", - "ii) a number of spectral convolutions, as shown above\n", - "iii) a projection layer projecting back from the number of hidden channels to the desired number of output channels\n" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "d51aec17-2cf4-40c4-9452-84a4b5259db6", - "metadata": {}, - "outputs": [], - "source": [ - "tfno = TFNO2d(n_modes_height=16, n_modes_width=16, hidden_channels=16, \n", - " factorization=None, skip='linear')" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "c87127e5-d24c-4096-be3a-8872a853a132", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TFNO2d(\n", - " (convs): FactorizedSpectralConv2d(\n", - " (weight): ModuleList(\n", - " (0): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", - " (1): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", - " 
(2): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", - " (3): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", - " (4): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", - " (5): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", - " (6): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", - " (7): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", - " )\n", - " )\n", - " (fno_skips): ModuleList(\n", - " (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (2): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " )\n", - " (lifting): Lifting(\n", - " (fc): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (projection): Projection(\n", - " (fc1): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))\n", - " (fc2): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - ")" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tfno" - ] - }, - { - "cell_type": "markdown", - "id": "0e70efec-bf3c-48ac-b53a-59800055f1b9", - "metadata": {}, - "source": [ - "## Lifting layer\n", - "\n", - "Increasing the number of channels" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "1deead74-bd3d-4aa9-8d2c-cfd9ab0763d7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Lifting(\n", - " (fc): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))\n", - ")" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tfno.lifting" - ] - }, - { - "cell_type": "markdown", - "id": "08844bac-9335-4ac4-afc8-f1d67c3e31bb", - "metadata": {}, - "source": [ - "## Spectral convolutions" - ] - }, - { - "cell_type": "code", - 
"execution_count": 14, - "id": "f2bc28dc-1226-4ed3-b757-3c42357d276a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "FactorizedSpectralConv2d(\n", - " (weight): ModuleList(\n", - " (0): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", - " (1): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", - " (2): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", - " (3): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", - " (4): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", - " (5): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", - " (6): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", - " (7): ComplexDenseTensor(shape=torch.Size([16, 16, 8, 8]), rank=None)\n", - " )\n", - ")" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tfno.convs" - ] - }, - { - "cell_type": "markdown", - "id": "1c7d9882-13db-447d-affd-07ef17256e1c", - "metadata": {}, - "source": [ - "## Skip connections: recovering non-periodicity\n", - "\n", - "Recall the FNO architecture has skip connections: the FFT transformation will loose non-periodic information that has to be reinjected through skip connections. These skip connections also help with learning.\n", - "\n", - "![FNO_layer](./images/fourier_layer.png)\n", - "\n", - "Here, linear layer (represented by weight W in the image). 
We can also use Identity skip (`skip='identity'`) or soft-gated connections (`skip='soft-gating'`)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "f063e3bf-34e5-4d7f-83f9-b3522aa6430b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ModuleList(\n", - " (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (2): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - ")" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tfno.fno_skips" - ] - }, - { - "cell_type": "markdown", - "id": "070e930e-38b6-4d3c-b62a-3ca700294c99", - "metadata": {}, - "source": [ - "## Projection: going back to the target number of channels \n", - "\n", - "Finally, the projection layer takes the hidden dimension to projection_channels and to the actual number of output channels (here, 1)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "88344f47-a7e8-458e-9fbb-775804fbbaad", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Projection(\n", - " (fc1): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))\n", - " (fc2): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))\n", - ")" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tfno.projection" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7aae1ab6-852c-4720-9b3b-5791c2b42872", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - 
"nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/2024_bootcamp_notebook.ipynb b/2024_bootcamp_notebook.ipynb new file mode 100644 index 0000000..2abfd8e --- /dev/null +++ b/2024_bootcamp_notebook.ipynb @@ -0,0 +1,1762 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5b47acb6-a558-40bf-bd76-872941fdf879", + "metadata": { + "id": "5b47acb6-a558-40bf-bd76-872941fdf879" + }, + "source": [ + "# Installing the dependencies\n", + "\n", + "Now, let's install the dependencies." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e0bb548e-6e98-4fac-935e-52a8115c4aac", + "metadata": { + "id": "e0bb548e-6e98-4fac-935e-52a8115c4aac", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "ebfac54f-998f-4f29-f455-a41d88874bb3" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting tensorly\n", + " Downloading tensorly-0.8.1-py3-none-any.whl (229 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m229.7/229.7 kB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from tensorly) (1.25.2)\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from tensorly) (1.11.4)\n", + "Installing collected packages: tensorly\n", + "Successfully installed tensorly-0.8.1\n", + "Collecting torch-harmonics\n", + " Downloading torch_harmonics-0.6.5-py3-none-any.whl (63 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m63.6/63.6 kB\u001b[0m \u001b[31m1.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from torch-harmonics) (2.2.1+cu121)\n", + "Requirement already satisfied: numpy in 
/usr/local/lib/python3.10/dist-packages (from torch-harmonics) (1.25.2)\n", + "Requirement already satisfied: triton in /usr/local/lib/python3.10/dist-packages (from torch-harmonics) (2.2.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->torch-harmonics) (3.13.4)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch->torch-harmonics) (4.11.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->torch-harmonics) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->torch-harmonics) (3.3)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->torch-harmonics) (3.1.3)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->torch-harmonics) (2023.6.0)\n", + "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->torch-harmonics)\n", + " Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n", + "Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->torch-harmonics)\n", + " Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n", + "Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->torch-harmonics)\n", + " Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n", + "Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->torch-harmonics)\n", + " Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n", + "Collecting nvidia-cublas-cu12==12.1.3.1 (from torch->torch-harmonics)\n", + " Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n", + "Collecting nvidia-cufft-cu12==11.0.2.54 (from torch->torch-harmonics)\n", + " Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 
MB)\n", + "Collecting nvidia-curand-cu12==10.3.2.106 (from torch->torch-harmonics)\n", + " Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n", + "Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch->torch-harmonics)\n", + " Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n", + "Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch->torch-harmonics)\n", + " Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n", + "Collecting nvidia-nccl-cu12==2.19.3 (from torch->torch-harmonics)\n", + " Using cached nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl (166.0 MB)\n", + "Collecting nvidia-nvtx-cu12==12.1.105 (from torch->torch-harmonics)\n", + " Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n", + "Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch->torch-harmonics)\n", + " Using cached nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->torch-harmonics) (2.1.5)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->torch-harmonics) (1.3.0)\n", + "Installing collected packages: nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch-harmonics\n", + "Successfully installed nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.19.3 nvidia-nvjitlink-cu12-12.4.127 
nvidia-nvtx-cu12-12.1.105 torch-harmonics-0.6.5\n", + "Collecting neuraloperator\n", + " Downloading neuraloperator-0.3.0-py3-none-any.whl (4.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.0/4.0 MB\u001b[0m \u001b[31m31.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from neuraloperator) (1.25.2)\n", + "Collecting configmypy (from neuraloperator)\n", + " Downloading configmypy-0.1.0-py3-none-any.whl (11 kB)\n", + "Requirement already satisfied: pytest in /usr/local/lib/python3.10/dist-packages (from neuraloperator) (7.4.4)\n", + "Collecting black (from neuraloperator)\n", + " Downloading black-24.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m55.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: tensorly in /usr/local/lib/python3.10/dist-packages (from neuraloperator) (0.8.1)\n", + "Collecting tensorly-torch (from neuraloperator)\n", + " Downloading tensorly_torch-0.4.0-py3-none-any.whl (59 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.1/59.1 kB\u001b[0m \u001b[31m6.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from neuraloperator) (3.3.0)\n", + "Requirement already satisfied: click>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from black->neuraloperator) (8.1.7)\n", + "Collecting mypy-extensions>=0.4.3 (from black->neuraloperator)\n", + " Downloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)\n", + "Requirement already satisfied: packaging>=22.0 in /usr/local/lib/python3.10/dist-packages (from black->neuraloperator) (24.0)\n", + "Collecting pathspec>=0.9.0 (from 
black->neuraloperator)\n", + " Downloading pathspec-0.12.1-py3-none-any.whl (31 kB)\n", + "Requirement already satisfied: platformdirs>=2 in /usr/local/lib/python3.10/dist-packages (from black->neuraloperator) (4.2.0)\n", + "Requirement already satisfied: tomli>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from black->neuraloperator) (2.0.1)\n", + "Requirement already satisfied: typing-extensions>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from black->neuraloperator) (4.11.0)\n", + "Collecting pytest-mock (from configmypy->neuraloperator)\n", + " Downloading pytest_mock-3.14.0-py3-none-any.whl (9.9 kB)\n", + "Collecting ruamel.yaml (from configmypy->neuraloperator)\n", + " Downloading ruamel.yaml-0.18.6-py3-none-any.whl (117 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m117.8/117.8 kB\u001b[0m \u001b[31m15.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: iniconfig in /usr/local/lib/python3.10/dist-packages (from pytest->neuraloperator) (2.0.0)\n", + "Requirement already satisfied: pluggy<2.0,>=0.12 in /usr/local/lib/python3.10/dist-packages (from pytest->neuraloperator) (1.4.0)\n", + "Requirement already satisfied: exceptiongroup>=1.0.0rc8 in /usr/local/lib/python3.10/dist-packages (from pytest->neuraloperator) (1.2.1)\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from tensorly->neuraloperator) (1.11.4)\n", + "Collecting nose (from tensorly-torch->neuraloperator)\n", + " Downloading nose-1.3.7-py3-none-any.whl (154 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m154.7/154.7 kB\u001b[0m \u001b[31m20.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting ruamel.yaml.clib>=0.2.7 (from ruamel.yaml->configmypy->neuraloperator)\n", + " Downloading ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (526 kB)\n", + 
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m526.7/526.7 kB\u001b[0m \u001b[31m47.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: nose, ruamel.yaml.clib, pathspec, mypy-extensions, tensorly-torch, ruamel.yaml, pytest-mock, black, configmypy, neuraloperator\n", + "Successfully installed black-24.4.1 configmypy-0.1.0 mypy-extensions-1.0.0 neuraloperator-0.3.0 nose-1.3.7 pathspec-0.12.1 pytest-mock-3.14.0 ruamel.yaml-0.18.6 ruamel.yaml.clib-0.2.8 tensorly-torch-0.4.0\n", + "Collecting gpustat\n", + " Downloading gpustat-1.1.1.tar.gz (98 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m98.1/98.1 kB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Collecting nvidia-ml-py>=11.450.129 (from gpustat)\n", + " Downloading nvidia_ml_py-12.535.133-py3-none-any.whl (37 kB)\n", + "Requirement already satisfied: psutil>=5.6.0 in /usr/local/lib/python3.10/dist-packages (from gpustat) (5.9.5)\n", + "Collecting blessed>=1.17.1 (from gpustat)\n", + " Downloading blessed-1.20.0-py2.py3-none-any.whl (58 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.4/58.4 kB\u001b[0m \u001b[31m8.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: wcwidth>=0.1.4 in /usr/local/lib/python3.10/dist-packages (from blessed>=1.17.1->gpustat) (0.2.13)\n", + "Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from blessed>=1.17.1->gpustat) (1.16.0)\n", + "Building wheels for collected packages: gpustat\n", + " Building wheel for gpustat (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", + " Created wheel for gpustat: filename=gpustat-1.1.1-py3-none-any.whl size=26532 sha256=eb0778d1f88000bcd0b5d52ae02e8eaae0b648c959bd0b27238ba2b6c0a3e66d\n", + " Stored in directory: /root/.cache/pip/wheels/ec/d7/80/a71ba3540900e1f276bcae685efd8e590c810d2108b95f1e47\n", + "Successfully built gpustat\n", + "Installing collected packages: nvidia-ml-py, blessed, gpustat\n", + "Successfully installed blessed-1.20.0 gpustat-1.1.1 nvidia-ml-py-12.535.133\n", + "Requirement already satisfied: gdown in /usr/local/lib/python3.10/dist-packages (5.1.0)\n", + "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.10/dist-packages (from gdown) (4.12.3)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from gdown) (3.13.4)\n", + "Requirement already satisfied: requests[socks] in /usr/local/lib/python3.10/dist-packages (from gdown) (2.31.0)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from gdown) (4.66.2)\n", + "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.10/dist-packages (from beautifulsoup4->gdown) (2.5)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests[socks]->gdown) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests[socks]->gdown) (3.7)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests[socks]->gdown) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests[socks]->gdown) (2024.2.2)\n", + "Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /usr/local/lib/python3.10/dist-packages (from requests[socks]->gdown) (1.7.1)\n", + "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (3.3.0)\n", + "Requirement already satisfied: 
numpy>=1.7 in /usr/local/lib/python3.10/dist-packages (from opt-einsum) (1.25.2)\n", + "Requirement already satisfied: h5py in /usr/local/lib/python3.10/dist-packages (3.9.0)\n", + "Collecting wandb\n", + " Downloading wandb-0.16.6-py3-none-any.whl (2.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m22.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: ruamel.yaml in /usr/local/lib/python3.10/dist-packages (0.18.6)\n", + "Collecting zarr\n", + " Downloading zarr-2.17.2-py3-none-any.whl (208 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m208.5/208.5 kB\u001b[0m \u001b[31m22.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.10/dist-packages (from h5py) (1.25.2)\n", + "Requirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb) (8.1.7)\n", + "Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)\n", + " Downloading GitPython-3.1.43-py3-none-any.whl (207 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.3/207.3 kB\u001b[0m \u001b[31m23.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (2.31.0)\n", + "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (5.9.5)\n", + "Collecting sentry-sdk>=1.0.0 (from wandb)\n", + " Downloading sentry_sdk-1.45.0-py2.py3-none-any.whl (267 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m267.1/267.1 kB\u001b[0m \u001b[31m24.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting docker-pycreds>=0.4.0 (from wandb)\n", + " Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 
kB)\n", + "Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from wandb) (6.0.1)\n", + "Collecting setproctitle (from wandb)\n", + " Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb) (67.7.2)\n", + "Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb) (1.4.4)\n", + "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.20.3)\n", + "Requirement already satisfied: ruamel.yaml.clib>=0.2.7 in /usr/local/lib/python3.10/dist-packages (from ruamel.yaml) (0.2.8)\n", + "Collecting asciitree (from zarr)\n", + " Downloading asciitree-0.3.3.tar.gz (4.0 kB)\n", + " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "Collecting numcodecs>=0.10.0 (from zarr)\n", + " Downloading numcodecs-0.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.7/7.7 MB\u001b[0m \u001b[31m86.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting fasteners (from zarr)\n", + " Downloading fasteners-0.19-py3-none-any.whl (18 kB)\n", + "Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n", + "Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wandb)\n", + " Downloading gitdb-4.0.11-py3-none-any.whl (62 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.7/62.7 kB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (3.3.2)\n", + 
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (3.7)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (2024.2.2)\n", + "Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb)\n", + " Downloading smmap-5.0.1-py3-none-any.whl (24 kB)\n", + "Building wheels for collected packages: asciitree\n", + " Building wheel for asciitree (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for asciitree: filename=asciitree-0.3.3-py3-none-any.whl size=5034 sha256=9b360729437a21a98c813a051c4706c88963112a53052d30dc1c010cd033e45a\n", + " Stored in directory: /root/.cache/pip/wheels/7f/4e/be/1171b40f43b918087657ec57cf3b81fa1a2e027d8755baa184\n", + "Successfully built asciitree\n", + "Installing collected packages: asciitree, smmap, setproctitle, sentry-sdk, numcodecs, fasteners, docker-pycreds, zarr, gitdb, GitPython, wandb\n", + "Successfully installed GitPython-3.1.43 asciitree-0.3.3 docker-pycreds-0.4.0 fasteners-0.19 gitdb-4.0.11 numcodecs-0.12.1 sentry-sdk-1.45.0 setproctitle-1.3.3 smmap-5.0.1 wandb-0.16.6 zarr-2.17.2\n" + ] + } + ], + "source": [ + "!pip install tensorly\n", + "!pip install torch-harmonics\n", + "!pip install neuraloperator\n", + "!pip install gpustat\n", + "!pip install gdown\n", + "!pip install opt-einsum\n", + "!pip install h5py wandb ruamel.yaml zarr" + ] + }, + { + "cell_type": "markdown", + "id": "f4ed3b9d-fffd-4d5d-852c-7dc95dad086f", + "metadata": { + "id": "f4ed3b9d-fffd-4d5d-852c-7dc95dad086f" + }, + "source": [ + "# Prepare data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3a2484ab-0f02-45c9-acce-cb0bbe803dbb", + "metadata": { + "id": "3a2484ab-0f02-45c9-acce-cb0bbe803dbb" + }, + "outputs": 
[], + "source": [ + "import os\n", + "import requests\n", + "import hashlib\n", + "url_dict = {\n", + " 'darcyflow-1':'https://caltech-pde-data.s3.us-west-2.amazonaws.com/piececonst_r241_N1024_smooth1.mat',\n", + " 'darcyflow-2': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/piececonst_r241_N1024_smooth2.mat',\n", + " 'Navier-Stokes': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/ns_V1e-3_N5000_T50.mat',\n", + " 'darcy-test-32': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_test_32.pt',\n", + " 'darcy-test-64': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_test_64.pt',\n", + " 'darcy-train-32': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_train_32.pt',\n", + " 'darcy-train-64': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_train_64.pt',\n", + " 'KF-Re100': 'https://caltech-pde-data.s3.us-west-2.amazonaws.com/KFvorticity_Re100_N50_T500.npy'\n", + "}\n", + "\n", + "chksum_dict = {\n", + " 'piececonst_r241_N1024_smooth1.mat': '5ab3edf67bb5fd6d49ebf308cd79ed70340106d1a18af8a8439d3e7fc8e82d21',\n", + " 'piececonst_r241_N1024_smooth2.mat': '51a818ed2e4f08752eea5d3f137f0e00271589c48297a46c641382a51eb80acf',\n", + " 'ns_V1e-3_N5000_T50.mat': '78b8d9e83d767dc7050fb8145ee7e7f11e2d18d325bff9abc7f108ec3292ee78',\n", + " 'darcy_train_64.pt': 'c05770239c91ebf093ea971e4d724008a49c9f21b5363fcf182e80499fae7fb4',\n", + " 'darcy_train_32.pt': 'b8d8095d3832ed67f55b4a8fcb1970618b4ca2c6fc91aee2fe49b9c9b2c071ae',\n", + " 'darcy_test_64.pt': '2220bb25c920109e9565a7fc07b675de16d124d563996f6e7256e2faa1fde24f',\n", + " 'darcy_test_32.pt': '65137910193a553295c26e3d8273761daa44766597f4b34cfb12299fc6e3f311',\n", + " 'KFvorticity_Re100_N50_T500.npy': '55f5af44a732a7843d631ace6384ac75c787d4fb36765b2e83ce1febb52d5463'\n", + "}\n", + "\n", + "def download_file(url, file_path):\n", + " with requests.get(url, stream=True) as r:\n", + " r.raise_for_status()\n", + " with open(file_path, 'wb') as f:\n", + " for chunk in 
r.iter_content(chunk_size=1024 * 1024 * 1024):\n", + " f.write(chunk)\n", + " print('Complete')\n" + ] + }, + { + "cell_type": "markdown", + "id": "d36ba93d", + "metadata": { + "id": "d36ba93d" + }, + "source": [ + "## Download Darcy datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "70b2c9d0-990d-43fd-9a80-7af9dbc8dd64", + "metadata": { + "id": "70b2c9d0-990d-43fd-9a80-7af9dbc8dd64", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "b5a2cfb6-0a34-44db-b466-a694c5d62c92" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_train_64.pt...\n", + "Complete\n", + "Downloading https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_train_32.pt...\n", + "Complete\n", + "Downloading https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_test_64.pt...\n", + "Complete\n", + "Downloading https://caltech-pde-data.s3.us-west-2.amazonaws.com/darcy_test_32.pt...\n", + "Complete\n" + ] + } + ], + "source": [ + "data_root = 'data'\n", + "darcy_dir = os.path.join(data_root, 'darcy_flow')\n", + "os.makedirs(darcy_dir, exist_ok=True)\n", + "\n", + "day1_data = ['darcy-train-64', 'darcy-train-32', 'darcy-test-64', 'darcy-test-32']\n", + "\n", + "for key in day1_data:\n", + " value = url_dict[key]\n", + " print(f'Downloading {value}...')\n", + " filename = os.path.basename(value)\n", + " save_path = os.path.join(darcy_dir, filename)\n", + " download_file(url=value, file_path=save_path)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "db98503a", + "metadata": { + "id": "db98503a", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "77095ee6-c06f-47b6-c95f-96a437ab165b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "darcy_test_64.pt verified!\n", + "darcy_test_32.pt verified!\n", + "darcy_train_32.pt verified!\n", + "darcy_train_64.pt verified!\n" + 
] + } + ], + "source": [ + "# verify data integrity\n", + "for data_file in os.listdir(darcy_dir):\n", + " data_path = os.path.join(darcy_dir, data_file)\n", + " with open(data_path, 'rb') as f:\n", + " data = f.read()\n", + " sha256 = hashlib.sha256(data).hexdigest()\n", + " if sha256 == chksum_dict[data_file]:\n", + " print(f'{data_file} verified!')\n", + " else:\n", + " print(f'{data_file} verfication failed!')" + ] + }, + { + "cell_type": "markdown", + "id": "6a5cc551", + "metadata": { + "id": "6a5cc551" + }, + "source": [ + "### Download KF datasets (2d NS)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "817f3d48", + "metadata": { + "id": "817f3d48", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "784cc542-ee78-4181-f2c6-e331d1857f43" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading https://caltech-pde-data.s3.us-west-2.amazonaws.com/KFvorticity_Re100_N50_T500.npy to data/kf\n", + "Complete\n" + ] + } + ], + "source": [ + "data_root = 'data'\n", + "kf_dir = os.path.join(data_root, 'kf')\n", + "os.makedirs(kf_dir, exist_ok=True)\n", + "\n", + "kf_data = ['KF-Re100']\n", + "for key in kf_data:\n", + " value = url_dict[key]\n", + " print(f'Downloading {value} to {kf_dir}')\n", + " filename = os.path.basename(value)\n", + " save_path = os.path.join(kf_dir, filename)\n", + " download_file(url=value, file_path=save_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "73acfd3d-23d4-4f02-9bc7-167438ac2de4", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "73acfd3d-23d4-4f02-9bc7-167438ac2de4", + "outputId": "f63b45c9-5218-4f03-8041-e8d05924361a" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "KFvorticity_Re100_N50_T500.npy verified!\n" + ] + } + ], + "source": [ + "for data_file in os.listdir(kf_dir):\n", + " data_path = os.path.join(kf_dir, data_file)\n", + " with open(data_path, 
'rb') as f:\n", + " data = f.read()\n", + " sha256 = hashlib.sha256(data).hexdigest()\n", + " if sha256 == chksum_dict[data_file]:\n", + " print(f'{data_file} verified!')\n", + " else:\n", + " print(f'{data_file} verfication failed!')\n" + ] + }, + { + "cell_type": "markdown", + "id": "59194c45-83c9-4a77-a1b0-185eca26afd5", + "metadata": { + "id": "59194c45-83c9-4a77-a1b0-185eca26afd5" + }, + "source": [ + "# Check the dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "af7a5c4c-b3a5-4f32-aee9-55290566ff56", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "af7a5c4c-b3a5-4f32-aee9-55290566ff56", + "outputId": "deb7b49a-b7ed-49c0-98cf-40a2370b99ee" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "tl.__version__='0.8.1'\n", + "no.__version__='0.3.0'\n" + ] + } + ], + "source": [ + "import tensorly as tl\n", + "import neuralop as no\n", + "\n", + "print(f'{tl.__version__=}')\n", + "# print(f'{tltorch.__version__=}')\n", + "print(f'{no.__version__=}')" + ] + }, + { + "cell_type": "markdown", + "id": "a36bb3e2-c158-497c-babe-5eead700cbf1", + "metadata": { + "id": "a36bb3e2-c158-497c-babe-5eead700cbf1", + "tags": [] + }, + "source": [ + "# FFT and Spectral Convolution\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4efa0d7f-e39c-496e-891d-6b34c62fbd9d", + "metadata": { + "id": "4efa0d7f-e39c-496e-891d-6b34c62fbd9d" + }, + "outputs": [], + "source": [ + "from neuralop.layers.spectral_convolution import SpectralConv\n", + "from neuralop.models import TFNO2d\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "2c8c3eb7-82e1-4df7-b6e6-3f34331637c4", + "metadata": { + "id": "2c8c3eb7-82e1-4df7-b6e6-3f34331637c4" + }, + "outputs": [], + "source": [ + "fourier_conv = SpectralConv(in_channels=3, out_channels=10, n_modes=(4, 4),\n", + " factorization='tucker', implementation='reconstructed')" + ] + }, + { + "cell_type": 
"code", + "execution_count": 10, + "id": "4eaf645a-b7f5-4dcc-b8de-fb388ccc9b26", + "metadata": { + "id": "4eaf645a-b7f5-4dcc-b8de-fb388ccc9b26" + }, + "outputs": [], + "source": [ + "in_data = torch.randn((2, 3, 16, 16))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "016b33e0-88a6-4215-99e0-19da4f8fd5f5", + "metadata": { + "id": "016b33e0-88a6-4215-99e0-19da4f8fd5f5" + }, + "outputs": [], + "source": [ + "out = fourier_conv(in_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "36d0f546-9fa9-4936-a6b6-19d7bde03639", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "36d0f546-9fa9-4936-a6b6-19d7bde03639", + "outputId": "bc07ed63-3855-40e3-a30b-9abfef5d6252" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "torch.Size([2, 10, 16, 16])" + ] + }, + "metadata": {}, + "execution_count": 12 + } + ], + "source": [ + "out.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "4936746b-5abb-4a8b-9e74-238502c65930", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4936746b-5abb-4a8b-9e74-238502c65930", + "outputId": "123c2655-5e96-46c9-9e26-50c2578736e1" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "SpectralConv(\n", + " (weight): ModuleList(\n", + " (0): ComplexTuckerTensor(shape=(3, 10, 4, 3), rank=(2, 7, 3, 2))\n", + " )\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 13 + } + ], + "source": [ + "fourier_conv" + ] + }, + { + "cell_type": "markdown", + "id": "a616d68d-677a-4e6f-abd5-9e631ebf7fb6", + "metadata": { + "id": "a616d68d-677a-4e6f-abd5-9e631ebf7fb6" + }, + "source": [ + "The way the spectral convolution works is that it multiplies (complex) coefficients with (complex) weights, learned end-to-end." 
+ ] + }, + { + "cell_type": "markdown", + "id": "0c8d9860-d43d-47f3-a6aa-c7ed4522684e", + "metadata": { + "id": "0c8d9860-d43d-47f3-a6aa-c7ed4522684e", + "tags": [] + }, + "source": [ + "# Tensorized Spectral Convolutions\n", + "\n", + "It is possible to express the weights of one or more layers in factorized form, as a low-rank decomposition of the full weights.\n", + "\n", + "`neuralop` comes with support for tensorization out of the box: you can simply specify, e.g., to use a Tucker factorization, `factorization='tucker'`." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "b3f919de-97c2-4f0b-bb40-8e47cd2c1e0e", + "metadata": { + "id": "b3f919de-97c2-4f0b-bb40-8e47cd2c1e0e" + }, + "outputs": [], + "source": [ + "fourier_conv = SpectralConv(in_channels=3, out_channels=10, n_modes=(4, 4),\n", + " factorization='tucker', implementation='reconstructed')" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "91a7aa04-9cc3-4f8c-b34f-54fbc625b718", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "91a7aa04-9cc3-4f8c-b34f-54fbc625b718", + "outputId": "cc0f8a05-2bb1-41f6-bc8d-56d9d5c19592" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "SpectralConv(\n", + " (weight): ModuleList(\n", + " (0): ComplexTuckerTensor(shape=(3, 10, 4, 3), rank=(2, 7, 3, 2))\n", + " )\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 15 + } + ], + "source": [ + "fourier_conv" + ] + }, + { + "cell_type": "markdown", + "id": "f8df876d-72e1-40cd-9a86-330a57dc0e8d", + "metadata": { + "id": "f8df876d-72e1-40cd-9a86-330a57dc0e8d" + }, + "source": [ + "## Efficient forward pass\n", + "\n", + "When factorizing the weights, we have two main options during the forward pass:\n", + "1. reconstruct the full weights and use that for the forward pass\n", + "2.
contract the input directly with the factorized weights to predict the output\n", + "\n", + "When the factorized weights are small, the second option can lead to large speedups or memory reduction, particularly when coupled with checkpointing.\n", + "\n", + "In `neuralop`, you can use those simply by specifying `implementation='reconstructed'` or `implementation='factorized'`:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "a0667a6b-1efe-47e0-8908-29c5fb0cf45a", + "metadata": { + "id": "a0667a6b-1efe-47e0-8908-29c5fb0cf45a" + }, + "outputs": [], + "source": [ + "fourier_conv = SpectralConv(in_channels=3, out_channels=10, n_modes=(4, 4),\n", + " factorization='tucker', implementation='factorized')" + ] + }, + { + "cell_type": "markdown", + "id": "ec3ab24a-09fe-4864-b2ed-e96b54792e9f", + "metadata": { + "id": "ec3ab24a-09fe-4864-b2ed-e96b54792e9f" + }, + "source": [ + "# Full Tensorized Fourier Neural Operator\n", + "\n", + "The full architecture is composed of:\n", + "\n", + "1. a lifting layer taking the number of input channels and lifting that to the desired number of hidden channels\n", + "2. a number of spectral convolutions, as shown above\n", + "3. a projection layer projecting back from the number of hidden channels to the desired number of output channels\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "d51aec17-2cf4-40c4-9452-84a4b5259db6", + "metadata": { + "id": "d51aec17-2cf4-40c4-9452-84a4b5259db6" + }, + "outputs": [], + "source": [ + "tfno = TFNO2d(n_modes_height=16, n_modes_width=16, hidden_channels=16,\n", + " factorization=None, skip='linear')" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "c87127e5-d24c-4096-be3a-8872a853a132", + "metadata": { + "id": "c87127e5-d24c-4096-be3a-8872a853a132", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "fb1db21e-b9e7-4f9e-904c-b1d50c73039d" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { +
"text/plain": [ + "TFNO2d(\n", + " (fno_blocks): FNOBlocks(\n", + " (convs): SpectralConv(\n", + " (weight): ModuleList(\n", + " (0-3): 4 x ComplexDenseTensor(shape=torch.Size([16, 16, 16, 9]), rank=None)\n", + " )\n", + " )\n", + " (fno_skips): ModuleList(\n", + " (0-3): 4 x Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (lifting): MLP(\n", + " (fcs): ModuleList(\n", + " (0): Conv2d(3, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Conv2d(256, 16, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " )\n", + " (projection): MLP(\n", + " (fcs): ModuleList(\n", + " (0): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " )\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 18 + } + ], + "source": [ + "tfno" + ] + }, + { + "cell_type": "markdown", + "id": "0e70efec-bf3c-48ac-b53a-59800055f1b9", + "metadata": { + "id": "0e70efec-bf3c-48ac-b53a-59800055f1b9" + }, + "source": [ + "## Lifting layer\n", + "\n", + "Increasing the number of channels" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "1deead74-bd3d-4aa9-8d2c-cfd9ab0763d7", + "metadata": { + "id": "1deead74-bd3d-4aa9-8d2c-cfd9ab0763d7", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "a676bbb8-9aff-45dc-b0e1-bc2fcd39404d" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "MLP(\n", + " (fcs): ModuleList(\n", + " (0): Conv2d(3, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Conv2d(256, 16, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 19 + } + ], + "source": [ + "tfno.lifting" + ] + }, + { + "cell_type": "markdown", + "id": "08844bac-9335-4ac4-afc8-f1d67c3e31bb", + "metadata": { + "id": "08844bac-9335-4ac4-afc8-f1d67c3e31bb" + }, + "source": [ + "## Spectral convolutions" + ] + }, + { + "cell_type": "code", + "execution_count": 20, +
"id": "f2bc28dc-1226-4ed3-b757-3c42357d276a", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "f2bc28dc-1226-4ed3-b757-3c42357d276a", + "outputId": "92f4c42f-301d-4741-ab65-9d4960877fb5" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "SpectralConv(\n", + " (weight): ModuleList(\n", + " (0-3): 4 x ComplexDenseTensor(shape=torch.Size([16, 16, 16, 9]), rank=None)\n", + " )\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 20 + } + ], + "source": [ + "tfno.fno_blocks.convs" + ] + }, + { + "cell_type": "markdown", + "id": "1c7d9882-13db-447d-affd-07ef17256e1c", + "metadata": { + "id": "1c7d9882-13db-447d-affd-07ef17256e1c" + }, + "source": [ + "## Skip connections: recovering non-periodicity\n", + "\n", + "Recall the FNO architecture has skip connections: the FFT transformation will lose non-periodic information that has to be reinjected through skip connections. These skip connections also help with learning.\n", + "\n", + "![FNO_layer](./images/fourier_layer.png)\n", + "\n", + "Here, we use a linear layer (represented by weight W in the image).
We can also use an identity skip (`skip='identity'`) or soft-gated connections (`skip='soft-gating'`)." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "f063e3bf-34e5-4d7f-83f9-b3522aa6430b", + "metadata": { + "id": "f063e3bf-34e5-4d7f-83f9-b3522aa6430b", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "8bcbf0a3-473f-4319-aac5-d6b36907537f" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "ModuleList(\n", + " (0-3): 4 x Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 21 + } + ], + "source": [ + "tfno.fno_blocks.fno_skips" + ] + }, + { + "cell_type": "markdown", + "id": "070e930e-38b6-4d3c-b62a-3ca700294c99", + "metadata": { + "id": "070e930e-38b6-4d3c-b62a-3ca700294c99" + }, + "source": [ + "## Projection: going back to the target number of channels\n", + "\n", + "Finally, the projection layer maps the hidden channels first to `projection_channels` and then to the actual number of output channels (here, 1)." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "88344f47-a7e8-458e-9fbb-775804fbbaad", + "metadata": { + "id": "88344f47-a7e8-458e-9fbb-775804fbbaad", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "3a4e8edc-5b22-46ce-e28b-9ebc25913887" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "MLP(\n", + " (fcs): ModuleList(\n", + " (0): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 22 + } + ], + "source": [ + "tfno.projection" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "4df7dcda-a364-4255-9339-a9a09c2a5e34", + "metadata": { + "id": "4df7dcda-a364-4255-9339-a9a09c2a5e34" + }, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "from neuralop.datasets import 
load_darcy_pt" + ] + }, + { + "cell_type": "markdown", + "id": "ff12d431-bde9-4eba-906b-d0faea8c49fb", + "metadata": { + "id": "ff12d431-bde9-4eba-906b-d0faea8c49fb" + }, + "source": [ + "# Load the data" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "aa9c49f5-878b-4cac-9a35-b9dc53085d11", + "metadata": { + "id": "aa9c49f5-878b-4cac-9a35-b9dc53085d11" + }, + "outputs": [], + "source": [ + "data_path=darcy_dir" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "f40e8c0a-c031-457b-863c-c728de7d1b80", + "metadata": { + "id": "f40e8c0a-c031-457b-863c-c728de7d1b80" + }, + "outputs": [], + "source": [ + "train_loader, test_loaders, data_processor = load_darcy_pt(data_path, n_train=100, n_tests=[10],\n", + " batch_size=3, test_batch_sizes=[3],\n", + " test_resolutions=[32], train_resolution=32)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "2f29d90a-4fc7-4b83-8ed4-f1bb4dce2574", + "metadata": { + "id": "2f29d90a-4fc7-4b83-8ed4-f1bb4dce2574" + }, + "outputs": [], + "source": [ + "train_dataset = train_loader.dataset" + ] + }, + { + "cell_type": "markdown", + "id": "21000189-ecac-42e2-b008-06eefa7b1710", + "metadata": { + "id": "21000189-ecac-42e2-b008-06eefa7b1710" + }, + "source": [ + "# Visualizing the data " + ] + }, + { + "cell_type": "markdown", + "id": "1cf47f09-1fb3-4667-9b04-a98b3ee8d08d", + "metadata": { + "id": "1cf47f09-1fb3-4667-9b04-a98b3ee8d08d" + }, + "source": [ + "The data is stored in a dictionary" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "b6a9aed5-6532-42ba-8131-0307460c960d", + "metadata": { + "id": "b6a9aed5-6532-42ba-8131-0307460c960d" + }, + "outputs": [], + "source": [ + "data = train_dataset[0]\n", + "x = data['x']\n", + "y = data['y']" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "9f475172-62b0-4ce3-8dce-a7d0d9dca9fb", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": 
"9f475172-62b0-4ce3-8dce-a7d0d9dca9fb", + "outputId": "e705cbc7-f6c6-460f-cab9-a4ef638ac11e" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "torch.Size([1, 32, 32])" + ] + }, + "metadata": {}, + "execution_count": 28 + } + ], + "source": [ + "x.shape" + ] + }, + { + "cell_type": "markdown", + "id": "7d7947ad-f98c-414b-8a12-64270988ad1f", + "metadata": { + "id": "7d7947ad-f98c-414b-8a12-64270988ad1f" + }, + "source": [ + "`x` is of shape (1, height, width).\n", + "\n", + "After preprocessing, it becomes shape (3, height, width).\n", + "\n", + "This is because, in addition to the binary input, we appended a positional encoding, so the model knows the location of each pixel.\n", + "\n", + "Let's check the actual data:" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "077ebd7d-883b-4300-b13d-ed88813a3be1", + "metadata": { + "id": "077ebd7d-883b-4300-b13d-ed88813a3be1" + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "814a044d-a52f-4cf2-aa1b-859370012af5", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 725 + }, + "id": "814a044d-a52f-4cf2-aa1b-859370012af5", + "outputId": "5f6664f9-9e53-46ca-e9a6-4760ee938362" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([3, 32, 32])\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": {} + } + ], + "source": [ + "# Which sample to view\n", + "index = 10\n", + "\n", + "data = train_dataset[index]\n", + "# add a batch dimension to both x and y for preprocessor\n", + "data['x'] = data['x'].unsqueeze(0)\n", + "data['y'] = data['y'].unsqueeze(0)\n", + "\n", + "# preprocessing is normally done during training\n", + "data = data_processor.preprocess(data)\n", + "\n", + "# squeeze the batch dimension out of x and y for visualization\n", + "x = data['x'].squeeze(0)\n", + "print(x.shape)\n", + "y = data['y'].squeeze(0)\n", + "fig = plt.figure(figsize=(7, 7))\n", + "ax = fig.add_subplot(2, 2, 1)\n", + "ax.imshow(x[0], cmap='gray')\n", + "ax.set_title('input x')\n", + "ax = fig.add_subplot(2, 2, 2)\n", + "ax.imshow(y.squeeze())\n", + "ax.set_title('input y')\n", + "ax = fig.add_subplot(2, 2, 3)\n", + "ax.imshow(x[1])\n", + "ax.set_title('x: 1st pos embedding')\n", + "ax = fig.add_subplot(2, 2, 4)\n", + "ax.imshow(x[2])\n", + "ax.set_title('x: 2nd pos embedding')\n", + "fig.suptitle('Visualizing one input sample', y=0.98)\n", + "plt.tight_layout()\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "012a357f-8533-482c-823d-a4587c49e726", + "metadata": { + "id": "012a357f-8533-482c-823d-a4587c49e726" + }, + "outputs": [], + "source": [ + "import torch\n", + "import wandb\n", + "import sys\n", + "from configmypy import ConfigPipeline, YamlConfig, ArgparseConfig\n", + "from neuralop import get_model\n", + "from neuralop import Trainer\n", + "from neuralop.training import setup\n", + "from neuralop.datasets import load_darcy_pt\n", + "from neuralop.utils import get_wandb_api_key, count_model_params\n", + "from neuralop import LpLoss, H1Loss" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Downloading the configuration\n", + "\n", + "Go to: https://github.com/neuraloperator/neuraloperator/blob/main/config/darcy_config.yaml and download the config file. 
Make the following changes:\n", + "\n", + "\n", + "\n", + "1. Add an option called `folder` under the `data` key, set to the path of the Darcy flow data directory. On Colab this is simply `/content/data/darcy_flow`.\n", + "2. Change the training resolution to 32 and the test resolutions to 32 and 64, as shown in the image below.\n", + "3. Under the `tfno2d` key, change `data_channels` to 1, since the input is a single-channel field rather than an RGB image.\n", + "\n", + "![config-change.png]()\n" + ], + "metadata": { + "id": "IRWO3VHFGqZV" + }, + "id": "IRWO3VHFGqZV" + }, + { + "cell_type": "markdown", + "id": "9b6358b5-78d1-4baf-8928-6bb49b150680", + "metadata": { + "id": "9b6358b5-78d1-4baf-8928-6bb49b150680" + }, + "source": [ + "# Loading the configuration\n", + "\n", + "You can open `darcy_config.yaml` (in the same folder as this notebook) to inspect the parameters and change them." + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "4503f065-4063-4a4f-b00f-06a7c3a88e27", + "metadata": { + "id": "4503f065-4063-4a4f-b00f-06a7c3a88e27" + }, + "outputs": [], + "source": [ + "# Read the configuration\n", + "config_name = 'default'\n", + "pipe = ConfigPipeline([YamlConfig('./darcy_config.yaml', config_name='default'),\n", + " ])\n", + "config = pipe.read_conf()\n", + "config_name = pipe.steps[-1].config_name" + ] + }, + { + "cell_type": "markdown", + "id": "e95d820d-9578-4ad7-80b4-05a5771f1642", + "metadata": { + "id": "e95d820d-9578-4ad7-80b4-05a5771f1642" + }, + "source": [ + "## Setup\n", + "\n", + "Here we just set up PyTorch and print the configuration." + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "46066d9f-21a3-4aab-b6e1-f7f38e05f88b", + "metadata": { + "id": "46066d9f-21a3-4aab-b6e1-f7f38e05f88b" + }, + "outputs": [], + "source": [ + "# Set-up distributed communication, if using\n", + "device, is_logger = setup(config)" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": 
"26d599f9-6463-4056-9a4d-72c01d05298e", + "metadata": { + "id": "26d599f9-6463-4056-9a4d-72c01d05298e", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "bc3a1ac1-75f0-4205-b5d1-c42259122ef4" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "###############################\n", + "##### CONFIGURATION #####\n", + "###############################\n", + "\n", + "Steps:\n", + "------\n", + " (1) YamlConfig with config_file=./darcy_config.yaml, config_name=default, config_folder=.\n", + "\n", + "-------------------------------\n", + "\n", + "Configuration:\n", + "--------------\n", + "\n", + "n_params_baseline=None\n", + "verbose=True\n", + "arch=tfno2d\n", + "distributed.use_distributed=False\n", + "distributed.wireup_info=mpi\n", + "distributed.wireup_store=tcp\n", + "distributed.model_parallel_size=2\n", + "distributed.seed=666\n", + "tfno2d.data_channels=1\n", + "tfno2d.n_modes_height=16\n", + "tfno2d.n_modes_width=16\n", + "tfno2d.hidden_channels=32\n", + "tfno2d.projection_channels=64\n", + "tfno2d.n_layers=4\n", + "tfno2d.domain_padding=None\n", + "tfno2d.domain_padding_mode=one-sided\n", + "tfno2d.fft_norm=forward\n", + "tfno2d.norm=group_norm\n", + "tfno2d.skip=linear\n", + "tfno2d.implementation=factorized\n", + "tfno2d.separable=0\n", + "tfno2d.preactivation=0\n", + "tfno2d.use_mlp=1\n", + "tfno2d.mlp.expansion=0.5\n", + "tfno2d.mlp.dropout=0\n", + "tfno2d.factorization=None\n", + "tfno2d.rank=1.0\n", + "tfno2d.fixed_rank_modes=None\n", + "tfno2d.dropout=0.0\n", + "tfno2d.tensor_lasso_penalty=0.0\n", + "tfno2d.joint_factorization=False\n", + "tfno2d.fno_block_precision=full\n", + "tfno2d.stabilizer=None\n", + "opt.n_epochs=2\n", + "opt.learning_rate=0.005\n", + "opt.training_loss=h1\n", + "opt.weight_decay=0.0001\n", + "opt.amp_autocast=False\n", + "opt.scheduler_T_max=500\n", + "opt.scheduler_patience=5\n", + "opt.scheduler=StepLR\n", + "opt.step_size=60\n", + "opt.gamma=0.5\n", + 
"data.folder=/content/data/darcy_flow\n", + "data.batch_size=16\n", + "data.n_train=1000\n", + "data.train_resolution=32\n", + "data.n_tests=[100, 50]\n", + "data.test_resolutions=[32, 64]\n", + "data.test_batch_sizes=[16, 16]\n", + "data.positional_encoding=True\n", + "data.encode_input=True\n", + "data.encode_output=False\n", + "patching.levels=0\n", + "patching.padding=0\n", + "patching.stitching=False\n", + "wandb.log=False\n", + "wandb.name=None\n", + "wandb.group=\n", + "wandb.project=\n", + "wandb.entity=\n", + "wandb.sweep=False\n", + "wandb.log_output=True\n", + "wandb.log_test_interval=1\n", + "\n", + "###############################\n" + ] + } + ], + "source": [ + "# Make sure we only print information when needed\n", + "config.verbose = config.verbose and is_logger\n", + "\n", + "#Print config to screen\n", + "if config.verbose and is_logger:\n", + " pipe.log()\n", + " sys.stdout.flush()" + ] + }, + { + "cell_type": "markdown", + "id": "1339c794-3e1c-469b-b0a0-cf968fc1dfa1", + "metadata": { + "id": "1339c794-3e1c-469b-b0a0-cf968fc1dfa1" + }, + "source": [ + "# Loading the data\n", + "\n", + "We train in one resolution and test in several resolutions to show the zero-shot super-resolution capabilities of neural-operators." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "3515a85a-40fc-4223-9cdb-8768de37d6e2", + "metadata": { + "id": "3515a85a-40fc-4223-9cdb-8768de37d6e2", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "7a6028bc-0c9a-4699-8c8c-bf53304d2b7f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Loading test db at resolution 64 with 50 samples and batch-size=16\n" + ] + } + ], + "source": [ + "# Loading the Darcy flow training set in 32x32 resolution, test set in 32x32 and 64x64 resolutions\n", + "train_loader, test_loaders, output_encoder = load_darcy_pt(\n", + " config.data.folder, train_resolution=config.data.train_resolution, n_train=config.data.n_train, batch_size=config.data.batch_size,\n", + " positional_encoding=config.data.positional_encoding,\n", + " test_resolutions=config.data.test_resolutions, n_tests=config.data.n_tests, test_batch_sizes=config.data.test_batch_sizes,\n", + " encode_input=config.data.encode_input, encode_output=config.data.encode_output,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "8109298a-aca3-45b7-a8de-c5cf4e1c210b", + "metadata": { + "id": "8109298a-aca3-45b7-a8de-c5cf4e1c210b" + }, + "source": [ + "# Creating the model and putting it on the GPU" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "db295d23-ab86-4f37-83cc-7af0a8e485ea", + "metadata": { + "id": "db295d23-ab86-4f37-83cc-7af0a8e485ea", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "3400fee1-42b7-4060-f6ab-0b9fae360f4d" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "n_params: 1199713\n" + ] + } + ], + "source": [ + "model = get_model(config)\n", + "model = model.to(device)\n", + "\n", + "#Log parameter count\n", + "if is_logger:\n", + " n_params = count_model_params(model)\n", + "\n", + " if config.verbose:\n", + " print(f'\\nn_params: {n_params}')\n", + " sys.stdout.flush()" + ] + }, + { + 
"cell_type": "markdown", + "id": "fec85d0a-4db4-4b1f-b599-8c2afc98520a", + "metadata": { + "id": "fec85d0a-4db4-4b1f-b599-8c2afc98520a" + }, + "source": [ + "# Create the optimizer and learning rate scheduler\n", + "\n", + "Here, we use an Adam optimizer and a learning rate schedule depending on the configuration" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "5164537a-267b-4fda-9bcd-257dc3ac4826", + "metadata": { + "id": "5164537a-267b-4fda-9bcd-257dc3ac4826" + }, + "outputs": [], + "source": [ + "#Create the optimizer\n", + "optimizer = torch.optim.Adam(model.parameters(),\n", + " lr=config.opt.learning_rate,\n", + " weight_decay=config.opt.weight_decay)\n", + "\n", + "if config.opt.scheduler == 'ReduceLROnPlateau':\n", + " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=config.opt.gamma, patience=config.opt.scheduler_patience, mode='min')\n", + "elif config.opt.scheduler == 'CosineAnnealingLR':\n", + " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.opt.scheduler_T_max)\n", + "elif config.opt.scheduler == 'StepLR':\n", + " scheduler = torch.optim.lr_scheduler.StepLR(optimizer,\n", + " step_size=config.opt.step_size,\n", + " gamma=config.opt.gamma)\n", + "else:\n", + " raise ValueError(f'Got {config.opt.scheduler=}')" + ] + }, + { + "cell_type": "markdown", + "id": "e52a72eb-965a-4997-89a4-0cdfcbcb0a1a", + "metadata": { + "id": "e52a72eb-965a-4997-89a4-0cdfcbcb0a1a" + }, + "source": [ + "# Creating the loss\n", + "\n", + "We will optimize the Sobolev norm but also evaluate our goal: the l2 relative error" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "07a53d9d-2d06-4d36-9b46-2c7f15f29c40", + "metadata": { + "id": "07a53d9d-2d06-4d36-9b46-2c7f15f29c40" + }, + "outputs": [], + "source": [ + "# Creating the losses\n", + "l2loss = LpLoss(d=2, p=2)\n", + "h1loss = H1Loss(d=2)\n", + "if config.opt.training_loss == 'l2':\n", + " train_loss = l2loss\n", + "elif 
config.opt.training_loss == 'h1':\n", + " train_loss = h1loss\n", + "else:\n", + " raise ValueError(f'Got training_loss={config.opt.training_loss} but expected one of [\"l2\", \"h1\"]')\n", + "eval_losses={'h1': h1loss, 'l2': l2loss}" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "5dad660e-43e9-4f38-91f6-8427b14b8ae0", + "metadata": { + "id": "5dad660e-43e9-4f38-91f6-8427b14b8ae0", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "88714486-3734-4f29-deeb-fe1862def46c" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "### MODEL ###\n", + " TFNO2d(\n", + " (fno_blocks): FNOBlocks(\n", + " (convs): SpectralConv(\n", + " (weight): ModuleList(\n", + " (0-3): 4 x ComplexDenseTensor(shape=torch.Size([32, 32, 16, 9]), rank=None)\n", + " )\n", + " )\n", + " (fno_skips): ModuleList(\n", + " (0-3): 4 x Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " )\n", + " (mlp): ModuleList(\n", + " (0-3): 4 x MLP(\n", + " (fcs): ModuleList(\n", + " (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " )\n", + " )\n", + " (mlp_skips): ModuleList(\n", + " (0-3): 4 x SoftGating()\n", + " )\n", + " (norm): ModuleList(\n", + " (0-7): 8 x GroupNorm(1, 32, eps=1e-05, affine=True)\n", + " )\n", + " )\n", + " (lifting): MLP(\n", + " (fcs): ModuleList(\n", + " (0): Conv2d(1, 256, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " )\n", + " (projection): MLP(\n", + " (fcs): ModuleList(\n", + " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " )\n", + ")\n", + "\n", + "### OPTIMIZER ###\n", + " Adam (\n", + "Parameter Group 0\n", + " amsgrad: False\n", + " betas: (0.9, 0.999)\n", + " capturable: False\n", + " differentiable: False\n", + " eps: 1e-08\n", + " 
foreach: None\n", + " fused: None\n", + " initial_lr: 0.005\n", + " lr: 0.005\n", + " maximize: False\n", + " weight_decay: 0.0001\n", + ")\n", + "\n", + "### SCHEDULER ###\n", + " \n", + "\n", + "### LOSSES ###\n", + "\n", + " * Train: \n", + "\n", + " * Test: {'h1': , 'l2': }\n", + "\n", + "### Beginning Training...\n", + "\n" + ] + } + ], + "source": [ + "if config.verbose and is_logger:\n", + " print('\\n### MODEL ###\\n', model)\n", + " print('\\n### OPTIMIZER ###\\n', optimizer)\n", + " print('\\n### SCHEDULER ###\\n', scheduler)\n", + " print('\\n### LOSSES ###')\n", + " print(f'\\n * Train: {train_loss}')\n", + " print(f'\\n * Test: {eval_losses}')\n", + " print(f'\\n### Beginning Training...\\n')\n", + " sys.stdout.flush()" + ] + }, + { + "cell_type": "markdown", + "id": "b5967441-b8bc-4ea8-a4d9-7a5bea384cbf", + "metadata": { + "id": "b5967441-b8bc-4ea8-a4d9-7a5bea384cbf" + }, + "source": [ + "# Creating the trainer" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "a19ebfd3-8a2b-42c0-af98-7a1db2dda0f6", + "metadata": { + "id": "a19ebfd3-8a2b-42c0-af98-7a1db2dda0f6", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "b3a644be-de61-42b1-f3d5-74aa9f781ee4" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "self.override_load_to_device=False\n", + "self.overrides_loss=False\n" + ] + } + ], + "source": [ + "trainer = Trainer(model=model, n_epochs=config.opt.n_epochs,\n", + " device=device,\n", + " wandb_log=config.wandb.log,\n", + " log_test_interval=config.wandb.log_test_interval,\n", + " log_output=False,\n", + " use_distributed=config.distributed.use_distributed,\n", + " verbose=config.verbose and is_logger)" + ] + }, + { + "cell_type": "markdown", + "id": "b16a3727-313d-4219-8f8f-0cec58d74b00", + "metadata": { + "id": "b16a3727-313d-4219-8f8f-0cec58d74b00" + }, + "source": [ + "# Training the model" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": 
"0d6e3298-99ee-4371-8bad-60e6aac03d56", + "metadata": { + "id": "0d6e3298-99ee-4371-8bad-60e6aac03d56", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "157c95e8-be98-41a8-8eea-a83a25362c86" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'64_h1': 0.6382459831237793, '64_l2': 0.4895939874649048}" + ] + }, + "metadata": {}, + "execution_count": 61 + } + ], + "source": [ + "trainer.train(train_loader, test_loaders,\n", + " optimizer,\n", + " scheduler,\n", + " regularizer=False,\n", + " training_loss=train_loss,\n", + " eval_losses=eval_losses)" + ] + }, + { + "cell_type": "markdown", + "id": "1b20be56-d200-44dc-b97b-fca021e353c8", + "metadata": { + "id": "1b20be56-d200-44dc-b97b-fca021e353c8" + }, + "source": [ + "# Follow-up questions" + ] + }, + { + "cell_type": "markdown", + "id": "9a67e1d5-4b9a-4be3-bff4-fb2a6b152f9c", + "metadata": { + "id": "9a67e1d5-4b9a-4be3-bff4-fb2a6b152f9c" + }, + "source": [ + "You can now play with the configuration and see how the performance is impacted.\n", + "\n", + "Which parameters do you think will most influence performance?\n", + "Learning rate? Learning schedule? hidden_channels? Number of training samples?\n", + "\n", + "Does your intuition match the results you are getting?" 
+ ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + }, + "vscode": { + "interpreter": { + "hash": "95d4b27ba6bfea4a66eebe0e0159b214d32a94d313a7f4c98bd9b87f5ee37cbe" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/3-darcy_flow.ipynb b/3-darcy_flow.ipynb deleted file mode 100644 index 6da114d..0000000 --- a/3-darcy_flow.ipynb +++ /dev/null @@ -1,198 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "id": "4df7dcda-a364-4255-9339-a9a09c2a5e34", - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "from neuralop.datasets import load_darcy_pt" - ] - }, - { - "cell_type": "markdown", - "id": "ff12d431-bde9-4eba-906b-d0faea8c49fb", - "metadata": {}, - "source": [ - "# Load the data " - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "aa9c49f5-878b-4cac-9a35-b9dc53085d11", - "metadata": {}, - "outputs": [], - "source": [ - "data_path=\"/dli/task/bootcamp/data/darcy_flow/\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f40e8c0a-c031-457b-863c-c728de7d1b80", - "metadata": {}, - "outputs": [], - "source": [ - "train_loader, test_loaders, output_encoder = load_darcy_pt(data_path, n_train=100, n_tests=[10], \n", - " batch_size=3, test_batch_sizes=[3],\n", - " test_resolutions=[32], train_resolution=32)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "2f29d90a-4fc7-4b83-8ed4-f1bb4dce2574", - "metadata": {}, - "outputs": [], - "source": [ - "train_dataset = train_loader.dataset" - ] - }, - { - "cell_type": "markdown", - "id": 
"21000189-ecac-42e2-b008-06eefa7b1710", - "metadata": {}, - "source": [ - "# Visualizing the data " - ] - }, - { - "cell_type": "markdown", - "id": "1cf47f09-1fb3-4667-9b04-a98b3ee8d08d", - "metadata": {}, - "source": [ - "The data is stored in a dictionary" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "b6a9aed5-6532-42ba-8131-0307460c960d", - "metadata": {}, - "outputs": [], - "source": [ - "data = train_dataset[0]\n", - "x = data['x']\n", - "y = data['y']" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "9f475172-62b0-4ce3-8dce-a7d0d9dca9fb", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([3, 128, 128])" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x.shape" - ] - }, - { - "cell_type": "markdown", - "id": "7d7947ad-f98c-414b-8a12-64270988ad1f", - "metadata": {}, - "source": [ - "`x` is of shape (3, height, width). \n", - "\n", - "This is because, in addition to the binary input, we appended a positional encoding, so the model knows the location of each pixel.\n", - "\n", - "Let's check the actual data:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "077ebd7d-883b-4300-b13d-ed88813a3be1", - "metadata": {}, - "outputs": [], - "source": [ - "%matplotlib inline\n", - "import matplotlib.pyplot as plt" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "814a044d-a52f-4cf2-aa1b-859370012af5", - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Which sample to view\n", - "index = 10\n", - "\n", - "data = train_dataset[index]\n", - "x = data['x']\n", - "y = data['y']\n", - "fig = plt.figure(figsize=(7, 7))\n", - "ax = fig.add_subplot(2, 2, 1)\n", - "ax.imshow(x[0], cmap='gray')\n", - "ax.set_title('input x')\n", - "ax = fig.add_subplot(2, 2, 2)\n", - "ax.imshow(y.squeeze())\n", - "ax.set_title('input y')\n", - "ax = fig.add_subplot(2, 2, 3)\n", - "ax.imshow(x[1])\n", - "ax.set_title('x: 1st pos embedding')\n", - "ax = fig.add_subplot(2, 2, 4)\n", - "ax.imshow(x[2])\n", - "ax.set_title('x: 2nd pos embedding')\n", - "fig.suptitle('Visualizing one input sample', y=0.98)\n", - "plt.tight_layout()\n", - "fig.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ba9e2a5c-98e7-47c0-9e24-7a8e41c657dc", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/4-training-on-Darcy-Flow.ipynb b/4-training-on-Darcy-Flow.ipynb deleted file mode 100644 index f6e1453..0000000 --- a/4-training-on-Darcy-Flow.ipynb +++ /dev/null @@ -1,681 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "012a357f-8533-482c-823d-a4587c49e726", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import wandb\n", - "import sys\n", - "from configmypy import ConfigPipeline, YamlConfig, ArgparseConfig\n", - "from neuralop import get_model\n", - "from neuralop import Trainer\n", - "from neuralop.training import setup\n", - "from neuralop.datasets 
import load_darcy_pt\n", - "from neuralop.utils import get_wandb_api_key, count_params\n", - "from neuralop import LpLoss, H1Loss" - ] - }, - { - "cell_type": "markdown", - "id": "9b6358b5-78d1-4baf-8928-6bb49b150680", - "metadata": {}, - "source": [ - "# Loading the configuration\n", - "\n", - "You can open the yaml file in config/darcy_config in the same folder as this notebook to inspect the parameters and change them." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "4503f065-4063-4a4f-b00f-06a7c3a88e27", - "metadata": {}, - "outputs": [], - "source": [ - "# Read the configuration\n", - "config_name = 'default'\n", - "pipe = ConfigPipeline([YamlConfig('./darcy_config.yaml', config_name='default', config_folder='./config'),\n", - " ])\n", - "config = pipe.read_conf()\n", - "config_name = pipe.steps[-1].config_name" - ] - }, - { - "cell_type": "markdown", - "id": "e95d820d-9578-4ad7-80b4-05a5771f1642", - "metadata": {}, - "source": [ - "## Setup\n", - "\n", - "Here we just setup pytorch and print the configuration" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "46066d9f-21a3-4aab-b6e1-f7f38e05f88b", - "metadata": {}, - "outputs": [], - "source": [ - "# Set-up distributed communication, if using\n", - "device, is_logger = setup(config)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "26d599f9-6463-4056-9a4d-72c01d05298e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "###############################\n", - "##### CONFIGURATION #####\n", - "###############################\n", - "\n", - "Steps:\n", - "------\n", - " (1) YamlConfig with config_file=./darcy_config.yaml, config_name=default, config_folder=./config\n", - "\n", - "-------------------------------\n", - "\n", - "Configuration:\n", - "--------------\n", - "\n", - "n_params_baseline=None\n", - "verbose=True\n", - "arch=tfno2d\n", - "distributed.use_distributed=False\n", - 
"tfno2d.data_channels=3\n", - "tfno2d.n_modes_height=32\n", - "tfno2d.n_modes_width=32\n", - "tfno2d.hidden_channels=64\n", - "tfno2d.projection_channels=256\n", - "tfno2d.n_layers=4\n", - "tfno2d.domain_padding=None\n", - "tfno2d.domain_padding_mode=one-sided\n", - "tfno2d.fft_norm=forward\n", - "tfno2d.norm=group_norm\n", - "tfno2d.skip=linear\n", - "tfno2d.implementation=factorized\n", - "tfno2d.separable=0\n", - "tfno2d.preactivation=0\n", - "tfno2d.use_mlp=1\n", - "tfno2d.mlp.expansion=0.5\n", - "tfno2d.mlp.dropout=0\n", - "tfno2d.factorization=None\n", - "tfno2d.rank=1.0\n", - "tfno2d.fixed_rank_modes=None\n", - "tfno2d.dropout=0.0\n", - "tfno2d.tensor_lasso_penalty=0.0\n", - "tfno2d.joint_factorization=False\n", - "opt.n_epochs=150\n", - "opt.learning_rate=0.005\n", - "opt.training_loss=h1\n", - "opt.weight_decay=0.0001\n", - "opt.amp_autocast=False\n", - "opt.scheduler_T_max=300\n", - "opt.scheduler_patience=5\n", - "opt.scheduler=CosineAnnealingLR\n", - "opt.step_size=50\n", - "opt.gamma=0.5\n", - "data.folder=/data/darcy_flow/\n", - "data.batch_size=32\n", - "data.n_train=3000\n", - "data.train_resolution=32\n", - "data.n_tests=[500, 500]\n", - "data.test_resolutions=[32, 64]\n", - "data.test_batch_sizes=[32, 32]\n", - "data.positional_encoding=True\n", - "data.encode_input=True\n", - "data.encode_output=False\n", - "patching.levels=0\n", - "patching.padding=0\n", - "patching.stitching=False\n", - "wandb.log=False\n", - "wandb.log_test_interval=1\n", - "\n", - "###############################\n" - ] - } - ], - "source": [ - "# Make sure we only print information when needed\n", - "config.verbose = config.verbose and is_logger\n", - "\n", - "#Print config to screen\n", - "if config.verbose and is_logger:\n", - " pipe.log()\n", - " sys.stdout.flush()" - ] - }, - { - "cell_type": "markdown", - "id": "1339c794-3e1c-469b-b0a0-cf968fc1dfa1", - "metadata": {}, - "source": [ - "# Loading the data \n", - "\n", - "We train in one resolution and test in several 
resolutions to show the zero-shot super-resolution capabilities of neural-operators. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3515a85a-40fc-4223-9cdb-8768de37d6e2", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "UnitGaussianNormalizer init on 3000, reducing over [0, 1, 2, 3], samples of shape [1, 32, 32].\n", - " Mean and std of shape torch.Size([1, 1, 1]), eps=1e-05\n", - "Loading test db at resolution 64 with 500 samples and batch-size=32\n" - ] - } - ], - "source": [ - "# Loading the Darcy flow training set in 32x32 resolution, test set in 32x32 and 64x64 resolutions\n", - "train_loader, test_loaders, output_encoder = load_darcy_pt(\n", - " config.data.folder, train_resolution=config.data.train_resolution, n_train=config.data.n_train, batch_size=config.data.batch_size, \n", - " positional_encoding=config.data.positional_encoding,\n", - " test_resolutions=config.data.test_resolutions, n_tests=config.data.n_tests, test_batch_sizes=config.data.test_batch_sizes,\n", - " encode_input=config.data.encode_input, encode_output=config.data.encode_output,\n", - " )" - ] - }, - { - "cell_type": "markdown", - "id": "8109298a-aca3-45b7-a8de-c5cf4e1c210b", - "metadata": {}, - "source": [ - "# Creating the model and putting it on the GPU " - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "db295d23-ab86-4f37-83cc-7af0a8e485ea", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Given argument key='dropout' that is not in TFNO2d's signature.\n", - "Given argument key='tensor_lasso_penalty' that is not in TFNO2d's signature.\n", - "Keyword argument out_channels not specified for model TFNO2d, using default=1.\n", - "Keyword argument lifting_channels not specified for model TFNO2d, using default=256.\n", - "Keyword argument non_linearity not specified for model TFNO2d, using default=.\n", - "Keyword argument decomposition_kwargs not 
specified for model TFNO2d, using default={}.\n", - "\n", - "n_params: 16844673\n" - ] - } - ], - "source": [ - "model = get_model(config)\n", - "model = model.to(device)\n", - "\n", - "#Log parameter count\n", - "if is_logger:\n", - " n_params = count_params(model)\n", - "\n", - " if config.verbose:\n", - " print(f'\\nn_params: {n_params}')\n", - " sys.stdout.flush()" - ] - }, - { - "cell_type": "markdown", - "id": "fec85d0a-4db4-4b1f-b599-8c2afc98520a", - "metadata": {}, - "source": [ - "# Create the optimizer and learning rate scheduler\n", - "\n", - "Here, we use an Adam optimizer and a learning rate schedule depending on the configuration" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "5164537a-267b-4fda-9bcd-257dc3ac4826", - "metadata": {}, - "outputs": [], - "source": [ - "#Create the optimizer\n", - "optimizer = torch.optim.Adam(model.parameters(), \n", - " lr=config.opt.learning_rate, \n", - " weight_decay=config.opt.weight_decay)\n", - "\n", - "if config.opt.scheduler == 'ReduceLROnPlateau':\n", - " scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=config.opt.gamma, patience=config.opt.scheduler_patience, mode='min')\n", - "elif config.opt.scheduler == 'CosineAnnealingLR':\n", - " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.opt.scheduler_T_max)\n", - "elif config.opt.scheduler == 'StepLR':\n", - " scheduler = torch.optim.lr_scheduler.StepLR(optimizer, \n", - " step_size=config.opt.step_size,\n", - " gamma=config.opt.gamma)\n", - "else:\n", - " raise ValueError(f'Got {config.opt.scheduler=}')" - ] - }, - { - "cell_type": "markdown", - "id": "e52a72eb-965a-4997-89a4-0cdfcbcb0a1a", - "metadata": {}, - "source": [ - "# Creating the loss\n", - "\n", - "We will optimize the Sobolev norm but also evaluate our goal: the l2 relative error" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "07a53d9d-2d06-4d36-9b46-2c7f15f29c40", - "metadata": {}, - "outputs": [], - 
"source": [ - "# Creating the losses\n", - "l2loss = LpLoss(d=2, p=2)\n", - "h1loss = H1Loss(d=2)\n", - "if config.opt.training_loss == 'l2':\n", - " train_loss = l2loss\n", - "elif config.opt.training_loss == 'h1':\n", - " train_loss = h1loss\n", - "else:\n", - " raise ValueError(f'Got training_loss={config.opt.training_loss} but expected one of [\"l2\", \"h1\"]')\n", - "eval_losses={'h1': h1loss, 'l2': l2loss}" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "5dad660e-43e9-4f38-91f6-8427b14b8ae0", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "### MODEL ###\n", - " TFNO2d(\n", - " (convs): FactorizedSpectralConv2d(\n", - " (weight): ModuleList(\n", - " (0): ComplexDenseTensor(shape=torch.Size([64, 64, 16, 16]), rank=None)\n", - " (1): ComplexDenseTensor(shape=torch.Size([64, 64, 16, 16]), rank=None)\n", - " (2): ComplexDenseTensor(shape=torch.Size([64, 64, 16, 16]), rank=None)\n", - " (3): ComplexDenseTensor(shape=torch.Size([64, 64, 16, 16]), rank=None)\n", - " (4): ComplexDenseTensor(shape=torch.Size([64, 64, 16, 16]), rank=None)\n", - " (5): ComplexDenseTensor(shape=torch.Size([64, 64, 16, 16]), rank=None)\n", - " (6): ComplexDenseTensor(shape=torch.Size([64, 64, 16, 16]), rank=None)\n", - " (7): ComplexDenseTensor(shape=torch.Size([64, 64, 16, 16]), rank=None)\n", - " )\n", - " )\n", - " (fno_skips): ModuleList(\n", - " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " )\n", - " (mlp): ModuleList(\n", - " (0): MLP(\n", - " (fcs): ModuleList(\n", - " (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " )\n", - " (1): MLP(\n", - " (fcs): ModuleList(\n", - 
" (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " )\n", - " (2): MLP(\n", - " (fcs): ModuleList(\n", - " (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " )\n", - " (3): MLP(\n", - " (fcs): ModuleList(\n", - " (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " )\n", - " )\n", - " (mlp_skips): ModuleList(\n", - " (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " )\n", - " (norm): ModuleList(\n", - " (0): GroupNorm(1, 64, eps=1e-05, affine=True)\n", - " (1): GroupNorm(1, 64, eps=1e-05, affine=True)\n", - " (2): GroupNorm(1, 64, eps=1e-05, affine=True)\n", - " (3): GroupNorm(1, 64, eps=1e-05, affine=True)\n", - " )\n", - " (lifting): Lifting(\n", - " (fc): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (projection): Projection(\n", - " (fc1): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))\n", - " (fc2): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - ")\n", - "\n", - "### OPTIMIZER ###\n", - " Adam (\n", - "Parameter Group 0\n", - " amsgrad: False\n", - " betas: (0.9, 0.999)\n", - " capturable: False\n", - " differentiable: False\n", - " eps: 1e-08\n", - " foreach: None\n", - " fused: False\n", - " initial_lr: 0.005\n", - " lr: 0.005\n", - " maximize: False\n", - " weight_decay: 0.0001\n", - ")\n", - "\n", - "### SCHEDULER ###\n", - " \n", - "\n", - "### LOSSES ###\n", - "\n", - " * Train: \n", - "\n", - " * Test: {'h1': , 'l2': }\n", - "\n", - "### Beginning Training...\n", - "\n" - ] - } - ], - "source": [ - "if 
config.verbose and is_logger:\n", - " print('\\n### MODEL ###\\n', model)\n", - " print('\\n### OPTIMIZER ###\\n', optimizer)\n", - " print('\\n### SCHEDULER ###\\n', scheduler)\n", - " print('\\n### LOSSES ###')\n", - " print(f'\\n * Train: {train_loss}')\n", - " print(f'\\n * Test: {eval_losses}')\n", - " print(f'\\n### Beginning Training...\\n')\n", - " sys.stdout.flush()" - ] - }, - { - "cell_type": "markdown", - "id": "b5967441-b8bc-4ea8-a4d9-7a5bea384cbf", - "metadata": {}, - "source": [ - "# Creating the trainer" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "a19ebfd3-8a2b-42c0-af98-7a1db2dda0f6", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training on regular inputs (no multi-grid patching).\n", - "MGPatching(self.n_patches=[1, 1], self.padding_fraction=[0, 0], self.levels=0, use_distributed=False, stitching=False)\n" - ] - } - ], - "source": [ - "trainer = Trainer(model, n_epochs=config.opt.n_epochs,\n", - " device=device,\n", - " mg_patching_levels=config.patching.levels,\n", - " mg_patching_padding=config.patching.padding,\n", - " mg_patching_stitching=config.patching.stitching,\n", - " wandb_log=config.wandb.log,\n", - " log_test_interval=config.wandb.log_test_interval,\n", - " log_output=False,\n", - " use_distributed=config.distributed.use_distributed,\n", - " verbose=config.verbose and is_logger)" - ] - }, - { - "cell_type": "markdown", - "id": "b16a3727-313d-4219-8f8f-0cec58d74b00", - "metadata": {}, - "source": [ - "# Training the model " - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "0d6e3298-99ee-4371-8bad-60e6aac03d56", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training on 3000 samples, testing on [32, 64].\n", - "[0] time=3.03, avg_loss=7.8899, train_err=0.3945, 32_h1=0.2295, 32_l2=0.1710, 64_h1=0.2847, 64_l2=0.1733\n", - "[1] time=1.38, avg_loss=3.7664, train_err=0.1883, 32_h1=0.1646, 
32_l2=0.1177, 64_h1=0.2326, 64_l2=0.1221\n", - "[2] time=1.37, avg_loss=3.1005, train_err=0.1550, 32_h1=0.1411, 32_l2=0.1027, 64_h1=0.2156, 64_l2=0.1106\n", - "[3] time=1.36, avg_loss=2.5222, train_err=0.1261, 32_h1=0.1238, 32_l2=0.0800, 64_h1=0.2026, 64_l2=0.0936\n", - "[4] time=1.36, avg_loss=2.3043, train_err=0.1152, 32_h1=0.1235, 32_l2=0.0808, 64_h1=0.1874, 64_l2=0.0858\n", - "[5] time=1.36, avg_loss=2.2108, train_err=0.1105, 32_h1=0.1332, 32_l2=0.1041, 64_h1=0.2055, 64_l2=0.1122\n", - "[6] time=1.37, avg_loss=1.9753, train_err=0.0988, 32_h1=0.1077, 32_l2=0.0720, 64_h1=0.1885, 64_l2=0.0838\n", - "[7] time=1.37, avg_loss=1.9352, train_err=0.0968, 32_h1=0.1032, 32_l2=0.0642, 64_h1=0.1847, 64_l2=0.0753\n", - "[8] time=1.36, avg_loss=1.8174, train_err=0.0909, 32_h1=0.1013, 32_l2=0.0632, 64_h1=0.1798, 64_l2=0.0763\n", - "[9] time=1.37, avg_loss=1.7847, train_err=0.0892, 32_h1=0.1053, 32_l2=0.0672, 64_h1=0.1909, 64_l2=0.0788\n", - "[10] time=1.37, avg_loss=1.6375, train_err=0.0819, 32_h1=0.0926, 32_l2=0.0513, 64_h1=0.1808, 64_l2=0.0666\n", - "[11] time=1.37, avg_loss=1.5826, train_err=0.0791, 32_h1=0.0958, 32_l2=0.0574, 64_h1=0.1810, 64_l2=0.0700\n", - "[12] time=1.36, avg_loss=1.6231, train_err=0.0812, 32_h1=0.0940, 32_l2=0.0534, 64_h1=0.1740, 64_l2=0.0636\n", - "[13] time=1.42, avg_loss=1.5427, train_err=0.0771, 32_h1=0.0937, 32_l2=0.0532, 64_h1=0.1834, 64_l2=0.0692\n", - "[14] time=1.37, avg_loss=1.4741, train_err=0.0737, 32_h1=0.0989, 32_l2=0.0623, 64_h1=0.1844, 64_l2=0.0798\n", - "[15] time=1.36, avg_loss=1.5156, train_err=0.0758, 32_h1=0.1020, 32_l2=0.0649, 64_h1=0.1844, 64_l2=0.0730\n", - "[16] time=1.37, avg_loss=1.5620, train_err=0.0781, 32_h1=0.0940, 32_l2=0.0608, 64_h1=0.1803, 64_l2=0.0747\n", - "[17] time=1.36, avg_loss=1.3939, train_err=0.0697, 32_h1=0.1018, 32_l2=0.0620, 64_h1=0.1842, 64_l2=0.0772\n", - "[18] time=1.89, avg_loss=1.4904, train_err=0.0745, 32_h1=0.1010, 32_l2=0.0704, 64_h1=0.1868, 64_l2=0.0794\n", - "[19] time=1.83, avg_loss=1.4300, 
train_err=0.0715, 32_h1=0.0929, 32_l2=0.0525, 64_h1=0.1784, 64_l2=0.0679\n", - "[20] time=1.84, avg_loss=1.3752, train_err=0.0688, 32_h1=0.0964, 32_l2=0.0635, 64_h1=0.1825, 64_l2=0.0694\n", - "[21] time=1.84, avg_loss=1.4671, train_err=0.0734, 32_h1=0.0911, 32_l2=0.0513, 64_h1=0.1832, 64_l2=0.0696\n", - "[22] time=1.88, avg_loss=1.3043, train_err=0.0652, 32_h1=0.0938, 32_l2=0.0538, 64_h1=0.1804, 64_l2=0.0687\n", - "[23] time=1.37, avg_loss=1.2880, train_err=0.0644, 32_h1=0.0897, 32_l2=0.0492, 64_h1=0.1824, 64_l2=0.0629\n", - "[24] time=1.37, avg_loss=1.3901, train_err=0.0695, 32_h1=0.1080, 32_l2=0.0701, 64_h1=0.1828, 64_l2=0.0785\n", - "[25] time=1.37, avg_loss=1.3788, train_err=0.0689, 32_h1=0.0878, 32_l2=0.0514, 64_h1=0.1744, 64_l2=0.0613\n", - "[26] time=1.37, avg_loss=1.3071, train_err=0.0654, 32_h1=0.0880, 32_l2=0.0489, 64_h1=0.1847, 64_l2=0.0698\n", - "[27] time=1.36, avg_loss=1.3056, train_err=0.0653, 32_h1=0.0980, 32_l2=0.0679, 64_h1=0.1828, 64_l2=0.0830\n", - "[28] time=1.37, avg_loss=1.2677, train_err=0.0634, 32_h1=0.0956, 32_l2=0.0621, 64_h1=0.1827, 64_l2=0.0692\n", - "[29] time=1.37, avg_loss=1.2611, train_err=0.0631, 32_h1=0.0913, 32_l2=0.0500, 64_h1=0.1855, 64_l2=0.0652\n", - "[30] time=1.37, avg_loss=1.1833, train_err=0.0592, 32_h1=0.0888, 32_l2=0.0512, 64_h1=0.1818, 64_l2=0.0655\n", - "[31] time=1.36, avg_loss=1.2170, train_err=0.0608, 32_h1=0.0879, 32_l2=0.0481, 64_h1=0.1758, 64_l2=0.0625\n", - "[32] time=1.36, avg_loss=1.1431, train_err=0.0572, 32_h1=0.0886, 32_l2=0.0479, 64_h1=0.1756, 64_l2=0.0594\n", - "[33] time=1.37, avg_loss=1.2162, train_err=0.0608, 32_h1=0.0923, 32_l2=0.0522, 64_h1=0.1749, 64_l2=0.0629\n", - "[34] time=1.37, avg_loss=1.1588, train_err=0.0579, 32_h1=0.0892, 32_l2=0.0526, 64_h1=0.1797, 64_l2=0.0656\n", - "[35] time=1.37, avg_loss=1.1747, train_err=0.0587, 32_h1=0.0884, 32_l2=0.0481, 64_h1=0.1829, 64_l2=0.0650\n", - "[36] time=1.36, avg_loss=1.1491, train_err=0.0575, 32_h1=0.0936, 32_l2=0.0542, 64_h1=0.1787, 64_l2=0.0672\n", - 
"[37] time=1.37, avg_loss=1.1532, train_err=0.0577, 32_h1=0.0950, 32_l2=0.0569, 64_h1=0.1737, 64_l2=0.0679\n", - "[38] time=1.37, avg_loss=1.2426, train_err=0.0621, 32_h1=0.0875, 32_l2=0.0488, 64_h1=0.1750, 64_l2=0.0638\n", - "[39] time=1.37, avg_loss=1.1345, train_err=0.0567, 32_h1=0.0874, 32_l2=0.0493, 64_h1=0.1780, 64_l2=0.0658\n", - "[40] time=1.36, avg_loss=1.1238, train_err=0.0562, 32_h1=0.0914, 32_l2=0.0516, 64_h1=0.1796, 64_l2=0.0662\n", - "[41] time=1.36, avg_loss=1.1093, train_err=0.0555, 32_h1=0.0855, 32_l2=0.0457, 64_h1=0.1741, 64_l2=0.0621\n", - "[42] time=1.36, avg_loss=1.0772, train_err=0.0539, 32_h1=0.0899, 32_l2=0.0523, 64_h1=0.1807, 64_l2=0.0688\n", - "[43] time=1.36, avg_loss=1.0772, train_err=0.0539, 32_h1=0.0894, 32_l2=0.0556, 64_h1=0.1769, 64_l2=0.0705\n", - "[44] time=1.36, avg_loss=1.0901, train_err=0.0545, 32_h1=0.0843, 32_l2=0.0443, 64_h1=0.1750, 64_l2=0.0589\n", - "[45] time=1.36, avg_loss=1.0783, train_err=0.0539, 32_h1=0.0874, 32_l2=0.0486, 64_h1=0.1778, 64_l2=0.0593\n", - "[46] time=1.46, avg_loss=1.0837, train_err=0.0542, 32_h1=0.0874, 32_l2=0.0482, 64_h1=0.1722, 64_l2=0.0575\n", - "[47] time=1.37, avg_loss=1.1760, train_err=0.0588, 32_h1=0.0873, 32_l2=0.0507, 64_h1=0.1706, 64_l2=0.0639\n", - "[48] time=1.37, avg_loss=1.0357, train_err=0.0518, 32_h1=0.0889, 32_l2=0.0503, 64_h1=0.1799, 64_l2=0.0663\n", - "[49] time=1.36, avg_loss=1.0873, train_err=0.0544, 32_h1=0.0846, 32_l2=0.0464, 64_h1=0.1725, 64_l2=0.0592\n", - "[50] time=1.36, avg_loss=1.0996, train_err=0.0550, 32_h1=0.0861, 32_l2=0.0461, 64_h1=0.1696, 64_l2=0.0598\n", - "[51] time=1.37, avg_loss=1.0487, train_err=0.0524, 32_h1=0.0839, 32_l2=0.0433, 64_h1=0.1752, 64_l2=0.0602\n", - "[52] time=1.37, avg_loss=1.0527, train_err=0.0526, 32_h1=0.0858, 32_l2=0.0469, 64_h1=0.1736, 64_l2=0.0588\n", - "[53] time=1.36, avg_loss=1.0138, train_err=0.0507, 32_h1=0.0854, 32_l2=0.0475, 64_h1=0.1777, 64_l2=0.0619\n", - "[54] time=1.36, avg_loss=1.0210, train_err=0.0511, 32_h1=0.0832, 
32_l2=0.0431, 64_h1=0.1728, 64_l2=0.0580\n", - "[55] time=1.36, avg_loss=0.9939, train_err=0.0497, 32_h1=0.0870, 32_l2=0.0474, 64_h1=0.1755, 64_l2=0.0609\n", - "[56] time=1.37, avg_loss=1.0085, train_err=0.0504, 32_h1=0.0833, 32_l2=0.0438, 64_h1=0.1731, 64_l2=0.0603\n", - "[57] time=1.37, avg_loss=1.0132, train_err=0.0507, 32_h1=0.0842, 32_l2=0.0462, 64_h1=0.1757, 64_l2=0.0613\n", - "[58] time=1.36, avg_loss=0.9938, train_err=0.0497, 32_h1=0.0839, 32_l2=0.0439, 64_h1=0.1811, 64_l2=0.0651\n", - "[59] time=1.36, avg_loss=0.9814, train_err=0.0491, 32_h1=0.0820, 32_l2=0.0425, 64_h1=0.1728, 64_l2=0.0565\n", - "[60] time=1.36, avg_loss=0.9849, train_err=0.0492, 32_h1=0.0861, 32_l2=0.0477, 64_h1=0.1715, 64_l2=0.0616\n", - "[61] time=1.37, avg_loss=0.9787, train_err=0.0489, 32_h1=0.0844, 32_l2=0.0450, 64_h1=0.1742, 64_l2=0.0623\n", - "[62] time=1.36, avg_loss=1.0104, train_err=0.0505, 32_h1=0.0830, 32_l2=0.0437, 64_h1=0.1769, 64_l2=0.0605\n", - "[63] time=1.36, avg_loss=0.9910, train_err=0.0495, 32_h1=0.0821, 32_l2=0.0415, 64_h1=0.1742, 64_l2=0.0579\n", - "[64] time=1.36, avg_loss=0.9622, train_err=0.0481, 32_h1=0.0849, 32_l2=0.0462, 64_h1=0.1763, 64_l2=0.0608\n", - "[65] time=1.36, avg_loss=1.0191, train_err=0.0510, 32_h1=0.0823, 32_l2=0.0419, 64_h1=0.1736, 64_l2=0.0570\n", - "[66] time=1.37, avg_loss=0.9814, train_err=0.0491, 32_h1=0.0873, 32_l2=0.0492, 64_h1=0.1752, 64_l2=0.0643\n", - "[67] time=1.36, avg_loss=0.9867, train_err=0.0493, 32_h1=0.0833, 32_l2=0.0446, 64_h1=0.1698, 64_l2=0.0588\n", - "[68] time=1.36, avg_loss=0.9983, train_err=0.0499, 32_h1=0.0815, 32_l2=0.0417, 64_h1=0.1712, 64_l2=0.0590\n", - "[69] time=1.37, avg_loss=0.9956, train_err=0.0498, 32_h1=0.0836, 32_l2=0.0453, 64_h1=0.1756, 64_l2=0.0604\n", - "[70] time=1.37, avg_loss=0.9433, train_err=0.0472, 32_h1=0.0830, 32_l2=0.0432, 64_h1=0.1739, 64_l2=0.0583\n", - "[71] time=1.36, avg_loss=0.9813, train_err=0.0491, 32_h1=0.0830, 32_l2=0.0433, 64_h1=0.1691, 64_l2=0.0588\n", - "[72] time=1.36, 
avg_loss=0.9456, train_err=0.0473, 32_h1=0.0828, 32_l2=0.0429, 64_h1=0.1695, 64_l2=0.0599\n", - "[73] time=1.37, avg_loss=0.9099, train_err=0.0455, 32_h1=0.0835, 32_l2=0.0438, 64_h1=0.1716, 64_l2=0.0599\n", - "[74] time=1.37, avg_loss=0.9241, train_err=0.0462, 32_h1=0.0816, 32_l2=0.0419, 64_h1=0.1699, 64_l2=0.0572\n", - "[75] time=1.37, avg_loss=0.8907, train_err=0.0445, 32_h1=0.0825, 32_l2=0.0410, 64_h1=0.1772, 64_l2=0.0604\n", - "[76] time=1.36, avg_loss=0.8940, train_err=0.0447, 32_h1=0.0821, 32_l2=0.0428, 64_h1=0.1733, 64_l2=0.0588\n", - "[77] time=1.37, avg_loss=0.8958, train_err=0.0448, 32_h1=0.0828, 32_l2=0.0447, 64_h1=0.1756, 64_l2=0.0593\n", - "[78] time=1.37, avg_loss=0.9276, train_err=0.0464, 32_h1=0.0816, 32_l2=0.0424, 64_h1=0.1740, 64_l2=0.0599\n", - "[79] time=1.37, avg_loss=0.8763, train_err=0.0438, 32_h1=0.0818, 32_l2=0.0414, 64_h1=0.1715, 64_l2=0.0570\n", - "[80] time=1.36, avg_loss=0.8634, train_err=0.0432, 32_h1=0.0812, 32_l2=0.0416, 64_h1=0.1753, 64_l2=0.0614\n", - "[81] time=1.36, avg_loss=0.8450, train_err=0.0423, 32_h1=0.0832, 32_l2=0.0448, 64_h1=0.1701, 64_l2=0.0626\n", - "[82] time=1.37, avg_loss=0.8997, train_err=0.0450, 32_h1=0.0818, 32_l2=0.0419, 64_h1=0.1718, 64_l2=0.0590\n", - "[83] time=1.37, avg_loss=0.8658, train_err=0.0433, 32_h1=0.0816, 32_l2=0.0415, 64_h1=0.1703, 64_l2=0.0552\n", - "[84] time=1.37, avg_loss=0.9292, train_err=0.0465, 32_h1=0.0815, 32_l2=0.0424, 64_h1=0.1674, 64_l2=0.0580\n", - "[85] time=1.36, avg_loss=0.9417, train_err=0.0471, 32_h1=0.0825, 32_l2=0.0439, 64_h1=0.1755, 64_l2=0.0608\n", - "[86] time=1.37, avg_loss=0.8608, train_err=0.0430, 32_h1=0.0792, 32_l2=0.0392, 64_h1=0.1720, 64_l2=0.0573\n", - "[87] time=1.38, avg_loss=0.9083, train_err=0.0454, 32_h1=0.0822, 32_l2=0.0440, 64_h1=0.1693, 64_l2=0.0602\n", - "[88] time=1.57, avg_loss=0.8522, train_err=0.0426, 32_h1=0.0823, 32_l2=0.0427, 64_h1=0.1695, 64_l2=0.0571\n", - "[89] time=1.36, avg_loss=0.8273, train_err=0.0414, 32_h1=0.0813, 32_l2=0.0414, 64_h1=0.1702, 
64_l2=0.0568\n", - "[90] time=1.36, avg_loss=0.8612, train_err=0.0431, 32_h1=0.0834, 32_l2=0.0468, 64_h1=0.1718, 64_l2=0.0641\n", - "[91] time=1.36, avg_loss=0.8358, train_err=0.0418, 32_h1=0.0811, 32_l2=0.0410, 64_h1=0.1678, 64_l2=0.0558\n", - "[92] time=1.37, avg_loss=0.8725, train_err=0.0436, 32_h1=0.0807, 32_l2=0.0408, 64_h1=0.1688, 64_l2=0.0557\n", - "[93] time=1.36, avg_loss=0.8163, train_err=0.0408, 32_h1=0.0804, 32_l2=0.0417, 64_h1=0.1714, 64_l2=0.0593\n", - "[94] time=1.36, avg_loss=0.8119, train_err=0.0406, 32_h1=0.0791, 32_l2=0.0393, 64_h1=0.1706, 64_l2=0.0581\n", - "[95] time=1.36, avg_loss=0.8022, train_err=0.0401, 32_h1=0.0819, 32_l2=0.0416, 64_h1=0.1697, 64_l2=0.0555\n", - "[96] time=1.37, avg_loss=0.8371, train_err=0.0419, 32_h1=0.0793, 32_l2=0.0393, 64_h1=0.1684, 64_l2=0.0570\n", - "[97] time=1.37, avg_loss=0.8227, train_err=0.0411, 32_h1=0.0800, 32_l2=0.0407, 64_h1=0.1685, 64_l2=0.0583\n", - "[98] time=1.43, avg_loss=0.8176, train_err=0.0409, 32_h1=0.0841, 32_l2=0.0471, 64_h1=0.1681, 64_l2=0.0578\n", - "[99] time=1.85, avg_loss=0.8517, train_err=0.0426, 32_h1=0.0809, 32_l2=0.0401, 64_h1=0.1726, 64_l2=0.0607\n", - "[100] time=1.85, avg_loss=0.8445, train_err=0.0422, 32_h1=0.0810, 32_l2=0.0408, 64_h1=0.1688, 64_l2=0.0558\n", - "[101] time=1.42, avg_loss=0.7962, train_err=0.0398, 32_h1=0.0796, 32_l2=0.0393, 64_h1=0.1680, 64_l2=0.0577\n", - "[102] time=1.84, avg_loss=0.7758, train_err=0.0388, 32_h1=0.0799, 32_l2=0.0398, 64_h1=0.1664, 64_l2=0.0556\n", - "[103] time=1.87, avg_loss=0.8005, train_err=0.0400, 32_h1=0.0792, 32_l2=0.0395, 64_h1=0.1688, 64_l2=0.0552\n", - "[104] time=1.43, avg_loss=0.8099, train_err=0.0405, 32_h1=0.0791, 32_l2=0.0394, 64_h1=0.1664, 64_l2=0.0535\n", - "[105] time=1.37, avg_loss=0.7828, train_err=0.0391, 32_h1=0.0815, 32_l2=0.0430, 64_h1=0.1691, 64_l2=0.0574\n", - "[106] time=1.37, avg_loss=0.7799, train_err=0.0390, 32_h1=0.0795, 32_l2=0.0393, 64_h1=0.1679, 64_l2=0.0556\n", - "[107] time=1.36, avg_loss=0.7685, train_err=0.0384, 
32_h1=0.0810, 32_l2=0.0434, 64_h1=0.1725, 64_l2=0.0633\n", - "[108] time=1.36, avg_loss=0.7581, train_err=0.0379, 32_h1=0.0801, 32_l2=0.0407, 64_h1=0.1744, 64_l2=0.0574\n", - "[109] time=1.37, avg_loss=0.7415, train_err=0.0371, 32_h1=0.0782, 32_l2=0.0383, 64_h1=0.1670, 64_l2=0.0540\n", - "[110] time=1.37, avg_loss=0.7387, train_err=0.0369, 32_h1=0.0790, 32_l2=0.0392, 64_h1=0.1664, 64_l2=0.0539\n", - "[111] time=1.37, avg_loss=0.7338, train_err=0.0367, 32_h1=0.0788, 32_l2=0.0385, 64_h1=0.1694, 64_l2=0.0574\n", - "[112] time=1.36, avg_loss=0.7426, train_err=0.0371, 32_h1=0.0811, 32_l2=0.0434, 64_h1=0.1745, 64_l2=0.0593\n", - "[113] time=1.36, avg_loss=0.7849, train_err=0.0392, 32_h1=0.0817, 32_l2=0.0452, 64_h1=0.1653, 64_l2=0.0627\n", - "[114] time=1.37, avg_loss=0.7933, train_err=0.0397, 32_h1=0.0803, 32_l2=0.0409, 64_h1=0.1715, 64_l2=0.0568\n", - "[115] time=1.37, avg_loss=0.7377, train_err=0.0369, 32_h1=0.0789, 32_l2=0.0389, 64_h1=0.1688, 64_l2=0.0556\n", - "[116] time=1.37, avg_loss=0.7639, train_err=0.0382, 32_h1=0.0794, 32_l2=0.0394, 64_h1=0.1683, 64_l2=0.0574\n", - "[117] time=1.36, avg_loss=0.7515, train_err=0.0376, 32_h1=0.0785, 32_l2=0.0382, 64_h1=0.1665, 64_l2=0.0549\n", - "[118] time=1.37, avg_loss=0.7180, train_err=0.0359, 32_h1=0.0792, 32_l2=0.0394, 64_h1=0.1671, 64_l2=0.0576\n", - "[119] time=1.37, avg_loss=0.7191, train_err=0.0360, 32_h1=0.0795, 32_l2=0.0396, 64_h1=0.1672, 64_l2=0.0541\n", - "[120] time=1.37, avg_loss=0.7148, train_err=0.0357, 32_h1=0.0792, 32_l2=0.0389, 64_h1=0.1671, 64_l2=0.0575\n", - "[121] time=1.36, avg_loss=0.7012, train_err=0.0351, 32_h1=0.0795, 32_l2=0.0399, 64_h1=0.1639, 64_l2=0.0555\n", - "[122] time=1.37, avg_loss=0.6962, train_err=0.0348, 32_h1=0.0787, 32_l2=0.0388, 64_h1=0.1697, 64_l2=0.0570\n", - "[123] time=1.37, avg_loss=0.6970, train_err=0.0349, 32_h1=0.0793, 32_l2=0.0388, 64_h1=0.1693, 64_l2=0.0567\n", - "[124] time=1.37, avg_loss=0.6888, train_err=0.0344, 32_h1=0.0788, 32_l2=0.0382, 64_h1=0.1687, 64_l2=0.0570\n", - 
"[125] time=1.37, avg_loss=0.7060, train_err=0.0353, 32_h1=0.0799, 32_l2=0.0412, 64_h1=0.1649, 64_l2=0.0576\n", - "[126] time=1.36, avg_loss=0.6991, train_err=0.0350, 32_h1=0.0792, 32_l2=0.0393, 64_h1=0.1681, 64_l2=0.0583\n", - "[127] time=1.37, avg_loss=0.7098, train_err=0.0355, 32_h1=0.0796, 32_l2=0.0406, 64_h1=0.1641, 64_l2=0.0574\n", - "[128] time=1.37, avg_loss=0.6971, train_err=0.0349, 32_h1=0.0792, 32_l2=0.0399, 64_h1=0.1690, 64_l2=0.0588\n", - "[129] time=1.37, avg_loss=0.6810, train_err=0.0340, 32_h1=0.0793, 32_l2=0.0393, 64_h1=0.1648, 64_l2=0.0559\n", - "[130] time=1.36, avg_loss=0.6848, train_err=0.0342, 32_h1=0.0780, 32_l2=0.0378, 64_h1=0.1670, 64_l2=0.0536\n", - "[131] time=1.94, avg_loss=0.6600, train_err=0.0330, 32_h1=0.0779, 32_l2=0.0379, 64_h1=0.1661, 64_l2=0.0545\n", - "[132] time=1.87, avg_loss=0.6428, train_err=0.0321, 32_h1=0.0794, 32_l2=0.0394, 64_h1=0.1695, 64_l2=0.0588\n", - "[133] time=1.85, avg_loss=0.6532, train_err=0.0327, 32_h1=0.0789, 32_l2=0.0392, 64_h1=0.1690, 64_l2=0.0568\n", - "[134] time=1.88, avg_loss=0.6573, train_err=0.0329, 32_h1=0.0780, 32_l2=0.0376, 64_h1=0.1664, 64_l2=0.0559\n", - "[135] time=1.78, avg_loss=0.6445, train_err=0.0322, 32_h1=0.0784, 32_l2=0.0386, 64_h1=0.1644, 64_l2=0.0560\n", - "[136] time=1.82, avg_loss=0.6378, train_err=0.0319, 32_h1=0.0780, 32_l2=0.0383, 64_h1=0.1651, 64_l2=0.0527\n", - "[137] time=1.85, avg_loss=0.6550, train_err=0.0327, 32_h1=0.0792, 32_l2=0.0400, 64_h1=0.1652, 64_l2=0.0551\n", - "[138] time=1.75, avg_loss=0.6341, train_err=0.0317, 32_h1=0.0775, 32_l2=0.0376, 64_h1=0.1661, 64_l2=0.0556\n", - "[139] time=2.00, avg_loss=0.8234, train_err=0.0412, 32_h1=0.0783, 32_l2=0.0386, 64_h1=0.1654, 64_l2=0.0566\n", - "[140] time=1.97, avg_loss=0.6822, train_err=0.0341, 32_h1=0.0784, 32_l2=0.0380, 64_h1=0.1675, 64_l2=0.0558\n", - "[141] time=1.98, avg_loss=0.6332, train_err=0.0317, 32_h1=0.0778, 32_l2=0.0379, 64_h1=0.1670, 64_l2=0.0546\n", - "[142] time=1.99, avg_loss=0.6205, train_err=0.0310, 
32_h1=0.0786, 32_l2=0.0394, 64_h1=0.1678, 64_l2=0.0592\n", - "[143] time=1.82, avg_loss=0.6098, train_err=0.0305, 32_h1=0.0785, 32_l2=0.0386, 64_h1=0.1676, 64_l2=0.0584\n", - "[144] time=1.38, avg_loss=0.6116, train_err=0.0306, 32_h1=0.0794, 32_l2=0.0422, 64_h1=0.1702, 64_l2=0.0591\n", - "[145] time=1.37, avg_loss=0.6018, train_err=0.0301, 32_h1=0.0776, 32_l2=0.0373, 64_h1=0.1674, 64_l2=0.0564\n", - "[146] time=1.38, avg_loss=0.6001, train_err=0.0300, 32_h1=0.0781, 32_l2=0.0386, 64_h1=0.1662, 64_l2=0.0583\n", - "[147] time=1.38, avg_loss=0.5990, train_err=0.0300, 32_h1=0.0796, 32_l2=0.0416, 64_h1=0.1679, 64_l2=0.0572\n", - "[148] time=1.38, avg_loss=0.6462, train_err=0.0323, 32_h1=0.0802, 32_l2=0.0411, 64_h1=0.1721, 64_l2=0.0580\n", - "[149] time=1.37, avg_loss=0.6152, train_err=0.0308, 32_h1=0.0777, 32_l2=0.0373, 64_h1=0.1688, 64_l2=0.0562\n" - ] - } - ], - "source": [ - "trainer.train(train_loader, test_loaders,\n", - " output_encoder,\n", - " model, \n", - " optimizer,\n", - " scheduler, \n", - " regularizer=False, \n", - " training_loss=train_loss,\n", - " eval_losses=eval_losses)" - ] - }, - { - "cell_type": "markdown", - "id": "1b20be56-d200-44dc-b97b-fca021e353c8", - "metadata": {}, - "source": [ - "# Follow-up questions" - ] - }, - { - "cell_type": "markdown", - "id": "9a67e1d5-4b9a-4be3-bff4-fb2a6b152f9c", - "metadata": {}, - "source": [ - "You can now play with the configuration and see how the performance is impacted.\n", - "\n", - "Which parameters do you think will most influence performance? \n", - "Learning rate? Learning schedule? hidden_channels? Number of training samples? \n", - "\n", - "Does your intuition match the results you are getting?" 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.15" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/slides/DeepONets.pdf b/slides/DeepONets.pdf new file mode 100644 index 0000000..e25bd00 Binary files /dev/null and b/slides/DeepONets.pdf differ diff --git a/slides/NeuralOperators.pdf b/slides/NeuralOperators.pdf new file mode 100644 index 0000000..cc248f9 Binary files /dev/null and b/slides/NeuralOperators.pdf differ