diff --git a/dataset.py b/dataset.py
index 193b1fb..33ef96f 100644
--- a/dataset.py
+++ b/dataset.py
@@ -5,8 +5,6 @@
 import torch
 
 class Dataset(torch.utils.data.Dataset):
     def __init__(self, file, seq_size):
-        self.seq_size = seq_size
-
         with open(file, 'r') as f:
             self.words = f.read().split()
@@ -20,6 +18,8 @@ class Dataset(torch.utils.data.Dataset):
 
         self.words_indexes = [self.word_to_index[w] for w in self.words]
 
+        self.seq_size = seq_size
+
     def __len__(self):
         return len(self.words_indexes) - self.seq_size
diff --git a/model.py b/model.py
index e54bba9..3ace3a2 100644
--- a/model.py
+++ b/model.py
@@ -1,8 +1,9 @@
+import torch
 import torch.nn as nn
 
 
 class Model(nn.Module):
-    def __init__(self, dataset, embedding_size, lstm_size, num_layers, dropout):
+    def __init__(self, dataset, embedding_dim, lstm_size, num_layers, dropout):
         super(Model, self).__init__()
 
         self.seq_size = dataset.seq_size
@@ -10,10 +11,10 @@ class Model(nn.Module):
         self.num_layers = num_layers
 
         n_vocab = len(dataset.uniq_words)
-        self.embedding = nn.Embedding(n_vocab, embedding_size)
+        self.embedding = nn.Embedding(n_vocab, embedding_dim)
         self.lstm = nn.LSTM(
-            input_size=embedding_size,
+            input_size=embedding_dim,
             hidden_size=lstm_size,
             num_layers=num_layers,
             batch_first=True,
diff --git a/test.py b/test.py
index 981b77d..83b3bfe 100644
--- a/test.py
+++ b/test.py
@@ -12,7 +12,7 @@
 from argparse import Namespace
 
 flags = Namespace(
-    train_file='text-small',
+    train_file='data',
     seq_size=32,
     batch_size=256,
     embedding_size=64,
diff --git a/train.py b/train.py
index 7941f3d..33ec5a0 100644
--- a/train.py
+++ b/train.py
@@ -33,24 +33,24 @@
 optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
 
 for t in range(args.epochs):
-    model.train()
-
-    state_h, state_c = net.zero_state(flags.batch_size)
+    state_h, state_c = model.zero_state(args.batch_size)
     state_h = state_h.to(args.device)
     state_c = state_c.to(args.device)
 
     iteration = 0
     for batch, (X, y) in enumerate(dataloader):
+        model.train()
+
         iteration += 1
         optimizer.zero_grad()
 
-        x = torch.tensor(x).to(args.device)
+        X = torch.tensor(X).to(args.device)
         y = torch.tensor(y).to(args.device)
 
         # Compute prediction error
-        logits, (state_h, state_c) = net(x, (state_h, state_c))
+        logits, (state_h, state_c) = model(X, (state_h, state_c))
         loss = loss_fn(logits.transpose(1, 2), y)
         loss_value = loss.item()
@@ -68,8 +68,8 @@ for t in range(args.epochs):
 
         if iteration % 10 == 0:
             print('Epoch: {}/{}'.format(t, args.epochs),
-                'Iteration: {}'.format(iteration),
-                'Loss: {}'.format(loss_value))
+                  'Iteration: {}'.format(iteration),
+                  'Loss: {}'.format(loss_value))
 
         if iteration % 1000 == 0:
-            predict(args.device, net, flags.initial_words, n_vocab,
+            predict(args.device, model, args.initial_words, n_vocab,
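
Note on dataset.py: `__len__` subtracts `seq_size` because each sample is a `seq_size`-wide window over `words_indexes`, so the last `seq_size` starting positions have no complete target window. The matching `__getitem__` sits outside the hunks above; a minimal sketch of what it presumably looks like (next-word prediction, target shifted one token ahead):

    def __getitem__(self, index):
        # Input: seq_size consecutive token indexes starting at `index`;
        # target: the same window shifted one position ahead, so the model
        # learns to predict the next word at every step.
        return (
            torch.tensor(self.words_indexes[index:index + self.seq_size]),
            torch.tensor(self.words_indexes[index + 1:index + self.seq_size + 1]),
        )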
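
Note on train.py / model.py: `model.zero_state(args.batch_size)` is called but not defined in the hunks above; presumably it builds the zeroed `(h_0, c_0)` pair for `nn.LSTM`, which would also explain the new `import torch` in model.py (the `torch` name is not bound by `import torch.nn as nn` alone). A plausible sketch, assuming the constructor also stores `self.lstm_size`:

    def zero_state(self, batch_size):
        # h_0 and c_0 are shaped (num_layers, batch, hidden_size);
        # batch_first=True affects input/output tensors, not these states.
        return (torch.zeros(self.num_layers, batch_size, self.lstm_size),
                torch.zeros(self.num_layers, batch_size, self.lstm_size))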