More cleanup

This commit is contained in:
Anthony Wang 2022-02-21 15:20:00 -06:00
parent 423c1d8304
commit 8b86ce3f65
Signed by: a
GPG Key ID: BC96B00AEC5F2D76
4 changed files with 14 additions and 13 deletions

View File

@@ -5,8 +5,6 @@ import torch
class Dataset(torch.utils.data.Dataset):
def __init__(self, file, seq_size):
self.seq_size = seq_size
with open(file, 'r') as f:
self.words = f.read().split()
@@ -20,6 +18,8 @@ class Dataset(torch.utils.data.Dataset):
self.words_indexes = [self.word_to_index[w] for w in self.words]
self.seq_size = seq_size
def __len__(self):
return len(self.words_indexes) - self.seq_size

View File

@@ -1,8 +1,9 @@
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self, dataset, embedding_size, lstm_size, num_layers, dropout):
def __init__(self, dataset, embedding_dim, lstm_size, num_layers, dropout):
super(Model, self).__init__()
self.seq_size = dataset.seq_size
@@ -10,10 +11,10 @@ class Model(nn.Module):
self.num_layers = num_layers
n_vocab = len(dataset.uniq_words)
self.embedding = nn.Embedding(n_vocab, embedding_size)
self.embedding = nn.Embedding(n_vocab, embedding_dim)
self.lstm = nn.LSTM(
input_size=embedding_size,
input_size=embedding_dim,
hidden_size=lstm_size,
num_layers=num_layers,
batch_first=True,

View File

@@ -12,7 +12,7 @@ from argparse import Namespace
flags = Namespace(
train_file='text-small',
train_file='data',
seq_size=32,
batch_size=256,
embedding_size=64,

View File

@@ -33,24 +33,24 @@ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for t in range(args.epochs):
model.train()
state_h, state_c = net.zero_state(flags.batch_size)
state_h, state_c = model.zero_state(args.batch_size)
state_h = state_h.to(args.device)
state_c = state_c.to(args.device)
iteration = 0
for batch, (X, y) in enumerate(dataloader):
model.train()
iteration += 1
optimizer.zero_grad()
x = torch.tensor(x).to(args.device)
X = torch.tensor(X).to(args.device)
y = torch.tensor(y).to(args.device)
# Compute prediction error
logits, (state_h, state_c) = net(x, (state_h, state_c))
logits, (state_h, state_c) = model(X, (state_h, state_c))
loss = loss_fn(logits.transpose(1, 2), y)
loss_value = loss.item()
@@ -68,8 +68,8 @@ for t in range(args.epochs):
if iteration % 10 == 0:
print('Epoch: {}/{}'.format(t, args.epochs),
'Iteration: {}'.format(iteration),
'Loss: {}'.format(loss_value))
'Iteration: {}'.format(iteration),
'Loss: {}'.format(loss_value))
if iteration % 1000 == 0:
predict(args.device, net, flags.initial_words, n_vocab,