from collections import Counter

import torch


class Dataset(torch.utils.data.Dataset):
    """Word-level language-modeling dataset over a whitespace-tokenized text file.

    Each item is a pair of LongTensors of length ``seq_size``: the input
    word-index sequence starting at ``index`` and the target sequence shifted
    one position to the right (next-word prediction).
    """

    def __init__(self, file, seq_size):
        """Build the vocabulary and index the corpus.

        Args:
            file: Path to a plain-text corpus; tokens are split on whitespace.
            seq_size: Length of each (input, target) sequence.
        """
        # Explicit encoding: the platform-default codec would make the
        # vocabulary (and thus the word indices) machine-dependent.
        with open(file, 'r', encoding='utf-8') as f:
            self.words = f.read().split()
        self.word_counts = Counter(self.words)
        # Vocabulary ordered by ascending frequency. sorted() is stable, so
        # ties keep Counter's first-seen order — this preserves the exact
        # word<->index mapping the original code produced.
        self.uniq_words = sorted(self.word_counts, key=self.word_counts.get)
        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}
        # Corpus as a flat list of vocabulary indices.
        self.words_indexes = [self.word_to_index[w] for w in self.words]
        self.seq_size = seq_size

    def __len__(self):
        # Number of valid starting positions for a full (input, target) pair.
        # Clamped to 0: a corpus shorter than seq_size yields an empty dataset
        # instead of a negative length (which makes len() raise ValueError).
        return max(0, len(self.words_indexes) - self.seq_size)

    def __getitem__(self, index):
        """Return ``(input, target)`` tensors for position ``index``.

        ``target`` is ``input`` shifted right by one word, so the model learns
        to predict the next word at every position.
        """
        return (
            torch.tensor(self.words_indexes[index:index + self.seq_size]),
            torch.tensor(self.words_indexes[index + 1:index + self.seq_size + 1]),
        )