98 lines
2.9 KiB

from argparse import ArgumentParser
import torch
from torch import nn
from torch.utils.data import DataLoader
from dataset import Dataset
from model import Model
from predict import predict
# Command-line configuration for a training run; parsed once into `args`.
parser = ArgumentParser()
parser.add_argument('-d', '--device', default='cpu',
                    help='device to train with')
parser.add_argument('-i', '--input', default='data',
                    help='training data input file')
parser.add_argument('-o', '--output', default='model.pt',
                    help='trained model output file')
parser.add_argument('-e', '--epochs', default=10, type=int,
                    help='number of epochs to train for')
parser.add_argument('-s', '--seq-size', default=32, type=int,
                    help='sequence size')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    help='size of each training batch')
parser.add_argument('-m', '--embedding-dim', default=64, type=int,
                    help='size of the embedding')
parser.add_argument('-l', '--lstm-size', default=256, type=int,
                    help='size of the LSTM hidden state')
parser.add_argument('-a', '--layers', default=3, type=int,
                    help='number of LSTM layers')
# Dropout is fractional (default 0.2); with `type=int` any user-supplied
# value such as "0.3" was rejected by argparse.
parser.add_argument('-r', '--dropout', default=0.2, type=float,
                    help='how much dropout to apply')
# Gradient-clipping thresholds are real-valued, so accept e.g. "0.5".
parser.add_argument('-n', '--max-norm', default=5.0, type=float,
                    help='maximum norm for gradient clipping')
args = parser.parse_args()
# Prepare dataloader.
dataset = Dataset(args.input, args.seq_size)
# The training loop initialises recurrent state for exactly `args.batch_size`
# sequences and carries it from batch to batch, so every batch must have that
# size. Drop the trailing partial batch rather than feed the LSTM a hidden
# state whose batch dimension does not match the input.
# NOTE: if the dataset holds fewer than `args.batch_size` samples, this yields
# no batches at all.
dataloader = DataLoader(dataset, args.batch_size, drop_last=True)

# Prepare model, loss and optimizer.
device = torch.device(args.device)
model = Model(dataset, args.embedding_dim, args.lstm_size,
              args.layers, args.dropout).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Train the model, carrying the LSTM state across batches within an epoch,
# then save it to `args.output`.
for t in range(args.epochs):
    # Fresh recurrent state at the start of every epoch.
    state_h, state_c = model.zero_state(args.batch_size)
    state_h = state_h.to(device)
    state_c = state_c.to(device)
    for iteration, (X, y) in enumerate(dataloader, start=1):
        # predict() (defined in predict.py) may switch the model to eval mode;
        # presumably it does -- TODO confirm. Re-enable training mode so
        # dropout is active for every training step.
        model.train()
        X = X.to(device)
        y = y.to(device)

        # Compute prediction error.
        optimizer.zero_grad()
        logits, (state_h, state_c) = model(X, (state_h, state_c))
        # CrossEntropyLoss wants the class dimension at index 1; assumes
        # logits are (batch, seq, vocab) -- TODO confirm against Model.
        loss = loss_fn(logits.transpose(1, 2), y)
        loss_value = loss.item()

        # Backpropagation: without backward()/step() the weights never change.
        loss.backward()
        # Keep the state values for the next batch but cut the autograd graph,
        # so gradients do not flow back into earlier batches.
        state_h = state_h.detach()
        state_c = state_c.detach()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        optimizer.step()

        # Report every iteration; epochs are shown 1-based to match the total.
        print(f'Epoch: {t + 1}/{args.epochs}',
              f'Iteration: {iteration}',
              f'Loss: {loss_value}')
        # Periodically sample text to eyeball training progress.
        if iteration % 10 == 0:
            print(' '.join(predict(args.device, dataset, model, 'i am')))

# Persist the whole trained module (pickled) once training has finished.
torch.save(model, args.output)