Start working on model.py

Anthony Wang 2021-08-23 21:32:41 -05:00
parent ac768636c6
commit cc2d7d02ad
3 changed files with 163 additions and 334 deletions

View file

@@ -79,7 +79,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -87,113 +87,11 @@
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "47918efb82854fc7a269ce73230391b0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/26421880 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw\n",
"\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9abecd52d9144d53bd028f14a2cfd60b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/29515 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n",
"\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "df61f428b0c44a818d2ab0f64420d9b3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4422102 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw\n",
"\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "418ca86b3df24c84979a54ca66cebe56",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/5148 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ta180m/.local/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /build/python-pytorch/src/pytorch-1.9.0-opt/torch/csrc/utils/tensor_numpy.cpp:174.)\n",
"/home/ta180m/git/PyTorch/.venv/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)\n",
" return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n"
]
}
@@ -228,7 +126,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -291,7 +189,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -377,7 +275,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -401,7 +299,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -439,7 +337,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 7,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -476,7 +374,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 8,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -490,78 +388,78 @@
"text": [
"Epoch 1\n",
"-------------------------------\n",
"loss: 1.758146 [ 0/60000]\n",
"loss: 1.820034 [ 6400/60000]\n",
"loss: 1.846449 [12800/60000]\n",
"loss: 1.975245 [19200/60000]\n",
"loss: 1.612495 [25600/60000]\n",
"loss: 1.748993 [32000/60000]\n",
"loss: 1.628008 [38400/60000]\n",
"loss: 1.655061 [44800/60000]\n",
"loss: 1.770255 [51200/60000]\n",
"loss: 1.654287 [57600/60000]\n",
"loss: 2.300270 [ 0/60000]\n",
"loss: 2.290948 [ 6400/60000]\n",
"loss: 2.280627 [12800/60000]\n",
"loss: 2.283042 [19200/60000]\n",
"loss: 2.268817 [25600/60000]\n",
"loss: 2.262104 [32000/60000]\n",
"loss: 2.248093 [38400/60000]\n",
"loss: 2.233517 [44800/60000]\n",
"loss: 2.234299 [51200/60000]\n",
"loss: 2.234841 [57600/60000]\n",
"Test Error: \n",
" Accuracy: 37.7%, Avg loss: 1.749445 \n",
" Accuracy: 51.1%, Avg loss: 2.221716 \n",
"\n",
"Epoch 2\n",
"-------------------------------\n",
"loss: 1.670408 [ 0/60000]\n",
"loss: 1.743051 [ 6400/60000]\n",
"loss: 1.773547 [12800/60000]\n",
"loss: 1.924395 [19200/60000]\n",
"loss: 1.529726 [25600/60000]\n",
"loss: 1.692361 [32000/60000]\n",
"loss: 1.559834 [38400/60000]\n",
"loss: 1.593531 [44800/60000]\n",
"loss: 1.712157 [51200/60000]\n",
"loss: 1.605115 [57600/60000]\n",
"loss: 2.203925 [ 0/60000]\n",
"loss: 2.200477 [ 6400/60000]\n",
"loss: 2.180246 [12800/60000]\n",
"loss: 2.213840 [19200/60000]\n",
"loss: 2.155768 [25600/60000]\n",
"loss: 2.154045 [32000/60000]\n",
"loss: 2.130811 [38400/60000]\n",
"loss: 2.101326 [44800/60000]\n",
"loss: 2.118105 [51200/60000]\n",
"loss: 2.116674 [57600/60000]\n",
"Test Error: \n",
" Accuracy: 38.1%, Avg loss: 1.694516 \n",
" Accuracy: 51.0%, Avg loss: 2.093179 \n",
"\n",
"Epoch 3\n",
"-------------------------------\n",
"loss: 1.607648 [ 0/60000]\n",
"loss: 1.684907 [ 6400/60000]\n",
"loss: 1.716139 [12800/60000]\n",
"loss: 1.888849 [19200/60000]\n",
"loss: 1.474264 [25600/60000]\n",
"loss: 1.652733 [32000/60000]\n",
"loss: 1.514825 [38400/60000]\n",
"loss: 1.549373 [44800/60000]\n",
"loss: 1.670293 [51200/60000]\n",
"loss: 1.571395 [57600/60000]\n",
"loss: 2.054594 [ 0/60000]\n",
"loss: 2.047193 [ 6400/60000]\n",
"loss: 2.009665 [12800/60000]\n",
"loss: 2.093936 [19200/60000]\n",
"loss: 1.965194 [25600/60000]\n",
"loss: 1.976750 [32000/60000]\n",
"loss: 1.938592 [38400/60000]\n",
"loss: 1.889513 [44800/60000]\n",
"loss: 1.942611 [51200/60000]\n",
"loss: 1.936963 [57600/60000]\n",
"Test Error: \n",
" Accuracy: 39.0%, Avg loss: 1.653676 \n",
" Accuracy: 51.2%, Avg loss: 1.900150 \n",
"\n",
"Epoch 4\n",
"-------------------------------\n",
"loss: 1.561757 [ 0/60000]\n",
"loss: 1.640771 [ 6400/60000]\n",
"loss: 1.669458 [12800/60000]\n",
"loss: 1.862879 [19200/60000]\n",
"loss: 1.435348 [25600/60000]\n",
"loss: 1.623189 [32000/60000]\n",
"loss: 1.482370 [38400/60000]\n",
"loss: 1.515045 [44800/60000]\n",
"loss: 1.638349 [51200/60000]\n",
"loss: 1.545919 [57600/60000]\n",
"loss: 1.837173 [ 0/60000]\n",
"loss: 1.822518 [ 6400/60000]\n",
"loss: 1.775139 [12800/60000]\n",
"loss: 1.925843 [19200/60000]\n",
"loss: 1.731390 [25600/60000]\n",
"loss: 1.778743 [32000/60000]\n",
"loss: 1.714922 [38400/60000]\n",
"loss: 1.670009 [44800/60000]\n",
"loss: 1.755909 [51200/60000]\n",
"loss: 1.763030 [57600/60000]\n",
"Test Error: \n",
" Accuracy: 39.9%, Avg loss: 1.621615 \n",
" Accuracy: 52.5%, Avg loss: 1.711389 \n",
"\n",
"Epoch 5\n",
"-------------------------------\n",
"loss: 1.525517 [ 0/60000]\n",
"loss: 1.604991 [ 6400/60000]\n",
"loss: 1.630397 [12800/60000]\n",
"loss: 1.841878 [19200/60000]\n",
"loss: 1.406707 [25600/60000]\n",
"loss: 1.599460 [32000/60000]\n",
"loss: 1.456716 [38400/60000]\n",
"loss: 1.485950 [44800/60000]\n",
"loss: 1.612476 [51200/60000]\n",
"loss: 1.525381 [57600/60000]\n",
"loss: 1.621799 [ 0/60000]\n",
"loss: 1.615258 [ 6400/60000]\n",
"loss: 1.567131 [12800/60000]\n",
"loss: 1.768921 [19200/60000]\n",
"loss: 1.539987 [25600/60000]\n",
"loss: 1.627408 [32000/60000]\n",
"loss: 1.533756 [38400/60000]\n",
"loss: 1.510703 [44800/60000]\n",
"loss: 1.603583 [51200/60000]\n",
"loss: 1.636089 [57600/60000]\n",
"Test Error: \n",
" Accuracy: 40.7%, Avg loss: 1.595456 \n",
" Accuracy: 53.2%, Avg loss: 1.568043 \n",
"\n",
"Done!\n"
]
@@ -725,9 +623,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "PyTorch",
"language": "python",
"name": "python3"
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {

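The recorded training log replaced in the diff above changes because the new run starts from freshly initialized weights (loss near 2.30, roughly ln 10 for an untrained 10-class classifier), while the old log had apparently continued from an already partially trained model (starting near 1.76). The cell source is not shown in these hunks, but the log format matches the train/test loop of the PyTorch quickstart tutorial this notebook follows; a minimal sketch under that assumption (model, loss_fn, optimizer and device are the tutorial's names, not visible in this diff):

import torch
from torch import nn

# Sketch of the loop assumed to produce the "loss: ...  [current/60000]" lines above.
def train(dataloader, model, loss_fn, optimizer, device="cpu"):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = model(X)                  # forward pass
        loss = loss_fn(pred, y)
        optimizer.zero_grad()            # clear gradients from the previous step
        loss.backward()                  # backpropagate
        optimizer.step()                 # update the weights
        if batch % 100 == 0:
            current = batch * len(X)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")

# Sketch of the evaluation that prints the "Test Error:" block above.
def test(dataloader, model, loss_fn, device="cpu"):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0.0, 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

Each "Epoch n" block in the log would then correspond to one train() call followed by one test() call, repeated for five epochs and finished with print("Done!").
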
33 model.py Normal file
View file

@@ -0,0 +1,33 @@
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break
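
model.py so far stops after the DataLoader sanity check, which matches the commit title ("Start working on model.py"). If it continues to follow the quickstart tutorial the notebooks are based on, the next addition would presumably be the network definition; a hedged sketch continuing from the imports above (the class and layer sizes are the tutorial's and are not yet part of this commit):

# Hypothetical continuation of model.py; not part of this commit.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()              # 1x28x28 image -> 784-long vector
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),                  # 10 FashionMNIST classes
        )

    def forward(self, x):
        x = self.flatten(x)
        return self.linear_relu_stack(x)         # raw logits; nn.CrossEntropyLoss applies log-softmax internally

model = NeuralNetwork().to(device)
print(model)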

View file

@@ -79,7 +79,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -87,113 +87,11 @@
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "47918efb82854fc7a269ce73230391b0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/26421880 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw\n",
"\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9abecd52d9144d53bd028f14a2cfd60b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/29515 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n",
"\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "df61f428b0c44a818d2ab0f64420d9b3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4422102 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw\n",
"\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\n",
"Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "418ca86b3df24c84979a54ca66cebe56",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/5148 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ta180m/.local/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /build/python-pytorch/src/pytorch-1.9.0-opt/torch/csrc/utils/tensor_numpy.cpp:174.)\n",
"/home/ta180m/git/PyTorch/.venv/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)\n",
" return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n"
]
}
@@ -228,7 +126,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -291,7 +189,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -377,7 +275,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -401,7 +299,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -439,7 +337,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 7,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -476,7 +374,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 8,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -490,78 +388,78 @@
"text": [
"Epoch 1\n",
"-------------------------------\n",
"loss: 1.758146 [ 0/60000]\n",
"loss: 1.820034 [ 6400/60000]\n",
"loss: 1.846449 [12800/60000]\n",
"loss: 1.975245 [19200/60000]\n",
"loss: 1.612495 [25600/60000]\n",
"loss: 1.748993 [32000/60000]\n",
"loss: 1.628008 [38400/60000]\n",
"loss: 1.655061 [44800/60000]\n",
"loss: 1.770255 [51200/60000]\n",
"loss: 1.654287 [57600/60000]\n",
"loss: 2.300270 [ 0/60000]\n",
"loss: 2.290948 [ 6400/60000]\n",
"loss: 2.280627 [12800/60000]\n",
"loss: 2.283042 [19200/60000]\n",
"loss: 2.268817 [25600/60000]\n",
"loss: 2.262104 [32000/60000]\n",
"loss: 2.248093 [38400/60000]\n",
"loss: 2.233517 [44800/60000]\n",
"loss: 2.234299 [51200/60000]\n",
"loss: 2.234841 [57600/60000]\n",
"Test Error: \n",
" Accuracy: 37.7%, Avg loss: 1.749445 \n",
" Accuracy: 51.1%, Avg loss: 2.221716 \n",
"\n",
"Epoch 2\n",
"-------------------------------\n",
"loss: 1.670408 [ 0/60000]\n",
"loss: 1.743051 [ 6400/60000]\n",
"loss: 1.773547 [12800/60000]\n",
"loss: 1.924395 [19200/60000]\n",
"loss: 1.529726 [25600/60000]\n",
"loss: 1.692361 [32000/60000]\n",
"loss: 1.559834 [38400/60000]\n",
"loss: 1.593531 [44800/60000]\n",
"loss: 1.712157 [51200/60000]\n",
"loss: 1.605115 [57600/60000]\n",
"loss: 2.203925 [ 0/60000]\n",
"loss: 2.200477 [ 6400/60000]\n",
"loss: 2.180246 [12800/60000]\n",
"loss: 2.213840 [19200/60000]\n",
"loss: 2.155768 [25600/60000]\n",
"loss: 2.154045 [32000/60000]\n",
"loss: 2.130811 [38400/60000]\n",
"loss: 2.101326 [44800/60000]\n",
"loss: 2.118105 [51200/60000]\n",
"loss: 2.116674 [57600/60000]\n",
"Test Error: \n",
" Accuracy: 38.1%, Avg loss: 1.694516 \n",
" Accuracy: 51.0%, Avg loss: 2.093179 \n",
"\n",
"Epoch 3\n",
"-------------------------------\n",
"loss: 1.607648 [ 0/60000]\n",
"loss: 1.684907 [ 6400/60000]\n",
"loss: 1.716139 [12800/60000]\n",
"loss: 1.888849 [19200/60000]\n",
"loss: 1.474264 [25600/60000]\n",
"loss: 1.652733 [32000/60000]\n",
"loss: 1.514825 [38400/60000]\n",
"loss: 1.549373 [44800/60000]\n",
"loss: 1.670293 [51200/60000]\n",
"loss: 1.571395 [57600/60000]\n",
"loss: 2.054594 [ 0/60000]\n",
"loss: 2.047193 [ 6400/60000]\n",
"loss: 2.009665 [12800/60000]\n",
"loss: 2.093936 [19200/60000]\n",
"loss: 1.965194 [25600/60000]\n",
"loss: 1.976750 [32000/60000]\n",
"loss: 1.938592 [38400/60000]\n",
"loss: 1.889513 [44800/60000]\n",
"loss: 1.942611 [51200/60000]\n",
"loss: 1.936963 [57600/60000]\n",
"Test Error: \n",
" Accuracy: 39.0%, Avg loss: 1.653676 \n",
" Accuracy: 51.2%, Avg loss: 1.900150 \n",
"\n",
"Epoch 4\n",
"-------------------------------\n",
"loss: 1.561757 [ 0/60000]\n",
"loss: 1.640771 [ 6400/60000]\n",
"loss: 1.669458 [12800/60000]\n",
"loss: 1.862879 [19200/60000]\n",
"loss: 1.435348 [25600/60000]\n",
"loss: 1.623189 [32000/60000]\n",
"loss: 1.482370 [38400/60000]\n",
"loss: 1.515045 [44800/60000]\n",
"loss: 1.638349 [51200/60000]\n",
"loss: 1.545919 [57600/60000]\n",
"loss: 1.837173 [ 0/60000]\n",
"loss: 1.822518 [ 6400/60000]\n",
"loss: 1.775139 [12800/60000]\n",
"loss: 1.925843 [19200/60000]\n",
"loss: 1.731390 [25600/60000]\n",
"loss: 1.778743 [32000/60000]\n",
"loss: 1.714922 [38400/60000]\n",
"loss: 1.670009 [44800/60000]\n",
"loss: 1.755909 [51200/60000]\n",
"loss: 1.763030 [57600/60000]\n",
"Test Error: \n",
" Accuracy: 39.9%, Avg loss: 1.621615 \n",
" Accuracy: 52.5%, Avg loss: 1.711389 \n",
"\n",
"Epoch 5\n",
"-------------------------------\n",
"loss: 1.525517 [ 0/60000]\n",
"loss: 1.604991 [ 6400/60000]\n",
"loss: 1.630397 [12800/60000]\n",
"loss: 1.841878 [19200/60000]\n",
"loss: 1.406707 [25600/60000]\n",
"loss: 1.599460 [32000/60000]\n",
"loss: 1.456716 [38400/60000]\n",
"loss: 1.485950 [44800/60000]\n",
"loss: 1.612476 [51200/60000]\n",
"loss: 1.525381 [57600/60000]\n",
"loss: 1.621799 [ 0/60000]\n",
"loss: 1.615258 [ 6400/60000]\n",
"loss: 1.567131 [12800/60000]\n",
"loss: 1.768921 [19200/60000]\n",
"loss: 1.539987 [25600/60000]\n",
"loss: 1.627408 [32000/60000]\n",
"loss: 1.533756 [38400/60000]\n",
"loss: 1.510703 [44800/60000]\n",
"loss: 1.603583 [51200/60000]\n",
"loss: 1.636089 [57600/60000]\n",
"Test Error: \n",
" Accuracy: 40.7%, Avg loss: 1.595456 \n",
" Accuracy: 53.2%, Avg loss: 1.568043 \n",
"\n",
"Done!\n"
]
@@ -725,9 +623,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "PyTorch",
"language": "python",
"name": "python3"
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {