Set output directory correctly
This commit is contained in:
parent
1c43115cd6
commit
8dab77d61b
6
train.py
6
train.py
|
@@ -2,7 +2,7 @@ from argparse import ArgumentParser
|
|||
from itertools import chain
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, default_data_collator
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, default_data_collator
|
||||
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
@@ -44,6 +44,6 @@ lm_dataset = tokenized_dataset.map(group_texts, batched=True)
|
|||
|
||||
# Create and train the model
|
||||
model = AutoModelForCausalLM.from_pretrained('distilgpt2')
|
||||
trainer = Trainer(model=model, train_dataset=lm_dataset['train'], tokenizer=tokenizer, data_collator=default_data_collator)
|
||||
trainer = Trainer(model, TrainingArguments(output_dir=args.output), default_data_collator, lm_dataset['train'])
|
||||
trainer.train()
|
||||
trainer.save_model(args.output)
|
||||
trainer.save_model()
|
||||
|
|
Loading…
Reference in a new issue