import torch
import torch.nn as nn


class Model(nn.Module):
    """Word-level LSTM language model: embedding -> LSTM -> linear projection to vocabulary."""

    def __init__(self, dataset, embedding_dim, lstm_size, num_layers, dropout):
        super(Model, self).__init__()
        self.seq_size = dataset.seq_size
        self.lstm_size = lstm_size
        self.num_layers = num_layers

        # Vocabulary size comes from the dataset's list of unique words.
        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(n_vocab, embedding_dim)

        # batch_first=True: inputs are (batch, seq_len, embedding_dim).
        # Note: nn.LSTM applies dropout only between layers, so it has no effect when num_layers == 1.
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=lstm_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.dense = nn.Linear(lstm_size, n_vocab)

    def forward(self, x, prev_state):
        # x: (batch, seq_len) word indices; prev_state: (h_0, c_0) tuple.
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        # Project every timestep's hidden state to vocabulary logits: (batch, seq_len, n_vocab).
        logits = self.dense(output)

        return logits, state

    def zero_state(self, batch_size):
        # Initial hidden and cell states, each of shape (num_layers, batch, lstm_size).
        return (torch.zeros(self.num_layers, batch_size, self.lstm_size),
                torch.zeros(self.num_layers, batch_size, self.lstm_size))
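

# A minimal usage sketch (not part of the original module): it assumes a hypothetical
# dataset object exposing `seq_size` and `uniq_words`, which is all that Model reads above.
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical stand-in for the real dataset: 4-word vocabulary, sequences of length 8.
    dataset = SimpleNamespace(seq_size=8, uniq_words=["the", "cat", "sat", "down"])

    model = Model(dataset, embedding_dim=16, lstm_size=32, num_layers=1, dropout=0.0)

    batch_size = 2
    state = model.zero_state(batch_size)
    x = torch.randint(0, len(dataset.uniq_words), (batch_size, dataset.seq_size))

    logits, state = model(x, state)
    print(logits.shape)  # torch.Size([2, 8, 4]) -> (batch, seq_len, n_vocab)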