Optimize LSTM training

This commit is contained in:
Anthony Wang 2022-02-22 12:36:52 -06:00
parent 772f3c3692
commit fb9b81284b
Signed by: a
GPG key ID: BC96B00AEC5F2D76

View file

@ -16,13 +16,13 @@ parser.add_argument('-i', '--input', default='data',
help='training data input file')
parser.add_argument('-o', '--output', default='model.pt',
help='trained model output file')
parser.add_argument('-e', '--epochs', default=100, type=int,
parser.add_argument('-e', '--epochs', default=10, type=int,
help='number of epochs to train for')
parser.add_argument('-s', '--seq-size', default=32, type=int,
help='sequence size')
parser.add_argument('-b', '--batch-size', default=256, type=int,
help='size of each training batch')
parser.add_argument('-m', '--embedding-dim', default=256, type=int,
parser.add_argument('-m', '--embedding-dim', default=64, type=int,
help='size of the embedding')
parser.add_argument('-l', '--lstm-size', default=256, type=int,
help='size of the LSTM hidden state')
@ -58,8 +58,6 @@ for t in range(args.epochs):
state_c = state_c.to(device)
iteration = 0
print(len(dataloader))
for batch, (X, y) in enumerate(dataloader):
iteration += 1
@ -67,8 +65,8 @@ for t in range(args.epochs):
optimizer.zero_grad()
X = torch.tensor(X).to(device)
y = torch.tensor(y).to(device)
X = X.to(device)
y = y.to(device)
# Compute prediction error
logits, (state_h, state_c) = model(X, (state_h, state_c))
@ -92,7 +90,7 @@ for t in range(args.epochs):
'Iteration: {}'.format(iteration),
'Loss: {}'.format(loss_value))
if iteration % 20 == 0:
if iteration % 10 == 0:
print(' '.join(predict(args.device, dataset, model, 'i am')))