Archived
1
0
Fork 0
This repository has been archived on 2024-04-26. You can view files and clone it, but cannot push or open issues or pull requests.
akari-bot/modules/ask/__init__.py

101 lines
3.7 KiB
Python
Raw Normal View History

2023-04-22 03:28:11 +00:00
import io
import re
2023-08-08 08:08:04 +00:00
from decimal import Decimal
2023-04-19 07:41:39 +00:00
from langchain.callbacks import get_openai_callback
2023-08-08 08:08:04 +00:00
from PIL import Image as PILImage
2023-04-19 07:41:39 +00:00
2023-07-13 04:26:10 +00:00
from config import Config
2023-04-22 03:28:11 +00:00
from core.builtins import Bot, Plain, Image
2023-04-19 07:41:39 +00:00
from core.component import module
2023-07-14 07:59:01 +00:00
from core.dirty_check import check_bool, rickroll
from core.exceptions import NoReportException
2023-08-08 07:36:41 +00:00
from database import BotDBUtil
2023-04-22 03:28:11 +00:00
from .agent import agent_executor
from .formatting import generate_latex, generate_code_snippet
2023-04-14 13:15:04 +00:00
# Pricing model used to convert OpenAI token usage into "petal" charges.
# All monetary values are Decimals so the multiplication chain stays exact.
ONE_K = Decimal('1000')

# https://openai.com/pricing
BASE_COST_GPT_3_5 = Decimal('0.002')  # gpt-3.5-turbo: $0.002 / 1K tokens

# Specific tool usage (e.g. searches) is not tracked individually; a single
# universal multiplier is applied on top of the base model cost instead.
THIRD_PARTY_MULTIPLIER = Decimal('1.5')

# At the time of writing the goal is really just to break even.
PROFIT_MULTIPLIER = Decimal('1.1')

# Price in USD charged per 1K tokens.
PRICE_PER_1K_TOKEN = BASE_COST_GPT_3_5 * THIRD_PARTY_MULTIPLIER * PROFIT_MULTIPLIER

# Assuming 1 USD = 7.3 CNY and 1 CNY = 100 petals.
USD_TO_CNY = Decimal('7.3')
CNY_TO_PETAL = 100
2023-07-08 09:24:35 +00:00
a = module('ask', developers=['Dianliang233'], desc='{ask.help.desc}')


