Archived
1
0
Fork 0

add rollback when commit failed

This commit is contained in:
yzhh 2021-08-27 00:24:21 +08:00
parent e281c7a509
commit 43bb225656
3 changed files with 125 additions and 77 deletions

View file

@ -49,7 +49,7 @@ async def parser(msg: MessageSession):
elif module.desc is not None:
return await msg.sendMessage(module.desc)
if module.need_superuser:
if not senderInfo.query.isSuperUser:
if not msg.checkSuperUser():
return await msg.sendMessage('你没有使用该命令的权限。')
elif not module.is_base_function:
if command_first_word not in enabled_modules_list: # 若未开启

View file

@ -55,39 +55,51 @@ class BotDBUtil:
return True if module_name in self.enable_modules_list else False
def enable(self, module_name) -> bool:
if isinstance(module_name, str):
if module_name not in self.enable_modules_list:
self.enable_modules_list.append(module_name)
elif isinstance(module_name, (list, tuple)):
for x in module_name:
if x not in self.enable_modules_list:
self.enable_modules_list.append(x)
value = convert_list_to_str(self.enable_modules_list)
if self.need_insert:
table = EnabledModules(targetId=self.targetId,
enabledModules=value)
session.add_all([table])
else:
self.query_EnabledModules.enabledModules = value
session.commit()
session.expire_all()
EnabledModulesCache.add_cache(self.targetId, self.enable_modules_list)
return True
def disable(self, module_name) -> bool:
if isinstance(module_name, str):
if module_name in self.enable_modules_list:
self.enable_modules_list.remove(module_name)
elif isinstance(module_name, (list, tuple)):
for x in module_name:
if x in self.enable_modules_list:
self.enable_modules_list.remove(x)
if not self.need_insert:
self.query_EnabledModules.enabledModules = convert_list_to_str(self.enable_modules_list)
try:
if isinstance(module_name, str):
if module_name not in self.enable_modules_list:
self.enable_modules_list.append(module_name)
elif isinstance(module_name, (list, tuple)):
for x in module_name:
if x not in self.enable_modules_list:
self.enable_modules_list.append(x)
value = convert_list_to_str(self.enable_modules_list)
if self.need_insert:
table = EnabledModules(targetId=self.targetId,
enabledModules=value)
session.add_all([table])
else:
self.query_EnabledModules.enabledModules = value
session.commit()
session.expire_all()
EnabledModulesCache.add_cache(self.targetId, self.enable_modules_list)
return True
return True
except Exception:
session.rollback()
raise
finally:
session.close()
def disable(self, module_name) -> bool:
try:
if isinstance(module_name, str):
if module_name in self.enable_modules_list:
self.enable_modules_list.remove(module_name)
elif isinstance(module_name, (list, tuple)):
for x in module_name:
if x in self.enable_modules_list:
self.enable_modules_list.remove(x)
if not self.need_insert:
self.query_EnabledModules.enabledModules = convert_list_to_str(self.enable_modules_list)
session.commit()
session.expire_all()
EnabledModulesCache.add_cache(self.targetId, self.enable_modules_list)
return True
except Exception:
session.rollback()
raise
finally:
session.close()
@staticmethod
def get_enabled_this(module_name):
@ -107,23 +119,35 @@ class BotDBUtil:
self.query = Dict2Object(query_cache)
else:
self.query = self.query_SenderInfo
if self.query is None:
session.add_all([SenderInfo(id=senderId)])
session.commit()
self.query = session.query(SenderInfo).filter_by(id=senderId).first()
SenderInfoCache.add_cache(self.senderId, self.query.__dict__)
try:
if self.query is None:
session.add_all([SenderInfo(id=senderId)])
session.commit()
self.query = session.query(SenderInfo).filter_by(id=senderId).first()
SenderInfoCache.add_cache(self.senderId, self.query.__dict__)
except Exception:
session.rollback()
raise
finally:
session.close()
@property
def query_SenderInfo(self):
return session.query(SenderInfo).filter_by(id=self.senderId).first()
def edit(self, column: str, value):
query = self.query_SenderInfo
setattr(query, column, value)
session.commit()
session.expire_all()
SenderInfoCache.add_cache(self.senderId, query.__dict__)
return True
try:
query = self.query_SenderInfo
setattr(query, column, value)
session.commit()
session.expire_all()
SenderInfoCache.add_cache(self.senderId, query.__dict__)
return True
except Exception:
session.rollback()
raise
finally:
session.close()
def check_TargetAdmin(self, targetId):
query = session.query(TargetAdmin).filter_by(senderId=self.senderId, targetId=targetId).first()
@ -132,16 +156,29 @@ class BotDBUtil:
return False
def add_TargetAdmin(self, targetId):
if not self.check_TargetAdmin(targetId):
session.add_all([TargetAdmin(senderId=self.senderId, targetId=targetId)])
session.commit()
return True
try:
if not self.check_TargetAdmin(targetId):
session.add_all([TargetAdmin(senderId=self.senderId, targetId=targetId)])
session.commit()
return True
except Exception:
session.rollback()
raise
finally:
session.close()
def remove_TargetAdmin(self, targetId):
query = self.check_TargetAdmin(targetId)
if query:
session.delete(query)
session.commit()
try:
query = self.check_TargetAdmin(targetId)
if query:
session.delete(query)
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
class CoolDown:
def __init__(self, msg: MessageSession, name):
@ -160,8 +197,14 @@ class BotDBUtil:
return 0
def reset(self):
if not self.need_insert:
session.delete(self.query)
try:
if not self.need_insert:
session.delete(self.query)
session.commit()
session.add_all([CommandTriggerTime(targetId=self.msg.target.targetId, commandName=self.name)])
session.commit()
session.add_all([CommandTriggerTime(targetId=self.msg.target.targetId, commandName=self.name)])
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()

