Set output directory correctly

This commit is contained in:
Anthony Wang 2022-02-22 17:51:52 -06:00
parent 1c43115cd6
commit 8dab77d61b
Signed by: a
GPG key ID: BC96B00AEC5F2D76

View file

@@ -2,7 +2,7 @@ from argparse import ArgumentParser
from itertools import chain
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, default_data_collator
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, default_data_collator
parser = ArgumentParser()
@@ -44,6 +44,6 @@ lm_dataset = tokenized_dataset.map(group_texts, batched=True)
# Create and train the model
model = AutoModelForCausalLM.from_pretrained('distilgpt2')
trainer = Trainer(model=model, train_dataset=lm_dataset['train'], tokenizer=tokenizer, data_collator=default_data_collator)
trainer = Trainer(model, TrainingArguments(output_dir=args.output), default_data_collator, lm_dataset['train'])
trainer.train()
trainer.save_model(args.output)
trainer.save_model()