Fix args formatting for train.py

This commit is contained in:
Anthony Wang 2022-02-21 16:40:09 -06:00
parent b240317d76
commit 6c7935489c
Signed by: a
GPG key ID: BC96B00AEC5F2D76

View file

@ -10,9 +10,9 @@ from predict import predict
parser = ArgumentParser()
parser.add_argument('-d', '--device', default='cpu',
parser.add_argument('-d', '--device', default='cpu',
help='device to train with')
parser.add_argument('-i', '--input', default='data',
parser.add_argument('-i', '--input', default='data',
help='training data input file')
parser.add_argument('-e', '--epochs', default=100, type=int,
help='number of epochs to train for')
@ -90,7 +90,7 @@ for t in range(args.epochs):
'Iteration: {}'.format(iteration),
'Loss: {}'.format(loss_value))
if iteration % 1 == 0:
if iteration % 20 == 0:
print(' '.join(predict(args.device, dataset, model, 'i am')))
#torch.save(model.state_dict(),
# torch.save(model.state_dict(),
# 'checkpoint/model-{}.pth'.format(iteration))