from collections import Counter
import torch
class Dataset(torch.utils.data.Dataset):
    """Word-level dataset for next-word prediction.

    Reads a whitespace-separated text file, builds a vocabulary ordered
    from least to most frequent word, and serves fixed-length windows of
    word indices paired with the same window shifted one position ahead.
    """

    def __init__(self, file, seq_size):
        """Load *file*, build the vocabulary, and index every word.

        Args:
            file: Path to a plain-text file; tokens are split on whitespace.
            seq_size: Length of each (input, target) index sequence.
        """
        with open(file, 'r') as handle:
            self.words = handle.read().split()

        self.word_counts = Counter(self.words)
        # Vocabulary sorted by ascending frequency (ties keep first-seen order).
        self.uniq_words = sorted(self.word_counts, key=self.word_counts.get)

        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {word: index
                              for index, word in self.index_to_word.items()}

        # The full corpus re-expressed as vocabulary indices.
        self.words_indexes = [self.word_to_index[w] for w in self.words]

        self.seq_size = seq_size

    def __len__(self):
        # Number of valid window start positions (the target needs one
        # extra trailing word, so the last seq_size words start no window).
        return len(self.words_indexes) - self.seq_size

    def __getitem__(self, index):
        """Return (input, target) index tensors for window *index*.

        The target is the input window shifted right by one word.
        """
        start = index
        stop = index + self.seq_size
        source = torch.tensor(self.words_indexes[start:stop])
        target = torch.tensor(self.words_indexes[start + 1:stop + 1])
        return source, target
|