#!/usr/bin/python3
# https://github.com/ChunML/NLP/blob/master/text_generation/train_pt.py
import os
from argparse import Namespace
from collections import Counter

import numpy as np
import torch
import torch.nn as nn

flags = Namespace(
    train_file='data',
    seq_size=32,
    batch_size=256,
    embedding_size=64,
    lstm_size=64,
    gradients_norm=5,
    initial_words=['i', 'am'],
    predict_top_k=3,
    checkpoint_path='checkpoint',
)


def get_data_from_file(train_file, batch_size, seq_size):
    with open(train_file, 'r', encoding='utf-8') as f:
        text = f.read()
    text = text.split()

    # Build the vocabulary, most frequent words first.
    word_counts = Counter(text)
    sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
    int_to_vocab = {k: w for k, w in enumerate(sorted_vocab)}
    vocab_to_int = {w: k for k, w in int_to_vocab.items()}
    n_vocab = len(int_to_vocab)

    print('Vocabulary size', n_vocab)

    int_text = [vocab_to_int[w] for w in text]
    num_batches = int(len(int_text) / (seq_size * batch_size))
    in_text = int_text[:num_batches * batch_size * seq_size]
    # Targets are the inputs shifted one step to the left; the last
    # target wraps around to the first token.
    out_text = np.zeros_like(in_text)
    out_text[:-1] = in_text[1:]
    out_text[-1] = in_text[0]
    in_text = np.reshape(in_text, (batch_size, -1))
    out_text = np.reshape(out_text, (batch_size, -1))
    return int_to_vocab, vocab_to_int, n_vocab, in_text, out_text


def get_batches(in_text, out_text, batch_size, seq_size):
    num_batches = np.prod(in_text.shape) // (seq_size * batch_size)
    for i in range(0, num_batches * seq_size, seq_size):
        yield in_text[:, i:i + seq_size], out_text[:, i:i + seq_size]


class RNNModule(nn.Module):
    def __init__(self, n_vocab, seq_size, embedding_size, lstm_size):
        super(RNNModule, self).__init__()
        self.seq_size = seq_size
        self.lstm_size = lstm_size
        self.embedding = nn.Embedding(n_vocab, embedding_size)
        self.lstm = nn.LSTM(embedding_size, lstm_size, batch_first=True)
        self.dense = nn.Linear(lstm_size, n_vocab)

    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.dense(output)
        return logits, state

    def zero_state(self, batch_size):
        # One LSTM layer: hidden and cell states of shape
        # (num_layers, batch, lstm_size).
        return (torch.zeros(1, batch_size, self.lstm_size),
                torch.zeros(1, batch_size, self.lstm_size))


def get_loss_and_train_op(net, lr=0.001):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    return criterion, optimizer


def predict(device, net, words, n_vocab, vocab_to_int, int_to_vocab, top_k=5):
    net.eval()
    # Copy the seed words so the caller's list (flags.initial_words)
    # is never mutated; the original hardcoded ['i', 'am'] here, which
    # silently ignored the `words` argument.
    words = list(words)

    state_h, state_c = net.zero_state(1)
    state_h = state_h.to(device)
    state_c = state_c.to(device)

    with torch.no_grad():
        # Prime the hidden state with the seed words.
        for w in words:
            ix = torch.tensor([[vocab_to_int[w]]]).to(device)
            output, (state_h, state_c) = net(ix, (state_h, state_c))

        # Sample uniformly from the top-k candidates at each step.
        _, top_ix = torch.topk(output[0], k=top_k)
        choices = top_ix.tolist()
        choice = np.random.choice(choices[0])
        words.append(int_to_vocab[choice])

        for _ in range(100):
            ix = torch.tensor([[choice]]).to(device)
            output, (state_h, state_c) = net(ix, (state_h, state_c))
            _, top_ix = torch.topk(output[0], k=top_k)
            choices = top_ix.tolist()
            choice = np.random.choice(choices[0])
            words.append(int_to_vocab[choice])

    print(' '.join(words))


def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    int_to_vocab, vocab_to_int, n_vocab, in_text, out_text = get_data_from_file(
        flags.train_file, flags.batch_size, flags.seq_size)

    net = RNNModule(n_vocab, flags.seq_size,
                    flags.embedding_size, flags.lstm_size)
    net = net.to(device)

    criterion, optimizer = get_loss_and_train_op(net, 0.01)

    # Make sure the checkpoint directory exists before torch.save is called.
    os.makedirs(flags.checkpoint_path, exist_ok=True)

    iteration = 0
    for e in range(200):
        batches = get_batches(in_text, out_text,
                              flags.batch_size, flags.seq_size)
        state_h, state_c = net.zero_state(flags.batch_size)
        state_h = state_h.to(device)
        state_c = state_c.to(device)

        for x, y in batches:
            iteration += 1
            net.train()
            optimizer.zero_grad()

            x = torch.tensor(x).to(device)
            y = torch.tensor(y).to(device)

            # CrossEntropyLoss expects (batch, n_vocab, seq),
            # hence the transpose of the (batch, seq, n_vocab) logits.
            logits, (state_h, state_c) = net(x, (state_h, state_c))
            loss = criterion(logits.transpose(1, 2), y)
            loss_value = loss.item()

            loss.backward()
            # Truncated BPTT: detach the states so gradients never
            # flow across batch boundaries.
            state_h = state_h.detach()
            state_c = state_c.detach()

            _ = torch.nn.utils.clip_grad_norm_(
                net.parameters(), flags.gradients_norm)
            optimizer.step()

            if iteration % 100 == 0:
                print('Epoch: {}/{}'.format(e, 200),
                      'Iteration: {}'.format(iteration),
                      'Loss: {}'.format(loss_value))

            if iteration % 1000 == 0:
                predict(device, net, flags.initial_words, n_vocab,
                        vocab_to_int, int_to_vocab,
                        top_k=flags.predict_top_k)
                torch.save(net.state_dict(),
                           os.path.join(flags.checkpoint_path,
                                        'model-{}.pth'.format(iteration)))


if __name__ == '__main__':
    main()
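
# --- Usage sketch (illustrative, not part of the upstream script) ---
# The script expects a whitespace-tokenizable plain-text corpus at ./data
# (flags.train_file) and writes checkpoints under ./checkpoint. A minimal
# way to sample from a saved checkpoint later, assuming the same corpus so
# the vocabulary maps match; the checkpoint filename below is hypothetical:
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   int_to_vocab, vocab_to_int, n_vocab, _, _ = get_data_from_file(
#       flags.train_file, flags.batch_size, flags.seq_size)
#   net = RNNModule(n_vocab, flags.seq_size,
#                   flags.embedding_size, flags.lstm_size).to(device)
#   net.load_state_dict(torch.load('checkpoint/model-1000.pth',
#                                  map_location=device))
#   predict(device, net, flags.initial_words, n_vocab,
#           vocab_to_int, int_to_vocab, top_k=flags.predict_top_k)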