import numpy as np
import torch


def predict(device, dataset, model, text, next_words=100, top_k=5):
    """Extend ``text`` by sampling ``next_words`` words from a recurrent model.

    The seed text is split on whitespace and fed word by word through
    ``model`` to build up its hidden state. Each new word is then chosen
    uniformly at random (``np.random.choice``) from the ``top_k``
    highest-scoring indices in the model's latest output. It is appended to
    the result and fed back into the model.

    Args:
        device: Device the initial state and every input index tensor are
            moved to (anything accepted by ``Tensor.to``).
        dataset: Object providing ``word_to_index`` (word -> index) and
            ``index_to_word`` (index -> word) lookups.
        model: Recurrent model exposing ``eval()``, ``zero_state(batch_size)``
            returning ``(h, c)``, and ``model(ix, (h, c))`` returning
            ``(output, (h, c))``.
        text: Seed text. Every whitespace-separated word must be present in
            ``dataset.word_to_index``.
        next_words: Number of words to generate.
        top_k: Number of top-scoring candidates sampled from at each step.

    Returns:
        list: The seed words followed by the generated words.

    Raises:
        ValueError: If ``text`` contains no words while ``next_words > 0``;
            there is then no model output to sample the first word from.
        KeyError: If a seed word is missing from ``dataset.word_to_index``.

    Note:
        Switches ``model`` to eval mode and leaves it there. Sampling uses
        NumPy's global RNG, so seed it with ``np.random.seed`` for
        reproducible output.
    """
    model.eval()
    words = text.split()
    if not words and next_words > 0:
        raise ValueError(
            "predict() needs at least one seed word in `text` to generate from"
        )

    # Pure inference. Without no_grad, autograd would chain each step's graph
    # through the recurrent state, and memory would grow with next_words.
    with torch.no_grad():
        state_h, state_c = model.zero_state(1)
        state_h = state_h.to(device)
        state_c = state_c.to(device)

        output = None
        # Warm up the hidden state on the seed words.
        for word in words:
            ix = torch.tensor([[dataset.word_to_index[word]]]).to(device)
            output, (state_h, state_c) = model(ix, (state_h, state_c))

        for _ in range(next_words):
            # Indices of the top_k scores along the last dimension of output[0].
            top_ix = torch.topk(output[0], k=top_k).indices
            choices = top_ix.tolist()
            # Plain int so it works as a dict key / list index and tensor payload.
            choice = int(np.random.choice(choices[0]))
            words.append(dataset.index_to_word[choice])
            ix = torch.tensor([[choice]]).to(device)
            output, (state_h, state_c) = model(ix, (state_h, state_c))

    return words