ebooks/predict.py

29 lines
768 B
Python
Raw Normal View History

2022-02-21 21:20:09 +00:00
import numpy as np
import torch
2022-02-21 22:39:58 +00:00
def predict(device, dataset, model, text, next_words=100, top_k=5):
    """Generate text by sampling from *model* one word at a time.

    The whitespace-separated words of *text* are fed through the model to
    prime its recurrent state. Then *next_words* further words are produced.
    Each one is picked uniformly at random (``np.random.choice``, i.e. the
    global NumPy RNG) from the model's ``top_k`` highest-scoring candidates,
    and is fed back as the next input.

    The model is switched to eval mode via ``model.eval()`` and left there.

    Args:
        device: Device that the input tensors and initial state are moved to.
        dataset: Object exposing ``word_to_index`` (word -> index) and
            ``index_to_word`` (index -> word) lookups.
        model: Recurrent model providing ``eval()``, ``zero_state(batch_size)``
            (returning an ``(h, c)`` pair) and the call
            ``model(ix, (h, c)) -> (output, (h, c))``.
        text: Prompt text. It must contain at least one word when
            ``next_words > 0``.
        next_words: Number of words to generate.
        top_k: Number of highest-scoring candidates to sample from. It is
            clamped to the size of the model output's last dimension.

    Returns:
        list[str]: The prompt words followed by the generated words.

    Raises:
        ValueError: If generation is requested (``next_words > 0``) with an
            empty prompt or with ``top_k < 1``.
        KeyError: If a prompt word is not in ``dataset.word_to_index``.
    """
    model.eval()
    words = text.split()
    if next_words > 0:
        # Without a prompt there is no model output to sample the first word from.
        if not words:
            raise ValueError("text must contain at least one word to prime the model")
        if top_k < 1:
            raise ValueError("top_k must be at least 1")

    # Inference only. If the model's parameters require grad, the recurrent
    # state fed back each step would otherwise chain an autograd graph across
    # all steps, and memory would grow with every generated word.
    with torch.no_grad():
        state_h, state_c = model.zero_state(1)
        state_h = state_h.to(device)
        state_c = state_c.to(device)

        # Prime the hidden state with the prompt, one word per step.
        output = None
        for word in words:
            try:
                index = dataset.word_to_index[word]
            except KeyError:
                raise KeyError(f"prompt word {word!r} is not in the vocabulary") from None
            ix = torch.tensor([[index]], device=device)
            output, (state_h, state_c) = model(ix, (state_h, state_c))

        for _ in range(next_words):
            # NOTE(review): assumes output is (seq_len=1, batch=1, vocab), so
            # output[0] is (1, vocab) and topk runs over the vocabulary. Confirm.
            k = min(top_k, output.size(-1))
            _, top_ix = torch.topk(output[0], k=k)
            choice = int(np.random.choice(top_ix.tolist()[0]))
            words.append(dataset.index_to_word[choice])
            # Feed the sampled word back in to advance the recurrent state.
            ix = torch.tensor([[choice]], device=device)
            output, (state_h, state_c) = model(ix, (state_h, state_c))
    return words