diff --git a/bot.py b/bot.py index c2a2530..8840a28 100644 --- a/bot.py +++ b/bot.py @@ -1,37 +1,31 @@ from argparse import ArgumentParser -import torch from mastodon import Mastodon - -from dataset import Dataset -from model import Model -from predict import predict +from transformers import AutoTokenizer, AutoModelForCausalLM parser = ArgumentParser() parser.add_argument('-t', '--token', help='Mastodon application access token') -parser.add_argument('-i', '--input', default='data', - help='training data input file') -parser.add_argument('-e', '--text', default='i am', - help='initial text for prediction') -parser.add_argument('-d', '--device', default='cpu', - help='device to run the model with') -parser.add_argument('-m', '--model', default='model.pt', +parser.add_argument('-i', '--input', default='i am', + help='initial input text for prediction') +parser.add_argument('-m', '--model', default='model', help='path to load saved model') args = parser.parse_args() +tokenizer = AutoTokenizer.from_pretrained('distilgpt2') +model = AutoModelForCausalLM.from_pretrained(args.model) + + +# Run the input through the model +inputs = tokenizer.encode(args.input, return_tensors="pt") +output = tokenizer.decode(model.generate(inputs, do_sample=True, max_length=25, top_p=0.9, temperature=0.8)[0]) +print(output) + + +# Post it to Mastodon mastodon = Mastodon( access_token=args.token, api_base_url='https://social.exozy.me/' ) - - -dataset = Dataset(args.input, 32) -device = torch.device(args.device) -model = torch.load(args.model) - - -text = predict(device, model, args.text) -print(text) -# mastodon.status_post(text) +mastodon.status_post(output) diff --git a/bot_lstm.py b/bot_lstm.py new file mode 100644 index 0000000..c2a2530 --- /dev/null +++ b/bot_lstm.py @@ -0,0 +1,37 @@ +from argparse import ArgumentParser + +import torch +from mastodon import Mastodon + +from dataset import Dataset +from model import Model +from predict import predict + + +parser = ArgumentParser() +parser.add_argument('-t', '--token', help='Mastodon application access token') +parser.add_argument('-i', '--input', default='data', + help='training data input file') +parser.add_argument('-e', '--text', default='i am', + help='initial text for prediction') +parser.add_argument('-d', '--device', default='cpu', + help='device to run the model with') +parser.add_argument('-m', '--model', default='model.pt', + help='path to load saved model') +args = parser.parse_args() + + +mastodon = Mastodon( + access_token=args.token, + api_base_url='https://social.exozy.me/' +) + + +dataset = Dataset(args.input, 32) +device = torch.device(args.device) +model = torch.load(args.model) + + +text = predict(device, model, args.text) +print(text) +# mastodon.status_post(text)