2023-04-08 04:25:48 +00:00
|
|
|
import ujson as json
|
|
|
|
from typing import Callable
|
|
|
|
|
|
|
|
from langchain.agents import Tool
|
|
|
|
|
|
|
|
from core.utils.i18n import Locale
|
|
|
|
from core.types.message import MessageSession, MsgInfo, Session
|
|
|
|
|
|
|
|
def to_json_func(func: Callable):
    """Wrap an async callable so that it returns its result as a JSON string.

    Args:
        func: a coroutine function; its awaited return value must be
            serializable by this module's ``json`` (imported as ``ujson``).

    Returns:
        A coroutine function accepting the same arguments as ``func``. It
        awaits ``func`` and returns ``json.dumps`` of the result.
    """
    async def wrapper(*args, **kwargs):
        result = await func(*args, **kwargs)
        serialized = json.dumps(result)
        return serialized

    return wrapper
|
|
|
|
|
|
|
|
def to_async_func(func: Callable):
    """Adapt a synchronous callable into a coroutine function.

    Args:
        func: any regular (non-async) callable.

    Returns:
        A coroutine function with the same call signature. When awaited, it
        calls ``func`` synchronously and returns its result unchanged.
    """
    async def wrapper(*args, **kwargs):
        outcome = func(*args, **kwargs)
        return outcome

    return wrapper
|
|
|
|
|
|
|
|
def with_args(func: Callable, *args, **kwargs):
    """Pre-bind leading arguments to an async callable.

    Args:
        func: a coroutine function to call.
        *args: positional arguments always passed first to ``func``.
        **kwargs: keyword arguments always passed to ``func``.

    Returns:
        A coroutine function. When called, it awaits ``func`` with the bound
        ``args`` followed by any positional arguments supplied at call time,
        plus the bound and call-time keyword arguments. A single call-time
        positional argument equal to ``''`` is dropped, i.e. treated as
        "no extra input".
    """
    async def wrapper(*call_args, **call_kwargs):
        # A lone empty-string argument contributes no positional input.
        only_empty_string = len(call_args) == 1 and call_args[0] == ''
        extra = () if only_empty_string else call_args
        return await func(*args, *extra, **kwargs, **call_kwargs)

    return wrapper
|
|
|
|
|
2023-04-08 05:22:26 +00:00
|
|
|
def parse_input(input: str):
    """Split a comma-separated input string into cleaned values.

    Each comma-separated piece is stripped of surrounding whitespace, then of
    surrounding double quotes, then of surrounding single quotes.

    Args:
        input: raw string such as ``a, "b", 'c'``. (The name shadows the
            builtin but is kept for caller compatibility.)

    Returns:
        A list of str with one entry per comma-separated piece. An empty
        input yields ``['']``.
    """
    # Strip double and single quotes with two separate calls. The previous
    # form `.strip('"'.strip("'"))` reduced to `.strip('"')` and never
    # removed single quotes.
    return [piece.strip().strip('"').strip("'") for piece in input.split(',')]
|
2023-04-08 04:25:48 +00:00
|
|
|
|
|
|
|
class AkariTool(Tool):
    """langchain ``Tool`` whose ``coroutine`` attribute is set to the same callable as ``func``.

    NOTE(review): presumably ``func`` is expected to be a coroutine function
    (e.g. one produced by ``to_async_func`` / ``with_args`` / ``to_json_func``
    in this module), so that the async execution path can await it. Confirm
    with callers.
    """

    def __init__(self, name: str, func: Callable, description: str = None):
        # Forward name/func/description positionally to the base Tool initializer.
        # NOTE(review): ``description`` defaults to None although annotated as str.
        super().__init__(name, func, description)
        # Assigned only after the base initializer has run. Keep this ordering.
        # NOTE(review): Tool is presumably a pydantic model, which would reject
        # attribute assignment before initialization. Confirm.
        self.coroutine = func
|
|
|
|
|
|
|
|
# Module-level placeholder MessageSession built entirely from hard-coded values
# ('Ask|0' identifiers, 'AkariBot', 'Ask', 0) rather than from a real incoming
# message. NOTE(review): presumably this gives tools a message context when
# invoked from the "Ask" feature. Confirm with callers. The meaning of each
# positional MsgInfo/Session argument is defined in core.types.message and is
# not visible here.
fake_msg = MessageSession(MsgInfo('Ask|0', 'Ask|0', 'AkariBot', 'Ask', 'Ask', 'Ask', 0),
                          Session('~lol lol', 'Ask|0', 'Ask|0'))
|
2023-04-15 08:18:51 +00:00
|
|
|
# Module-level 'en_us' Locale, attached to the placeholder session so that code
# reading ``fake_msg.locale`` receives this Locale instance.
locale_en = Locale('en_us')
fake_msg.locale = locale_en
|