2023-04-22 03:28:11 +00:00
|
|
|
|
import io
|
2023-05-28 09:38:21 +00:00
|
|
|
|
import re
|
|
|
|
|
from decimal import Decimal
|
2023-03-31 09:38:53 +00:00
|
|
|
|
|
2023-05-28 09:38:21 +00:00
|
|
|
|
from PIL import Image as PILImage
|
2023-04-19 07:41:39 +00:00
|
|
|
|
from langchain.callbacks import get_openai_callback
|
|
|
|
|
|
2023-04-22 03:28:11 +00:00
|
|
|
|
from core.builtins import Bot, Plain, Image
|
2023-04-19 07:41:39 +00:00
|
|
|
|
from core.component import module
|
|
|
|
|
from core.dirty_check import check_bool
|
2023-03-31 09:38:53 +00:00
|
|
|
|
from core.exceptions import NoReportException
|
2023-04-22 03:28:11 +00:00
|
|
|
|
from .agent import agent_executor
|
|
|
|
|
from .formatting import generate_latex, generate_code_snippet
|
2023-03-31 09:38:53 +00:00
|
|
|
|
|
2023-04-14 13:15:04 +00:00
|
|
|
|
# Pricing constants used to convert OpenAI token usage into petals.
ONE_K = Decimal('1000')

# Base model price, see https://openai.com/pricing
BASE_COST_GPT_3_5 = Decimal('0.002')  # gpt-3.5-turbo: $0.002 per 1K tokens

# Tool usage (e.g. searches) is not metered per call; a single flat
# multiplier on the token price approximates that extra cost instead.
THIRD_PARTY_MULTIPLIER = Decimal('1.5')

# Markup on top of cost; for now the aim is simply to break even.
PROFIT_MULTIPLIER = Decimal('1.1')

# Effective USD price charged per 1K tokens.
PRICE_PER_1K_TOKEN = (
    BASE_COST_GPT_3_5
    * THIRD_PARTY_MULTIPLIER
    * PROFIT_MULTIPLIER
)

# Currency conversion: assumes 1 USD = 7 CNY and 1 CNY = 100 petals.
USD_TO_CNY = 7
CNY_TO_PETAL = 100
|
|
|
|
|
|
2023-04-05 10:15:09 +00:00
|
|
|
|
# Register the `ask` module. `required_superuser=True` presumably restricts
# it to superusers for now — confirm against core.component.module.
a = module(
    'ask',
    developers=['Dianliang233'],
    desc='{ask.help.desc}',
    required_superuser=True,
)
|
2023-03-31 09:38:53 +00:00
|
|
|
|
|
|
|
|
|
|
2023-04-05 11:08:13 +00:00
|
|
|
|
@a.command('<question> {{ask.help}}')
@a.regex(r'^(?:ask|问)[\::]? ?(.+?)[??]$')
async def _(msg: Bot.MessageSession):
    """Answer a question with the LangChain agent and bill the sender in petals.

    The question is taken from the ``<question>`` command argument when the
    message was parsed as a command, otherwise from ``msg.matched_msg[0]``
    (the regex trigger).

    Raises:
        NoReportException: if a non-superuser has fewer than 100 petals, or
            if ``check_bool`` flags the question or the agent's answer
            (presumably a content check — see ``core.dirty_check``).
    """
    is_superuser = msg.checkSuperUser()
    if not is_superuser and msg.data.petal < 100:  # refuse
        raise NoReportException(msg.locale.t('petal_'))

    if hasattr(msg, 'parsed_msg'):
        question = msg.parsed_msg['<question>']
    else:
        question = msg.matched_msg[0]

    if await check_bool(question):
        raise NoReportException('https://wdf.ink/6OUp')

    # The callback context collects the token usage of the agent run.
    with get_openai_callback() as cb:
        res = await agent_executor.arun(question)
        tokens = cb.total_tokens

    # TODO: REMEMBER TO UNCOMMENT THIS BEFORE LAUNCH!!!!
    # Before launch the charge below must apply to non-superusers only
    # (wrap it in `if not is_superuser:`); until then everyone is charged.
    price = tokens / ONE_K * PRICE_PER_1K_TOKEN  # USD
    petal = price * USD_TO_CNY * CNY_TO_PETAL
    # int() truncates toward zero, so a fractional petal is not charged.
    msg.data.modify_petal(-int(petal))

    # Check the answer before rendering it: a flagged answer is rejected
    # anyway, so rendering its LaTeX / code images first is wasted work.
    if await check_bool(res):
        raise NoReportException('https://wdf.ink/6OUp')

    await msg.finish(await _render_answer(res))


async def _render_answer(res: str) -> list:
    """Convert the agent's markdown answer into a message chain.

    Text blocks become ``Plain`` elements; LaTeX and fenced code blocks are
    rendered to image bytes by ``generate_latex`` / ``generate_code_snippet``
    and wrapped as ``Image`` elements, in the order they appear in ``res``.
    """
    chain = []
    for block in parse_markdown(res):
        kind = block['type']
        content = block['content']
        if kind == 'text':
            chain.append(Plain(content))
        elif kind == 'latex':
            rendered = await generate_latex(content)
            chain.append(Image(PILImage.open(io.BytesIO(rendered))))
        elif kind == 'code':
            rendered = await generate_code_snippet(content['code'], content['language'])
            chain.append(Image(PILImage.open(io.BytesIO(rendered))))
    return chain
|
|
|
|
|
|
2023-04-30 03:30:59 +00:00
|
|
|
|
|
2023-04-22 03:28:11 +00:00
|
|
|
|
def parse_markdown(md: str) -> list:
    """Split a markdown answer into text, LaTeX and fenced-code blocks.

    The input is scanned left to right and every match becomes one block,
    classified by which pattern matched (not by its leading characters):

    * a fenced code block (three backticks, optional language tag on the
      same line, a newline, the code, a newline, three closing backticks)
      -> ``{'type': 'code', 'content': {'language': lang, 'code': code}}``,
      where ``lang`` is ``''`` when the fence has no tag;
    * a ``$$...$$`` or ``$...$`` span -> ``{'type': 'latex', 'content':
      expr}`` with the delimiters removed and surrounding whitespace
      stripped;
    * any other non-empty line -> ``{'type': 'text', 'content': line}``.

    An unclosed fence or a ``$`` without a closing ``$`` stays plain text.
    Inline math in the middle of a text line is kept inside the text block,
    and newlines between blocks are dropped.

    :param md: markdown text to split.
    :return: list of block dicts in input order (empty for empty input).
    """
    # Alternatives are tried in this order at each position, so fences win
    # over math, and display math ($$) wins over inline math ($).
    pattern = re.compile(
        r'(?P<code>```[\s\S]*?\n```)'
        r'|(?P<display>\$\$[\s\S]*?\$\$)'
        r'|(?P<inline>\$[\s\S]*?\$)'
        r'|(?P<text>[^\n]+)'
    )
    blocks = []
    for match in pattern.finditer(md):
        kind = match.lastgroup
        content = match.group(kind)
        if kind == 'code':
            # Drop the opening ``` and the closing "\n```"; the first
            # remaining line is the (possibly empty) language tag, the rest
            # is the (possibly empty) code.
            language, _, code = content[3:-4].partition('\n')
            blocks.append({'type': 'code',
                           'content': {'language': language.strip(), 'code': code}})
        elif kind == 'display':
            blocks.append({'type': 'latex', 'content': content[2:-2].strip()})
        elif kind == 'inline':
            blocks.append({'type': 'latex', 'content': content[1:-1].strip()})
        else:
            blocks.append({'type': 'text', 'content': content})
    return blocks
|