diff --git a/mnist.py b/mnist.py index d6bed0f..4466052 100644 --- a/mnist.py +++ b/mnist.py @@ -8,14 +8,14 @@ import matplotlib.pyplot as plt training_data = datasets.MNIST( - root="data", + root=".data", train=True, download=True, transform=ToTensor(), ) test_data = datasets.MNIST( - root="data", + root=".data", train=False, download=True, transform=ToTensor(),