Add new transformers bot script and move old one to bot_lstm.py
This commit is contained in:
parent
edd4708123
commit
6bab795fe8
38
bot.py
38
bot.py
|
@ -1,37 +1,31 @@
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
import torch
|
|
||||||
from mastodon import Mastodon
|
from mastodon import Mastodon
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
from dataset import Dataset
|
|
||||||
from model import Model
|
|
||||||
from predict import predict
|
|
||||||
|
|
||||||
|
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument('-t', '--token', help='Mastodon application access token')
|
parser.add_argument('-t', '--token', help='Mastodon application access token')
|
||||||
parser.add_argument('-i', '--input', default='data',
|
parser.add_argument('-i', '--input', default='i am',
|
||||||
help='training data input file')
|
help='initial input text for prediction')
|
||||||
parser.add_argument('-e', '--text', default='i am',
|
parser.add_argument('-m', '--model', default='model',
|
||||||
help='initial text for prediction')
|
|
||||||
parser.add_argument('-d', '--device', default='cpu',
|
|
||||||
help='device to run the model with')
|
|
||||||
parser.add_argument('-m', '--model', default='model.pt',
|
|
||||||
help='path to load saved model')
|
help='path to load saved model')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(args.model)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the input through the model
|
||||||
|
inputs = tokenizer.encode(args.input, return_tensors="pt")
|
||||||
|
output = tokenizer.decode(model.generate(inputs, do_sample=True, max_length=25, top_p=0.9, temperature=0.8)[0])
|
||||||
|
print(output)
|
||||||
|
|
||||||
|
|
||||||
|
# Post it to Mastodon
|
||||||
mastodon = Mastodon(
|
mastodon = Mastodon(
|
||||||
access_token=args.token,
|
access_token=args.token,
|
||||||
api_base_url='https://social.exozy.me/'
|
api_base_url='https://social.exozy.me/'
|
||||||
)
|
)
|
||||||
|
mastodon.status_post(output)
|
||||||
|
|
||||||
dataset = Dataset(args.input, 32)
|
|
||||||
device = torch.device(args.device)
|
|
||||||
model = torch.load(args.model)
|
|
||||||
|
|
||||||
|
|
||||||
text = predict(device, model, args.text)
|
|
||||||
print(text)
|
|
||||||
# mastodon.status_post(text)
|
|
||||||
|
|
37
bot_lstm.py
Normal file
37
bot_lstm.py
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from mastodon import Mastodon
|
||||||
|
|
||||||
|
from dataset import Dataset
|
||||||
|
from model import Model
|
||||||
|
from predict import predict
|
||||||
|
|
||||||
|
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument('-t', '--token', help='Mastodon application access token')
|
||||||
|
parser.add_argument('-i', '--input', default='data',
|
||||||
|
help='training data input file')
|
||||||
|
parser.add_argument('-e', '--text', default='i am',
|
||||||
|
help='initial text for prediction')
|
||||||
|
parser.add_argument('-d', '--device', default='cpu',
|
||||||
|
help='device to run the model with')
|
||||||
|
parser.add_argument('-m', '--model', default='model.pt',
|
||||||
|
help='path to load saved model')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
mastodon = Mastodon(
|
||||||
|
access_token=args.token,
|
||||||
|
api_base_url='https://social.exozy.me/'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
dataset = Dataset(args.input, 32)
|
||||||
|
device = torch.device(args.device)
|
||||||
|
model = torch.load(args.model)
|
||||||
|
|
||||||
|
|
||||||
|
text = predict(device, model, args.text)
|
||||||
|
print(text)
|
||||||
|
# mastodon.status_post(text)
|
Loading…
Reference in a new issue