Add a Kaggle tutorial

This commit is contained in:
Anthony Wang 2021-08-23 21:59:53 -05:00
parent c6ec5f983d
commit 44d82a9276
4 changed files with 1241 additions and 87 deletions

File diff suppressed because one or more lines are too long

BIN
model.pth

Binary file not shown.

View file

@ -126,7 +126,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 23,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"jupyter": { "jupyter": {
@ -138,13 +138,13 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])\n", "Shape of X [N, C, H, W]: torch.Size([200, 1, 28, 28])\n",
"Shape of y: torch.Size([64]) torch.int64\n" "Shape of y: torch.Size([200]) torch.int64\n"
] ]
} }
], ],
"source": [ "source": [
"batch_size = 64\n", "batch_size = 200\n",
"\n", "\n",
"# Create data loaders.\n", "# Create data loaders.\n",
"train_dataloader = DataLoader(training_data, batch_size=batch_size)\n", "train_dataloader = DataLoader(training_data, batch_size=batch_size)\n",
@ -189,7 +189,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 24,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"jupyter": { "jupyter": {
@ -205,11 +205,11 @@
"NeuralNetwork(\n", "NeuralNetwork(\n",
" (flatten): Flatten(start_dim=1, end_dim=-1)\n", " (flatten): Flatten(start_dim=1, end_dim=-1)\n",
" (linear_relu_stack): Sequential(\n", " (linear_relu_stack): Sequential(\n",
" (0): Linear(in_features=784, out_features=512, bias=True)\n", " (0): Linear(in_features=784, out_features=1024, bias=True)\n",
" (1): ReLU()\n", " (1): ReLU()\n",
" (2): Linear(in_features=512, out_features=512, bias=True)\n", " (2): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (3): ReLU()\n", " (3): ReLU()\n",
" (4): Linear(in_features=512, out_features=10, bias=True)\n", " (4): Linear(in_features=1024, out_features=10, bias=True)\n",
" (5): ReLU()\n", " (5): ReLU()\n",
" )\n", " )\n",
")\n" ")\n"
@ -227,11 +227,11 @@
" super(NeuralNetwork, self).__init__()\n", " super(NeuralNetwork, self).__init__()\n",
" self.flatten = nn.Flatten()\n", " self.flatten = nn.Flatten()\n",
" self.linear_relu_stack = nn.Sequential(\n", " self.linear_relu_stack = nn.Sequential(\n",
" nn.Linear(28*28, 512),\n", " nn.Linear(28*28, 1024),\n",
" nn.ReLU(),\n", " nn.ReLU(),\n",
" nn.Linear(512, 512),\n", " nn.Linear(1024, 1024),\n",
" nn.ReLU(),\n", " nn.ReLU(),\n",
" nn.Linear(512, 10),\n", " nn.Linear(1024, 10),\n",
" nn.ReLU()\n", " nn.ReLU()\n",
" )\n", " )\n",
"\n", "\n",
@ -275,7 +275,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 25,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"jupyter": { "jupyter": {
@ -299,7 +299,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 27,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"jupyter": { "jupyter": {
@ -337,7 +337,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 28,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"jupyter": { "jupyter": {
@ -374,7 +374,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 29,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"jupyter": { "jupyter": {
@ -388,85 +388,170 @@
"text": [ "text": [
"Epoch 1\n", "Epoch 1\n",
"-------------------------------\n", "-------------------------------\n",
"loss: 2.300270 [ 0/60000]\n", "loss: 2.301592 [ 0/60000]\n",
"loss: 2.290948 [ 6400/60000]\n", "loss: 2.289894 [20000/60000]\n",
"loss: 2.280627 [12800/60000]\n", "loss: 2.280160 [40000/60000]\n",
"loss: 2.283042 [19200/60000]\n",
"loss: 2.268817 [25600/60000]\n",
"loss: 2.262104 [32000/60000]\n",
"loss: 2.248093 [38400/60000]\n",
"loss: 2.233517 [44800/60000]\n",
"loss: 2.234299 [51200/60000]\n",
"loss: 2.234841 [57600/60000]\n",
"Test Error: \n", "Test Error: \n",
" Accuracy: 51.1%, Avg loss: 2.221716 \n", " Accuracy: 30.9%, Avg loss: 2.269258 \n",
"\n", "\n",
"Epoch 2\n", "Epoch 2\n",
"-------------------------------\n", "-------------------------------\n",
"loss: 2.203925 [ 0/60000]\n", "loss: 2.268190 [ 0/60000]\n",
"loss: 2.200477 [ 6400/60000]\n", "loss: 2.259592 [20000/60000]\n",
"loss: 2.180246 [12800/60000]\n", "loss: 2.251709 [40000/60000]\n",
"loss: 2.213840 [19200/60000]\n",
"loss: 2.155768 [25600/60000]\n",
"loss: 2.154045 [32000/60000]\n",
"loss: 2.130811 [38400/60000]\n",
"loss: 2.101326 [44800/60000]\n",
"loss: 2.118105 [51200/60000]\n",
"loss: 2.116674 [57600/60000]\n",
"Test Error: \n", "Test Error: \n",
" Accuracy: 51.0%, Avg loss: 2.093179 \n", " Accuracy: 34.5%, Avg loss: 2.238048 \n",
"\n", "\n",
"Epoch 3\n", "Epoch 3\n",
"-------------------------------\n", "-------------------------------\n",
"loss: 2.054594 [ 0/60000]\n", "loss: 2.236049 [ 0/60000]\n",
"loss: 2.047193 [ 6400/60000]\n", "loss: 2.230108 [20000/60000]\n",
"loss: 2.009665 [12800/60000]\n", "loss: 2.223023 [40000/60000]\n",
"loss: 2.093936 [19200/60000]\n",
"loss: 1.965194 [25600/60000]\n",
"loss: 1.976750 [32000/60000]\n",
"loss: 1.938592 [38400/60000]\n",
"loss: 1.889513 [44800/60000]\n",
"loss: 1.942611 [51200/60000]\n",
"loss: 1.936963 [57600/60000]\n",
"Test Error: \n", "Test Error: \n",
" Accuracy: 51.2%, Avg loss: 1.900150 \n", " Accuracy: 34.4%, Avg loss: 2.204923 \n",
"\n", "\n",
"Epoch 4\n", "Epoch 4\n",
"-------------------------------\n", "-------------------------------\n",
"loss: 1.837173 [ 0/60000]\n", "loss: 2.201698 [ 0/60000]\n",
"loss: 1.822518 [ 6400/60000]\n", "loss: 2.199297 [20000/60000]\n",
"loss: 1.775139 [12800/60000]\n", "loss: 2.192280 [40000/60000]\n",
"loss: 1.925843 [19200/60000]\n",
"loss: 1.731390 [25600/60000]\n",
"loss: 1.778743 [32000/60000]\n",
"loss: 1.714922 [38400/60000]\n",
"loss: 1.670009 [44800/60000]\n",
"loss: 1.755909 [51200/60000]\n",
"loss: 1.763030 [57600/60000]\n",
"Test Error: \n", "Test Error: \n",
" Accuracy: 52.5%, Avg loss: 1.711389 \n", " Accuracy: 34.7%, Avg loss: 2.168218 \n",
"\n", "\n",
"Epoch 5\n", "Epoch 5\n",
"-------------------------------\n", "-------------------------------\n",
"loss: 1.621799 [ 0/60000]\n", "loss: 2.163056 [ 0/60000]\n",
"loss: 1.615258 [ 6400/60000]\n", "loss: 2.165860 [20000/60000]\n",
"loss: 1.567131 [12800/60000]\n", "loss: 2.158360 [40000/60000]\n",
"loss: 1.768921 [19200/60000]\n",
"loss: 1.539987 [25600/60000]\n",
"loss: 1.627408 [32000/60000]\n",
"loss: 1.533756 [38400/60000]\n",
"loss: 1.510703 [44800/60000]\n",
"loss: 1.603583 [51200/60000]\n",
"loss: 1.636089 [57600/60000]\n",
"Test Error: \n", "Test Error: \n",
" Accuracy: 53.2%, Avg loss: 1.568043 \n", " Accuracy: 35.1%, Avg loss: 2.126790 \n",
"\n",
"Epoch 6\n",
"-------------------------------\n",
"loss: 2.118984 [ 0/60000]\n",
"loss: 2.128839 [20000/60000]\n",
"loss: 2.120741 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 35.5%, Avg loss: 2.080902 \n",
"\n",
"Epoch 7\n",
"-------------------------------\n",
"loss: 2.069693 [ 0/60000]\n",
"loss: 2.088246 [20000/60000]\n",
"loss: 2.080452 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 35.9%, Avg loss: 2.032795 \n",
"\n",
"Epoch 8\n",
"-------------------------------\n",
"loss: 2.017604 [ 0/60000]\n",
"loss: 2.045965 [20000/60000]\n",
"loss: 2.040246 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 36.2%, Avg loss: 1.986216 \n",
"\n",
"Epoch 9\n",
"-------------------------------\n",
"loss: 1.966350 [ 0/60000]\n",
"loss: 2.004781 [20000/60000]\n",
"loss: 2.002753 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 36.6%, Avg loss: 1.944349 \n",
"\n",
"Epoch 10\n",
"-------------------------------\n",
"loss: 1.919495 [ 0/60000]\n",
"loss: 1.967277 [20000/60000]\n",
"loss: 1.969415 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 37.1%, Avg loss: 1.908297 \n",
"\n",
"Epoch 11\n",
"-------------------------------\n",
"loss: 1.878579 [ 0/60000]\n",
"loss: 1.934277 [20000/60000]\n",
"loss: 1.940233 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 38.0%, Avg loss: 1.877542 \n",
"\n",
"Epoch 12\n",
"-------------------------------\n",
"loss: 1.843474 [ 0/60000]\n",
"loss: 1.905277 [20000/60000]\n",
"loss: 1.914412 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 38.9%, Avg loss: 1.850760 \n",
"\n",
"Epoch 13\n",
"-------------------------------\n",
"loss: 1.812682 [ 0/60000]\n",
"loss: 1.879232 [20000/60000]\n",
"loss: 1.890868 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 39.6%, Avg loss: 1.826784 \n",
"\n",
"Epoch 14\n",
"-------------------------------\n",
"loss: 1.785155 [ 0/60000]\n",
"loss: 1.855434 [20000/60000]\n",
"loss: 1.868896 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 40.1%, Avg loss: 1.804814 \n",
"\n",
"Epoch 15\n",
"-------------------------------\n",
"loss: 1.760130 [ 0/60000]\n",
"loss: 1.833383 [20000/60000]\n",
"loss: 1.848235 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 40.5%, Avg loss: 1.784439 \n",
"\n",
"Epoch 16\n",
"-------------------------------\n",
"loss: 1.737113 [ 0/60000]\n",
"loss: 1.812643 [20000/60000]\n",
"loss: 1.828667 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 40.9%, Avg loss: 1.765419 \n",
"\n",
"Epoch 17\n",
"-------------------------------\n",
"loss: 1.715825 [ 0/60000]\n",
"loss: 1.792752 [20000/60000]\n",
"loss: 1.809863 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 41.2%, Avg loss: 1.747591 \n",
"\n",
"Epoch 18\n",
"-------------------------------\n",
"loss: 1.695961 [ 0/60000]\n",
"loss: 1.773857 [20000/60000]\n",
"loss: 1.791943 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 41.5%, Avg loss: 1.723384 \n",
"\n",
"Epoch 19\n",
"-------------------------------\n",
"loss: 1.669920 [ 0/60000]\n",
"loss: 1.726300 [20000/60000]\n",
"loss: 1.758728 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 42.2%, Avg loss: 1.677625 \n",
"\n",
"Epoch 20\n",
"-------------------------------\n",
"loss: 1.626003 [ 0/60000]\n",
"loss: 1.674431 [20000/60000]\n",
"loss: 1.733464 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 43.5%, Avg loss: 1.645380 \n",
"\n", "\n",
"Done!\n" "Done!\n"
] ]
} }
], ],
"source": [ "source": [
"epochs = 5\n", "epochs = 20\n",
"for t in range(epochs):\n", "for t in range(epochs):\n",
" print(f\"Epoch {t+1}\\n-------------------------------\")\n", " print(f\"Epoch {t+1}\\n-------------------------------\")\n",
" train(train_dataloader, model, loss_fn, optimizer)\n", " train(train_dataloader, model, loss_fn, optimizer)\n",
@ -504,7 +589,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 9,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"jupyter": { "jupyter": {
@ -539,7 +624,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 10,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"jupyter": { "jupyter": {
@ -553,7 +638,7 @@
"<All keys matched successfully>" "<All keys matched successfully>"
] ]
}, },
"execution_count": 13, "execution_count": 10,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -573,7 +658,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 22,
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"jupyter": { "jupyter": {
@ -585,7 +670,56 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n" "Predicted: \"Sandal\", Actual: \"Ankle boot\"\n",
"Predicted: \"Shirt\", Actual: \"Pullover\"\n",
"Predicted: \"Trouser\", Actual: \"Trouser\"\n",
"Predicted: \"Trouser\", Actual: \"Trouser\"\n",
"Predicted: \"Shirt\", Actual: \"Shirt\"\n",
"Predicted: \"Trouser\", Actual: \"Trouser\"\n",
"Predicted: \"Shirt\", Actual: \"Coat\"\n",
"Predicted: \"Shirt\", Actual: \"Shirt\"\n",
"Predicted: \"Sneaker\", Actual: \"Sandal\"\n",
"Predicted: \"Sneaker\", Actual: \"Sneaker\"\n",
"Predicted: \"T-shirt/top\", Actual: \"Coat\"\n",
"Predicted: \"Sandal\", Actual: \"Sandal\"\n",
"Predicted: \"Sandal\", Actual: \"Sneaker\"\n",
"Predicted: \"Dress\", Actual: \"Dress\"\n",
"Predicted: \"Shirt\", Actual: \"Coat\"\n",
"Predicted: \"Trouser\", Actual: \"Trouser\"\n",
"Predicted: \"Shirt\", Actual: \"Pullover\"\n",
"Predicted: \"Shirt\", Actual: \"Coat\"\n",
"Predicted: \"Bag\", Actual: \"Bag\"\n",
"Predicted: \"T-shirt/top\", Actual: \"T-shirt/top\"\n",
"Predicted: \"Shirt\", Actual: \"Pullover\"\n",
"Predicted: \"Sneaker\", Actual: \"Sandal\"\n",
"Predicted: \"Sneaker\", Actual: \"Sneaker\"\n",
"Predicted: \"Sneaker\", Actual: \"Ankle boot\"\n",
"Predicted: \"Trouser\", Actual: \"Trouser\"\n",
"Predicted: \"Shirt\", Actual: \"Coat\"\n",
"Predicted: \"Shirt\", Actual: \"Shirt\"\n",
"Predicted: \"Dress\", Actual: \"T-shirt/top\"\n",
"Predicted: \"Sandal\", Actual: \"Ankle boot\"\n",
"Predicted: \"Dress\", Actual: \"Dress\"\n",
"Predicted: \"Bag\", Actual: \"Bag\"\n",
"Predicted: \"Bag\", Actual: \"Bag\"\n",
"Predicted: \"Dress\", Actual: \"Dress\"\n",
"Predicted: \"Dress\", Actual: \"Dress\"\n",
"Predicted: \"Bag\", Actual: \"Bag\"\n",
"Predicted: \"T-shirt/top\", Actual: \"T-shirt/top\"\n",
"Predicted: \"Sneaker\", Actual: \"Sneaker\"\n",
"Predicted: \"Sandal\", Actual: \"Sandal\"\n",
"Predicted: \"Sneaker\", Actual: \"Sneaker\"\n",
"Predicted: \"T-shirt/top\", Actual: \"Ankle boot\"\n",
"Predicted: \"T-shirt/top\", Actual: \"Shirt\"\n",
"Predicted: \"Trouser\", Actual: \"Trouser\"\n",
"Predicted: \"Dress\", Actual: \"Dress\"\n",
"Predicted: \"Sneaker\", Actual: \"Sneaker\"\n",
"Predicted: \"Shirt\", Actual: \"Shirt\"\n",
"Predicted: \"Sandal\", Actual: \"Sneaker\"\n",
"Predicted: \"Shirt\", Actual: \"Pullover\"\n",
"Predicted: \"Trouser\", Actual: \"Trouser\"\n",
"Predicted: \"Trouser\", Actual: \"Pullover\"\n",
"Predicted: \"Shirt\", Actual: \"Pullover\"\n"
] ]
} }
], ],
@ -604,21 +738,20 @@
"]\n", "]\n",
"\n", "\n",
"model.eval()\n", "model.eval()\n",
"x, y = test_data[0][0], test_data[0][1]\n", "for i in range(0, 50):\n",
"with torch.no_grad():\n", " x, y = test_data[i][0], test_data[i][1]\n",
" pred = model(x)\n", " with torch.no_grad():\n",
" predicted, actual = classes[pred[0].argmax(0)], classes[y]\n", " pred = model(x)\n",
" print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')" " predicted, actual = classes[pred[0].argmax(0)], classes[y]\n",
" print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "code",
"execution_count": null,
"metadata": {}, "metadata": {},
"source": [ "outputs": [],
"Read more about `Saving & Loading your model <saveloadrun_tutorial.html>`_.\n", "source": []
"\n",
"\n"
]
} }
], ],
"metadata": { "metadata": {

35
requirements.txt Normal file
View file

@ -0,0 +1,35 @@
backcall==0.2.0
cycler==0.10.0
debugpy==1.4.1
decorator==5.0.9
entrypoints==0.3
ipykernel==6.2.0
ipython==7.26.0
ipython-genutils==0.2.0
jedi==0.18.0
jupyter-client==7.0.1
jupyter-core==4.7.1
kiwisolver==1.3.1
matplotlib==3.4.3
matplotlib-inline==0.1.2
nest-asyncio==1.5.1
numpy==1.21.2
pandas==1.3.2
parso==0.8.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.3.1
prompt-toolkit==3.0.20
ptyprocess==0.7.0
Pygments==2.10.0
pyparsing==2.4.7
python-dateutil==2.8.2
pytz==2021.1
pyzmq==22.2.1
six==1.16.0
torch==1.9.0
torchvision==0.10.0
tornado==6.1
traitlets==5.0.5
typing-extensions==3.10.0.0
wcwidth==0.2.5