Fix args formatting for train.py
This commit is contained in:
parent
b240317d76
commit
6c7935489c
8
train.py
8
train.py
|
@ -10,9 +10,9 @@ from predict import predict
|
|||
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('-d', '--device', default='cpu',
|
||||
parser.add_argument('-d', '--device', default='cpu',
|
||||
help='device to train with')
|
||||
parser.add_argument('-i', '--input', default='data',
|
||||
parser.add_argument('-i', '--input', default='data',
|
||||
help='training data input file')
|
||||
parser.add_argument('-e', '--epochs', default=100, type=int,
|
||||
help='number of epochs to train for')
|
||||
|
@ -90,7 +90,7 @@ for t in range(args.epochs):
|
|||
'Iteration: {}'.format(iteration),
|
||||
'Loss: {}'.format(loss_value))
|
||||
|
||||
if iteration % 1 == 0:
|
||||
if iteration % 20 == 0:
|
||||
print(' '.join(predict(args.device, dataset, model, 'i am')))
|
||||
#torch.save(model.state_dict(),
|
||||
# torch.save(model.state_dict(),
|
||||
# 'checkpoint/model-{}.pth'.format(iteration))
|
||||
|
|
Loading…
Reference in a new issue