Archived
1
0
Fork 0
This repository has been archived on 2024-04-26. You can view files and clone it, but cannot push or open issues or pull requests.
akari-bot/modules/ask/tools/utils.py
2023-06-20 15:29:56 +08:00

54 lines
1.5 KiB
Python

from typing import Any, Callable
import ujson as json
from langchain.tools import StructuredTool
from core.types.message import MessageSession, MsgInfo, Session
from core.utils.i18n import Locale
def to_json_func(func: Callable):
    """Wrap an awaitable-returning callable so its result comes back as JSON.

    :param func: Callable whose return value can be awaited.
    :return: A coroutine function that awaits ``func`` with the same
        arguments and returns the result serialized by ``json.dumps``.
    """

    async def wrapper(*args, **kwargs):
        result = await func(*args, **kwargs)
        serialized = json.dumps(result)
        return serialized

    return wrapper
def to_async_func(func: Callable):
    """Expose a synchronous callable through a coroutine function.

    :param func: Plain (non-async) callable.
    :return: A coroutine function that calls ``func`` with the given
        arguments and returns whatever ``func`` returned, without awaiting it.
    """

    async def wrapper(*args, **kwargs):
        outcome = func(*args, **kwargs)
        return outcome

    return wrapper
def with_args(func: Callable, *args, **kwargs):
    """Pre-bind leading positional and keyword arguments to an async callable.

    :param func: Callable whose return value can be awaited.
    :param args: Positional arguments placed before any call-time positionals.
    :param kwargs: Keyword arguments merged with any call-time keywords.
    :return: A coroutine function. When it is called with exactly one
        positional argument equal to ``''``, that argument is discarded and
        only the pre-bound positionals are forwarded; otherwise the call-time
        positionals are appended after the pre-bound ones.
    """

    async def wrapper(*a, **k):
        # A single empty-string positional is treated as "no input given".
        only_blank = len(a) == 1 and a[0] == ''
        extra = () if only_blank else a
        return await func(*args, *extra, **kwargs, **k)

    return wrapper
def parse_input(input: str):
    """Split a comma-separated string into trimmed, unquoted values.

    Each piece has surrounding whitespace removed, then any surrounding
    double quotes, then any surrounding single quotes.

    The previous implementation evaluated ``'"'.strip("'")`` (which is just
    ``'"'``) as the strip argument, so single quotes were never removed.

    :param input: Raw comma-separated text. The name shadows the builtin but
        is kept so keyword callers continue to work.
    :return: List of cleaned values, one per comma-separated piece; an empty
        input yields ``['']``.
    """
    return [
        piece.strip().strip('"').strip("'")
        for piece in input.split(',')
    ]
class AkariTool(StructuredTool):
    """``StructuredTool`` subclass that also stores its callable as ``coroutine``."""

    def __init__(
        self,
        name: str,
        func: Callable,
        args_schema: Any = None,
        description: str = None,
        return_direct: bool = False,
    ):
        """Build the tool and store ``func`` as its coroutine.

        :param name: Tool name forwarded to ``StructuredTool``.
        :param func: Callable forwarded as ``func`` and then also assigned to
            ``self.coroutine`` (presumably an async callable; verify against
            callers).
        :param args_schema: Argument schema forwarded to ``StructuredTool``.
        :param description: Tool description forwarded to ``StructuredTool``.
        :param return_direct: Forwarded to ``StructuredTool`` unchanged.
        """
        super().__init__(
            name=name,
            args_schema=args_schema,
            description=description,
            return_direct=return_direct,
            func=func,
        )
        # Set after base initialisation, exactly as the original did.
        self.coroutine = func
# Module-level placeholder session built from fixed 'Ask' identifiers.
# NOTE(review): presumably passed to code that needs a MessageSession when no
# real chat message exists — confirm against callers.
_fake_msg_info = MsgInfo('Ask|0', 'Ask|0', 'AkariBot', 'Ask', 'Ask', 'Ask', 0)
_fake_session = Session('~lol lol', 'Ask|0', 'Ask|0')
fake_msg = MessageSession(_fake_msg_info, _fake_session)

# The placeholder session's locale is set to the 'en_us' Locale.
locale_en = Locale('en_us')
fake_msg.locale = locale_en