View file

@ -35,28 +35,33 @@ async def config_modules(msg: MessageSession):
query = BotDBUtil.Module(msg)
msglist = []
if msg.parsed_msg['enable']:
for module in wait_config_list:
if module == 'all':
for function in modules:
if not modules[function].need_superuser:
if query.enable(function):
msglist.append(f'成功:打开模块“{function}')
elif module not in modules:
msglist.append(f'失败:“{module}”模块不存在')
else:
if query.enable(wait_config_list):
msglist.append(f'成功:打开模块“{module}')
if wait_config_list == ['all']:
for function in modules:
if not modules[function].need_superuser:
if query.enable(function):
msglist.append(f'成功:打开模块“{function}')
else:
for module in wait_config_list:
if module not in modules:
msglist.append(f'失败:“{module}”模块不存在')
else:
if modules[module].need_superuser and not msg.checkSuperUser():
msglist.append(f'失败:你没有打开“{module}”的权限。')
else:
if query.enable(wait_config_list):
msglist.append(f'成功:打开模块“{module}')
elif msg.parsed_msg['disable']:
for module in wait_config_list:
if module == 'all':
for function in modules:
if query.disable(function):
msglist.append(f'成功:关闭模块“{function}')
elif module not in modules:
msglist.append(f'失败:“{module}”模块不存在')
else:
if query.disable(wait_config_list):
msglist.append(f'成功:关闭模块“{module}')
if wait_config_list == ['all']:
for function in modules:
if query.disable(function):
msglist.append(f'成功:关闭模块“{function}')
else:
for module in wait_config_list:
if module not in modules:
msglist.append(f'失败:“{module}”模块不存在')
else:
if query.disable(wait_config_list):
msglist.append(f'成功:关闭模块“{module}')
if msglist is not None:
await msg.sendMessage('\n'.join(msglist))
@ -95,7 +100,7 @@ async def bot_help(msg: MessageSession):
help_msg.append(' | '.join(module))
print(help_msg)
help_msg.append(
'使用~help <对应模块名>查看详细信息。\n使用~modules查看所有的可用模块。\n你也可以通过查阅文档获取帮助:\nhttps://bot.teahou.se/wiki/\n请向我们捐赠以维持机器人稳定服务:\nhttps://bot.teahou.se/wiki/%E6%8D%90%E8%B5%A0')
'使用~help <对应模块名>查看详细信息。\n使用~modules查看所有的可用模块。\n你也可以通过查阅文档获取帮助:\nhttps://bot.teahou.se/wiki/')
help_msg.append('[本消息将在一分钟后撤回]')
send = await msg.sendMessage('\n'.join(help_msg))
await asyncio.sleep(60)
@ -114,7 +119,7 @@ async def modules_help(msg: MessageSession):
module.append(module_list[x].bind_prefix)
help_msg.append(' | '.join(module))
help_msg.append(
'使用~help <模块名>查看详细信息。\n你也可以通过查阅文档获取帮助:\nhttps://bot.teahou.se/wiki/\n请向我们捐赠以维持机器人稳定服务:\nhttps://bot.teahou.se/wiki/%E6%8D%90%E8%B5%A0')
'使用~help <模块名>查看详细信息。\n你也可以通过查阅文档获取帮助:\nhttps://bot.teahou.se/wiki/')
help_msg.append('[本消息将在一分钟后撤回]')
send = await msg.sendMessage('\n'.join(help_msg))
await asyncio.sleep(60)