diff --git a/predict.py b/predict.py index 53447b9..9ef453d 100644 --- a/predict.py +++ b/predict.py @@ -2,7 +2,7 @@ import numpy as np import torch -def predict(device, dataset, model, text, next_words=100, top_k=3): +def predict(device, dataset, model, text, next_words=100, top_k=5): model.eval() words = text.split()