More cleanup
This commit is contained in:
parent
423c1d8304
commit
8b86ce3f65
|
@ -5,8 +5,6 @@ import torch
|
|||
|
||||
class Dataset(torch.utils.data.Dataset):
|
||||
def __init__(self, file, seq_size):
|
||||
self.seq_size = seq_size
|
||||
|
||||
with open(file, 'r') as f:
|
||||
self.words = f.read().split()
|
||||
|
||||
|
@ -20,6 +18,8 @@ class Dataset(torch.utils.data.Dataset):
|
|||
|
||||
self.words_indexes = [self.word_to_index[w] for w in self.words]
|
||||
|
||||
self.seq_size = seq_size
|
||||
|
||||
def __len__(self):
|
||||
return len(self.words_indexes) - self.seq_size
|
||||
|
||||
|
|
7
model.py
7
model.py
|
@ -1,8 +1,9 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, dataset, embedding_size, lstm_size, num_layers, dropout):
|
||||
def __init__(self, dataset, embedding_dim, lstm_size, num_layers, dropout):
|
||||
super(Model, self).__init__()
|
||||
|
||||
self.seq_size = dataset.seq_size
|
||||
|
@ -10,10 +11,10 @@ class Model(nn.Module):
|
|||
self.num_layers = num_layers
|
||||
|
||||
n_vocab = len(dataset.uniq_words)
|
||||
self.embedding = nn.Embedding(n_vocab, embedding_size)
|
||||
self.embedding = nn.Embedding(n_vocab, embedding_dim)
|
||||
|
||||
self.lstm = nn.LSTM(
|
||||
input_size=embedding_size,
|
||||
input_size=embedding_dim,
|
||||
hidden_size=lstm_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
|
|
2
test.py
2
test.py
|
@ -12,7 +12,7 @@ from argparse import Namespace
|
|||
|
||||
|
||||
flags = Namespace(
|
||||
train_file='text-small',
|
||||
train_file='data',
|
||||
seq_size=32,
|
||||
batch_size=256,
|
||||
embedding_size=64,
|
||||
|
|
10
train.py
10
train.py
|
@ -33,24 +33,24 @@ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
|||
|
||||
|
||||
for t in range(args.epochs):
|
||||
model.train()
|
||||
|
||||
state_h, state_c = net.zero_state(flags.batch_size)
|
||||
state_h, state_c = model.zero_state(args.batch_size)
|
||||
state_h = state_h.to(args.device)
|
||||
state_c = state_c.to(args.device)
|
||||
|
||||
iteration = 0
|
||||
|
||||
for batch, (X, y) in enumerate(dataloader):
|
||||
model.train()
|
||||
|
||||
iteration += 1
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
x = torch.tensor(x).to(args.device)
|
||||
X = torch.tensor(X).to(args.device)
|
||||
y = torch.tensor(y).to(args.device)
|
||||
|
||||
# Compute prediction error
|
||||
logits, (state_h, state_c) = net(x, (state_h, state_c))
|
||||
logits, (state_h, state_c) = model(X, (state_h, state_c))
|
||||
loss = loss_fn(logits.transpose(1, 2), y)
|
||||
|
||||
loss_value = loss.item()
|
||||
|
|
Loading…
Reference in a new issue