# Source listing metadata (from the hosting page): 154 lines, 4.3 KiB.

from argparse import ArgumentParser
from asyncio import create_task, sleep, run
from random import randint, choice
from re import sub
from transformers import AutoTokenizer, AutoModelForCausalLM
parser = ArgumentParser()
parser.add_argument('-n', '--input', help='initial input text')
parser.add_argument('-b', '--backend', choices=['mastodon', 'misskey', 'matrix'],
                    action='append', help='fediverse server type')
parser.add_argument('-i', '--instance', action='append',
                    help='instance hosting the bot (one per --backend)')
parser.add_argument('-t', '--token', action='append',
                    help='access token for the instance (one per --backend)')
parser.add_argument('-d', '--data', default='data',
                    help='data for automatic input generation')
parser.add_argument('-m', '--model', default='model',
                    help='path to load saved model')
parser.add_argument('-y', '--yes', action='store_true',
                    help='answer yes to all prompts')
args = parser.parse_args()
# -b/-i/-t are repeatable and are consumed later by zipping them together into
# one (backend, instance, token) triple per posting target. zip() silently
# drops extras when the counts differ, so reject mismatches up front, before
# the (slow) model load. parser.error() prints usage and exits with status 2.
_target_counts = {len(opt or []) for opt in (args.backend, args.instance, args.token)}
if len(_target_counts) > 1:
    parser.error('--backend, --instance and --token must each be given '
                 'the same number of times')
# NOTE(review): the tokenizer is always gpt2-large while the weights come from
# --model; this assumes the saved model was fine-tuned from gpt2-large — confirm.
tokenizer = AutoTokenizer.from_pretrained('gpt2-large')
# The model is moved to 'cuda', so a CUDA-capable GPU is required.
model = AutoModelForCausalLM.from_pretrained(args.model).to('cuda')
# Opening phrases returned when the random input is not drawn from the data file.
STARTER_PROMPTS = (
    'I am',
    'My life is',
    'Computers are',
    'This is',
    'No one',
    'I love',
    'I will die of',
    'I\'m going to die',
    'My favorite',
    'I\'m not',
    'I hate',
    'I think',
    'In my opinion',
    'Breaking news:',
    'Have I ever told you that',
    'I read on the news that',
    'I never knew that',
    'My dream is',
    'It\'s terrible that',
    'My new theory:',
    'My conspiracy theory',
    'The worst thing',
)


def generate_input(data_path=None):
    """Return a short random prompt used to seed text generation.

    With probability 1/2 one of STARTER_PROMPTS is returned. Otherwise a
    random line of the data file that contains at least two
    whitespace-separated words is picked, and its first two words are
    returned joined by a single space.

    Args:
        data_path: UTF-8 text file to sample lines from. Defaults to the
            ``-d/--data`` command-line option (``args.data``).

    Raises:
        ValueError: If the data file is consulted but has no line with at
            least two words (previously this looped forever).
    """
    if randint(0, 1) == 0:
        return choice(STARTER_PROMPTS)
    if data_path is None:
        data_path = args.data
    with open(data_path, 'r', encoding='utf-8') as f:
        # Pre-filter to lines with >= 2 words. Choosing uniformly among them
        # (duplicates included) gives the same distribution as repeatedly
        # re-drawing until a suitable line appears, but terminates.
        candidates = [words for words in (line.split() for line in f)
                      if len(words) >= 2]
    if not candidates:
        raise ValueError(f'no line with at least two words in {data_path!r}')
    return ' '.join(choice(candidates)[:2])
# Unless --input was supplied on the command line, seed the model with an
# automatically generated prompt.
args.input = generate_input() if args.input is None else args.input
# Loop until we're satisfied
while True:
# Run the input through the model
inputs = tokenizer.encode(args.input, return_tensors='pt').to('cuda')
output = tokenizer.decode(model.generate(
inputs, max_length=150, do_sample=True, top_p=0.9)[0])
# Prepare the post
output = output.split('\n')
post = output[0]
if len(post) < 200 and len(output) > 1:
post = output[0] + '\n' + output[1]
post = post[:500]
# Remove mentions
post = sub('(@[^ ]*)@[^ ]*', '\\1', post)
if args.yes:
# Prompt the user
res = input('Post/Retry/New input/Custom input/Quit: ')
if res not in 'prnPRNcC':
if res in 'pP':
if res in 'nN':
args.input = generate_input()
if res in 'cC':
args.input = input('Enter custom input: ')
# Post it!
for backend, instance, token in zip(args.backend, args.instance, args.token):
if backend == 'mastodon':
from mastodon import Mastodon
mastodon = Mastodon(
elif backend == 'misskey':
from Misskey import Misskey
misskey = Misskey(instance, i=token)
elif backend == 'matrix':
import simplematrixbotlib as botlib
creds = botlib.Creds(instance, 'ebooks', token)
bot = botlib.Bot(creds)
async def room_joined(room_id):
await bot.api.send_text_message(room_id=room_id, message=post)
async def wait_quit():
await sleep(5)
async def run_bot():
run = create_task(bot.main())
wait = create_task(wait_quit())
await run
await wait