Reformat code with autopep8
This commit is contained in:
parent
8dab77d61b
commit
f08e5bfc5f
6
bot.py
6
bot.py
|
@@ -6,8 +6,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
|||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('-t', '--token', help='Mastodon application access token')
|
||||
parser.add_argument('-i', '--input', default='i am',
|
||||
help='initial input text for prediction')
|
||||
parser.add_argument('-i', '--input', help='initial input text for prediction')
|
||||
parser.add_argument('-m', '--model', default='model',
|
||||
help='path to load saved model')
|
||||
args = parser.parse_args()
|
||||
|
@@ -19,7 +18,8 @@ model = AutoModelForCausalLM.from_pretrained(args.model)
|
|||
|
||||
# Run the input through the model
|
||||
inputs = tokenizer.encode(args.input, return_tensors="pt")
|
||||
output = tokenizer.decode(model.generate(inputs, do_sample=True, max_length=25, top_p=0.9, temperature=0.8)[0])
|
||||
output = tokenizer.decode(model.generate(
|
||||
inputs, do_sample=True, max_length=25, top_p=0.9, temperature=0.8)[0])
|
||||
print(output)
|
||||
|
||||
|
||||
|
|
6
train.py
6
train.py
|
@@ -16,7 +16,8 @@ args = parser.parse_args()
|
|||
# Load and tokenize dataset
|
||||
raw_dataset = load_dataset('text', data_files={'train': args.input}, keep_linebreaks=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained('distilgpt2', use_fast=True)
|
||||
tokenized_dataset = raw_dataset.map(lambda examples : tokenizer(examples['text']), batched=True, remove_columns='text')
|
||||
tokenized_dataset = raw_dataset.map(lambda examples: tokenizer(examples['text']),
|
||||
batched=True, remove_columns='text')
|
||||
|
||||
|
||||
# Generate chunks of block_size
|
||||
|
@@ -44,6 +45,7 @@ lm_dataset = tokenized_dataset.map(group_texts, batched=True)
|
|||
|
||||
# Create and train the model
|
||||
model = AutoModelForCausalLM.from_pretrained('distilgpt2')
|
||||
trainer = Trainer(model, TrainingArguments(output_dir=args.output), default_data_collator, lm_dataset['train'])
|
||||
trainer = Trainer(model, TrainingArguments(output_dir=args.output),
|
||||
default_data_collator, lm_dataset['train'])
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
|
|
Loading…
Reference in a new issue