ebooks/dataset.py
2022-02-21 15:20:00 -06:00

29 lines
950 B
Python

from collections import Counter
import torch
class Dataset(torch.utils.data.Dataset):
    """Word-level next-word dataset built from a whitespace-tokenized text file.

    The file is split on whitespace into words, and every distinct word gets an
    integer id. Sample ``i`` is the pair ``(x, y)``:

    - ``x`` holds the ids of words ``i .. i+seq_size-1``.
    - ``y`` holds the ids of words ``i+1 .. i+seq_size``, i.e. ``x`` shifted
      one word to the left.

    Attributes:
        words: every token of the file, in order (``str.split()`` semantics,
            so punctuation stays attached to words).
        word_counts: ``Counter`` of token frequencies.
        uniq_words: distinct tokens sorted by ascending frequency, with ties
            kept in first-appearance order. A token's position here is its id.
        index_to_word: id -> token mapping.
        word_to_index: token -> id mapping.
        words_indexes: ``words`` converted to ids.
        seq_size: length of each input/target window.
    """

    def __init__(self, file, seq_size, encoding=None):
        """Read ``file`` and build the vocabulary and the id sequence.

        Args:
            file: path of the text file to read.
            seq_size: number of words in each input (and target) window.
                Must be at least 1.
            encoding: text encoding passed to ``open``. ``None`` (the
                default) keeps the platform's locale-dependent default. Pass
                e.g. ``"utf-8"`` for portable results on non-ASCII text.

        Raises:
            ValueError: if ``seq_size`` is less than 1.
        """
        if seq_size < 1:
            raise ValueError(f"seq_size must be >= 1, got {seq_size}")
        with open(file, 'r', encoding=encoding) as f:
            self.words = f.read().split()
        self.word_counts = Counter(self.words)
        # Ascending frequency. sorted() is stable, so ties keep the Counter's
        # first-appearance order. This order is kept unchanged so ids stay
        # compatible with anything already built on this vocabulary.
        self.uniq_words = sorted(self.word_counts, key=self.word_counts.get)
        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {word: index
                              for index, word in enumerate(self.uniq_words)}
        self.words_indexes = [self.word_to_index[w] for w in self.words]
        self.seq_size = seq_size

    def __len__(self):
        """Return the number of full (input, target) windows (0 if the text is too short)."""
        # The last window's target ends at the final word, which gives
        # len(words) - seq_size windows. Clamp because len() must be >= 0.
        return max(0, len(self.words_indexes) - self.seq_size)

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair of id tensors for window ``index``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: if ``index`` is out of range. Raising this (instead of
                returning empty slices) is what lets plain iteration such as
                ``for x, y in dataset`` terminate.
        """
        length = len(self)
        position = index + length if index < 0 else index
        if not 0 <= position < length:
            raise IndexError(
                f"index {index} out of range for dataset of length {length}")
        end = position + self.seq_size
        return (torch.tensor(self.words_indexes[position:end]),
                torch.tensor(self.words_indexes[position + 1:end + 1]))