Add new transformers bot script and move old one to bot_lstm.py

This commit is contained in:
Anthony Wang 2022-02-22 16:58:19 -06:00
parent edd4708123
commit 6bab795fe8
Signed by: a
GPG key ID: BC96B00AEC5F2D76
2 changed files with 53 additions and 22 deletions

38
bot.py
View file

@ -1,37 +1,31 @@
from argparse import ArgumentParser from argparse import ArgumentParser
import torch
from mastodon import Mastodon from mastodon import Mastodon
from transformers import AutoTokenizer, AutoModelForCausalLM
from dataset import Dataset
from model import Model
from predict import predict
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('-t', '--token', help='Mastodon application access token') parser.add_argument('-t', '--token', help='Mastodon application access token')
parser.add_argument('-i', '--input', default='data', parser.add_argument('-i', '--input', default='i am',
help='training data input file') help='initial input text for prediction')
parser.add_argument('-e', '--text', default='i am', parser.add_argument('-m', '--model', default='model',
help='initial text for prediction')
parser.add_argument('-d', '--device', default='cpu',
help='device to run the model with')
parser.add_argument('-m', '--model', default='model.pt',
help='path to load saved model') help='path to load saved model')
args = parser.parse_args() args = parser.parse_args()
tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
model = AutoModelForCausalLM.from_pretrained(args.model)
# Run the input through the model
inputs = tokenizer.encode(args.input, return_tensors="pt")
output = tokenizer.decode(model.generate(inputs, do_sample=True, max_length=25, top_p=0.9, temperature=0.8)[0])
print(output)
# Post it to Mastodon
mastodon = Mastodon( mastodon = Mastodon(
access_token=args.token, access_token=args.token,
api_base_url='https://social.exozy.me/' api_base_url='https://social.exozy.me/'
) )
mastodon.status_post(output)
dataset = Dataset(args.input, 32)
device = torch.device(args.device)
model = torch.load(args.model)
text = predict(device, model, args.text)
print(text)
# mastodon.status_post(text)

37
bot_lstm.py Normal file
View file

@ -0,0 +1,37 @@
from argparse import ArgumentParser
import torch
from mastodon import Mastodon
from dataset import Dataset
from model import Model
from predict import predict
parser = ArgumentParser()
parser.add_argument('-t', '--token', help='Mastodon application access token')
parser.add_argument('-i', '--input', default='data',
help='training data input file')
parser.add_argument('-e', '--text', default='i am',
help='initial text for prediction')
parser.add_argument('-d', '--device', default='cpu',
help='device to run the model with')
parser.add_argument('-m', '--model', default='model.pt',
help='path to load saved model')
args = parser.parse_args()
mastodon = Mastodon(
access_token=args.token,
api_base_url='https://social.exozy.me/'
)
dataset = Dataset(args.input, 32)
device = torch.device(args.device)
model = torch.load(args.model)
text = predict(device, model, args.text)
print(text)
# mastodon.status_post(text)