"""Generate a post with a causal language model and publish it.

The prompt is taken from ``--input`` when given. Otherwise a coin flip picks
either a canned sentence starter or the first two words of a random line of
the ``--data`` file. The generated text is cut to at most two lines and 500
characters, printed, and then posted to the selected backend (nothing is
posted with ``--backend none``).
"""
from argparse import ArgumentParser
from random import randint, choice

from torch import float16  # NOTE(review): not used anywhere in this script
from transformers import AutoTokenizer, AutoModelForCausalLM

# Canned sentence starters used when no --input is given.
PROMPTS = [
    'I am', 'My life is', 'Computers are', 'This is', 'My', 'I\'ve',
    'No one', 'I love', 'I will die of', 'I', 'The', 'Anime',
    'I\'m going to die', 'Hello', '@ta180m@exozy.me', 'Life', 'My favorite',
    'I\'m not', 'I hate', 'I think', 'In my opinion', 'Breaking news:',
    'Have I ever told you that', 'I read on the news that',
    'I never knew that', 'My dream is', 'It\'s terrible that',
]

# The second generated line is only appended when the first is shorter than this.
SHORT_LINE = 200
# Posts are truncated to this many characters.
MAX_POST_CHARS = 500


def _strip_mention_domain(word):
    """Cut ``@user@domain`` down to ``@user``; other words are returned as-is.

    Any word containing more than one ``@`` keeps only the text up to (but not
    including) its second ``@``.
    """
    if word.count('@') > 1:
        return '@'.join(word.split('@')[:2])
    return word


def _prompt_from_data(path):
    """Build a two-word prompt from a random line of *path* with >= 2 words.

    Raises SystemExit with a readable message when no such line exists
    (previously this case looped forever or raised a bare IndexError).
    """
    with open(path, 'r', encoding='utf-8') as f:
        # Choosing uniformly among qualifying lines gives the same distribution
        # as the old "pick any line, retry if too short" rejection loop.
        candidates = [words for words in (line.split() for line in f)
                      if len(words) >= 2]
    if not candidates:
        raise SystemExit(f'{path}: no line with at least two words')
    first, second = choice(candidates)[:2]
    return _strip_mention_domain(first) + ' ' + _strip_mention_domain(second)


def _prepare_post(text):
    """Return the first line of *text*, plus the second when the first is short.

    The result is truncated to MAX_POST_CHARS characters.
    """
    lines = text.split('\n')
    post = lines[0]
    if len(post) < SHORT_LINE and len(lines) > 1:
        post = lines[0] + '\n' + lines[1]
    return post[:MAX_POST_CHARS]


parser = ArgumentParser()
parser.add_argument('-b', '--backend', choices=['mastodon', 'misskey', 'matrix', 'none'],
                    default='mastodon', help='fediverse server type')
parser.add_argument('-i', '--instance',
                    help='instance/server hosting the bot (used by every backend except none)')
parser.add_argument('-t', '--token',
                    help='access token/credential for the selected backend')
parser.add_argument('-n', '--input', help='initial input text')
parser.add_argument('-d', '--data', default='data',
                    help='text file (one post per line) for automatic input generation')
parser.add_argument('-m', '--model', default='model', help='path to load saved model')
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained('gpt2-large')
model = AutoModelForCausalLM.from_pretrained(args.model).to('cuda')

if args.input is None:
    # No prompt supplied: flip a coin between a canned starter and one
    # derived from the data file.
    if randint(0, 1) == 0:
        args.input = choice(PROMPTS)
    else:
        args.input = _prompt_from_data(args.data)

# Run the input through the model
print(args.input)
inputs = tokenizer.encode(args.input, return_tensors='pt').to('cuda')
generated = model.generate(inputs, do_sample=True, max_length=150, top_p=0.9)
# skip_special_tokens keeps markers such as an end-of-text token out of the post.
output = tokenizer.decode(generated[0], skip_special_tokens=True)
print(output)

# Prepare the post
post = _prepare_post(output)

# Post it! (backend 'none' posts nothing; the text was only printed above.)
if args.backend == 'mastodon':
    from mastodon import Mastodon

    mastodon = Mastodon(
        access_token=args.token,
        api_base_url=args.instance,
    )
    mastodon.status_post(post)
elif args.backend == 'misskey':
    # The Misskey.py package installs the lowercase module ``misskey``.
    from misskey import Misskey

    misskey = Misskey(args.instance, i=args.token)
    misskey.notes_create(post)
elif args.backend == 'matrix':
    import simplematrixbotlib as botlib

    creds = botlib.Creds(args.instance, 'ebooks', args.token)
    bot = botlib.Bot(creds)

    @bot.listener.on_startup
    async def room_joined(room_id):
        # Invoked by the startup hook with a room id (presumably once per
        # joined room -- confirm against simplematrixbotlib); sends the post.
        await bot.api.send_text_message(room_id=room_id, message=post)

    # NOTE(review): bot.run() presumably keeps running until interrupted -- confirm.
    bot.run()