diff --git a/README.md b/README.md
index 66484e0..529f003 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,5 @@
+[![Review Assignment Due Date](https://classroom.github.com/assets/deadline-readme-button-24ddc0f5d75046c5622901739e7c5dd533143b0c8e959d652212380cedb1ea36.svg)](https://classroom.github.com/a/fSA_TVih)
+[![Review Assignment Due Date](https://classroom.github.com/assets/deadline-readme-button-8d59dc4de5201274e310e4c54b9627a8934c3b88527886e3b421487c677d23eb.svg)](https://classroom.github.com/a/fSA_TVih)
 [![Open in Visual Studio Code](https://classroom.github.com/assets/open-in-vscode-c66648af7eb3fe8bc4f294546bfd86ef473780cde1dea487d3c4ff354943c9ae.svg)](https://classroom.github.com/online_ide?assignment_repo_id=10680703&assignment_repo_type=AssignmentRepo)
 ## Final Project Repo
diff --git a/implementation.ipynb b/implementation.ipynb
new file mode 100644
index 0000000..a297ab2
--- /dev/null
+++ b/implementation.ipynb
@@ -0,0 +1,495 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "a041fef9",
+ "metadata": {},
+ "source": [
+ "## 4095 Final Project"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "47891953",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "import numpy as np\n",
+ "#import matplotlib.pyplot as plt\n",
+ "import os\n",
+ "from tqdm import tqdm\n",
+ "import cv2\n",
+ "from numpy import random as rng\n",
+ "from sklearn.utils import shuffle\n",
+ "import pickle\n",
+ "import time\n",
+ "\n",
+ "from tensorflow.keras.layers import Input, Dense, Lambda\n",
+ "from tensorflow.keras.models import Model\n",
+ "from tensorflow.keras.optimizers import Adam\n",
+ "from tensorflow.keras import backend as K\n",
+ "\n",
+ "#from vit_keras import vit\n",
+ "\n",
+ "# import timm\n",
+ "# import onnx\n",
+ "# import torch\n",
+ "# from onnx2keras import onnx_to_keras"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1e873e2e",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "\n",
+ "### Preprocessing the data\n",
+ "Here we load the image data from the downloaded Omniglot dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "75755838",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def load_images(path):\n",
+ "    #Load every character image under path; return the image array X, the class labels y, and a dict mapping each alphabet to its range of class numbers\n",
+ "    \n",
+ "    X = []\n",
+ "    y = []\n",
+ "    lang_dict = {} #maps each alphabet to [first class number, last class number]\n",
+ "    classNum = 0\n",
+ "    \n",
+ "    #Next, we iterate over all the alphabet folders in the Omniglot dataset\n",
+ "    for alphabet in tqdm(sorted(os.listdir(path))):\n",
+ "        lang_dict[alphabet] = [classNum, None]\n",
+ "        #set the path to the current alphabet folder.\n",
+ "        alpha_path = os.path.join(path, alphabet)\n",
+ "        \n",
+ "        #We iterate over all the letter folders in the current alphabet folder\n",
+ "        for letter in sorted(os.listdir(alpha_path)):\n",
+ "            cat_images = [] #images belonging to the current character class\n",
+ "            \n",
+ "            #iterate over all the image files in the current letter folder\n",
+ "            for img in sorted(os.listdir(os.path.join(alpha_path, letter))):\n",
+ "                #define the path to the current image file\n",
+ "                img_path = os.path.join(alpha_path, letter, img)\n",
+ "                \n",
+ "                #read the current image file and convert it to grayscale\n",
+ "                img_gray = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2GRAY)\n",
+ "                \n",
+ "                #Resize the grayscale image\n",
+ "                img_resized = cv2.resize(img_gray, (224, 224))\n",
+ "                \n",
+ "                # Convert the resized grayscale image to RGB\n",
+ "                img_rgb = 
cv2.cvtColor(img_resized, cv2.COLOR_GRAY2RGB)\n", + " cat_images.append(img_rgb)\n", + " \n", + " y.append(classNum)\n", + " \n", + " classNum+=1\n", + " X.append(cat_images) #appends the list of images for the current letter\n", + " lang_dict[alphabet][1] = classNum - 1 #Sets the second val in the list to the current class number-1\n", + " #Make X and y numpy arrays\n", + " X = np.array(X)\n", + " print(X.shape)\n", + " y = np.array(y)\n", + " return X, y, lang_dict\n", + "\n", + "\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5e535b31", + "metadata": {}, + "outputs": [], + "source": [ + "#We store the path to our training directory and evaluation directory\n", + "img_train_PATH = '/Users/siddharthsinha/Desktop/Spring_2023/CSE_5819/Honors_work/omniglot/python/images_background'\n", + "img_eval_PATH = '/Users/siddharthsinha/Desktop/Spring_2023/CSE_5819/Honors_work/omniglot/python/images_evaluation'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ebaf5238", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:10<00:00, 2.88it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(964, 20, 224, 224, 3)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00, 2.61it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(659, 20, 224, 224, 3)\n" + ] + } + ], + "source": [ + "trainImages, trainLabels, lang_dict = load_images(img_train_PATH)\n", + "valImages, valLabels, lang_dictVal = load_images(img_eval_PATH)" + ] + }, + { + "cell_type": "markdown", + "id": "2006ec33", + "metadata": {}, + "source": [ + "---\n", + "The model can be trained using a batch generator that randomly samples pairs of images from the training set. We also define the functions make_one_shot_task and test_one_shot to test the siamese network. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "2fa71b32", + "metadata": {}, + "outputs": [], + "source": [ + "def get_batch(batch_size,dset='train'):\n", + " if dset == 'train':\n", + " X = trainImages\n", + " else:\n", + " X = valImages\n", + " \n", + " n_classes, n_examples, w, h, _ = X.shape\n", + " cat = rng.choice(n_classes, size=batch_size, replace=False)\n", + " targets = np.zeros((batch_size,))\n", + " targets[batch_size//2:] = 1\n", + " pairs = [np.zeros((batch_size,w,h,3)) for _ in range(2)]\n", + " for i in range(batch_size):\n", + " ex_no = rng.randint(n_examples)\n", + " pairs[0][i,:,:,:] = X[cat[i],ex_no,:,:].reshape(w,h,3)\n", + " cat2 = 0\n", + " if i >= batch_size // 2:\n", + " cat2 = cat[i]\n", + " else:\n", + " cat2 = (cat[i] + rng.randint(1,n_classes)) % n_classes\n", + " ex_no2 = rng.randint(n_examples)\n", + " pairs[1][i,:,:,:] = X[cat2,ex_no2,:,:].reshape(w,h,3)\n", + " return pairs,targets\n" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "aede925a", + "metadata": {}, + "outputs": [], + "source": [ + "def make_one_shot_task(N, dset='val'):\n", + " if dset == 'train':\n", + " X = trainImages\n", + " else:\n", + " X = valImages\n", + " n_classes, n_examples, w, h, _ = X.shape # Updated to unpack 5 values\n", + " cats = rng.choice(n_classes, size=(N,))\n", + " indices = rng.choice(n_examples, size=(N,))\n", + " true_cat = cats[0]\n", + " ex1 = rng.randint(n_examples)\n", + " test_image = np.array([X[true_cat, ex1]] * N).reshape(N, w, h, 3) # Updated to handle 3 channels\n", + " support_set = X[cats, indices].reshape(N, w, h, 3) # Updated to handle 3 channels\n", + " targets = np.zeros((N,))\n", + " targets[0] = 1\n", + "\n", + " test_image, support_set, targets = shuffle(test_image, support_set, targets)\n", + "\n", + " return [test_image, support_set], targets" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "8cd66805", + "metadata": {}, + "outputs": [], + "source": [ + "def test_one_shot(model,N,k,dset='val'):\n", + " n_correct = 0\n", + " for _ in range(k):\n", + " inputs, outputs = make_one_shot_task(N,dset)\n", + " preds = model.predict(inputs)\n", + " if np.argmax(outputs) == np.argmax(preds):\n", + " n_correct += 1\n", + " return n_correct / k" + ] + }, + { + "cell_type": "markdown", + "id": "fe5f9d70", + "metadata": {}, + "source": [ + "---\n", + "### Siamese Neural Network" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c1e40b34", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/siddharthsinha/Desktop/Spring_2023/CSE_5819/Honors_work/myenv/lib/python3.9/site-packages/tensorflow_addons/utils/tfa_eol_msg.py:23: UserWarning: \n", + "\n", + "TensorFlow Addons (TFA) has ended development and introduction of new features.\n", + "TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.\n", + "Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). \n", + "\n", + "For more information see: https://github.com/tensorflow/addons/issues/2807 \n", + "\n", + " warnings.warn(\n", + "/Users/siddharthsinha/Desktop/Spring_2023/CSE_5819/Honors_work/myenv/lib/python3.9/site-packages/tensorflow_addons/utils/ensure_tf_install.py:53: UserWarning: Tensorflow Addons supports using Python ops for all Tensorflow versions above or equal to 2.10.0 and strictly below 2.13.0 (nightly versions are not supported). 
\n", + " The versions of TensorFlow you are currently using is 2.7.0 and is not supported. \n", + "Some things might work, some things might not.\n", + "If you were to encounter a bug, do not file an issue.\n", + "If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. \n", + "You can find the compatibility matrix in TensorFlow Addon's readme:\n", + "https://github.com/tensorflow/addons\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from vit_keras import vit\n", + "\n", + "def create_vit_embedding_model(input_shape):\n", + " vit_model = vit.vit_b32(\n", + " image_size=input_shape[0],\n", + " activation=None,\n", + " pretrained=True,\n", + " include_top=False,\n", + " pretrained_top=False,\n", + " weights='imagenet21k',\n", + " )\n", + "\n", + " input_layer = Input(input_shape)\n", + " outputs = vit_model(input_layer)\n", + " return Model(inputs=input_layer, outputs=outputs)\n", + "\n", + "def get_siamese(input_shape, patch_size, num_patches):\n", + " # Define input tensors\n", + " left_input = Input(input_shape)\n", + " right_input = Input(input_shape)\n", + "\n", + " embedding_model = create_vit_embedding_model(input_shape)\n", + "\n", + " left_emb = embedding_model(left_input)\n", + " right_emb = embedding_model(right_input)\n", + "\n", + " L1_Layer = Lambda(lambda tensors: K.abs(tensors[0] - tensors[1]))\n", + " L1_Dist = L1_Layer([left_emb, right_emb])\n", + " OP = Dense(1, activation='sigmoid', kernel_regularizer='l2')(L1_Dist)\n", + "\n", + " siamese_net = Model(inputs=[left_input, right_input], outputs=OP)\n", + "\n", + " return siamese_net" + ] + }, + { + "cell_type": "markdown", + "id": "223fcea2", + "metadata": {}, + "source": [ + "### Training Loop and Hyperparameters" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "bacc4b06", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(964, 20, 224, 224, 3)\n", + "Model: \"model_2\"\n", + "__________________________________________________________________________________________________\n", + " Layer (type) Output Shape Param # Connected to \n", + "==================================================================================================\n", + " input_5 (InputLayer) [(None, 224, 224, 3 0 [] \n", + " )] \n", + " \n", + " input_6 (InputLayer) [(None, 224, 224, 3 0 [] \n", + " )] \n", + " \n", + " model_1 (Functional) (None, 768) 88045824 ['input_5[0][0]', \n", + " 'input_6[0][0]'] \n", + " \n", + " lambda_1 (Lambda) (None, 768) 0 ['model_1[0][0]', \n", + " 'model_1[1][0]'] \n", + " \n", + " dense (Dense) (None, 1) 769 ['lambda_1[0][0]'] \n", + " \n", + "==================================================================================================\n", + "Total params: 88,046,593\n", + "Trainable params: 88,046,593\n", + "Non-trainable params: 0\n", + "__________________________________________________________________________________________________\n" + ] + } + ], + "source": [ + "num_iterations = 7000\n", + "batch_size = 128\n", + "\n", + "evaluateEvery = 100\n", + "k = 250\n", + "N = 1\n", + "\n", + "#n_classes, n_examples, w, h = trainImages.shape\n", + "print(trainImages.shape)\n", + "\n", + "lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(\n", + " initial_learning_rate=0.05,\n", + " decay_steps=4000,\n", + " decay_rate=0.0001)\n", + "\n", + "opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)\n", + "\n", + "input_shape = (224, 224, 
3)\n", + "model = get_siamese(input_shape, 7, 225)\n", + "\n", + "model.compile(\n", + " loss='binary_crossentropy',\n", + " optimizer=opt,\n", + " metrics=['accuracy']\n", + ")\n", + "\n", + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "id": "c470585d", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "7b120f9d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Iteration 0 (350.5s) - Loss: 1.6035696268081665 Acc: 0.5 250 1-way train accuracy: 100.0 %, 250 1-way val accuracy: 100.0 %\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[25], line 9\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m,num_iterations\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m):\n\u001b[1;32m 8\u001b[0m x,y \u001b[38;5;241m=\u001b[39m get_batch(batch_size)\n\u001b[0;32m----> 9\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_on_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m%\u001b[39m evaluateEvery \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 11\u001b[0m lossArr\u001b[38;5;241m.\u001b[39mappend(loss[\u001b[38;5;241m0\u001b[39m])\n", + "File \u001b[0;32m~/Desktop/Spring_2023/CSE_5819/Honors_work/myenv/lib/python3.9/site-packages/keras/engine/training.py:1900\u001b[0m, in \u001b[0;36mModel.train_on_batch\u001b[0;34m(self, x, y, sample_weight, class_weight, reset_metrics, return_dict)\u001b[0m\n\u001b[1;32m 1896\u001b[0m iterator \u001b[38;5;241m=\u001b[39m data_adapter\u001b[38;5;241m.\u001b[39msingle_batch_iterator(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdistribute_strategy, x,\n\u001b[1;32m 1897\u001b[0m y, sample_weight,\n\u001b[1;32m 1898\u001b[0m class_weight)\n\u001b[1;32m 1899\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_function \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmake_train_function()\n\u001b[0;32m-> 1900\u001b[0m logs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_function\u001b[49m\u001b[43m(\u001b[49m\u001b[43miterator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1902\u001b[0m logs \u001b[38;5;241m=\u001b[39m tf_utils\u001b[38;5;241m.\u001b[39msync_to_numpy_or_python_type(logs)\n\u001b[1;32m 1903\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m return_dict:\n", + "File \u001b[0;32m~/Desktop/Spring_2023/CSE_5819/Honors_work/myenv/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py:150\u001b[0m, in \u001b[0;36mfilter_traceback..error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 148\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 149\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 150\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 152\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n", + "File \u001b[0;32m~/Desktop/Spring_2023/CSE_5819/Honors_work/myenv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py:910\u001b[0m, in \u001b[0;36mFunction.__call__\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m 907\u001b[0m compiler \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mxla\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jit_compile \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnonXla\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 909\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m OptionalXlaContext(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jit_compile):\n\u001b[0;32m--> 910\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 912\u001b[0m new_tracing_count \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mexperimental_get_tracing_count()\n\u001b[1;32m 913\u001b[0m without_tracing \u001b[38;5;241m=\u001b[39m (tracing_count \u001b[38;5;241m==\u001b[39m new_tracing_count)\n", + "File \u001b[0;32m~/Desktop/Spring_2023/CSE_5819/Honors_work/myenv/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py:942\u001b[0m, in \u001b[0;36mFunction._call\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m 939\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock\u001b[38;5;241m.\u001b[39mrelease()\n\u001b[1;32m 940\u001b[0m \u001b[38;5;66;03m# In this case we have created variables on the first call, so we run the\u001b[39;00m\n\u001b[1;32m 941\u001b[0m \u001b[38;5;66;03m# defunned version which is guaranteed to never create variables.\u001b[39;00m\n\u001b[0;32m--> 942\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_stateless_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# pylint: disable=not-callable\u001b[39;00m\n\u001b[1;32m 943\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_stateful_fn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 944\u001b[0m \u001b[38;5;66;03m# Release the lock early so that multiple threads can perform the call\u001b[39;00m\n\u001b[1;32m 945\u001b[0m \u001b[38;5;66;03m# in parallel.\u001b[39;00m\n\u001b[1;32m 946\u001b[0m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock\u001b[38;5;241m.\u001b[39mrelease()\n", + "File \u001b[0;32m~/Desktop/Spring_2023/CSE_5819/Honors_work/myenv/lib/python3.9/site-packages/tensorflow/python/eager/function.py:3130\u001b[0m, in \u001b[0;36mFunction.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 3127\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock:\n\u001b[1;32m 3128\u001b[0m (graph_function,\n\u001b[1;32m 3129\u001b[0m filtered_flat_args) \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_maybe_define_function(args, kwargs)\n\u001b[0;32m-> 3130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mgraph_function\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_flat\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3131\u001b[0m \u001b[43m \u001b[49m\u001b[43mfiltered_flat_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcaptured_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgraph_function\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcaptured_inputs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Desktop/Spring_2023/CSE_5819/Honors_work/myenv/lib/python3.9/site-packages/tensorflow/python/eager/function.py:1959\u001b[0m, in \u001b[0;36mConcreteFunction._call_flat\u001b[0;34m(self, args, captured_inputs, cancellation_manager)\u001b[0m\n\u001b[1;32m 1955\u001b[0m possible_gradient_type \u001b[38;5;241m=\u001b[39m gradients_util\u001b[38;5;241m.\u001b[39mPossibleTapeGradientTypes(args)\n\u001b[1;32m 1956\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (possible_gradient_type \u001b[38;5;241m==\u001b[39m gradients_util\u001b[38;5;241m.\u001b[39mPOSSIBLE_GRADIENT_TYPES_NONE\n\u001b[1;32m 1957\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m executing_eagerly):\n\u001b[1;32m 1958\u001b[0m \u001b[38;5;66;03m# No tape is watching; skip to running the function.\u001b[39;00m\n\u001b[0;32m-> 1959\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_build_call_outputs(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_inference_function\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcall\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1960\u001b[0m \u001b[43m \u001b[49m\u001b[43mctx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcancellation_manager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcancellation_manager\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 1961\u001b[0m forward_backward \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_select_forward_and_backward_functions(\n\u001b[1;32m 1962\u001b[0m args,\n\u001b[1;32m 1963\u001b[0m possible_gradient_type,\n\u001b[1;32m 1964\u001b[0m executing_eagerly)\n\u001b[1;32m 1965\u001b[0m forward_function, args_with_tangents \u001b[38;5;241m=\u001b[39m forward_backward\u001b[38;5;241m.\u001b[39mforward()\n", + "File \u001b[0;32m~/Desktop/Spring_2023/CSE_5819/Honors_work/myenv/lib/python3.9/site-packages/tensorflow/python/eager/function.py:598\u001b[0m, in \u001b[0;36m_EagerDefinedFunction.call\u001b[0;34m(self, ctx, args, cancellation_manager)\u001b[0m\n\u001b[1;32m 596\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _InterpolateFunctionError(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 597\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cancellation_manager 
\u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 598\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mexecute\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 599\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mstr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msignature\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 600\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_outputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_num_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 601\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 602\u001b[0m \u001b[43m \u001b[49m\u001b[43mattrs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattrs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 603\u001b[0m \u001b[43m \u001b[49m\u001b[43mctx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mctx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 604\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 605\u001b[0m outputs \u001b[38;5;241m=\u001b[39m execute\u001b[38;5;241m.\u001b[39mexecute_with_cancellation(\n\u001b[1;32m 606\u001b[0m \u001b[38;5;28mstr\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msignature\u001b[38;5;241m.\u001b[39mname),\n\u001b[1;32m 607\u001b[0m num_outputs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_outputs,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 610\u001b[0m ctx\u001b[38;5;241m=\u001b[39mctx,\n\u001b[1;32m 611\u001b[0m cancellation_manager\u001b[38;5;241m=\u001b[39mcancellation_manager)\n", + "File \u001b[0;32m~/Desktop/Spring_2023/CSE_5819/Honors_work/myenv/lib/python3.9/site-packages/tensorflow/python/eager/execute.py:58\u001b[0m, in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 57\u001b[0m ctx\u001b[38;5;241m.\u001b[39mensure_initialized()\n\u001b[0;32m---> 58\u001b[0m tensors \u001b[38;5;241m=\u001b[39m \u001b[43mpywrap_tfe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mTFE_Py_Execute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mctx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_handle\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mop_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 59\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattrs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_outputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m core\u001b[38;5;241m.\u001b[39m_NotOkStatusException \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 61\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "lossArr = []\n", + "trainAccArr = []\n", + "valAccArr = []\n", + "currTime = time.time()\n", + "x,y = get_batch(batch_size)\n", + "\n", + "for i in range(0,num_iterations+1):\n", + " x,y = 
get_batch(batch_size)\n", + " loss = model.train_on_batch(x,y)\n", + " if i % evaluateEvery == 0:\n", + " lossArr.append(loss[0])\n", + " trainAcc = round(test_one_shot(model,N,k,'train') * 100,2)\n", + " valAcc = round(test_one_shot(model,N,k,'val') * 100,2)\n", + " trainAccArr.append(trainAcc)\n", + " valAccArr.append(valAcc)\n", + " print('Iteration',i,'('+str(round(time.time() - currTime,1))+'s) - Loss:',loss[0],'Acc:',round(loss[1],2),'',end='')\n", + " print(k,str(N)+'-way train accuracy:', trainAcc,'%, ',end='')\n", + " print(k,str(N)+'-way val accuracy:', valAcc,'%')\n", + " currTime = time.time()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "2da464e2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Final Validation Accuracy: 100.0\n" + ] + } + ], + "source": [ + "print('Final Validation Accuracy:', round(test_one_shot(model,N,k,'val') * 100,2))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
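Usage note on the evaluation helpers defined in implementation.ipynb: the accuracies printed by the training loop and by the final cell are computed with N = 1, i.e. 1-way one-shot tasks. The snippet below is a minimal, hypothetical sketch of calling the notebook's own test_one_shot helper with a larger task size once the notebook has been executed; the 20-way setting and the 250-trial count are illustrative assumptions, not values taken from the notebook.

```python
# Minimal usage sketch (assumes the notebook above has already been run, so
# `model` and `test_one_shot` exist in the session). It reuses the notebook's
# evaluation helper with a larger task size; N_WAY and N_TRIALS are
# illustrative choices, not values from the notebook.
N_WAY = 20      # candidate classes per one-shot task
N_TRIALS = 250  # number of random tasks to average over

val_acc = test_one_shot(model, N_WAY, N_TRIALS, dset='val')
print(f'{N_TRIALS} trials of {N_WAY}-way one-shot validation accuracy: {val_acc * 100:.2f}%')
```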