src: add provider notify

This commit is contained in:
0xMRTT 2023-04-30 21:03:34 +02:00
parent ab7010a8db
commit 67e83a4fb5
Signed by: 0xmrtt
GPG key ID: 19C1449A774028BD
13 changed files with 214 additions and 10 deletions

View file

@ -17,7 +17,7 @@
<default>false</default>
</key>
<key name="enabled-providers" type="as">
<default>["huggingchat", "baichat"]</default>
<default>["baichat", "huggingchat", "openaigpt35turbo", "openaigpt4", "catgpt"]</default>
</key>
<key name="latest-provider" type="s">
<default>'huggingchat'</default>

View file

@ -112,6 +112,7 @@ class BavarderApplication(Adw.Application):
self.providers_data = self.settings.get_value("providers-data")
print(self.providers_data)
print(self.enabled_providers)
for provider, i in zip(self.enabled_providers, range(len(self.enabled_providers))):
try:
@ -122,12 +123,16 @@ class BavarderApplication(Adw.Application):
self.providers[i] = PROVIDERS[provider](self.win, self, None)
self.win.provider_selector.set_model(self.provider_selector_model)
self.win.provider_selector.connect('notify', self.on_provider_selector_notify)
for k, p in self.providers.items():
if p.slug == self.latest_provider:
self.win.provider_selector.set_selected(k)
break
def on_provider_selector_notify(self, _unused, pspec):
self.win.banner.set_revealed(False)
def on_about_action(self, widget, _):
"""Callback for the app.about action."""
about = Adw.AboutWindow(
@ -152,7 +157,7 @@ class BavarderApplication(Adw.Application):
)
about.present()
def on_preferences_action(self, widget, _):
def on_preferences_action(self, widget, *args, **kwargs):
"""Callback for the app.preferences action."""
print("app.preferences action activated")

View file

@ -6,6 +6,7 @@ class Preferences(Adw.PreferencesWindow):
__gtype_name__ = "Preferences"
clear_after_send_switch = Gtk.Template.Child()
provider_group = Gtk.Template.Child()
def __init__(self, application, **kwargs):
super().__init__(**kwargs)
@ -19,6 +20,8 @@ class Preferences(Adw.PreferencesWindow):
"state-set", self.on_clear_after_send_switch_toggled
)
self.setup_providers()
def on_clear_after_send_switch_toggled(self, *args):
"""Callback for the clear_after_send_switch toggled event."""
state = self.clear_after_send_switch.props.state
@ -27,3 +30,10 @@ class Preferences(Adw.PreferencesWindow):
self.settings.set_boolean("clear-after-send", True)
else:
self.settings.set_boolean("clear-after-send", False)
def setup_providers(self):
for provider in self.app.providers.values():
try:
self.provider_group.add(provider.preferences())
except TypeError: # no prefs
pass

View file

@ -1,7 +1,15 @@
from .huggingchat import HuggingChatProvider
from .baichat import BAIChatProvider
from .openaigpt35turbo import OpenAIGPT35TurboProvider
from .openaigpt4 import OpenAIGPT4Provider
from .catgpt import CatGPTProvider
from .openaitextdavinci003 import OpenAITextDavinci003
PROVIDERS = {
'huggingchat': HuggingChatProvider,
'baichat': BAIChatProvider,
'catgpt': CatGPTProvider,
'huggingchat': HuggingChatProvider,
'openaigpt35turbo': OpenAIGPT35TurboProvider,
'openaigpt4': OpenAIGPT4Provider,
'openaitextdavinci003': OpenAITextDavinci003,
}

View file

@ -17,7 +17,7 @@ class BAIChatProvider(BavarderProvider):
self.win.banner.set_revealed(False)
return ""
except socket.gaierror:
self.win.banner.set_revealed(True)
self.no_connection()
return ""
else:
self.win.banner.set_revealed(False)
@ -43,7 +43,7 @@ class BAIChatProvider(BavarderProvider):
)
def save(self):
return []
return {}
def load(self, data):
pass

View file

@ -26,6 +26,23 @@ class BavarderProvider:
def preferences(self):
raise NotImplementedError()
def no_api_key(self, title=None):
if title:
self.win.banner.props.title = title
else:
self.win.banner.props.title = "No API key provided, you can provide one in settings"
self.win.banner.props.button_label = "Open settings"
self.win.banner.connect("button-clicked", self.app.on_preferences_action)
self.win.banner.set_revealed(True)
def no_connection(self):
self.win.banner.props.title = "No network connection"
self.win.banner.props.button_label = ""
self.win.banner.set_revealed(True)
def hide_banner(self):
self.win.banner.set_revealed(False)
def about(self):
raise NotImplementedError()

67
src/provider/catgpt.py Normal file
View file

