from argparse import ArgumentParser

from datasets import load_dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer)

parser = ArgumentParser()
parser.add_argument('-i', '--input', default='data', help='training data input file')
args = parser.parse_args()

# Load the plain-text training file as a single 'train' split.
raw_dataset = load_dataset('text', data_files={'train': args.input}, keep_linebreaks=True)

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
# GPT-2 has no pad token; reuse the EOS token so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token

# Tokenize every line and drop the raw 'text' column.
tokenized_dataset = raw_dataset.map(
    lambda examples: tokenizer(examples['text'], truncation=True),
    batched=True,
    remove_columns=raw_dataset['train'].column_names,
)

model = AutoModelForCausalLM.from_pretrained('distilgpt2')

# For causal language modelling the labels are the input ids themselves;
# the collator copies them over (mlm=False) and pads each batch, so the
# Trainer can compute a loss.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(model=model, train_dataset=tokenized_dataset['train'],
                  data_collator=data_collator, tokenizer=tokenizer)
trainer.train()