ebooks/train.py

from argparse import ArgumentParser

from datasets import load_dataset
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          DataCollatorForLanguageModeling, Trainer)

parser = ArgumentParser()
parser.add_argument('-i', '--input', default='data',
                    help='training data input file')
args = parser.parse_args()

# Load the corpus as a line-based text dataset and tokenize it.
raw_dataset = load_dataset('text', data_files={'train': args.input}, keep_linebreaks=True)
tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS so batches can be padded
tokenized_dataset = raw_dataset.map(lambda examples: tokenizer(examples['text'], truncation=True),
                                    batched=True, remove_columns=raw_dataset['train'].column_names)

# mlm=False makes the collator copy input_ids into labels, so the Trainer
# can compute the causal language-modelling loss.
model = AutoModelForCausalLM.from_pretrained('distilgpt2')
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
trainer = Trainer(model=model, train_dataset=tokenized_dataset['train'],
                  tokenizer=tokenizer, data_collator=data_collator)
trainer.train()
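
# Example invocation (the corpus path below is illustrative, not part of the repo):
#   python train.py --input corpus.txt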