Reformat mnist.py code
This commit is contained in:
parent
d50f8445a0
commit
2f0109fdbd
|
@ -0,0 +1,21 @@
|
|||
# This initialization file is associated with your terminal in PyTorch/..mnist.py-0-python3.term.init.
|
||||
# It is automatically run whenever it starts up -- restart the terminal via Ctrl-d and Return-key.
|
||||
|
||||
# Usually, your ~/.bashrc is executed and this behavior is emulated for completeness:
|
||||
source ~/.bashrc
|
||||
|
||||
# You can export environment variables, e.g. to set custom GIT_* variables
|
||||
# https://git-scm.com/book/en/v2/Git-Internals-Environment-Variables
|
||||
#export GIT_AUTHOR_NAME="Your Name"
|
||||
#export GIT_AUTHOR_EMAIL="your@email.address"
|
||||
#export GIT_COMMITTER_NAME="Your Name"
|
||||
#export GIT_COMMITTER_EMAIL="your@email.address"
|
||||
|
||||
# It is also possible to automatically start a program ...
|
||||
|
||||
#sage
|
||||
#sage -ipython
|
||||
#top
|
||||
|
||||
# ... or even define a terminal specific function.
|
||||
#hello () { echo "hello world"; }
|
|
@ -4,4 +4,31 @@ Type "help", "copyright", "credits" or "license" for more information.
|
|||
Python 3.8.10 (default, Jun 2 2021, 10:49:15)
|
||||
[GCC 9.4.0] on linux
|
||||
Type "help", "copyright", "credits" or "license" for more information.
|
||||
>>> Python 3.8.10 (default, Jun 2 2021, 10:49:15)
|
||||
[GCC 9.4.0] on linux
|
||||
Type "help", "copyright", "credits" or "license" for more information.
|
||||
>>>
[K>>> mnist.py
|
||||
Traceback (most recent call last):
|
||||
File "<stdin>", line 1, in <module>
|
||||
NameError: name 'mnist' is not defined
|
||||
>>> execfile('mnist.py')
|
||||
Traceback (most recent call last):
|
||||
File "<stdin>", line 1, in <module>
|
||||
NameError: name 'execfile' is not defined
|
||||
>>> exec(open('mnist.py').read())
|
||||
/projects/800fec81-81db-4589-8df3-d839b1d21871/.local/lib/python3.8/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)
|
||||
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
|
||||
/projects/800fec81-81db-4589-8df3-d839b1d21871/.local/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)
|
||||
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||
Iteration: 50, Loss: 0.39537739753723145, Accuracy: 92.41999816894531%
|
||||
Iteration: 100, Loss: 0.13765878975391388, Accuracy: 92.33000183105469%
|
||||
Iteration: 150, Loss: 0.13507212698459625, Accuracy: 96.91999816894531%
|
||||
^CTraceback (most recent call last):
|
||||
File "<stdin>", line 1, in <module>
|
||||
File "<string>", line 86, in <module>
|
||||
File "/projects/800fec81-81db-4589-8df3-d839b1d21871/.local/lib/python3.8/site-packages/torch/_tensor.py", line 255, in backward
|
||||
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
|
||||
File "/projects/800fec81-81db-4589-8df3-d839b1d21871/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 147, in backward
|
||||
Variable._execution_engine.run_backward(
|
||||
KeyboardInterrupt
|
||||
>>>
|
112
.mnist.py-0.term
112
.mnist.py-0.term
|
@ -366,4 +366,114 @@ Iteration: 1050, Loss: 0.04832194745540619, Accuracy: 97.83000183105469%
|
|||
Variable._execution_engine.run_backward(
|
||||
KeyboardInterrupt
|
||||
|
||||
]0;~/PyTorch[01;34m~/PyTorch[00m$ g
|
||||
]0;~/PyTorch[01;34m~/PyTorch[00m$ git add- A
|
||||
git: 'add-' is not a git command. See 'git --help'.
|
||||
|
||||
The most similar command is
|
||||
add
|
||||
]0;~/PyTorch[01;34m~/PyTorch[00m$ g[Kgit add -A
|
||||
]0;~/PyTorch[01;34m~/PyTorch[00m$ git commit
|
||||
hint: Waiting for your editor to close the file... [?1049h[?1h=[1;11r[23m[24m[0m[H[J[?25l[11;1H"~/PyTorch/.git/COMMIT_EDITMSG" 11L, 288C[2;1H[34m# Please enter the commit message for your changes. Lines starting
|
||||
# with '#' will be ignored, and an empty message aborts the commit.
|
||||
#
|
||||
# On branch [0m[35mmain[0m
|
||||
[34m# Your branch is up to date with '[0m[35morigin/main[0m[34m'.
|
||||
#
|
||||
# [0m[35mChanges to be committed:[0m
|
||||
[34m# [0m[32mmodified[0m[34m: [0m[31m .mnist.py-0.term[0m
|
||||
[34m# [0m[32mmodified[0m[34m: [0m[31m mnist.py[0m[11;194H1,0-1[9CTop[1;1H[34h[?25h[?25l[11;184Hi[1;1H[11;184H [1;1H[11;1H[1m-- INSERT --[0m[11;13H[K[11;194H1,1[11CTop[1;1H[34h[?25h[?25l[33mU[0m[11;196H2[1;2H[34h[?25h[?25l[33ms[0m[11;196H3[1;3H[34h[?25h[?25l[33me[0m[11;196H4[1;4H[34h[?25h[?25l[33m [0m[11;196H5[1;5H[34h[?25h[?25l[33mC[0m[11;196H6[1;6H[34h[?25h[?25l[1;5H[K[11;196H5[1;5H[34h[?25h[?25l[33mG[0m[11;196H6[1;6H[34h[?25h[?25l[33mU[0m[11;196H7[1;7H[34h[?25h[?25l[33m [0m[11;196H8[1;8H[34h[?25h[?25l[33mi[0m[11;196H9[1;9H[34h[?25h[?25l[1;8H[K[11;196H8[1;8H[34h[?25h[?25l[1;7H[K[11;196H7[1;7H[34h[?25h[?25l[1;6H[K[11;196H6[1;6H[34h[?25h[?25l[33mP[0m[11;196H7[1;7H[34h[?25h[?25l[33mU[0m[11;196H8[1;8H[34h[?25h[?25l[33m [0m[11;196H9[1;9H[34h[?25h[?25l[33mi[0m[11;196H10[1;10H[34h[?25h[?25l[33mf[0m[11;197H1[1;11H[34h[?25h[?25l[33m [0m[11;197H2[1;12H[34h[?25h[?25l[33ma[0m[11;197H3[1;13H[34h[?25h[?25l[33mv[0m[11;197H4[1;14H[34h[?25h[?25l[33ma[0m[11;197H5[1;15H[34h[?25h[?25l[33ml[0m[11;197H6[1;16H[34h[?25h[?25l[33mi[0m[11;197H7[1;17H[34h[?25h[?25l[33ma[0m[11;197H8[1;18H[34h[?25h[?25l[33mb[0m[11;197H9[1;19H[34h[?25h[?25l[33ml[0m[11;196H20[1;20H[34h[?25h[?25l[1;19H[K[11;196H19[1;19H[34h[?25h[?25l[1;18H[K[11;197H8[1;18H[34h[?25h[?25l[1;17H[K[11;197H7[1;17H[34h[?25h[?25l[1;16H[K[11;197H6[1;16H[34h[?25h[?25l[1;15H[K[11;197H5[1;15H[34h[?25h[?25l[33mi[0m[11;197H6[1;16H[34h[?25h[?25l[33ma[0m[11;197H7[1;17H[34h[?25h[?25l[33mb[0m[11;197H8[1;18H[34h[?25h[?25l[33ml[0m[11;197H9[1;19H[34h[?25h[?25l[33me[0m[11;196H20[1;20H[34h[?25h[?25l[1;19H[K[11;196H19[1;19H[34h[?25h[?25l[1;18H[K[11;197H8[1;18H[34h[?25h[?25l[1;17H[K[11;197H7[1;17H[34h[?25h[?25l[1;16H[K[11;197H6[1;16H[34h[?25h[?25l[33ml[0m[11;197H7[1;17H[34h[?25h[?25l[33ma[0m[11;197H8[1;18H[34h[?25h[?25l[33mb[0m[11;197H9[1;19H[34h[?25h[?25l[33ml[0m[11;196H20[1;20H[34h[?25h[?25l[33me[0m[11;197H1[1;21H[34h[?25h[11;1H[K[1;20H[?25l[11;184H^[[1;20H[11;184H [1;21H[11;194H1,20[10CTop[1;20H[34h[?25h[?25l[11;184H:[1;20H[11;184H[K[11;1H:[34h[?25hx
[?25l".git/COMMIT_EDITMSG" 11L, 308C written
|
||||
[?1l>[34h[?25h[?1049l
[K[main 438d211] Use GPU if available
|
||||
2 files changed, 93 insertions(+), 2 deletions(-)
|
||||
]0;~/PyTorch[01;34m~/PyTorch[00m$ git push
|
||||
Warning: Permanently added the RSA host key for IP address '140.82.112.3' to the list of known hosts.
|
||||
Enumerating objects: 7, done.
|
||||
Counting objects: 14% (1/7)
Counting objects: 28% (2/7)
Counting objects: 42% (3/7)
Counting objects: 57% (4/7)
Counting objects: 71% (5/7)
Counting objects: 85% (6/7)
Counting objects: 100% (7/7)
Counting objects: 100% (7/7), done.
|
||||
Delta compression using up to 24 threads
|
||||
Compressing objects: 25% (1/4)
Compressing objects: 50% (2/4)
Compressing objects: 75% (3/4)
Compressing objects: 100% (4/4)
Compressing objects: 100% (4/4), done.
|
||||
Writing objects: 25% (1/4)
Writing objects: 50% (2/4)
Writing objects: 75% (3/4)
Writing objects: 100% (4/4)
Writing objects: 100% (4/4), 2.31 KiB | 2.31 MiB/s, done.
|
||||
Total 4 (delta 3), reused 0 (delta 0)
|
||||
remote: Resolving deltas: 0% (0/3)[K
remote: Resolving deltas: 33% (1/3)[K
remote: Resolving deltas: 66% (2/3)[K
remote: Resolving deltas: 100% (3/3)[K
remote: Resolving deltas: 100% (3/3), completed with 3 local objects.[K
|
||||
To github.com:Ta180m/PyTorch.git
|
||||
ada4698..438d211 main -> main
|
||||
]0;~/PyTorch[01;34m~/PyTorch[00m$ ]0;~/PyTorch[01;34m~/PyTorch[00m$
[K]0;~/PyTorch[01;34m~/PyTorch[00m$ [H[J]0;~/PyTorch[01;34m~/PyTorch[00m$ python[K/mn[K[K[K./mnist.py
|
||||
/projects/800fec81-81db-4589-8df3-d839b1d21871/.local/lib/python3.8/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)
|
||||
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
|
||||
/projects/800fec81-81db-4589-8df3-d839b1d21871/.local/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)
|
||||
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||
^CTraceback (most recent call last):
|
||||
File "./mnist.py", line 95, in <module>
|
||||
loss.backward()
|
||||
File "/projects/800fec81-81db-4589-8df3-d839b1d21871/.local/lib/python3.8/site-packages/torch/_tensor.py", line 255, in backward
|
||||
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
|
||||
File "/projects/800fec81-81db-4589-8df3-d839b1d21871/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 147, in backward
|
||||
Variable._execution_engine.run_backward(
|
||||
KeyboardInterrupt
|
||||
|
||||
]0;~/PyTorch[01;34m~/PyTorch[00m$ [H[J]0;~/PyTorch[01;34m~/PyTorch[00m$
[K]0;~/PyTorch[01;34m~/PyTorch[00m$ ./mnist.py
|
||||
/projects/800fec81-81db-4589-8df3-d839b1d21871/.local/lib/python3.8/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)
|
||||
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
|
||||
/projects/800fec81-81db-4589-8df3-d839b1d21871/.local/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)
|
||||
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||
Iteration: 50, Loss: 0.2655351459980011, Accuracy: 93.6500015258789%
|
||||
Iteration: 100, Loss: 0.09417656064033508, Accuracy: 93.77999877929688%
|
||||
Iteration: 150, Loss: 0.14728593826293945, Accuracy: 96.0999984741211%
|
||||
Iteration: 200, Loss: 0.08001164346933365, Accuracy: 97.04000091552734%
|
||||
Iteration: 250, Loss: 0.15678852796554565, Accuracy: 96.12999725341797%
|
||||
Iteration: 300, Loss: 0.09041783958673477, Accuracy: 96.87999725341797%
|
||||
Iteration: 350, Loss: 0.10900043696165085, Accuracy: 97.08999633789062%
|
||||
Iteration: 400, Loss: 0.2091383934020996, Accuracy: 96.8499984741211%
|
||||
Iteration: 450, Loss: 0.012568566016852856, Accuracy: 97.58999633789062%
|
||||
Iteration: 500, Loss: 0.09534303843975067, Accuracy: 95.9000015258789%
|
||||
Iteration: 550, Loss: 0.12004967778921127, Accuracy: 97.94999694824219%
|
||||
Iteration: 600, Loss: 0.26723435521125793, Accuracy: 96.75%
|
||||
Iteration: 650, Loss: 0.10009327530860901, Accuracy: 98.08999633789062%
|
||||
Iteration: 700, Loss: 0.03489111363887787, Accuracy: 95.26000213623047%
|
||||
Iteration: 750, Loss: 0.030101114884018898, Accuracy: 98.12999725341797%
|
||||
Iteration: 800, Loss: 0.05416644737124443, Accuracy: 97.61000061035156%
|
||||
Iteration: 850, Loss: 0.11499112099409103, Accuracy: 97.8499984741211%
|
||||
Iteration: 900, Loss: 0.20542272925376892, Accuracy: 97.66999816894531%
|
||||
Iteration: 950, Loss: 0.05691840127110481, Accuracy: 97.88999938964844%
|
||||
Iteration: 1000, Loss: 0.17045655846595764, Accuracy: 96.95999908447266%
|
||||
Iteration: 1050, Loss: 0.028369026258587837, Accuracy: 98.19000244140625%
|
||||
Iteration: 1100, Loss: 0.09225992113351822, Accuracy: 97.76000213623047%
|
||||
Iteration: 1150, Loss: 0.038039229810237885, Accuracy: 98.22000122070312%
|
||||
Iteration: 1200, Loss: 0.23273861408233643, Accuracy: 98.30000305175781%
|
||||
Iteration: 1250, Loss: 0.08464375138282776, Accuracy: 98.66999816894531%
|
||||
Iteration: 1300, Loss: 0.017008038237690926, Accuracy: 97.8499984741211%
|
||||
Iteration: 1350, Loss: 0.05763726308941841, Accuracy: 98.63999938964844%
|
||||
Iteration: 1400, Loss: 0.022395288571715355, Accuracy: 98.51000213623047%
|
||||
Iteration: 1450, Loss: 0.06815487146377563, Accuracy: 98.51000213623047%
|
||||
Iteration: 1500, Loss: 0.14768916368484497, Accuracy: 98.3499984741211%
|
||||
Iteration: 1550, Loss: 0.021466469392180443, Accuracy: 98.52999877929688%
|
||||
Iteration: 1600, Loss: 0.054903920739889145, Accuracy: 98.0999984741211%
|
||||
Iteration: 1650, Loss: 0.009115751832723618, Accuracy: 98.44999694824219%
|
||||
Iteration: 1700, Loss: 0.027846679091453552, Accuracy: 98.70999908447266%
|
||||
Iteration: 1750, Loss: 0.019951678812503815, Accuracy: 98.5199966430664%
|
||||
Iteration: 1800, Loss: 0.25205621123313904, Accuracy: 98.62000274658203%
|
||||
Iteration: 1850, Loss: 0.02951984480023384, Accuracy: 98.62999725341797%
|
||||
Iteration: 1900, Loss: 0.011210460215806961, Accuracy: 98.55000305175781%
|
||||
Iteration: 1950, Loss: 0.05040852725505829, Accuracy: 98.5%
|
||||
Iteration: 2000, Loss: 0.008486397564411163, Accuracy: 98.55999755859375%
|
||||
Iteration: 2050, Loss: 0.059381142258644104, Accuracy: 98.61000061035156%
|
||||
Iteration: 2100, Loss: 0.10324683040380478, Accuracy: 98.37000274658203%
|
||||
Iteration: 2150, Loss: 0.06498480588197708, Accuracy: 98.16999816894531%
|
||||
Iteration: 2200, Loss: 0.036080557852983475, Accuracy: 97.70999908447266%
|
||||
Iteration: 2250, Loss: 0.013293210417032242, Accuracy: 98.66000366210938%
|
||||
Iteration: 2300, Loss: 0.06331712007522583, Accuracy: 97.91000366210938%
|
||||
Iteration: 2350, Loss: 0.004426905419677496, Accuracy: 98.02999877929688%
|
||||
Iteration: 2400, Loss: 0.27985191345214844, Accuracy: 98.5%
|
||||
Iteration: 2450, Loss: 0.04614001885056496, Accuracy: 98.5%
|
||||
Iteration: 2500, Loss: 0.005236199591308832, Accuracy: 98.43000030517578%
|
||||
Iteration: 2550, Loss: 0.026349853724241257, Accuracy: 98.43000030517578%
|
||||
Iteration: 2600, Loss: 0.007622480392456055, Accuracy: 98.77999877929688%
|
||||
Iteration: 2650, Loss: 0.04031902924180031, Accuracy: 98.58999633789062%
|
||||
Iteration: 2700, Loss: 0.00840453989803791, Accuracy: 98.83999633789062%
|
||||
Iteration: 2750, Loss: 0.07304922491312027, Accuracy: 98.19000244140625%
|
||||
Iteration: 2800, Loss: 0.11154232174158096, Accuracy: 97.13999938964844%
|
||||
Iteration: 2850, Loss: 0.014337321743369102, Accuracy: 98.37999725341797%
|
||||
Iteration: 2900, Loss: 0.03685985505580902, Accuracy: 98.66999816894531%
|
||||
Iteration: 2950, Loss: 0.018538275733590126, Accuracy: 98.58000183105469%
|
||||
Iteration: 3000, Loss: 0.1944340467453003, Accuracy: 98.75%
|
||||
Saved PyTorch Model State to model.pth
|
||||
]0;~/PyTorch[01;34m~/PyTorch[00m$ exit
|
||||
]0;~/PyTorch[01;34m~/PyTorch[00m$ ]0;~/PyTorch[01;34m~/PyTorch[00m$
|
BIN
accuracy.png
BIN
accuracy.png
Binary file not shown.
Before Width: | Height: | Size: 20 KiB After Width: | Height: | Size: 20 KiB |
BIN
loss.png
BIN
loss.png
Binary file not shown.
Before Width: | Height: | Size: 36 KiB After Width: | Height: | Size: 45 KiB |
32
mnist.py
32
mnist.py
|
@ -7,7 +7,6 @@ from torchvision import datasets
|
|||
from torchvision.transforms import ToTensor, Lambda, Compose
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
training_data = datasets.MNIST(
|
||||
root=".data",
|
||||
train=True,
|
||||
|
@ -22,7 +21,6 @@ test_data = datasets.MNIST(
|
|||
transform=ToTensor(),
|
||||
)
|
||||
|
||||
|
||||
batch_size = 100
|
||||
|
||||
train_loader = DataLoader(training_data, batch_size=batch_size)
|
||||
|
@ -34,18 +32,13 @@ class CNN(nn.Module):
|
|||
super(CNN, self).__init__()
|
||||
|
||||
self.layer1 = nn.Sequential(
|
||||
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
)
|
||||
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3,
|
||||
padding=1), nn.BatchNorm2d(32), nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2))
|
||||
self.layer2 = nn.Sequential(
|
||||
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(2)
|
||||
)
|
||||
self.fc1 = nn.Linear(in_features=64*6*6, out_features=600)
|
||||
nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2))
|
||||
self.fc1 = nn.Linear(in_features=64 * 6 * 6, out_features=600)
|
||||
self.drop = nn.Dropout2d(0.25)
|
||||
self.fc2 = nn.Linear(in_features=600, out_features=120)
|
||||
self.fc3 = nn.Linear(in_features=120, out_features=10)
|
||||
|
@ -61,7 +54,6 @@ class CNN(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
|
||||
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
model = CNN()
|
||||
|
@ -70,7 +62,6 @@ error = nn.CrossEntropyLoss()
|
|||
learning_rate = 0.001
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
||||
|
||||
|
||||
num_epochs = 5
|
||||
count = 0
|
||||
|
||||
|
@ -104,28 +95,27 @@ for epoch in range(num_epochs):
|
|||
images, labels = images.to(device), labels.to(device)
|
||||
|
||||
labels_list.append(labels)
|
||||
|
||||
|
||||
test = Variable(images.view(batch_size, 1, 28, 28))
|
||||
outputs = model(test)
|
||||
|
||||
|
||||
predictions = torch.max(outputs, 1)[1].to(device)
|
||||
predictions_list.append(predictions)
|
||||
correct += (predictions == labels).sum()
|
||||
|
||||
|
||||
total += len(labels)
|
||||
|
||||
|
||||
accuracy = correct * batch_size / total
|
||||
loss_list.append(loss.data)
|
||||
iteration_list.append(count)
|
||||
accuracy_list.append(accuracy)
|
||||
|
||||
print("Iteration: {}, Loss: {}, Accuracy: {}%".format(count, loss.data, accuracy))
|
||||
|
||||
print("Iteration: {}, Loss: {}, Accuracy: {}%".format(
|
||||
count, loss.data, accuracy))
|
||||
|
||||
torch.save(model.state_dict(), "model.pth")
|
||||
print("Saved PyTorch Model State to model.pth")
|
||||
|
||||
|
||||
plt.plot(iteration_list, loss_list)
|
||||
plt.xlabel("No. of Iteration")
|
||||
plt.ylabel("Loss")
|
||||
|
|
Loading…
Reference in New Issue