# ebooks/model.py
import torch
import torch.nn as nn
class Model(nn.Module):
    """Word-level LSTM language model: embedding -> stacked LSTM -> vocab logits."""

    def __init__(self, dataset, embedding_dim, lstm_size, num_layers, dropout):
        """
        Args:
            dataset: object exposing ``seq_size`` (int) and ``uniq_words``
                (a sized collection of vocabulary words; its length is the
                vocabulary size).
            embedding_dim: dimensionality of each word-embedding vector.
            lstm_size: LSTM hidden-state size.
            num_layers: number of stacked LSTM layers.
            dropout: inter-layer dropout probability (only meaningful when
                ``num_layers > 1``).
        """
        super().__init__()
        self.seq_size = dataset.seq_size
        self.lstm_size = lstm_size
        self.num_layers = num_layers
        n_vocab = len(dataset.uniq_words)

        self.embedding = nn.Embedding(n_vocab, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=lstm_size,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM ignores dropout for a single layer and emits a
            # UserWarning; zero it explicitly so a 1-layer model stays
            # warning-free (identical behavior either way).
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.dense = nn.Linear(lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Score next-word logits for a batch of token-id sequences.

        Args:
            x: LongTensor of shape ``(batch, seq_len)`` holding token indices.
            prev_state: ``(h_0, c_0)`` tuple, each of shape
                ``(num_layers, batch, lstm_size)``.

        Returns:
            ``(logits, state)``: logits of shape ``(batch, seq_len, n_vocab)``
            and the final ``(h_n, c_n)`` state tuple.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.dense(output)
        return logits, state

    def zero_state(self, batch_size):
        """Return an all-zero ``(h_0, c_0)`` initial state for ``batch_size``.

        The state tensors are allocated on the same device and with the same
        dtype as the model's parameters, so they remain valid after the model
        is moved with ``.to(device)`` — the original always allocated on CPU,
        which fails against a CUDA-resident model.
        """
        weight = next(self.parameters())
        shape = (self.num_layers, batch_size, self.lstm_size)
        return (weight.new_zeros(shape), weight.new_zeros(shape))