@a.command('<question> {{ask.help}}')
@a.regex(r'^(?:question||问|問)[\:]\s?(.+?)[?]$', flags=re.I, desc='{ask.help.regex}')
async def _(msg: Bot.MessageSession):
    """Answer a question through the agent executor and send the reply.

    Flow:
      1. Refuse when no ``openai_api_key`` is configured, or when a
         non-superuser has no petals left.
      2. Check the ``call_openai`` cooldown; superusers and the
         ``TEST|Console`` target bypass it.
      3. Run the question through ``agent_executor`` while recording token
         usage, and charge non-superusers
         ``tokens / 1000 * PRICE_PER_1K_TOKEN * USD_TO_CNY * CNY_TO_PETAL``
         petals.
      4. Split the answer with ``parse_markdown``; plain text is sent as
         ``Plain``, LaTeX and code blocks are rendered to images.
      5. Reset the cooldown for regular users after the reply is sent.

    Raises:
        Exception: when ``openai_api_key`` is not configured.
        NoReportException: when a non-superuser has no petals.
    """
    is_superuser = msg.checkSuperUser()
    if not Config('openai_api_key'):
        raise Exception(msg.locale.t('error.config.secret'))
    if not is_superuser and msg.data.petal <= 0:  # refuse
        raise NoReportException(msg.locale.t('core.message.petal.no_petals'))

    qc = BotDBUtil.CoolDown(msg, 'call_openai')
    # A result of 0 lets the request through; any other value is reported to
    # the user as the cooldown time (presumably remaining seconds - confirm).
    c = qc.check(60)
    if c == 0 or msg.target.targetFrom == 'TEST|Console' or is_superuser:
        # The handler is reachable both as a command (parsed_msg is set) and
        # through the regex trigger (matched_msg is set).
        if hasattr(msg, 'parsed_msg'):
            question = msg.parsed_msg['<question>']
        else:
            # NOTE(review): index 0 is used here - confirm whether this is
            # the captured question or the whole matched message.
            question = msg.matched_msg[0]
        if await check_bool(question):
            # NOTE(review): the result of rickroll() is discarded and control
            # continues - confirm that rickroll itself aborts the handler.
            rickroll(msg)
        with get_openai_callback() as cb:
            res = await agent_executor.arun(question)
            tokens = cb.total_tokens
        if not is_superuser:
            price = tokens / ONE_K * PRICE_PER_1K_TOKEN
            petal = price * USD_TO_CNY * CNY_TO_PETAL
            msg.data.modify_petal(-petal)

        blocks = parse_markdown(res)

        chain = []
        for block in blocks:
            if block['type'] == 'text':
                chain.append(Plain(block['content']))
            elif block['type'] == 'latex':
                chain.append(Image(PILImage.open(io.BytesIO(await generate_latex(block['content'])))))
            elif block['type'] == 'code':
                chain.append(Image(PILImage.open(io.BytesIO(await generate_code_snippet(
                    block['content']['code'], block['content']['language'])))))

        # The generated answer is screened as well before it is sent.
        if await check_bool(res):
            rickroll(msg)
        await msg.sendMessage(chain)

        if msg.target.targetFrom != 'TEST|Console' and not is_superuser:
            qc.reset()
    else:
        await msg.finish(msg.locale.t('ask.message.cooldown', time=int(c)))
2023-04-22 03:28:11 +00:00
2023-04-30 03:30:59 +00:00
2023-04-22 03:28:11 +00:00
def parse_markdown(md: str):
    """Split a Markdown-like string into text, LaTeX and code blocks.

    The input is scanned left to right for, in order of priority:
      * a fenced code block (three backticks ... newline + three backticks),
      * a ``$...$`` span (may cross lines),
      * a run of characters up to the next newline (plain text).
    Newlines between matches are dropped.

    Args:
        md: The text to split.

    Returns:
        A list of ``{'type': ..., 'content': ...}`` dicts where ``type`` is
        ``'text'``, ``'latex'`` or ``'code'``. For ``'text'`` the content is
        the raw line, for ``'latex'`` it is the stripped text between the
        dollar signs, and for ``'code'`` it is a dict with ``'language'`` and
        ``'code'`` keys.

    Raises:
        ValueError: if a block starts with a code fence but is not a complete
            fenced block with a body (e.g. an unterminated or empty fence).
    """
    # Named groups record which alternative matched, so a line that merely
    # starts with '$' (e.g. "$5 is cheap") is not mistaken for a LaTeX span
    # and does not lose its last character to the delimiter stripping.
    block_re = re.compile(
        r'(?P<code>```[\s\S]*?\n```)'
        r'|(?P<latex>\$[\s\S]*?\$)'
        r'|(?P<text>[^\n]+)'
    )
    fence_re = re.compile(r'```(.*)\n([\s\S]*?)\n```')

    blocks = []
    for match in block_re.finditer(md):
        kind = match.lastgroup
        content = match.group(kind)
        if content.startswith('```'):
            # Covers both complete fences and plain lines that open a fence
            # which is never closed; the latter cannot match and is rejected.
            fenced = fence_re.match(content)
            if fenced is None:
                raise ValueError('Code block is missing language or code')
            language, code = fenced.groups()
            blocks.append({'type': 'code', 'content': {'language': language, 'code': code}})
        elif kind == 'latex':
            # Drop the surrounding dollar signs.
            blocks.append({'type': 'latex', 'content': content[1:-1].strip()})
        else:
            blocks.append({'type': 'text', 'content': content})
    return blocks