PyTorch/fashion-mnist-with-pytorch-93-accuracy.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### This is the tutorial of deep learning on FashionMNIST dataset using Pytorch. We will build a Convolutional Neural Network for predicting the classes of Dataset. I am assuming you know the basics of deep leanrning like layer architecture... convolution concepts. Without further ado... Lets start the tutorial."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "1rm65OztSCpF"
},
"source": [
"# **Importing Important Libraries**"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "-1_8VZgpEtea"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.autograd import Variable\n",
"\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from sklearn.metrics import confusion_matrix"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "s6bjL6wESfoV"
},
"source": [
"### If the GPU is available use it for the computation otherwise use the CPU."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "q6btjJ9YTXXw"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "CGIhqokgSv1W"
},
"source": [
"There are 2 ways to load the Fashion MNIST dataset. \n",
"\n",
"\n",
" 1. Load csv and then inherite Pytorch Dataset class .\n",
" 2. Use Pytorch module torchvision.datasets. It has many popular datasets like MNIST, FashionMNIST, CIFAR10 e.t.c.\n",
" \n",
" \n",
"\n",
"* We use DataLoader class from torch.utils.data to load data in batches in both method.\n",
"* Comment out the code of a method which you are not using. \n",
"\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "4jJh32jCT_J0"
},
"source": [
"### 1. Using a Dataset class.\n",
" \n",
" * First load the data from the disk using pandas read_csv() method.\n",
"\n",
" * Now inherit Dataset class in your own class that you are building, lets say FashionData.\n",
"\n",
" * It has 2 methods: __get_item__( ) and __len__().\n",
" * __get_item__( ) return the images and labels and __len__( ) returns the number of items in a dataset."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 413
},
"colab_type": "code",
"id": "PV2g1_8qUDvA",
"outputId": "b429a123-7574-413d-d15a-92f0481e6753"
},
"outputs": [
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: '../input/fashion-mnist_train.csv'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_45959/798855398.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrain_csv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_csv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"../input/fashion-mnist_train.csv\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mtest_csv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_csv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"../input/fashion-mnist_test.csv\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/pandas/util/_decorators.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[0mstacklevel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstacklevel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 310\u001b[0m )\n\u001b[0;32m--> 311\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 312\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 313\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/pandas/io/parsers/readers.py\u001b[0m in \u001b[0;36mread_csv\u001b[0;34m(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, squeeze, prefix, mangle_dupe_cols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, skipfooter, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, dayfirst, cache_dates, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, doublequote, escapechar, comment, encoding, encoding_errors, dialect, error_bad_lines, warn_bad_lines, on_bad_lines, delim_whitespace, low_memory, memory_map, float_precision, storage_options)\u001b[0m\n\u001b[1;32m 584\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwds_defaults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 585\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 586\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_read\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilepath_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 587\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 588\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/pandas/io/parsers/readers.py\u001b[0m in \u001b[0;36m_read\u001b[0;34m(filepath_or_buffer, kwds)\u001b[0m\n\u001b[1;32m 480\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 481\u001b[0m \u001b[0;31m# Create the parser.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 482\u001b[0;31m \u001b[0mparser\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTextFileReader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilepath_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 483\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 484\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mchunksize\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/pandas/io/parsers/readers.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, f, engine, **kwds)\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptions\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"has_index_names\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"has_index_names\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 810\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 811\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_engine\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_make_engine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mengine\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 812\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 813\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/pandas/io/parsers/readers.py\u001b[0m in \u001b[0;36m_make_engine\u001b[0;34m(self, engine)\u001b[0m\n\u001b[1;32m 1038\u001b[0m )\n\u001b[1;32m 1039\u001b[0m \u001b[0;31m# error: Too many arguments for \"ParserBase\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1040\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mmapping\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mengine\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptions\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[call-arg]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1041\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1042\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_failover_to_python\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/pandas/io/parsers/c_parser_wrapper.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, src, **kwds)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;31m# open handles\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_open_handles\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msrc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 52\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandles\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/pandas/io/parsers/base_parser.py\u001b[0m in \u001b[0;36m_open_handles\u001b[0;34m(self, src, kwds)\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0mLet\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mreaders\u001b[0m \u001b[0mopen\u001b[0m \u001b[0mIOHandles\u001b[0m \u001b[0mafter\u001b[0m \u001b[0mthey\u001b[0m \u001b[0mare\u001b[0m \u001b[0mdone\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtheir\u001b[0m \u001b[0mpotential\u001b[0m \u001b[0mraises\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 221\u001b[0m \"\"\"\n\u001b[0;32m--> 222\u001b[0;31m self.handles = get_handle(\n\u001b[0m\u001b[1;32m 223\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 224\u001b[0m \u001b[0;34m\"r\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/pandas/io/common.py\u001b[0m in \u001b[0;36mget_handle\u001b[0;34m(path_or_buf, mode, encoding, compression, memory_map, is_text, errors, storage_options)\u001b[0m\n\u001b[1;32m 699\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mioargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoding\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;34m\"b\"\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mioargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 700\u001b[0m \u001b[0;31m# Encoding\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 701\u001b[0;31m handle = open(\n\u001b[0m\u001b[1;32m 702\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 703\u001b[0m \u001b[0mioargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../input/fashion-mnist_train.csv'"
]
}
],
"source": [
"train_csv = pd.read_csv(\"../input/fashion-mnist_train.csv\")\n",
"test_csv = pd.read_csv(\"../input/fashion-mnist_test.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "3Q36xjPfeo0a"
},
"outputs": [],
"source": [
"class FashionDataset(Dataset):\n",
" \"\"\"User defined class to build a datset using Pytorch class Dataset.\"\"\"\n",
" \n",
" def __init__(self, data, transform = None):\n",
" \"\"\"Method to initilaize variables.\"\"\" \n",
" self.fashion_MNIST = list(data.values)\n",
" self.transform = transform\n",
" \n",
" label = []\n",
" image = []\n",
" \n",
" for i in self.fashion_MNIST:\n",
" # first column is of labels.\n",
" label.append(i[0])\n",
" image.append(i[1:])\n",
" self.labels = np.asarray(label)\n",
" # Dimension of Images = 28 * 28 * 1. where height = width = 28 and color_channels = 1.\n",
" self.images = np.asarray(image).reshape(-1, 28, 28, 1).astype('float32')\n",
"\n",
" def __getitem__(self, index):\n",
" label = self.labels[index]\n",
" image = self.images[index]\n",
" \n",
" if self.transform is not None:\n",
" image = self.transform(image)\n",
"\n",
" return image, label\n",
"\n",
" def __len__(self):\n",
" return len(self.images)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "uLeNwQWJkFkQ"
},
"outputs": [],
"source": [
"# Transform data into Tensor that has a range from 0 to 1\n",
"train_set = FashionDataset(train_csv, transform=transforms.Compose([transforms.ToTensor()]))\n",
"test_set = FashionDataset(test_csv, transform=transforms.Compose([transforms.ToTensor()]))\n",
"\n",
"train_loader = DataLoader(train_set, batch_size=100)\n",
"test_loader = DataLoader(train_set, batch_size=100)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "VePPK186m6iS"
},
"source": [
"### 2. Using FashionMNIST class from torchvision module.\n",
"\n",
"\n",
"* It will download the dataset first time.\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "zM618_wYGM0n"
},
"outputs": [],
"source": [
"\n",
"train_set = torchvision.datasets.FashionMNIST(\"./data\", download=True, transform=\n",
" transforms.Compose([transforms.ToTensor()]))\n",
"test_set = torchvision.datasets.FashionMNIST(\"./data\", download=True, train=False, transform=\n",
" transforms.Compose([transforms.ToTensor()])) \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "8s2uhlGJZOOP"
},
"outputs": [],
"source": [
"\n",
"train_loader = torch.utils.data.DataLoader(train_set, \n",
" batch_size=100)\n",
"test_loader = torch.utils.data.DataLoader(test_set,\n",
" batch_size=100)\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "4SfsxOX4peHU"
},
"source": [
"### We have 10 types of clothes in FashionMNIST dataset.\n",
"\n",
"\n",
"> Making a method that return the name of class for the label number.\n",
"ex. if the label is 5, we return Sandal.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "uWIE3hVqOlMi"
},
"outputs": [],
"source": [
"def output_label(label):\n",
" output_mapping = {\n",
" 0: \"T-shirt/Top\",\n",
" 1: \"Trouser\",\n",
" 2: \"Pullover\",\n",
" 3: \"Dress\",\n",
" 4: \"Coat\", \n",
" 5: \"Sandal\", \n",
" 6: \"Shirt\",\n",
" 7: \"Sneaker\",\n",
" 8: \"Bag\",\n",
" 9: \"Ankle Boot\"\n",
" }\n",
" input = (label.item() if type(label) == torch.Tensor else label)\n",
" return output_mapping[input]"
]
},
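{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick, optional check of the helper (an addition, not part of the original notebook): it accepts both plain integers and tensors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check for output_label(): works with ints and with tensors.\n",
"print(output_label(5))                # Sandal\n",
"print(output_label(torch.tensor(9)))  # Ankle Boot"
]
},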
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "yDH7i5UFo7w3"
},
"source": [
"### Playing with data and displaying some images using matplotlib imshow() method.\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"colab_type": "code",
"id": "RB9jenaDYZmt",
"outputId": "3d8212d3-3cbf-4e30-b369-e6ede4ad6350"
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([100, 1, 28, 28])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = next(iter(train_loader))\n",
"a[0].size()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"colab_type": "code",
"id": "D-PnIoRpjuZW",
"outputId": "cff25b7c-a4c6-42ea-81d1-43a4789959bc"
},
"outputs": [
{
"data": {
"text/plain": [
"60000"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train_set)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 287
},
"colab_type": "code",
"id": "2kC6CrJrlbf_",
"outputId": "2eefede6-0e54-4512-b6a8-1c17eb42286a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"9\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAR10lEQVR4nO3db2yVdZYH8O+xgNqCBaxA+RPBESOTjVvWikbRjI4Q9IUwanB4scGo24kZk5lkTNa4L8bEFxLdmcm+IJN01AyzzjqZZCBi/DcMmcTdFEcqYdtKd0ZACK2lBUFoS6EUzr7og+lgn3Pqfe69z5Xz/SSk7T393fvrvf1yb+95fs9PVBVEdOm7LO8JEFF5MOxEQTDsREEw7ERBMOxEQUwq542JCN/6JyoxVZXxLs/0zC4iq0TkryKyV0SeyXJdRFRaUmifXUSqAPwNwAoAXQB2AlinqnuMMXxmJyqxUjyzLwOwV1X3q+owgN8BWJ3h+oiohLKEfR6AQ2O+7kou+zsi0iQirSLSmuG2iCijkr9Bp6rNAJoBvownylOWZ/ZuAAvGfD0/uYyIKlCWsO8EsFhEFonIFADfB7C1ONMiomIr+GW8qo6IyFMA3gNQBeBVVf24aDMjoqIquPVW0I3xb3aikivJQTVE9M3BsBMFwbATBcGwEwXBsBMFwbATBcGwEwXBsBMFwbATBcGwEwXBsBMFwbATBcGwEwVR1lNJU/mJjLsA6ktZVz1OmzbNrC9fvjy19s4772S6be9nq6qqSq2NjIxkuu2svLlbCn3M+MxOFATDThQEw04UBMNOFATDThQEw04UBMNOFAT77Je4yy6z/z8/d+6cWb/++uvN+hNPPGHWh4aGUmuDg4Pm2NOnT5v1Dz/80Kxn6aV7fXDvfvXGZ5mbdfyA9XjymZ0oCIadKAiGnSgIhp0oCIadKAiGnSgIhp0oCPbZL3FWTxbw++z33HOPWb/33nvNeldXV2rt8ssvN8dWV1eb9RUrVpj1l19+ObXW29trjvXWjHv3m2fq1KmptfPnz5tjT506VdBtZgq7iBwA0A/gHIARVW3Mcn1EVDrFeGa/W1WPFuF6iKiE+Dc7URBZw64A/igiH4lI03jfICJNItIqIq0Zb4uIMsj6Mn65qnaLyCwA20Tk/1T1/bHfoKrNAJoBQESynd2QiAqW6ZldVbuTj30AtgBYVoxJEVHxFRx2EakRkWkXPgewEkBHsSZGRMWV5WX8bABbknW7kwD8l6q+W5RZUdEMDw9nGn/LLbeY9YULF5p1q8/vrQl/7733zPrSpUvN+osvvphaa22130Jqb283652dnWZ92TL7Ra51v7a0tJhjd+zYkVobGBhIrRUcdlXdD+AfCx1PROXF1htREAw7URAMO1EQDDtREAw7URCSdcver3VjPIKuJKzTFnuPr7dM1GpfAcD06dPN+tmzZ1Nr3lJOz86dO8363r17U2tZW5L19fVm3fq5AXvuDz/8sDl248aNqbXW1lacPHly3F8IPrMTBcGwEwXBsBMFwbATBcGwEwXBsBMFwbATBcE+ewXwtvfNwnt8P/jgA7PuLWH1WD+bt21x1l64teWz1+PftWuXWbd6+ID/s61atSq1dt1115lj582bZ9ZVlX12osgYdqIgGHaiIBh2oiAYdqIgGHaiIBh2oiC4ZXMFKOexDhc7fvy4WffWbQ8NDZl1a1vmSZPsXz9rW2PA7qMDwJVXXpla8/rsd955p1m//fbbzbp3muxZs2al1t59tzRnZOczO1EQDDtREAw7URAMO1EQDDtREAw7URAMO1EQ7LMHV11dbda9frFXP3XqVGrtxIkT5tjPP//crHtr7a3jF7xzCHg/l3e/nTt3zqxbff4FCxaYYwvlPrOLyKsi0iciHWMumyki20Tkk+TjjJLMjoiKZiIv438N4OLTajwDYLuqLgawPfmaiCqYG3ZVfR/AsYsuXg1gU/L5JgBrijstIiq2Qv9mn62qPcnnhwHMTvtGEWkC0FTg7RBRkWR+g05V1TqRpKo2A2gGeMJJojwV2nrrFZF6AEg+9hVvSkRUCoWGfSuA9cnn6wG8UZzpEFGpuC/jReR1AN8BUCciXQB+CmADgN+LyOMADgJYW8pJXuqy9nytnq63Jnzu3Llm/cyZM5nq1np277zwVo8e8PeGt/r0Xp98ypQpZr2/v9+s19bWmvW2trbUmveYNTY2ptb27NmTWnPDrqrrUkrf9cYSUeXg4bJEQTDsREEw7ERBMOxEQTDsREFwiWsF8E4lXVVVZdat1tsjjzxijp0zZ45ZP3LkiFm3TtcM2Es5a2pqzLHeUk+vdWe1/c6ePWuO9U5z7f3cV199tVnfuHFjaq2hocEca83NauPymZ0oCIadKAiGnSgIhp0oCIadKAiGnSgIhp0oCCnndsE8U834vJ7uyMhIwdd96623mvW33nrLrHtbMmc5BmDatGnmWG9LZu9U05MnTy6oBvjHAHhbXXusn+2ll14yx7722mtmXVXHbbbzmZ0oCIadKAiGnSgIhp0oCIadKAiGnSgIhp0oiG/UenZrra7X7/VOx+ydztla/2yt2Z6ILH10z9tvv23WBwcHzbrXZ/dOuWwdx+Gtlfce0yuuuMKse2vWs4z1HnNv7jfddFNqzdvKulB8ZicKgmEnCoJhJwqCYScKgmEnCoJhJwqCYScKoqL67FnWRpeyV11qd911l1l/6KGHzPodd9yRWvO2PfbWhHt9dG8tvvWYeXPzfh+s88IDdh/eO4+DNzePd78NDAyk1h588EFz7JtvvlnQnNxndhF5VUT6RKRjzGXPiUi3iOxO/t1f0K0TUdlM5GX8rwGsGufyX6hqQ/LPPkyLiHLnhl1V3wdwrAxzIaISyvIG3VMi0pa8zJ+R9k0i0iQirSLSmuG2iCijQsP+SwDfAtAAoAfAz9K+UVWbVbVRVRsLvC0iKoKCwq6qvap6TlXPA/gVgGXFnRYRFVtBYReR+jFffg9AR9r3ElFlcM8bLyKvA/gOgDoAvQB+mnzdAEABHADwA1XtcW8sx/PGz5w506zPnTvXrC9evLjgsV7f9IYbbjDrZ86cMevWWn1vXba3z/hnn31m1r3zr1v9Zm8Pc2//9erqarPe0tKSWps6dao51jv2wVvP7q1Jt+633t5ec+ySJUvMetp5492DalR13TgXv+KNI6LKwsNliYJg2ImCYNiJgmDYiYJg2ImCqKgtm2+77TZz/PPPP59au+aaa8yx06dPN+vWUkzAXm75xRdfmGO95bdeC8lrQVmnwfZOBd3Z2WnW165da9ZbW+2joK1tmWfMSD3KGgCwcOFCs+7Zv39/as3bLrq/v9+se0tgvZam1fq76qqrzLHe7wu3bCYKjmEnCoJhJwqCYScKgmEnCoJhJwqCYScKoux9dqtfvWPHDnN8fX19as3rk3v1LKcO9k557PW6s6qtrU2t1dXVmWMfffRRs75y5Uqz/uSTT5p1a4ns6dOnzbGffvqpWbf66IC9LDnr8lpvaa/Xx7fGe8tnr732WrPOPjtRcAw7U
RAMO1EQDDtREAw7URAMO1EQDDtREGXts9fV1ekDDzyQWt+wYYM5ft++fak179TAXt3b/tfi9VytPjgAHDp0yKx7p3O21vJbp5kGgDlz5pj1NWvWmHVrW2TAXpPuPSY333xzprr1s3t9dO9+87Zk9ljnIPB+n6zzPhw+fBjDw8PssxNFxrATBcGwEwXBsBMFwbATBcGwEwXBsBMF4e7iWkwjIyPo6+tLrXv9ZmuNsLetsXfdXs/X6qt65/k+duyYWT948KBZ9+ZmrZf31ox757TfsmWLWW9vbzfrVp/d20bb64V75+u3tqv2fm5vTbnXC/fGW312r4dvbfFt3SfuM7uILBCRP4vIHhH5WER+lFw+U0S2icgnyUf7jP9ElKuJvIwfAfATVf02gNsA/FBEvg3gGQDbVXUxgO3J10RUodywq2qPqu5KPu8H0AlgHoDVADYl37YJwJoSzZGIiuBrvUEnIgsBLAXwFwCzVbUnKR0GMDtlTJOItIpIq/c3GBGVzoTDLiJTAfwBwI9V9eTYmo6
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"image, label = next(iter(train_set))\n",
"plt.imshow(image.squeeze(), cmap=\"gray\")\n",
"print(label)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 54
},
"colab_type": "code",
"id": "sgpYsgh0PY09",
"outputId": "4d00bf1d-fb85-4e2a-bb21-067e75360637"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'torch.Tensor'> <class 'torch.Tensor'>\n",
"torch.Size([10, 1, 28, 28]) torch.Size([10])\n"
]
}
],
"source": [
"demo_loader = torch.utils.data.DataLoader(train_set, batch_size=10)\n",
"\n",
"batch = next(iter(demo_loader))\n",
"images, labels = batch\n",
"print(type(images), type(labels))\n",
"print(images.shape, labels.shape)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 158
},
"colab_type": "code",
"id": "2Z0D4BgQRW8e",
"outputId": "c33042d3-017f-4dcb-8bf9-244910e2b49e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"labels: Ankle Boot, T-shirt/Top, T-shirt/Top, Dress, T-shirt/Top, Pullover, Sneaker, Pullover, Sandal, Sandal, "
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2wAAAB6CAYAAADDC9BKAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAACinElEQVR4nO39aYxkWXYeCH7P9n333T08Fo/IzKiIyKUyyawskmJVsapIgRJnBIEQB9JQAwIcDChgGtM/mtN/1AP0Dw0G08IMpkcDChJEDXqaJNiSKLRINYslsUpVLGVVVlZWRmTsEe4Rvrubue378uaHx3f92I33zM093D3MIt8HONzd7Nmz9+6799xzvrMZpmnCgQMHDhw4cODAgQMHDhyMHlwv+wIcOHDgwIEDBw4cOHDgwIE1HIPNgQMHDhw4cODAgQMHDkYUjsHmwIEDBw4cOHDgwIEDByMKx2Bz4MCBAwcOHDhw4MCBgxGFY7A5cODAgQMHDhw4cODAwYjCMdgcOHDgwIEDBw4cOHDgYETxQgabYRi/bBjGPcMwHhqG8bsndVEOHDhw4MCBAwcOHDhw4AAwjtuHzTAMN4D7AL4OYA3AjwD8hmmat0/u8hw4cODAgQMHDhw4cODg8wvPC3z2ZwA8NE3zMQAYhvEHAH4NgK3BZhiG06XbgQMHDhw4cODAgQMHn2dkTdOcGPbgFwmJnAOwKv5fe/ZaHwzD+G3DMD4yDOOjF/guBw4cOHDgwIEDBw4cOHgV8OQoB7+Ih20omKb5ewB+D3A8bA4cOHDgwIEDBw4cOHBwFLyIh20dwIL4f/7Zaw4cOHDgwIEDBw4cOHDg4ATwIh62HwG4bBjGBewban8HwP/mRK7KwdjAMAwYhgG32w3DMGCapvoB0Pf3oM/zh+Bner0eer3e6d7ECMDlcsHlcqnxetF75ljqY3rcIkOjDHmv+o+8316v99z8HPRZq7GTP5+Heeng5KDPLZerny+VMtPu8/K3PL7b7b6Sa9vB2UOXnUeZV1Zz1JmXDhycDI5tsJmm2TEM4x8A+F8AuAH8c9M0PzuxK3Mw0jAMA8FgEIFAAOFwGJcvX8bk5CSq1Sp2d3fRaDRQr9dRqVSU0aUbcR6PB+FwGF6vF9FoFJlMBl6vF81mE41GA61WC2tra9jd3UWv10On03mllGS32w2v1wuPx4N0Oo2JiQn0ej3s7Owgn8+j1+uh1Wqh2+0e6bxerxeRSARerxd+vx+hUAgAsLe3p877Koyjy+VCIBCA1+tFIBBAIpGAz+dDLBZDOp2G2+1Gp9NBp9NBq9XCzs4OisUiWq0WKpUKWq0W/H4/gsEgPB4Pkskk0uk0vF4vEokEIpGIMsxM00SpVMLOzg7q9TpyuRy2t7fR6XRe9jCMDQzDgM/ng9frRbfbRbPZfCXm4SAYhqHWuN/vRzKZVL+np6fVWPCnXq+j2Wz2zTu32w232w2Xy4VIJIJoNAoAKJfLqFarqFarWF5eRi6XU8abAwfHgc/nQyKRgN/vR7vdRr1eR7fbRbvdRrPZtP2cx+NBIBCA2+1GKBRCJBIBAORyOeTzecdoc+DgBPBCOWymaf4pgD89oWtxMEYwDAPhcBiJRAKTk5P42te+hi984QvY3d3FZ599hkKhgHw+j62tLbTbbXS7XaWA0GAIBoOYnJxEKBTC7OwsXnvtNYRCIZRKJRQKBVQqFXz44Ycol8vKWHuVFDyPx4NgMAi/34/z58/j6tWr6HQ6uHXrFtrtthq3oypgPp8PqVQKoVBIGS8A8OjRIzWWrwLz6Xa7EQ6HEQwGkUwmsbi4iEgkgoWFBVy+fBl+v18pwJVKBTdv3sTTp09RrVaxubmJarWKaDSKdDqNQCCAixcvYmlpCeFwGOfPn8f09DRM01TPYWNjAzdv3kQ+n8e9e/eQy+Ucg+0IoIEdDAbRarVeOQLGCoZhwO/3w+/3Ix6P48KFC+r3W2+9hXA4jFarhWaziXa7jVwuh1KppAgqEluBQAAejwdTU1OYnZ0FAGxsbGB3d1eRCMVi8TlyzIGDo8Dn82FychLxeBzVahX5fB6tVgvVahWtVst2Xnk8HkSjUfh8PqTTaczMzKhjC4WCMx8dODgBnHrREQfjC4Y5ut1ueDweuFwu9bfb7UYqlUI8Hkc6nUYikUA8Hken01HeDbfbjV6vh3a7jV6vpwwPhu8EAgFlsPEcwWBQhVN4vV5MTExgenpasX3yXFRq+No4KSoc10AggEAggEgkgkQigW63i1QqpTxBPp+vj9nU70+G7fHvSCSCdDqNUCiEcDiMcDgMAMojSrZ0nIwNzid6aTweD7xeL+LxOILBIOLxOOLxuPJAhMNh+P1+5cU0DAPJZBK1Wg3BYBDdbhe1Wg2RSER5PVKpFJLJJEKhEKLRqPKw0cvJc/d6PeWNazQaai7SazRO40oYhoFQKASfz9cXDiq94jQsZGjpILhcLhiGoeY4lbpAIKDmdqvVQqvVQr1eH6v1K8Hx4pzk2uYP52IsFkMymUQ0GkU8HlfzlDKs3W4D2FearQw2t9uNZDKJeDwO0zRRrVbRbDbRbDbVfKRXnvKQsrbdbqvzAUcLc3MwXtDXrg6uS4/Ho/Z1eoGj0WifwcY1WqvVEA6Hbc8pIxxSqRQmJvYrlReLRVSr1b49R85RBw4cDI9jN84+1pc5VSLHBgwLCwQCiMfjmJubQygUQiaTUaE8ZI5DoRCWlpYwMTGBRqOBQqGglNx2u22b1+bxeBAKhZQBSGWRSki320W5XEa9Xken01EsX6VSUazy+vo6Hj9+jEajgUqlgkql8pJHbjC4UbpcLkxMTGBxcRHRaBQ3btzA22+/DZfLhZ2dHRQKBbTb7T5mUyrK3Ow4dsyJcblcCAaDymtUqVRQLBbRbDbx5MkTPH36FLVaDRsbG8jlci95NIaDYRhIpVLIZDIIBoO4cOEC5ubm+sgDzlOfz4dwOIxYLNaXI0SDgwqD9PiapgnDMJQi7Xa74ff74fF40Ov11FzUGedKpYJ2u41sNotCoYBisYibN29ifX38ai9FIhG8//77uHr1qlqLLpcL3W5Xjdv9+/dx8+ZNFe7caDRsz+fz+eD3+xEIBPDuu+/ii1/8Inw+HwKBgCIhOC8fPHiAH/zgBygWi8fyKL9McKzcbjdmZ2dx4cIFtf6SyWTfXJKhyvF4HFNTU0qGBgIBAFAhvMBBziUVbEmedbtd7O3toVAooF6vY2NjA8ViEZ1ORynGpVIJe3t7aDQaWF1dxerqqjIOx2mMHRwNgww2rkGv14vZ2VnMzs4iHA7j0qVLmJmZUcRVMBhURla321Vh5YM8bCTI6EUHgHw+j3w+j0qlgtu3b2NtbQ3FYhEPHz5EqVQ69D4cYsHBK44fm6b57rAHOx42B5bweDyIxWKIRqOYnZ3F9evXkUgkcOHCBbz++uvw+/1K2XW5XPD7/fB6vXC5XFhYWFChQMFg8LnkegluLpVKBblcThl4vV4Pbrcb6XQasVgMnU5H5R3t7u7iwYMHKJVKuHnzJiqVigr1q1arIy3ka
bC53W7EYjHMz8+rEKmlpSX4fD4sLS2pTVIabDQyaHDInCBpCNJ48Xq9WF1dxb1791CtVtHpdGAYBsrlMorF4tgYbMC+QTE3N4dYLIaf+ZmfwfXr12EYBjqdDrrdLlwuFzweDwzD6FP6OT4y1Mfr9arcyVqthkKhoAw5zr1KpdKXv9HpdODz+TA7OwuXy6XyNLrdLlZWVrC+vo6trS2sr6+PpcEWCARw48YNfOMb34DX61V5fWTXW60WgsEg1tbWUCqV0Ol0BhpsbrcbwWAQsVgMb7/9Nv7W3/pbSlF0u91oNpsoFApoNpv47ne/i88++0yt3XEyJqTsm52dxZtvvol4PI7FxUUsLCz0RSfIPFzOV7fbjUQigVQqBY/Ho7zHErKoSLVa7QsR59q/ePEiXC6XikTodDrY3t7GkydPUC6X0ev1kM1mlQI+TmPsYHjIiAsrkCQNBAI4d+4crl27hlQqhS996Uu4evVqnzdXx2H7ql5oRKY/5HI5/Pmf/zk+/fRTbGxsYGNjY6DBJouXjPJ
"text/plain": [
"<Figure size 1080x1440 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"grid = torchvision.utils.make_grid(images, nrow=10)\n",
"\n",
"plt.figure(figsize=(15, 20))\n",
"plt.imshow(np.transpose(grid, (1, 2, 0)))\n",
"print(\"labels: \", end=\" \")\n",
"for i, label in enumerate(labels):\n",
" print(output_label(label), end=\", \")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "YM2u0Riup9mx"
},
"source": [
"## Building a CNN \n",
"\n",
"\n",
"* Make a model class (FashionCNN in our case)\n",
" * It inherit nn.Module class that is a super class for all the neural networks in Pytorch.\n",
"* Our Neural Net has following layers:\n",
" * Two Sequential layers each consists of following layers-\n",
" * Convolution layer that has kernel size of 3 * 3, padding = 1 (zero_padding) in 1st layer and padding = 0 in second one. Stride of 1 in both layer.\n",
" * Batch Normalization layer.\n",
" * Acitvation function: ReLU.\n",
" * Max Pooling layer with kernel size of 2 * 2 and stride 2.\n",
" * Flatten out the output for dense layer(a.k.a. fully connected layer).\n",
" * 3 Fully connected layer with different in/out features.\n",
" * 1 Dropout layer that has class probability p = 0.25.\n",
" \n",
" * All the functionaltiy is given in forward method that defines the forward pass of CNN.\n",
" * Our input image is changing in a following way:\n",
" * First Convulation layer : input: 28 \\* 28 \\* 3, output: 28 \\* 28 \\* 32\n",
" * First Max Pooling layer : input: 28 \\* 28 \\* 32, output: 14 \\* 14 \\* 32\n",
" * Second Conv layer : input : 14 \\* 14 \\* 32, output: 12 \\* 12 \\* 64\n",
" * Second Max Pooling layer : 12 \\* 12 \\* 64, output: 6 \\* 6 \\* 64\n",
" * Final fully connected layer has 10 output features for 10 types of clothes.\n",
"\n",
"> Lets implementing the network...\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "hyCH0Q4hSgFB"
},
"outputs": [],
"source": [
"class FashionCNN(nn.Module):\n",
" \n",
" def __init__(self):\n",
" super(FashionCNN, self).__init__()\n",
" \n",
" self.layer1 = nn.Sequential(\n",
" nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),\n",
" nn.BatchNorm2d(32),\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=2, stride=2)\n",
" )\n",
" \n",
" self.layer2 = nn.Sequential(\n",
" nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),\n",
" nn.BatchNorm2d(64),\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(2)\n",
" )\n",
" \n",
" self.fc1 = nn.Linear(in_features=64*6*6, out_features=600)\n",
" self.drop = nn.Dropout2d(0.25)\n",
" self.fc2 = nn.Linear(in_features=600, out_features=120)\n",
" self.fc3 = nn.Linear(in_features=120, out_features=10)\n",
" \n",
" def forward(self, x):\n",
" out = self.layer1(x)\n",
" out = self.layer2(out)\n",
" out = out.view(out.size(0), -1)\n",
" out = self.fc1(out)\n",
" out = self.drop(out)\n",
" out = self.fc2(out)\n",
" out = self.fc3(out)\n",
" \n",
" return out\n"
]
},
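{
"cell_type": "markdown",
"metadata": {},
"source": [
"The flattened size used by fc1 (64 \* 6 \* 6 = 2304) follows from the shape changes listed above. Below is a minimal, optional sanity check (an addition, not part of the original tutorial) that feeds a dummy batch through an untrained FashionCNN to confirm those shapes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: verify the intermediate shapes described above.\n",
"check_model = FashionCNN()\n",
"dummy = torch.randn(1, 1, 28, 28)   # one grayscale 28 x 28 image\n",
"out1 = check_model.layer1(dummy)    # expected: [1, 32, 14, 14]\n",
"out2 = check_model.layer2(out1)     # expected: [1, 64, 6, 6]\n",
"print(out1.shape, out2.shape)\n",
"print(check_model(dummy).shape)     # expected: [1, 10]"
]
},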
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0mMVaW_PvCC2"
},
"source": [
"### Making a model of our CNN class\n",
"\n",
"* Creating a object(model in the code)\n",
"* Transfering it into GPU if available.\n",
"* Defining a Loss function. we're using CrossEntropyLoss() here.\n",
"* Using Adam algorithm for optimization purpose.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 345
},
"colab_type": "code",
"id": "NILDHzNgQ1Gt",
"outputId": "d16327ae-e7d2-4c46-bffe-e9724272c51d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FashionCNN(\n",
" (layer1): Sequential(\n",
" (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ReLU()\n",
" (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (layer2): Sequential(\n",
" (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n",
" (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): ReLU()\n",
" (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (fc1): Linear(in_features=2304, out_features=600, bias=True)\n",
" (drop): Dropout2d(p=0.25, inplace=False)\n",
" (fc2): Linear(in_features=600, out_features=120, bias=True)\n",
" (fc3): Linear(in_features=120, out_features=10, bias=True)\n",
")\n"
]
}
],
"source": [
"model = FashionCNN()\n",
"model.to(device)\n",
"\n",
"error = nn.CrossEntropyLoss()\n",
"\n",
"learning_rate = 0.001\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
"print(model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "im58ZlvMvkty"
},
"source": [
"## Training a network and Testing it on test dataset"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 126
},
"colab_type": "code",
"id": "SYh_6HtpUlNl",
"outputId": "b92aa41b-4cbf-4ceb-c63c-7d2a3999df87"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration: 500, Loss: 0.19216448068618774, Accuracy: 90.94000244140625%\n",
"Iteration: 1000, Loss: 0.2398412674665451, Accuracy: 90.38999938964844%\n",
"Iteration: 1500, Loss: 0.14675875008106232, Accuracy: 89.44999694824219%\n",
"Iteration: 2000, Loss: 0.2056727409362793, Accuracy: 90.43000030517578%\n",
"Iteration: 2500, Loss: 0.08225765824317932, Accuracy: 90.12999725341797%\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_190212/1265142370.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;31m# Forward pass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1052\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/tmp/ipykernel_190212/2706098268.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdrop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc3\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1049\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1050\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1051\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1052\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1053\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/torch/nn/modules/dropout.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 99\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 100\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minplace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mdropout2d\u001b[0;34m(input, p, training, inplace)\u001b[0m\n\u001b[1;32m 1200\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mp\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m0.0\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mp\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1.0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1201\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"dropout probability has to be between 0 and 1, \"\u001b[0m \u001b[0;34m\"but got {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1202\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_VF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeature_dropout_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minplace\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0m_VF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeature_dropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1203\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1204\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"num_epochs = 5\n",
"count = 0\n",
"# Lists for visualization of loss and accuracy \n",
"loss_list = []\n",
"iteration_list = []\n",
"accuracy_list = []\n",
"\n",
"# Lists for knowing classwise accuracy\n",
"predictions_list = []\n",
"labels_list = []\n",
"\n",
"for epoch in range(num_epochs):\n",
" for images, labels in train_loader:\n",
" # Transfering images and labels to GPU if available\n",
" images, labels = images.to(device), labels.to(device)\n",
" \n",
" train = Variable(images.view(100, 1, 28, 28))\n",
" labels = Variable(labels)\n",
" \n",
" # Forward pass \n",
" outputs = model(train)\n",
" loss = error(outputs, labels)\n",
" \n",
" # Initializing a gradient as 0 so there is no mixing of gradient among the batches\n",
" optimizer.zero_grad()\n",
" \n",
" #Propagating the error backward\n",
" loss.backward()\n",
" \n",
" # Optimizing the parameters\n",
" optimizer.step()\n",
" \n",
" count += 1\n",
" \n",
" # Testing the model\n",
" \n",
" if not (count % 50): # It's same as \"if count % 50 == 0\"\n",
" total = 0\n",
" correct = 0\n",
" \n",
" for images, labels in test_loader:\n",
" images, labels = images.to(device), labels.to(device)\n",
" labels_list.append(labels)\n",
" \n",
" test = Variable(images.view(100, 1, 28, 28))\n",
" \n",
" outputs = model(test)\n",
" \n",
" predictions = torch.max(outputs, 1)[1].to(device)\n",
" predictions_list.append(predictions)\n",
" correct += (predictions == labels).sum()\n",
" \n",
" total += len(labels)\n",
" \n",
" accuracy = correct * 100 / total\n",
" loss_list.append(loss.data)\n",
" iteration_list.append(count)\n",
" accuracy_list.append(accuracy)\n",
" \n",
" if not (count % 500):\n",
" print(\"Iteration: {}, Loss: {}, Accuracy: {}%\".format(count, loss.data, accuracy))\n"
]
},
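{
"cell_type": "markdown",
"metadata": {},
"source": [
"During training we evaluated while the model was still in training mode, so dropout stayed active. As an optional extra (an addition, not part of the original tutorial), the sketch below switches the model to eval mode and disables gradients, which is the usual practice when measuring final test accuracy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: final evaluation with dropout/batch norm in eval mode and no gradients.\n",
"model.eval()\n",
"correct, total = 0, 0\n",
"with torch.no_grad():\n",
"    for images, labels in test_loader:\n",
"        images, labels = images.to(device), labels.to(device)\n",
"        outputs = model(images)\n",
"        predictions = torch.max(outputs, 1)[1]\n",
"        correct += (predictions == labels).sum().item()\n",
"        total += len(labels)\n",
"print(\"Final test accuracy: {:.2f}%\".format(100 * correct / total))\n",
"model.train()"
]
},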
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "1xm9MZ6O0irC"
},
"source": [
"### Visualizing the Loss and Accuracy with Iterations\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 295
},
"colab_type": "code",
"id": "4s4cGX4Hanyz",
"outputId": "4f19caa4-327f-4cca-bd9f-726774b9ddfd"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEWCAYAAAB8LwAVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABcv0lEQVR4nO29eZwcZ3nv+316755NM5oZLaPdkhd5t2XZZg8BbCfBZg1mh3BCuAcnZAcuuUBIOAS4cAgJXEJOIGyOIXB8cILBgYTFLLYlW7axZBtr36WZ0WzdM72/94+qt6emp3qd7mr19Pv9fPRRT3VVV1Uv9dSz/R5RSmEwGAwGQzG+Vh+AwWAwGM5PjIEwGAwGgyvGQBgMBoPBFWMgDAaDweCKMRAGg8FgcMUYCIPBYDC4YgyEwVACEYmLyJZWH4fB0CqMgTCcl4jIYRF5kf34LSLy0ybv70ci8t+cy5RS3Uqpg83cb6MRkU0iokQk0OpjMbQ/xkAYlj3mYmkw1IcxEIbzGhG5BPgccKMd8pm0l4dF5P8VkaMickZEPiciUfu5F4jIcRF5t4icBr4oIv0i8u8iMioiE/bjdfb6HwaeC/y9vY+/t5crEdlqP+4TkS/b2x8Rkb8QEZ/93FtE5Kf28UyIyCERucVxDm8RkYMiMmM/93qX81wrInMiMuBYdrWIjIlIUES2isiPRWTKXvb1Ot7LtSJyj4icE5H9IvK7jud2ishuEZm2389P2ssjIvJVERkXkUkR2SUiq2rdt6E9MQbCcF6jlHoSeAfwCzvks8J+6m+AC4GrgK3ACPB+x6argQFgI/B2rO/6F+2/NwBzwN/b+3gfcD9wh72PO1wO5e+APmAL8HzgTcBbHc9fDzwNDAIfA/5JLLqATwO3KKV6gGcBj7qc50ngF8ArHYtfB3xTKZUB/gr4D6AfWGcfT63cBRwH1gKvAv6HiLzQfu5vgb9VSvUCFwDfsJe/2T7v9cBKrM9iro59G9oQYyAMbYeICNZF/4+UUueUUjPA/wBud6yWBz6glEoppeaUUuNKqW8ppWbt9T+MdaGvZn9++7Xfq5SaUUodBj4BvNGx2hGl1D8qpXLAl4A1gL7TzgOXiUhUKXVKKbW3xK7uBF7rOMfb7WUAGSzjtlYplVRK1ZSTEZH1wLOBd9vbPwr8LyxDp19/q4gMKqXiSqkHHMtXAluVUjml1MNKqela9m1oX4yBMLQjQ0AMeNgOe0wC37OXa0aVUkn9h4jEROQf7PDQNPATYIV98a/EIBAEjjiWHcHyWjSn9QOl1Kz9sFsplQBeg3XnfUpEviMiF5fYz7ewQmlrgOdhGZb77ef+HBDgIRHZKyK/U8VxO1kLaGPqdg5vw/LInrLDSL9lL/8KcB9wl4icFJGPiUiwxn0b2hRjIAztQLHk8BhWmONSpdQK+1+fUqq7zDZ/AlwEXG+HUZ5nL5cS6xfvT9/BazYAJ6o6eKXuU0q9GMureAr4xxLrTWCFkV6DFV66S9lyy0qp00qp31VKrQV+D/iszo9UyUlgQER63M5BKfWMUuq1wDDwUeCbItKllMoopf5SKbUdKzz2W8x7HYZljjEQhnbgDLBOREIASqk81kX2f4rIMICIjIjITWVeowfLqEzaieAPuOzDtefBDht9A/iwiPSIyEbgj4GvVjpwEVklIrfZuYgUEMfyDEpxJ9YF+FXMh5cQkVfrpDowgWXQyr1O2E4wR0QkgmUIfg58xF52BZbX8FX79d8gIkP2eztpv0ZeRH5NRC63Pa1pLENZbr+GZYQxEIZ24L+AvcBpERmzl70b2A88YIeMfoDlIZTiU0AUyxt4ACsk5eRvgVfZVUifdtn+94EEcBD4KdbF+wtVHLsPy5icBM5h5T3+rzLr3wNsA04rpR5zLL8OeFBE4vY676rQoxHHMoj63wux8hub7GO5GytH8wN7/ZuBvfbr/y1wu1JqDivZ/00s4/Ak8GOssJOhAxAzMMhgMBgMbhgPwmAwGAyuGANhMBgMBleMgTAYDAaDK8ZAGAwGg8GVZSNiNjg4qDZt2tTqwzAYDIa24uGHHx5TSg25PbdsDMSmTZvYvXt3qw/DYDAY2goROVLqORNiMhgMBoMrxkAYDAaDwRVjIAwGg8HgijEQBoPBYHClqQZCRG4Wkaft6VXvKbPeK+3pXTvsvzfZ07Uetf99rpnHaTAYDIbFNK2KyVZ//AzwYqwpVrtE5B6l1L6i9XqAdwEPFr3EAaXUVc06PoPBYDCUp5kexE5gv1LqoFIqjTXu8DaX9f4KS38+6fKcwWAwGFpEMw3ECHDM8fdxFk7gQkSuAdYrpb7jsv1mEdljD2p/rtsOROTt9qD13aOjow078HbgP/ae5vSUsakGg6F5tCxJLSI+4JNYk76KOQVsUEpdjaWlf6eI9BavpJT6vFJqh1Jqx9CQayPgsiSby/OOrz7MnQ+W7G8xGAyGJdNMA3ECWO/4ex0LRzT2AJcBPxKRw8ANwD0issMeND8OoJR6GDiANS/XAEwns+QVzKSyrT4Ug8GwjGmmgdgFbBORzfaoyNuxJmEBoJSaUkoNKqU2KaU2YU35ulUptVtEhvQweRHZgjVhq9z0rI5iei4DQDKTa/GRGAyG5UzTqpiUUlkRuQO4D/ADX1BK7RWRDwG7lVL3lNn8ecCHRETPv32HUupcs4613ZiyDcRs2hgIg8HQPJoq1qeUuhe4t2jZ+0us+wLH428B32rmsbUzxkAYDAYvMJ3Ubch00jIQc8ZAGAyGJmIMRBuiPYg5k4MA4MTkHH/6r4+Rypr3w2BoJMZAtCHTc1b1kgkxWfzsmTG++fBxjozPtvpQDIZlhTEQbUjBg0ibMleYL/c1BtNgaCzGQLQhJsS0kHjSNhCmL8RgaCjGQLQhOklt7pgt4inzfhgMzcAYiDZkes5UMTmJ255DwoTcDIaGYgxEG6JDTNm8Ip3Nt/hoWs+MHWIyBtNgaCzGQLQh2oMAk4eAeQORMAbCYGgoxkC0IVNzGYJ+AcxdM8yHmExVl8HQWIyBaDOUUkwns6zqjQAway6KhSom40EYDI3FGIg2I5HOkcsrVtsGwoSYnB6EeS8MhkZiDESboRPUq/psA2EuiszYZb8J0wdhMDQUYyDajKlZ62K4phBi6mwDoZQqeBCzxpsyGBqKMRBthm6SW91nDARYIba8sh93+HthMDQaYyDajEKIyfYgOn2qnE5QgwkxGQyNxhiINkMbiDXGgwAWzuU2CXuDobE01UCIyM0i8rSI7BeR95RZ75UiokRkh2PZe+3tnhaRm5p5nO3EdJEH0ellrtqDiAR9xoMwGBpM00aOiogf+AzwYuA4sEtE7lFK7Starwd4F/CgY9l24HbgUmAt8AMRuVAp1fG3iNNzGURguDcMmBCTTlCv6o2YHITB0GCa6UHsBPYrpQ4qpdLAXcBtLuv9FfBRIOlYdhtwl1IqpZQ6BOy3X6/jmZrL0BMOEA74CfjEhJhsD2JVT8Q0yhkMDaaZBmIEOOb4+7i9rICIXAOsV0p9p9Zt7e3fLiK7RWT36
OhoY476PGc6maUvFgQgGvJ3vIHQHsRQb9h4EAZDg2lZklpEfMAngT+p9zWUUp9XSu1QSu0YGhpq3MGdx0zNZeiNWAYiFvJ3/EUxbpf9ruqJkM7lyeSMuq3B0CialoMATgDrHX+vs5dpeoDLgB+JCMBq4B4RubWKbTuWqbkMfVFtIAIdX7mjPQidk5lN5+iLmuI8g6ERNPOXtAvYJiKbRSSElXS+Rz+plJpSSg0qpTYppTYBDwC3KqV22+vdLiJhEdkMbAMeauKxtg3TDgMRCZoQ00wqSzjgK7wnnV7VZTA0kqZ5EEqprIjcAdwH+IEvKKX2isiHgN1KqXvKbLtXRL4B7AOywDtNBZPFohBTprMviPFklp5IgFjID5i+EIOhkTQzxIRS6l7g3qJl7y+x7guK/v4w8OGmHVybMjWXKSSpYyF/x9f+x1NZusMBYiHrq9zpORmABw6Ok8zkeMFFw60+FEOb01QDYWgsyUyOVDa/IMQ0OpNq8VG1lngyS7fDg+h0gwnw6f98hulkxhgIw5I
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(iteration_list, loss_list)\n",
"plt.xlabel(\"No. of Iteration\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.title(\"Iterations vs Loss\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 295
},
"colab_type": "code",
"id": "TtsjkmY8qo_t",
"outputId": "18a10340-f465-4c80-db26-cdce0201e347"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEWCAYAAAB8LwAVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABdYUlEQVR4nO29d5gkZ3mvfT+dw/SknbBBm7W7CiuQ0EpIAgUkIUA20dgm2AjbIDAYA7YxhuPPcOwLH2OD4RCMkAlHxmRhTDJIQqAEKKyklbSr1eYcJqfumen4fn9UVU9PT4fqnq7pMO99XXNNd3VVd1WHeupJv0eUUmg0Go1Gk4+r3jug0Wg0msZEGwiNRqPRFEQbCI1Go9EURBsIjUaj0RREGwiNRqPRFEQbCI1Go9EURBsITcsjIlER2VTv/dBomg1tIDSOIiJHReRG8/ZbReQhh1/vPhF5W+4ypVSbUuqwk6/rFCLSZhq4n9Z7XzTLD20gNE2DiHjqvQ914HeAOPBSEVm5lC+8TN9vTQ7aQGiWBBE5H7gNuNK8Ih43l/tF5BMiclxEBkTkNhEJmo9dJyInReSDInIW+KqIdInIj0VkSETGzNvnmOt/DLga+Jz5Gp8zlysROde83SEi/2Fuf0xE/lZEXOZjbxWRh8z9GRORIyLyipxjeKuIHBaRKfOxNxc4ztUiMiMi3TnLLhGRYRHxisi5InK/iEyYy75d5q27xXzfngb+IO+1XiwivxaRcRE5ISJvNZcHReST5vFNmMcUtN7PvOfI9fA+KiJ3ish/isgk8FYRuVxEfmO+xhkR+ZyI+HK2v1BE7hGRUfPz+7CIrBSRaRFZkbPeC8z33FvmeDUNhDYQmiVBKbUXeCfwGzPk02k+9E/AVuBi4FxgDfB3OZuuBLqB9cCtGN/Zr5r31wEzwOfM1/hfwIPAn5mv8WcFduWzQAewCbgWeAvwRzmPvxDYB/QA/wx8WQzCwGeAVyilIsBVwK4Cx3ka+A3Glb/Fm4A7lVJJ4B+Au4Eu4BxzfwoiIuuB64Cvm39vyXvsp+b2vRjvn7U/nwAuNfexG/hrIFPsdfJ4NXAn0Gm+Zhp4P8b7cSVwA/Aucx8iwM+BnwGrMT6/e5VSZ4H7gN/Led4/BL5lvgeaZkEppf/0n2N/wFHgRvP2W4GHch4TIAZszll2JXDEvH0dkAACJZ7/YmAs5/59wNvy1lEYJy+3+XwX5Dz2DuC+nP07mPNYyNx2JRAGxjFO/MEyx/w24Bc5x3gCuMa8/x/A7cA5Nt67vwV2mbfXYJysLzHvfwj4foFtXBhG8/kFHrsOOFni8/ko8ECZfXqf9brAG4Eni6z3+8CvzNtu4Cxweb2/j/qvsj/tQWjqSS/GSfhxM4QxjnE12puzzpBSata6IyIhEfmiGT6ZBB4AOkXEbeP1egAvcCxn2TGMk6/FWeuGUmravNmmlIphnPTeCZwRkZ+IyHlFXud7GKG0VcA1GFfvD5qP/TWG0XhURPaIyB+X2N+3YFzFo5Q6BdyPEXICWAscKnKMgSKP2eFE7h0R2WqG8c6a7/c/mq9Rah8AfgBcICIbgZcCE0qpR6vcJ02d0AZCs5TkSwcPY1ztXqiU6jT/OpRSbSW2+UtgG/BCpVQ7xgkYjJNuofXzXy+JEZ6yWAecsrXzSt2llHopsAp4Dvj3IuuNYYSRfh8jvPQtZV5KK6XOKqXerpRajeG9/JuVH8lFRK4CtgAfMk/OZzHCX28yk8cngM1FjnG2yGMxDINsvYab+cYYFr5/XzCPdYv5fn+Yuff6BEaortB7MAt8ByNv8ofA1wqtp2lstIHQLCUDwDlWklMplcE4yX5KRPoARGSNiLysxHNEMIzKuJkI/kiB1yh20kpjnLQ+JiIRM47/F8B/lttxEekXkVebuYg4EKV0XP8bGB7A683b1vP8rpVUB8YwTsiFnucW4B7gAoww2sXAdiAIvALDs7hRRH5PRDwiskJELjbf068A/2omzN0icqWI+IH9QEBEfstMFv8t4C9z6BFgEoiaHtOf5jz2Y2CViLxPjGKDiIi8MOfx/8AI270KbSCaEm0gNEvJL4A9wFkRGTaXfRA4CDxshjB+juEhFOPTGCfJYeBhjJBULv8XeL1ZhfSZAtu/B+NK+jDwEMbJ+ys29t2FYUxOA6MYCe4/LbH+DzE8gLNKqadyll8GPCIiUXOd96q8Hg0RCWAkeD9rehzW3xGME+0tSqnjwM0YHtUoRoL6+eZT/BXwDPCY+djHAZdSagIjwfwlDK8pBsyrairAX2F4QVMYxjxbdaWUmsIIH70SIzR3AHhJzuO/wjB+TyilcsN6miZBTM9Xo9Foao6I/AL4hlLqS/XeF03laAOh0WgcQUQuwwiTrTW9DU2ToUNMGo2m5ojIHRjhwvdp49C8OOpBiMh7gbdjVD38u1Lq0yLyuxj11udj1EXvLLLtUYy4ZxpIKaV2OLajGo1Go1mAY1orIrIdwzhcjtGc9DMR+TGwG3gd8EUbT/MSpdRw+dU0Go1GU2ucFOM6H3jEajYSkfuB1yml/tm8X/MX7OnpURs2bKj582o0Gk2r8vjjjw8rpfL7YQBnDcRujHrzFRh16zcDBcNJRVDA3SKigC8qpW4vtJKI3Iqh0cO6devYubOSl9BoNJrljYgULUF2zEAopfaKyMcxOkpjGHXa6Qqe4sVKqVNmA9U9IvKcUuqBAq9zO4a2DTt27NAlWRqNRlMjHK1iUkp9WSl1qVLqGoyu0f0VbHvK/D8IfB8jl6HRaDSaJcJRA5Ejn7AOIzH9jdJbZLcLm1LCmNIGN2GErDQajUazRDjdB/E9EXkW+BHwbqXUuIi81hxaciXwExG5C7KDVv7H3K4feEhEngIeBX6ilMqXVNBoNBqNgzg6UlApdXWBZd/HCBnlLz+NkcjG1KZ5fv46Go1Go1k6dCe1RqPRaAqiDYRGo9FoCqINhGYes8k039l5Ai3iqNFotIHQzOOeZwf46zuf5plTE/XeFY1GU2e0gdDMY2gqDsCZidkya2o0mlZHGwjNPEZjCQAGJ7WB0GiWO9pAaOYxEjM8iLPaQGg0JTk9PsNn7z3Q0vk6bSA08xiOGh7EwGS8znui0TQ2P9h1mk/es59T4zP13hXH0AZCM4+RqGEYBrQHodGU5LRpGMank3XeE+fQBkIzj5GY5UFoA6FpDR47OpotvqgllucwNp2o+XM3CtpAaOYxokNMmhZCKcVbvvwon/q5bSFp25zOGgjtQWiWAbPJNNF4ija/h4mZJLPJSsZ3aDSNx+RMiplkmieOjdX8uU9lQ0zag2hpXvF/H+S2+w/VezfqjhVeumBVO6DDTMuZrz18jI/95Nl678aiGTJzavsHpojGUzV73snZJFOzxvONxbQH0dKcHp/hTAtXItjFSlBfsNoyEDrMtFy5a/dZ/v3BI+w7O1XvXVkUw+Z3OqPg6ZPjNXve0znnC52DaHHa/B6manh10axY+QfLg9C9EDAxnXQkwdnojM8Y34XbHzhc5z1ZHJaBAHjy+HjNnlcbiGVEJOAhp
g1E9sdkeRC6mxr+7oe7eed/Pl7v3VhyJmaMsMkPdp3izETzetfDpnHvDvtqaiBOjRnvycr2gE5Stzptfk9N45PNipWD2NgTxu9xNX0OYiyW4MDA4kIkR4djnG0RXap9Z6eYsHkyG59OcuP5/SjgKw8dcXbHHGQ4msAlcO3WXnadGKtZ1/Op8Vm8bmFLf5tOUrc6bQEP0VltIEaicYJeN2G/h5UdAc42eQ7iz7/1JG/5yqOLeo7BqTiTs81/hTgcjfPKzz3EFx8oX4yRziimZlNcuLqdVz5vFd945HjWo2g2hqNxusN+Ll3fxXA0wcmx2nhDp8dnWNURpKfNr0NMrY7OQRiMRBOsaPMB0B8JNLUH8eTxMR48MMzZyVlS6UxVz5HJKIam4kTjKTKZ5tbb+fZjJ0ikMrYKDyZNY9AZ8nLrNZuJJdJ8/ZFjTu+iIwxHE/S0+bhkXScATxyvTbnr6fEZVncG6Ax5GddVTK1Nm197EADDsQQr2vwA9HcEmjoH8blfHARAKRit8gpvbDpBKqN
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(iteration_list, accuracy_list)\n",
"plt.xlabel(\"No. of Iteration\")\n",
"plt.ylabel(\"Accuracy\")\n",
"plt.title(\"Iterations vs Accuracy\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mZTkzUKn0vq3"
},
"source": [
"### Looking the Accuracy in each class of FashionMNIST dataset"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 199
},
"colab_type": "code",
"id": "Mq9_qes8Qg6h",
"outputId": "fd7e3194-d5ab-4c3d-8693-ced0369df983"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy of T-shirt/Top: 83.60%\n",
"Accuracy of Trouser: 98.70%\n",
"Accuracy of Pullover: 87.40%\n",
"Accuracy of Dress: 92.00%\n",
"Accuracy of Coat: 78.70%\n",
"Accuracy of Sandal: 97.20%\n",
"Accuracy of Shirt: 79.30%\n",
"Accuracy of Sneaker: 98.40%\n",
"Accuracy of Bag: 97.60%\n",
"Accuracy of Ankle Boot: 94.50%\n"
]
}
],
"source": [
"class_correct = [0. for _ in range(10)]\n",
"total_correct = [0. for _ in range(10)]\n",
"\n",
"with torch.no_grad():\n",
" for images, labels in test_loader:\n",
" images, labels = images.to(device), labels.to(device)\n",
" test = Variable(images)\n",
" outputs = model(test)\n",
" predicted = torch.max(outputs, 1)[1]\n",
" c = (predicted == labels).squeeze()\n",
" \n",
" for i in range(100):\n",
" label = labels[i]\n",
" class_correct[label] += c[i].item()\n",
" total_correct[label] += 1\n",
" \n",
"for i in range(10):\n",
" print(\"Accuracy of {}: {:.2f}%\".format(output_label(i), class_correct[i] * 100 / total_correct[i]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ZxZhix0v0_KZ"
},
"source": [
"### Printing the Confusion Matrix "
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "GXme4c22cii2"
},
"outputs": [],
"source": [
"from itertools import chain \n",
"\n",
"predictions_l = [predictions_list[i].tolist() for i in range(len(predictions_list))]\n",
"labels_l = [labels_list[i].tolist() for i in range(len(labels_list))]\n",
"predictions_l = list(chain.from_iterable(predictions_l))\n",
"labels_l = list(chain.from_iterable(labels_l))"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 363
},
"colab_type": "code",
"id": "ft-Qlbb5bl0A",
"outputId": "5665491d-d6ee-4891-95ff-2749de05a110"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classification report for CNN :\n",
" precision recall f1-score support\n",
"\n",
" 0 0.86 0.83 0.84 55000\n",
" 1 0.99 0.98 0.98 55000\n",
" 2 0.86 0.86 0.86 55000\n",
" 3 0.91 0.89 0.90 55000\n",
" 4 0.82 0.89 0.85 55000\n",
" 5 0.98 0.97 0.97 55000\n",
" 6 0.74 0.72 0.73 55000\n",
" 7 0.94 0.96 0.95 55000\n",
" 8 0.97 0.98 0.97 55000\n",
" 9 0.96 0.96 0.96 55000\n",
"\n",
" accuracy 0.90 550000\n",
" macro avg 0.90 0.90 0.90 550000\n",
"weighted avg 0.90 0.90 0.90 550000\n",
"\n",
"\n"
]
}
],
"source": [
"import sklearn.metrics as metrics\n",
"\n",
"confusion_matrix(labels_l, predictions_l)\n",
"print(\"Classification report for CNN :\\n%s\\n\"\n",
" % (metrics.classification_report(labels_l, predictions_l)))"
]
},
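{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional extra (an addition, not part of the original notebook), the confusion matrix computed above (cm) can also be shown as a heatmap, which makes the commonly confused classes (e.g. Shirt vs. T-shirt/Top) easier to spot."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: visualize the confusion matrix computed above as a heatmap.\n",
"class_names = [output_label(i) for i in range(10)]\n",
"\n",
"plt.figure(figsize=(8, 8))\n",
"plt.imshow(cm, cmap=\"Blues\")\n",
"plt.colorbar()\n",
"plt.xticks(range(10), class_names, rotation=90)\n",
"plt.yticks(range(10), class_names)\n",
"plt.xlabel(\"Predicted label\")\n",
"plt.ylabel(\"True label\")\n",
"plt.title(\"Confusion matrix\")\n",
"plt.show()"
]
},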
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "4Y9pGNv64UYl"
},
"source": [
"### This is my implementation of deep learning in FashionMNIST dataset using Pytorch. I've achieved 93% test accuracy. Change those layer architecture or parameters to make it better. \n",
"***I hope you like it. Give your feedback. It helps me to a lot. Thank you. :)***"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "fashion_MNIST.ipynb",
"provenance": [],
"version": "0.3.2"
},
"kernelspec": {
"display_name": "PyTorch",
"language": "python",
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}