ebooks/train_lstm.py

from argparse import ArgumentParser
import torch
from torch import nn
from torch.utils.data import DataLoader
from dataset import Dataset
from model import Model
from predict import predict
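# NOTE (assumed interfaces, inferred from how these local modules are used below):
#   - Dataset(path, seq_size) yields (input, target) pairs of token-index
#     tensors of length seq_size, with targets shifted one step ahead.
#   - Model(dataset, embedding_dim, lstm_size, layers, dropout) wraps an
#     embedding layer, a multi-layer LSTM, and a linear output head, and
#     Model.zero_state(batch_size) returns zeroed (h, c) state tensors.
#   - predict(device, dataset, model, prompt) returns a list of tokens
#     continuing the prompt.
# Adjust these descriptions if your local implementations differ.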
parser = ArgumentParser()
parser.add_argument('-d', '--device', default='cpu',
                    help='device to train with')
parser.add_argument('-i', '--input', default='data',
                    help='training data input file')
parser.add_argument('-o', '--output', default='model.pt',
                    help='trained model output file')
parser.add_argument('-e', '--epochs', default=10, type=int,
                    help='number of epochs to train for')
parser.add_argument('-s', '--seq-size', default=32, type=int,
                    help='sequence size')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    help='size of each training batch')
parser.add_argument('-m', '--embedding-dim', default=64, type=int,
                    help='size of the embedding')
parser.add_argument('-l', '--lstm-size', default=256, type=int,
                    help='size of the LSTM hidden state')
parser.add_argument('-a', '--layers', default=3, type=int,
                    help='number of LSTM layers')
parser.add_argument('-r', '--dropout', default=0.2, type=float,
                    help='how much dropout to apply')
parser.add_argument('-n', '--max-norm', default=5.0, type=float,
                    help='maximum norm for gradient clipping')
args = parser.parse_args()
# Prepare dataloader
dataset = Dataset(args.input, args.seq_size)
# Drop the last incomplete batch so every batch matches the hidden-state
# batch size preallocated once per epoch below
dataloader = DataLoader(dataset, args.batch_size, drop_last=True)
print('Batches per epoch:', len(dataloader))
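# Note: the DataLoader is left unshuffled, so batches arrive in order; the LSTM
# state carried across batches below therefore lines up with contiguous text,
# assuming Dataset emits its sequences in document order.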
# Prepare model
device = torch.device(args.device)
model = Model(dataset, args.embedding_dim, args.lstm_size,
              args.layers, args.dropout).to(device)
print(model)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for t in range(args.epochs):
    state_h, state_c = model.zero_state(args.batch_size)
    state_h = state_h.to(device)
    state_c = state_c.to(device)
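    # state_h / state_c are presumably shaped (layers, batch_size, lstm_size)
    # and are re-zeroed at the start of every epoch.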
    iteration = 0
    for batch, (X, y) in enumerate(dataloader):
        iteration += 1
        model.train()
        optimizer.zero_grad()
        X = X.to(device)
        y = y.to(device)
        # Compute prediction error
        logits, (state_h, state_c) = model(X, (state_h, state_c))
        loss = loss_fn(logits.transpose(1, 2), y)
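        # CrossEntropyLoss expects the class (vocabulary) dimension second,
        # hence the transpose: logits are assumed to come out of the model as
        # (batch, seq_len, vocab_size) and y as (batch, seq_len).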
        loss_value = loss.item()
        # Backpropagation
        loss.backward()
        state_h = state_h.detach()
        state_c = state_c.detach()
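        # Detaching the recurrent state truncates backpropagation at the batch
        # boundary (truncated BPTT), so the autograd graph is not kept alive
        # across the whole epoch. The next call clips gradient norms to
        # max_norm to guard against exploding gradients, a common problem
        # when training RNNs.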
        _ = torch.nn.utils.clip_grad_norm_(
            model.parameters(), args.max_norm)
        optimizer.step()
        if iteration % 1 == 0:
            print('Epoch: {}/{}'.format(t, args.epochs),
                  'Iteration: {}'.format(iteration),
                  'Loss: {}'.format(loss_value))
        if iteration % 10 == 0:
            print(' '.join(predict(args.device, dataset, model, 'i am')))
torch.save(model, args.output)
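# torch.save(model, ...) pickles the whole Module, so loading it later requires
# the Model (and Dataset) classes to be importable; saving model.state_dict()
# is the more portable alternative. A minimal reload sketch, assuming the same
# predict() signature used above:
#     model = torch.load(args.output, map_location='cpu')
#     print(' '.join(predict('cpu', dataset, model, 'i am')))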