Add new transformers bot script and move old one to bot_lstm.py
This commit is contained in:
parent
edd4708123
commit
6bab795fe8
2 changed files with 53 additions and 22 deletions
38
bot.py
38
bot.py
|
@ -1,37 +1,31 @@
|
|||
from argparse import ArgumentParser
|
||||
|
||||
import torch
|
||||
from mastodon import Mastodon
|
||||
|
||||
from dataset import Dataset
|
||||
from model import Model
|
||||
from predict import predict
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('-t', '--token', help='Mastodon application access token')
|
||||
parser.add_argument('-i', '--input', default='data',
|
||||
help='training data input file')
|
||||
parser.add_argument('-e', '--text', default='i am',
|
||||
help='initial text for prediction')
|
||||
parser.add_argument('-d', '--device', default='cpu',
|
||||
help='device to run the model with')
|
||||
parser.add_argument('-m', '--model', default='model.pt',
|
||||
parser.add_argument('-i', '--input', default='i am',
|
||||
help='initial input text for prediction')
|
||||
parser.add_argument('-m', '--model', default='model',
|
||||
help='path to load saved model')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model)
|
||||
|
||||
|
||||
# Run the input through the model
|
||||
inputs = tokenizer.encode(args.input, return_tensors="pt")
|
||||
output = tokenizer.decode(model.generate(inputs, do_sample=True, max_length=25, top_p=0.9, temperature=0.8)[0])
|
||||
print(output)
|
||||
|
||||
|
||||
# Post it to Mastodon
|
||||
mastodon = Mastodon(
|
||||
access_token=args.token,
|
||||
api_base_url='https://social.exozy.me/'
|
||||
)
|
||||
|
||||
|
||||
dataset = Dataset(args.input, 32)
|
||||
device = torch.device(args.device)
|
||||
model = torch.load(args.model)
|
||||
|
||||
|
||||
text = predict(device, model, args.text)
|
||||
print(text)
|
||||
# mastodon.status_post(text)
|
||||
mastodon.status_post(output)
|
||||
|
|
37
bot_lstm.py
Normal file
37
bot_lstm.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
from argparse import ArgumentParser
|
||||
|
||||
import torch
|
||||
from mastodon import Mastodon
|
||||
|
||||
from dataset import Dataset
|
||||
from model import Model
|
||||
from predict import predict
|
||||
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('-t', '--token', help='Mastodon application access token')
|
||||
parser.add_argument('-i', '--input', default='data',
|
||||
help='training data input file')
|
||||
parser.add_argument('-e', '--text', default='i am',
|
||||
help='initial text for prediction')
|
||||
parser.add_argument('-d', '--device', default='cpu',
|
||||
help='device to run the model with')
|
||||
parser.add_argument('-m', '--model', default='model.pt',
|
||||
help='path to load saved model')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
mastodon = Mastodon(
|
||||
access_token=args.token,
|
||||
api_base_url='https://social.exozy.me/'
|
||||
)
|
||||
|
||||
|
||||
dataset = Dataset(args.input, 32)
|
||||
device = torch.device(args.device)
|
||||
model = torch.load(args.model)
|
||||
|
||||
|
||||
text = predict(device, model, args.text)
|
||||
print(text)
|
||||
# mastodon.status_post(text)
|
Loading…
Reference in a new issue