diff --git a/train.py b/train.py index b90f02b..bbdc54a 100644 --- a/train.py +++ b/train.py @@ -1,97 +1,21 @@ from argparse import ArgumentParser -import torch -from torch import nn -from torch.utils.data import DataLoader - -from dataset import Dataset -from model import Model -from predict import predict +from datasets import load_dataset +from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer parser = ArgumentParser() -parser.add_argument('-d', '--device', default='cpu', - help='device to train with') parser.add_argument('-i', '--input', default='data', help='training data input file') -parser.add_argument('-o', '--output', default='model.pt', - help='trained model output file') -parser.add_argument('-e', '--epochs', default=10, type=int, - help='number of epochs to train for') -parser.add_argument('-s', '--seq-size', default=32, type=int, - help='sequence size') -parser.add_argument('-b', '--batch-size', default=256, type=int, - help='size of each training batch') -parser.add_argument('-m', '--embedding-dim', default=64, type=int, - help='size of the embedding') -parser.add_argument('-l', '--lstm-size', default=256, type=int, - help='size of the LSTM hidden state') -parser.add_argument('-a', '--layers', default=3, type=int, - help='number of LSTM layers') -parser.add_argument('-r', '--dropout', default=0.2, type=int, - help='how much dropout to apply') -parser.add_argument('-n', '--max-norm', default=5, type=int, - help='maximum norm for gradient clipping') args = parser.parse_args() -# Prepare dataloader -dataset = Dataset(args.input, args.seq_size) -dataloader = DataLoader(dataset, args.batch_size) -print(len(dataloader)) +raw_dataset = load_dataset('text', data_files={'train': args.input}, keep_linebreaks=True) +tokenizer = AutoTokenizer.from_pretrained('distilgpt2') +tokenized_dataset = raw_dataset.map(lambda examples : tokenizer(examples['text']), batched=True, remove_columns=raw_dataset["train"].column_names) -# Prepare model -device = torch.device(args.device) -model = Model(dataset, args.embedding_dim, args.lstm_size, - args.layers, args.dropout).to(device) -print(model) +model = AutoModelForCausalLM.from_pretrained('distilgpt2') - -loss_fn = nn.CrossEntropyLoss() -optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - - -for t in range(args.epochs): - state_h, state_c = model.zero_state(args.batch_size) - state_h = state_h.to(device) - state_c = state_c.to(device) - - iteration = 0 - for batch, (X, y) in enumerate(dataloader): - iteration += 1 - - model.train() - - optimizer.zero_grad() - - X = X.to(device) - y = y.to(device) - - # Compute prediction error - logits, (state_h, state_c) = model(X, (state_h, state_c)) - loss = loss_fn(logits.transpose(1, 2), y) - - loss_value = loss.item() - - # Backpropogation - loss.backward() - - state_h = state_h.detach() - state_c = state_c.detach() - - _ = torch.nn.utils.clip_grad_norm_( - model.parameters(), args.max_norm) - - optimizer.step() - - if iteration % 1 == 0: - print('Epoch: {}/{}'.format(t, args.epochs), - 'Iteration: {}'.format(iteration), - 'Loss: {}'.format(loss_value)) - - if iteration % 10 == 0: - print(' '.join(predict(args.device, dataset, model, 'i am'))) - - -torch.save(model, args.output) +trainer = Trainer(model=model, train_dataset=tokenized_dataset['train'], tokenizer=tokenizer) +trainer.train() diff --git a/train_lstm.py b/train_lstm.py new file mode 100644 index 0000000..b90f02b --- /dev/null +++ b/train_lstm.py @@ -0,0 +1,97 @@ +from argparse import ArgumentParser + +import torch +from torch import nn +from torch.utils.data import DataLoader + +from dataset import Dataset +from model import Model +from predict import predict + + +parser = ArgumentParser() +parser.add_argument('-d', '--device', default='cpu', + help='device to train with') +parser.add_argument('-i', '--input', default='data', + help='training data input file') +parser.add_argument('-o', '--output', default='model.pt', + help='trained model output file') +parser.add_argument('-e', '--epochs', default=10, type=int, + help='number of epochs to train for') +parser.add_argument('-s', '--seq-size', default=32, type=int, + help='sequence size') +parser.add_argument('-b', '--batch-size', default=256, type=int, + help='size of each training batch') +parser.add_argument('-m', '--embedding-dim', default=64, type=int, + help='size of the embedding') +parser.add_argument('-l', '--lstm-size', default=256, type=int, + help='size of the LSTM hidden state') +parser.add_argument('-a', '--layers', default=3, type=int, + help='number of LSTM layers') +parser.add_argument('-r', '--dropout', default=0.2, type=int, + help='how much dropout to apply') +parser.add_argument('-n', '--max-norm', default=5, type=int, + help='maximum norm for gradient clipping') +args = parser.parse_args() + + +# Prepare dataloader +dataset = Dataset(args.input, args.seq_size) +dataloader = DataLoader(dataset, args.batch_size) +print(len(dataloader)) + + +# Prepare model +device = torch.device(args.device) +model = Model(dataset, args.embedding_dim, args.lstm_size, + args.layers, args.dropout).to(device) +print(model) + + +loss_fn = nn.CrossEntropyLoss() +optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + +for t in range(args.epochs): + state_h, state_c = model.zero_state(args.batch_size) + state_h = state_h.to(device) + state_c = state_c.to(device) + + iteration = 0 + for batch, (X, y) in enumerate(dataloader): + iteration += 1 + + model.train() + + optimizer.zero_grad() + + X = X.to(device) + y = y.to(device) + + # Compute prediction error + logits, (state_h, state_c) = model(X, (state_h, state_c)) + loss = loss_fn(logits.transpose(1, 2), y) + + loss_value = loss.item() + + # Backpropogation + loss.backward() + + state_h = state_h.detach() + state_c = state_c.detach() + + _ = torch.nn.utils.clip_grad_norm_( + model.parameters(), args.max_norm) + + optimizer.step() + + if iteration % 1 == 0: + print('Epoch: {}/{}'.format(t, args.epochs), + 'Iteration: {}'.format(iteration), + 'Loss: {}'.format(loss_value)) + + if iteration % 10 == 0: + print(' '.join(predict(args.device, dataset, model, 'i am'))) + + +torch.save(model, args.output)