More cleanup
This commit is contained in:
parent
423c1d8304
commit
8b86ce3f65
|
@ -5,8 +5,6 @@ import torch
|
||||||
|
|
||||||
class Dataset(torch.utils.data.Dataset):
|
class Dataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, file, seq_size):
|
def __init__(self, file, seq_size):
|
||||||
self.seq_size = seq_size
|
|
||||||
|
|
||||||
with open(file, 'r') as f:
|
with open(file, 'r') as f:
|
||||||
self.words = f.read().split()
|
self.words = f.read().split()
|
||||||
|
|
||||||
|
@ -20,6 +18,8 @@ class Dataset(torch.utils.data.Dataset):
|
||||||
|
|
||||||
self.words_indexes = [self.word_to_index[w] for w in self.words]
|
self.words_indexes = [self.word_to_index[w] for w in self.words]
|
||||||
|
|
||||||
|
self.seq_size = seq_size
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.words_indexes) - self.seq_size
|
return len(self.words_indexes) - self.seq_size
|
||||||
|
|
||||||
|
|
7
model.py
7
model.py
|
@ -1,8 +1,9 @@
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
def __init__(self, dataset, embedding_size, lstm_size, num_layers, dropout):
|
def __init__(self, dataset, embedding_dim, lstm_size, num_layers, dropout):
|
||||||
super(Model, self).__init__()
|
super(Model, self).__init__()
|
||||||
|
|
||||||
self.seq_size = dataset.seq_size
|
self.seq_size = dataset.seq_size
|
||||||
|
@ -10,10 +11,10 @@ class Model(nn.Module):
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
|
||||||
n_vocab = len(dataset.uniq_words)
|
n_vocab = len(dataset.uniq_words)
|
||||||
self.embedding = nn.Embedding(n_vocab, embedding_size)
|
self.embedding = nn.Embedding(n_vocab, embedding_dim)
|
||||||
|
|
||||||
self.lstm = nn.LSTM(
|
self.lstm = nn.LSTM(
|
||||||
input_size=embedding_size,
|
input_size=embedding_dim,
|
||||||
hidden_size=lstm_size,
|
hidden_size=lstm_size,
|
||||||
num_layers=num_layers,
|
num_layers=num_layers,
|
||||||
batch_first=True,
|
batch_first=True,
|
||||||
|
|
2
test.py
2
test.py
|
@ -12,7 +12,7 @@ from argparse import Namespace
|
||||||
|
|
||||||
|
|
||||||
flags = Namespace(
|
flags = Namespace(
|
||||||
train_file='text-small',
|
train_file='data',
|
||||||
seq_size=32,
|
seq_size=32,
|
||||||
batch_size=256,
|
batch_size=256,
|
||||||
embedding_size=64,
|
embedding_size=64,
|
||||||
|
|
14
train.py
14
train.py
|
@ -33,24 +33,24 @@ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||||
|
|
||||||
|
|
||||||
for t in range(args.epochs):
|
for t in range(args.epochs):
|
||||||
model.train()
|
state_h, state_c = model.zero_state(args.batch_size)
|
||||||
|
|
||||||
state_h, state_c = net.zero_state(flags.batch_size)
|
|
||||||
state_h = state_h.to(args.device)
|
state_h = state_h.to(args.device)
|
||||||
state_c = state_c.to(args.device)
|
state_c = state_c.to(args.device)
|
||||||
|
|
||||||
iteration = 0
|
iteration = 0
|
||||||
|
|
||||||
for batch, (X, y) in enumerate(dataloader):
|
for batch, (X, y) in enumerate(dataloader):
|
||||||
|
model.train()
|
||||||
|
|
||||||
iteration += 1
|
iteration += 1
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
x = torch.tensor(x).to(args.device)
|
X = torch.tensor(X).to(args.device)
|
||||||
y = torch.tensor(y).to(args.device)
|
y = torch.tensor(y).to(args.device)
|
||||||
|
|
||||||
# Compute prediction error
|
# Compute prediction error
|
||||||
logits, (state_h, state_c) = net(x, (state_h, state_c))
|
logits, (state_h, state_c) = model(X, (state_h, state_c))
|
||||||
loss = loss_fn(logits.transpose(1, 2), y)
|
loss = loss_fn(logits.transpose(1, 2), y)
|
||||||
|
|
||||||
loss_value = loss.item()
|
loss_value = loss.item()
|
||||||
|
@ -68,8 +68,8 @@ for t in range(args.epochs):
|
||||||
|
|
||||||
if iteration % 10 == 0:
|
if iteration % 10 == 0:
|
||||||
print('Epoch: {}/{}'.format(t, args.epochs),
|
print('Epoch: {}/{}'.format(t, args.epochs),
|
||||||
'Iteration: {}'.format(iteration),
|
'Iteration: {}'.format(iteration),
|
||||||
'Loss: {}'.format(loss_value))
|
'Loss: {}'.format(loss_value))
|
||||||
|
|
||||||
if iteration % 1000 == 0:
|
if iteration % 1000 == 0:
|
||||||
predict(args.device, net, flags.initial_words, n_vocab,
|
predict(args.device, net, flags.initial_words, n_vocab,
|
||||||
|
|
Loading…
Reference in a new issue