Add a Kaggle tutorial

This commit is contained in:
Anthony Wang 2021-08-23 21:59:53 -05:00
parent c6ec5f983d
commit 44d82a9276
4 changed files with 1241 additions and 87 deletions

File diff suppressed because one or more lines are too long

BIN
model.pth

Binary file not shown.

View file

@ -126,7 +126,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 23,
"metadata": {
"collapsed": false,
"jupyter": {
@ -138,13 +138,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])\n",
"Shape of y: torch.Size([64]) torch.int64\n"
"Shape of X [N, C, H, W]: torch.Size([200, 1, 28, 28])\n",
"Shape of y: torch.Size([200]) torch.int64\n"
]
}
],
"source": [
"batch_size = 64\n",
"batch_size = 200\n",
"\n",
"# Create data loaders.\n",
"train_dataloader = DataLoader(training_data, batch_size=batch_size)\n",
@ -189,7 +189,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 24,
"metadata": {
"collapsed": false,
"jupyter": {
@ -205,11 +205,11 @@
"NeuralNetwork(\n",
" (flatten): Flatten(start_dim=1, end_dim=-1)\n",
" (linear_relu_stack): Sequential(\n",
" (0): Linear(in_features=784, out_features=512, bias=True)\n",
" (0): Linear(in_features=784, out_features=1024, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=512, out_features=512, bias=True)\n",
" (2): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (3): ReLU()\n",
" (4): Linear(in_features=512, out_features=10, bias=True)\n",
" (4): Linear(in_features=1024, out_features=10, bias=True)\n",
" (5): ReLU()\n",
" )\n",
")\n"
@ -227,11 +227,11 @@
" super(NeuralNetwork, self).__init__()\n",
" self.flatten = nn.Flatten()\n",
" self.linear_relu_stack = nn.Sequential(\n",
" nn.Linear(28*28, 512),\n",
" nn.Linear(28*28, 1024),\n",
" nn.ReLU(),\n",
" nn.Linear(512, 512),\n",
" nn.Linear(1024, 1024),\n",
" nn.ReLU(),\n",
" nn.Linear(512, 10),\n",
" nn.Linear(1024, 10),\n",
" nn.ReLU()\n",
" )\n",
"\n",
@ -275,7 +275,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 25,
"metadata": {
"collapsed": false,
"jupyter": {
@ -299,7 +299,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 27,
"metadata": {
"collapsed": false,
"jupyter": {
@ -337,7 +337,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 28,
"metadata": {
"collapsed": false,
"jupyter": {
@ -374,7 +374,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 29,
"metadata": {
"collapsed": false,
"jupyter": {
@ -388,85 +388,170 @@
"text": [
"Epoch 1\n",
"-------------------------------\n",
"loss: 2.300270 [ 0/60000]\n",
"loss: 2.290948 [ 6400/60000]\n",
"loss: 2.280627 [12800/60000]\n",
"loss: 2.283042 [19200/60000]\n",
"loss: 2.268817 [25600/60000]\n",
"loss: 2.262104 [32000/60000]\n",
"loss: 2.248093 [38400/60000]\n",
"loss: 2.233517 [44800/60000]\n",
"loss: 2.234299 [51200/60000]\n",
"loss: 2.234841 [57600/60000]\n",
"loss: 2.301592 [ 0/60000]\n",
"loss: 2.289894 [20000/60000]\n",
"loss: 2.280160 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 51.1%, Avg loss: 2.221716 \n",
" Accuracy: 30.9%, Avg loss: 2.269258 \n",
"\n",
"Epoch 2\n",
"-------------------------------\n",
"loss: 2.203925 [ 0/60000]\n",
"loss: 2.200477 [ 6400/60000]\n",
"loss: 2.180246 [12800/60000]\n",
"loss: 2.213840 [19200/60000]\n",
"loss: 2.155768 [25600/60000]\n",
"loss: 2.154045 [32000/60000]\n",
"loss: 2.130811 [38400/60000]\n",
"loss: 2.101326 [44800/60000]\n",
"loss: 2.118105 [51200/60000]\n",
"loss: 2.116674 [57600/60000]\n",
"loss: 2.268190 [ 0/60000]\n",
"loss: 2.259592 [20000/60000]\n",
"loss: 2.251709 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 51.0%, Avg loss: 2.093179 \n",
" Accuracy: 34.5%, Avg loss: 2.238048 \n",
"\n",
"Epoch 3\n",
"-------------------------------\n",
"loss: 2.054594 [ 0/60000]\n",
"loss: 2.047193 [ 6400/60000]\n",
"loss: 2.009665 [12800/60000]\n",
"loss: 2.093936 [19200/60000]\n",
"loss: 1.965194 [25600/60000]\n",
"loss: 1.976750 [32000/60000]\n",
"loss: 1.938592 [38400/60000]\n",
"loss: 1.889513 [44800/60000]\n",
"loss: 1.942611 [51200/60000]\n",
"loss: 1.936963 [57600/60000]\n",
"loss: 2.236049 [ 0/60000]\n",
"loss: 2.230108 [20000/60000]\n",
"loss: 2.223023 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 51.2%, Avg loss: 1.900150 \n",
" Accuracy: 34.4%, Avg loss: 2.204923 \n",
"\n",
"Epoch 4\n",
"-------------------------------\n",
"loss: 1.837173 [ 0/60000]\n",
"loss: 1.822518 [ 6400/60000]\n",
"loss: 1.775139 [12800/60000]\n",
"loss: 1.925843 [19200/60000]\n",
"loss: 1.731390 [25600/60000]\n",
"loss: 1.778743 [32000/60000]\n",
"loss: 1.714922 [38400/60000]\n",
"loss: 1.670009 [44800/60000]\n",
"loss: 1.755909 [51200/60000]\n",
"loss: 1.763030 [57600/60000]\n",
"loss: 2.201698 [ 0/60000]\n",
"loss: 2.199297 [20000/60000]\n",
"loss: 2.192280 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 52.5%, Avg loss: 1.711389 \n",
" Accuracy: 34.7%, Avg loss: 2.168218 \n",
"\n",
"Epoch 5\n",
"-------------------------------\n",
"loss: 1.621799 [ 0/60000]\n",
"loss: 1.615258 [ 6400/60000]\n",
"loss: 1.567131 [12800/60000]\n",
"loss: 1.768921 [19200/60000]\n",
"loss: 1.539987 [25600/60000]\n",
"loss: 1.627408 [32000/60000]\n",
"loss: 1.533756 [38400/60000]\n",
"loss: 1.510703 [44800/60000]\n",
"loss: 1.603583 [51200/60000]\n",
"loss: 1.636089 [57600/60000]\n",
"loss: 2.163056 [ 0/60000]\n",
"loss: 2.165860 [20000/60000]\n",
"loss: 2.158360 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 53.2%, Avg loss: 1.568043 \n",
" Accuracy: 35.1%, Avg loss: 2.126790 \n",
"\n",
"Epoch 6\n",
"-------------------------------\n",
"loss: 2.118984 [ 0/60000]\n",
"loss: 2.128839 [20000/60000]\n",
"loss: 2.120741 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 35.5%, Avg loss: 2.080902 \n",
"\n",
"Epoch 7\n",
"-------------------------------\n",
"loss: 2.069693 [ 0/60000]\n",
"loss: 2.088246 [20000/60000]\n",
"loss: 2.080452 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 35.9%, Avg loss: 2.032795 \n",
"\n",
"Epoch 8\n",
"-------------------------------\n",
"loss: 2.017604 [ 0/60000]\n",
"loss: 2.045965 [20000/60000]\n",
"loss: 2.040246 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 36.2%, Avg loss: 1.986216 \n",
"\n",
"Epoch 9\n",
"-------------------------------\n",
"loss: 1.966350 [ 0/60000]\n",
"loss: 2.004781 [20000/60000]\n",
"loss: 2.002753 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 36.6%, Avg loss: 1.944349 \n",
"\n",
"Epoch 10\n",
"-------------------------------\n",
"loss: 1.919495 [ 0/60000]\n",
"loss: 1.967277 [20000/60000]\n",
"loss: 1.969415 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 37.1%, Avg loss: 1.908297 \n",
"\n",
"Epoch 11\n",
"-------------------------------\n",
"loss: 1.878579 [ 0/60000]\n",
"loss: 1.934277 [20000/60000]\n",
"loss: 1.940233 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 38.0%, Avg loss: 1.877542 \n",
"\n",
"Epoch 12\n",
"-------------------------------\n",
"loss: 1.843474 [ 0/60000]\n",
"loss: 1.905277 [20000/60000]\n",
"loss: 1.914412 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 38.9%, Avg loss: 1.850760 \n",
"\n",
"Epoch 13\n",
"-------------------------------\n",
"loss: 1.812682 [ 0/60000]\n",
"loss: 1.879232 [20000/60000]\n",
"loss: 1.890868 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 39.6%, Avg loss: 1.826784 \n",
"\n",
"Epoch 14\n",
"-------------------------------\n",
"loss: 1.785155 [ 0/60000]\n",
"loss: 1.855434 [20000/60000]\n",
"loss: 1.868896 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 40.1%, Avg loss: 1.804814 \n",
"\n",
"Epoch 15\n",
"-------------------------------\n",
"loss: 1.760130 [ 0/60000]\n",
"loss: 1.833383 [20000/60000]\n",
"loss: 1.848235 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 40.5%, Avg loss: 1.784439 \n",
"\n",
"Epoch 16\n",
"-------------------------------\n",
"loss: 1.737113 [ 0/60000]\n",
"loss: 1.812643 [20000/60000]\n",
"loss: 1.828667 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 40.9%, Avg loss: 1.765419 \n",
"\n",
"Epoch 17\n",
"-------------------------------\n",
"loss: 1.715825 [ 0/60000]\n",
"loss: 1.792752 [20000/60000]\n",
"loss: 1.809863 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 41.2%, Avg loss: 1.747591 \n",
"\n",
"Epoch 18\n",
"-------------------------------\n",
"loss: 1.695961 [ 0/60000]\n",
"loss: 1.773857 [20000/60000]\n",
"loss: 1.791943 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 41.5%, Avg loss: 1.723384 \n",
"\n",
"Epoch 19\n",
"-------------------------------\n",
"loss: 1.669920 [ 0/60000]\n",
"loss: 1.726300 [20000/60000]\n",
"loss: 1.758728 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 42.2%, Avg loss: 1.677625 \n",
"\n",
"Epoch 20\n",
"-------------------------------\n",
"loss: 1.626003 [ 0/60000]\n",
"loss: 1.674431 [20000/60000]\n",
"loss: 1.733464 [40000/60000]\n",
"Test Error: \n",
" Accuracy: 43.5%, Avg loss: 1.645380 \n",
"\n",
"Done!\n"
]
}
],
"source": [
"epochs = 5\n",
"epochs = 20\n",
"for t in range(epochs):\n",
" print(f\"Epoch {t+1}\\n-------------------------------\")\n",
" train(train_dataloader, model, loss_fn, optimizer)\n",
@ -504,7 +589,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 9,
"metadata": {
"collapsed": false,
"jupyter": {
@ -539,7 +624,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 10,
"metadata": {
"collapsed": false,
"jupyter": {
@ -553,7 +638,7 @@
"<All keys matched successfully>"
]
},
"execution_count": 13,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@ -573,7 +658,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 22,
"metadata": {
"collapsed": false,
"jupyter": {
@ -585,7 +670,56 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n"
"Predicted: \"Sandal\", Actual: \"Ankle boot\"\n",
"Predicted: \"Shirt\", Actual: \"Pullover\"\n",
"Predicted: \"Trouser\", Actual: \"Trouser\"\n",
"Predicted: \"Trouser\", Actual: \"Trouser\"\n",
"Predicted: \"Shirt\", Actual: \"Shirt\"\n",
"Predicted: \"Trouser\", Actual: \"Trouser\"\n",
"Predicted: \"Shirt\", Actual: \"Coat\"\n",
"Predicted: \"Shirt\", Actual: \"Shirt\"\n",
"Predicted: \"Sneaker\", Actual: \"Sandal\"\n",
"Predicted: \"Sneaker\", Actual: \"Sneaker\"\n",
"Predicted: \"T-shirt/top\", Actual: \"Coat\"\n",
"Predicted: \"Sandal\", Actual: \"Sandal\"\n",
"Predicted: \"Sandal\", Actual: \"Sneaker\"\n",
"Predicted: \"Dress\", Actual: \"Dress\"\n",
"Predicted: \"Shirt\", Actual: \"Coat\"\n",
"Predicted: \"Trouser\", Actual: \"Trouser\"\n",
"Predicted: \"Shirt\", Actual: \"Pullover\"\n",
"Predicted: \"Shirt\", Actual: \"Coat\"\n",
"Predicted: \"Bag\", Actual: \"Bag\"\n",
"Predicted: \"T-shirt/top\", Actual: \"T-shirt/top\"\n",
"Predicted: \"Shirt\", Actual: \"Pullover\"\n",
"Predicted: \"Sneaker\", Actual: \"Sandal\"\n",
"Predicted: \"Sneaker\", Actual: \"Sneaker\"\n",
"Predicted: \"Sneaker\", Actual: \"Ankle boot\"\n",
"Predicted: \"Trouser\", Actual: \"Trouser\"\n",
"Predicted: \"Shirt\", Actual: \"Coat\"\n",
"Predicted: \"Shirt\", Actual: \"Shirt\"\n",
"Predicted: \"Dress\", Actual: \"T-shirt/top\"\n",
"Predicted: \"Sandal\", Actual: \"Ankle boot\"\n",
"Predicted: \"Dress\", Actual: \"Dress\"\n",
"Predicted: \"Bag\", Actual: \"Bag\"\n",
"Predicted: \"Bag\", Actual: \"Bag\"\n",
"Predicted: \"Dress\", Actual: \"Dress\"\n",
"Predicted: \"Dress\", Actual: \"Dress\"\n",
"Predicted: \"Bag\", Actual: \"Bag\"\n",
"Predicted: \"T-shirt/top\", Actual: \"T-shirt/top\"\n",
"Predicted: \"Sneaker\", Actual: \"Sneaker\"\n",
"Predicted: \"Sandal\", Actual: \"Sandal\"\n",
"Predicted: \"Sneaker\", Actual: \"Sneaker\"\n",
"Predicted: \"T-shirt/top\", Actual: \"Ankle boot\"\n",
"Predicted: \"T-shirt/top\", Actual: \"Shirt\"\n",
"Predicted: \"Trouser\", Actual: \"Trouser\"\n",
"Predicted: \"Dress\", Actual: \"Dress\"\n",
"Predicted: \"Sneaker\", Actual: \"Sneaker\"\n",
"Predicted: \"Shirt\", Actual: \"Shirt\"\n",
"Predicted: \"Sandal\", Actual: \"Sneaker\"\n",
"Predicted: \"Shirt\", Actual: \"Pullover\"\n",
"Predicted: \"Trouser\", Actual: \"Trouser\"\n",
"Predicted: \"Trouser\", Actual: \"Pullover\"\n",
"Predicted: \"Shirt\", Actual: \"Pullover\"\n"
]
}
],
@ -604,21 +738,20 @@
"]\n",
"\n",
"model.eval()\n",
"x, y = test_data[0][0], test_data[0][1]\n",
"with torch.no_grad():\n",
" pred = model(x)\n",
" predicted, actual = classes[pred[0].argmax(0)], classes[y]\n",
" print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')"
"for i in range(0, 50):\n",
" x, y = test_data[i][0], test_data[i][1]\n",
" with torch.no_grad():\n",
" pred = model(x)\n",
" predicted, actual = classes[pred[0].argmax(0)], classes[y]\n",
" print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')"
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"source": [
"Read more about `Saving & Loading your model <saveloadrun_tutorial.html>`_.\n",
"\n",
"\n"
]
"outputs": [],
"source": []
}
],
"metadata": {

35
requirements.txt Normal file
View file

@ -0,0 +1,35 @@
backcall==0.2.0
cycler==0.10.0
debugpy==1.4.1
decorator==5.0.9
entrypoints==0.3
ipykernel==6.2.0
ipython==7.26.0
ipython-genutils==0.2.0
jedi==0.18.0
jupyter-client==7.0.1
jupyter-core==4.7.1
kiwisolver==1.3.1
matplotlib==3.4.3
matplotlib-inline==0.1.2
nest-asyncio==1.5.1
numpy==1.21.2
pandas==1.3.2
parso==0.8.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.3.1
prompt-toolkit==3.0.20
ptyprocess==0.7.0
Pygments==2.10.0
pyparsing==2.4.7
python-dateutil==2.8.2
pytz==2021.1
pyzmq==22.2.1
six==1.16.0
torch==1.9.0
torchvision==0.10.0
tornado==6.1
traitlets==5.0.5
typing-extensions==3.10.0.0
wcwidth==0.2.5