Move train.py to train_lstm.py and add new transformers training code
This commit is contained in:
parent
fb9b81284b
commit
d191b6204f
92
train.py
92
train.py
|
@ -1,97 +1,21 @@
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
import torch
|
from datasets import load_dataset
|
||||||
from torch import nn
|
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from dataset import Dataset
|
|
||||||
from model import Model
|
|
||||||
from predict import predict
|
|
||||||
|
|
||||||
|
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument('-d', '--device', default='cpu',
|
|
||||||
help='device to train with')
|
|
||||||
parser.add_argument('-i', '--input', default='data',
|
parser.add_argument('-i', '--input', default='data',
|
||||||
help='training data input file')
|
help='training data input file')
|
||||||
parser.add_argument('-o', '--output', default='model.pt',
|
|
||||||
help='trained model output file')
|
|
||||||
parser.add_argument('-e', '--epochs', default=10, type=int,
|
|
||||||
help='number of epochs to train for')
|
|
||||||
parser.add_argument('-s', '--seq-size', default=32, type=int,
|
|
||||||
help='sequence size')
|
|
||||||
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
|
||||||
help='size of each training batch')
|
|
||||||
parser.add_argument('-m', '--embedding-dim', default=64, type=int,
|
|
||||||
help='size of the embedding')
|
|
||||||
parser.add_argument('-l', '--lstm-size', default=256, type=int,
|
|
||||||
help='size of the LSTM hidden state')
|
|
||||||
parser.add_argument('-a', '--layers', default=3, type=int,
|
|
||||||
help='number of LSTM layers')
|
|
||||||
parser.add_argument('-r', '--dropout', default=0.2, type=int,
|
|
||||||
help='how much dropout to apply')
|
|
||||||
parser.add_argument('-n', '--max-norm', default=5, type=int,
|
|
||||||
help='maximum norm for gradient clipping')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
# Prepare dataloader
|
raw_dataset = load_dataset('text', data_files={'train': args.input}, keep_linebreaks=True)
|
||||||
dataset = Dataset(args.input, args.seq_size)
|
|
||||||
dataloader = DataLoader(dataset, args.batch_size)
|
|
||||||
print(len(dataloader))
|
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
|
||||||
|
tokenized_dataset = raw_dataset.map(lambda examples : tokenizer(examples['text']), batched=True, remove_columns=raw_dataset["train"].column_names)
|
||||||
|
|
||||||
# Prepare model
|
model = AutoModelForCausalLM.from_pretrained('distilgpt2')
|
||||||
device = torch.device(args.device)
|
|
||||||
model = Model(dataset, args.embedding_dim, args.lstm_size,
|
|
||||||
args.layers, args.dropout).to(device)
|
|
||||||
print(model)
|
|
||||||
|
|
||||||
|
trainer = Trainer(model=model, train_dataset=tokenized_dataset['train'], tokenizer=tokenizer)
|
||||||
loss_fn = nn.CrossEntropyLoss()
|
trainer.train()
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
|
||||||
|
|
||||||
|
|
||||||
for t in range(args.epochs):
|
|
||||||
state_h, state_c = model.zero_state(args.batch_size)
|
|
||||||
state_h = state_h.to(device)
|
|
||||||
state_c = state_c.to(device)
|
|
||||||
|
|
||||||
iteration = 0
|
|
||||||
for batch, (X, y) in enumerate(dataloader):
|
|
||||||
iteration += 1
|
|
||||||
|
|
||||||
model.train()
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
X = X.to(device)
|
|
||||||
y = y.to(device)
|
|
||||||
|
|
||||||
# Compute prediction error
|
|
||||||
logits, (state_h, state_c) = model(X, (state_h, state_c))
|
|
||||||
loss = loss_fn(logits.transpose(1, 2), y)
|
|
||||||
|
|
||||||
loss_value = loss.item()
|
|
||||||
|
|
||||||
# Backpropogation
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
state_h = state_h.detach()
|
|
||||||
state_c = state_c.detach()
|
|
||||||
|
|
||||||
_ = torch.nn.utils.clip_grad_norm_(
|
|
||||||
model.parameters(), args.max_norm)
|
|
||||||
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
if iteration % 1 == 0:
|
|
||||||
print('Epoch: {}/{}'.format(t, args.epochs),
|
|
||||||
'Iteration: {}'.format(iteration),
|
|
||||||
'Loss: {}'.format(loss_value))
|
|
||||||
|
|
||||||
if iteration % 10 == 0:
|
|
||||||
print(' '.join(predict(args.device, dataset, model, 'i am')))
|
|
||||||
|
|
||||||
|
|
||||||
torch.save(model, args.output)
|
|
||||||
|
|
97
train_lstm.py
Normal file
97
train_lstm.py
Normal file
|
@ -0,0 +1,97 @@
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from dataset import Dataset
|
||||||
|
from model import Model
|
||||||
|
from predict import predict
|
||||||
|
|
||||||
|
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument('-d', '--device', default='cpu',
|
||||||
|
help='device to train with')
|
||||||
|
parser.add_argument('-i', '--input', default='data',
|
||||||
|
help='training data input file')
|
||||||
|
parser.add_argument('-o', '--output', default='model.pt',
|
||||||
|
help='trained model output file')
|
||||||
|
parser.add_argument('-e', '--epochs', default=10, type=int,
|
||||||
|
help='number of epochs to train for')
|
||||||
|
parser.add_argument('-s', '--seq-size', default=32, type=int,
|
||||||
|
help='sequence size')
|
||||||
|
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
||||||
|
help='size of each training batch')
|
||||||
|
parser.add_argument('-m', '--embedding-dim', default=64, type=int,
|
||||||
|
help='size of the embedding')
|
||||||
|
parser.add_argument('-l', '--lstm-size', default=256, type=int,
|
||||||
|
help='size of the LSTM hidden state')
|
||||||
|
parser.add_argument('-a', '--layers', default=3, type=int,
|
||||||
|
help='number of LSTM layers')
|
||||||
|
parser.add_argument('-r', '--dropout', default=0.2, type=int,
|
||||||
|
help='how much dropout to apply')
|
||||||
|
parser.add_argument('-n', '--max-norm', default=5, type=int,
|
||||||
|
help='maximum norm for gradient clipping')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
# Prepare dataloader
|
||||||
|
dataset = Dataset(args.input, args.seq_size)
|
||||||
|
dataloader = DataLoader(dataset, args.batch_size)
|
||||||
|
print(len(dataloader))
|
||||||
|
|
||||||
|
|
||||||
|
# Prepare model
|
||||||
|
device = torch.device(args.device)
|
||||||
|
model = Model(dataset, args.embedding_dim, args.lstm_size,
|
||||||
|
args.layers, args.dropout).to(device)
|
||||||
|
print(model)
|
||||||
|
|
||||||
|
|
||||||
|
loss_fn = nn.CrossEntropyLoss()
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
for t in range(args.epochs):
|
||||||
|
state_h, state_c = model.zero_state(args.batch_size)
|
||||||
|
state_h = state_h.to(device)
|
||||||
|
state_c = state_c.to(device)
|
||||||
|
|
||||||
|
iteration = 0
|
||||||
|
for batch, (X, y) in enumerate(dataloader):
|
||||||
|
iteration += 1
|
||||||
|
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
X = X.to(device)
|
||||||
|
y = y.to(device)
|
||||||
|
|
||||||
|
# Compute prediction error
|
||||||
|
logits, (state_h, state_c) = model(X, (state_h, state_c))
|
||||||
|
loss = loss_fn(logits.transpose(1, 2), y)
|
||||||
|
|
||||||
|
loss_value = loss.item()
|
||||||
|
|
||||||
|
# Backpropogation
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
state_h = state_h.detach()
|
||||||
|
state_c = state_c.detach()
|
||||||
|
|
||||||
|
_ = torch.nn.utils.clip_grad_norm_(
|
||||||
|
model.parameters(), args.max_norm)
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
if iteration % 1 == 0:
|
||||||
|
print('Epoch: {}/{}'.format(t, args.epochs),
|
||||||
|
'Iteration: {}'.format(iteration),
|
||||||
|
'Loss: {}'.format(loss_value))
|
||||||
|
|
||||||
|
if iteration % 10 == 0:
|
||||||
|
print(' '.join(predict(args.device, dataset, model, 'i am')))
|
||||||
|
|
||||||
|
|
||||||
|
torch.save(model, args.output)
|
Loading…
Reference in a new issue