ebooks/train.py
2022-02-21 15:20:00 -06:00

79 lines
2.2 KiB
Python

from argparse import ArgumentParser
import os

import torch
from torch import nn
from torch.utils.data import DataLoader

from dataset import Dataset
from model import Model
# Command-line options for the training run.
parser = ArgumentParser()
parser.add_argument('-d', '--device', help='device to train with', default='cpu')
parser.add_argument('-i', '--input', help='training data file', default='data')
# Numeric options need an explicit type: without it argparse hands back the
# raw string whenever a value is supplied on the command line (e.g. "-e 10"
# gives "10"), which breaks range(args.epochs) and the DataLoader batch size.
parser.add_argument('-e', '--epochs', help='number of epochs to train for',
                    type=int, default=100)
parser.add_argument('-s', '--seq-size', help='sequence size',
                    type=int, default=32)
parser.add_argument('-b', '--batch-size', help='size of each training batch',
                    type=int, default=256)
args = parser.parse_args()
# Prepare dataloader
dataset = Dataset(args.input, args.seq_size)
# drop_last=True: the training loop creates the recurrent state once per
# epoch via model.zero_state(args.batch_size) and feeds it to every batch.
# A shorter final batch would then mismatch that state, assuming zero_state
# sizes its batch dimension from the argument (TODO confirm in model.Model).
dataloader = DataLoader(dataset, batch_size=args.batch_size, drop_last=True)
print(len(dataloader))
# Prepare model: build the network, place it on the requested device, and
# create the objective and optimizer used by the training loop.
# The positional hyper-parameters are interpreted by model.Model.
model = Model(dataset, 512, 512, 3, 0.2)
model = model.to(args.device)  # Module.to() returns the same module
print(model)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
# Maximum global L2 norm of the gradients before each optimizer step.
# Clipping guards the recurrent layers against exploding gradients.
# 5.0 is a chosen default; tune as needed.
GRADIENTS_NORM = 5.0
# torch.save() does not create missing directories.
os.makedirs('checkpoint', exist_ok=True)
for t in range(args.epochs):
    model.train()
    # Start each epoch from a fresh recurrent state. It is then passed from
    # one batch to the next and detached after every backward pass.
    state_h, state_c = model.zero_state(args.batch_size)
    state_h = state_h.to(args.device)
    state_c = state_c.to(args.device)
    for iteration, (X, y) in enumerate(dataloader, start=1):
        optimizer.zero_grad()
        # as_tensor accepts tensors as well as arrays/lists, without the
        # copy-construct UserWarning torch.tensor() emits for tensor input.
        X = torch.as_tensor(X).to(args.device)
        y = torch.as_tensor(y).to(args.device)
        # Compute prediction error. CrossEntropyLoss wants the class dimension
        # second; presumably logits are (batch, seq, vocab). TODO confirm.
        logits, (state_h, state_c) = model(X, (state_h, state_c))
        loss = loss_fn(logits.transpose(1, 2), y)
        loss_value = loss.item()
        # Backpropagation
        loss.backward()
        # Cut the autograd graph so the next batch does not backpropagate
        # into this batch's computation through the carried state.
        state_h = state_h.detach()
        state_c = state_c.detach()
        nn.utils.clip_grad_norm_(model.parameters(), GRADIENTS_NORM)
        optimizer.step()
        if iteration % 10 == 0:
            print('Epoch: {}/{}'.format(t + 1, args.epochs),
                  'Iteration: {}'.format(iteration),
                  'Loss: {}'.format(loss_value))
        if iteration % 1000 == 0:
            # TODO: sample text from the model here. The previous call relied
            # on predict(), net, flags, n_vocab, vocab_to_int and int_to_vocab,
            # none of which are defined in this script.
            # `iteration` restarts every epoch, so the epoch number is part of
            # the file name; otherwise each epoch overwrites the previous one.
            torch.save(model.state_dict(), os.path.join(
                'checkpoint', 'model-{}-{}.pth'.format(t + 1, iteration)))