@ -0,0 +1,67 @@
from .base import BavarderProvider
from random import choice, randint
import string
class CatGPTProvider(BavarderProvider):
name = "CatGPT"
slug = "catgpt"
def __init__(self, win, app, *args, **kwargs):
super().__init__(win, app, *args, **kwargs)
self.chat = None
def ask(self, prompt):
return ' '.join([
self.pick_generator()()
for i in range(randint(1, 12))
])
def pick_generator(self):
if randint(1, 15) == 1:
return choice([
lambda: "ня" * randint(1, 4),
lambda: "ニャン" * randint(1, 4),
lambda: "" * randint(1, 4),
lambda: "ña" * randint(1, 4),
lambda: "ڽا" * randint(1, 4),
lambda: "ম্যাও" * randint(1, 4)
])
return choice([
lambda: 'meow' * randint(1, 3),
lambda: 'mew' * randint(1, 3),
lambda: 'miau' * randint(1, 3),
lambda: 'miaou' * randint(1, 3),
lambda: 'miao' * randint(1, 3),
lambda: 'nya' * randint(1, 3),
lambda: 'm' + 'r' * randint(1, 6) + 'p',
lambda: 'pur' + 'r' * randint(1, 6),
lambda: 'nya' * randint(1, 3) + 'ny' + 'a' * randint(1, 10),
])
@property
def require_api_key(self):
return False
def preferences(self):
self.no_preferences()
def about(self):
about = Adw.AboutWindow(
transient_for=self.props.active_window,
application_name="Cat GPT",
developer_name="0xMRTT",
developers=["0xMRTT https://github.com/0xMRTT"],
license_type=Gtk.License.GPL_3_0,
version=version,
copyright="© 2023 0xMRTT",
)
def save(self):
return {}
def load(self, data):
pass

View file

@ -16,11 +16,8 @@ class HuggingChatProvider(BavarderProvider):
def ask(self, prompt):
try:
response = self.chat.ask(prompt)
except KeyError:
self.win.banner.set_revealed(False)
return ""
except socket.gaierror:
self.win.banner.set_revealed(True)
self.no_connection()
return ""
else:
self.win.banner.set_revealed(False)
@ -53,7 +50,7 @@ class HuggingChatProvider(BavarderProvider):
)
def save(self):
return []
return {}
def load(self, data):
pass

View file

@ -5,7 +5,12 @@ providers_sources = [
'__init__.py',
'baichat.py',
'base.py',
'catgpt.py',
'huggingchat.py',
'openai.py',
'openaigpt4.py',
'openaigpt35turbo.py',
'openaitextdavinci003.py',
]
PY_INSTALLDIR.install_sources(providers_sources, subdir: providers_dir)

77
src/provider/openai.py Normal file
View file

@ -0,0 +1,77 @@
from .base import BavarderProvider
import openai
import socket
from gi.repository import Gtk, Adw
class BaseOpenAIProvider(BavarderProvider):
name = None
slug = None
model = None
def __init__(self, win, app, *args, **kwargs):
super().__init__(win, app, *args, **kwargs)
self.chat = openai.ChatCompletion
def ask(self, prompt):
try:
response = self.chat.create(model=self.model, messages=[{"role": "user", "content": prompt}])
response = response.choices[0].message.content
except openai.error.AuthenticationError:
self.no_api_key()
return ""
except openai.error.InvalidRequestError:
self.win.banner.props.title = "You don't have access to this model"
self.win.banner.props.button_label = ""
self.win.banner.set_revealed(True)
return ""
except socket.gaierror:
self.no_connection()
return ""
else:
self.hide_banner()
self.update_response(response)
return response
@property
def require_api_key(self):
return False
def preferences(self):
self.expander = Adw.ExpanderRow()
self.expander.props.title = self.name
self.api_row = Adw.PasswordEntryRow()
self.api_row.connect("apply", self.on_apply)
self.api_row.props.title = "API Key"
self.api_row.set_show_apply_button(True)
self.expander.add_row(self.api_row)
return self.expander
def on_apply(self, widget):
self.hide_banner()
api_key = self.api_row.get_text()
print(api_key)
openai.api_key = api_key
def about(self):
about = Adw.AboutWindow(
transient_for=self.props.active_window,
application_name=self.name,
developer_name="OpenAI",
developers=["0xMRTT https://github.com/0xMRTT"],
license_type=Gtk.License.GPL_3_0,
version=version,
copyright="© 2023 0xMRTT",
)
def save(self):
return {
"api_key": openai.api_key
}
def load(self, data):
openai.api_key = data["api_key"]

View file

@ -0,0 +1,6 @@
from .openai import BaseOpenAIProvider
class OpenAIGPT35TurboProvider(BaseOpenAIProvider):
name = "OpenAI GPT 3.5 Turbo"
slug = "openaigpt35turbo"
model = "gpt-3.5-turbo"

View file

@ -0,0 +1,6 @@
from .openai import BaseOpenAIProvider
class OpenAIGPT4Provider(BaseOpenAIProvider):
name = "OpenAI GPT 4"
slug = "openaigpt4"
model = "gpt-4"

View file

@ -0,0 +1,6 @@
from .openai import BaseOpenAIProvider
class OpenAITextDavinci003(BaseOpenAIProvider):
name = "OpenAI Text Davinci 003"
slug = "openaitextdavinci003"
model = "text-davinci-003"