"""Fine-tune distilgpt2 as a causal language model on a plain-text file."""
from argparse import ArgumentParser

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
# Command-line interface: a single flag for the training-data path.
parser = ArgumentParser(description='Fine-tune distilgpt2 on a plain-text file.')
parser.add_argument('-i', '--input', default='data',
                    help='training data input file')
args = parser.parse_args()

# Load the raw text: one training example per line, line breaks preserved.
raw_dataset = load_dataset('text', data_files={'train': args.input}, keep_linebreaks=True)

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
# GPT-2 tokenizers define no pad token; reuse EOS so the collator can pad
# batches of unequal length instead of raising.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Tokenize every line and drop the original 'text' column so the dataset
# holds only model inputs (input_ids / attention_mask).
tokenized_dataset = raw_dataset.map(
    lambda examples: tokenizer(examples['text']),
    batched=True,
    remove_columns=raw_dataset['train'].column_names,
)

model = AutoModelForCausalLM.from_pretrained('distilgpt2')

# Causal-LM collator (mlm=False) pads each batch and copies input_ids into
# labels. Without it the batches carry no 'labels', the model returns no
# loss, and trainer.train() fails.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Explicit output directory for checkpoints and final weights.
training_args = TrainingArguments(output_dir='distilgpt2-finetuned')

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    tokenizer=tokenizer,
    data_collator=data_collator,
)
trainer.train()