src: add provider notify
This commit is contained in:
parent
ab7010a8db
commit
67e83a4fb5
|
@ -17,7 +17,7 @@
|
|||
<default>false</default>
|
||||
</key>
|
||||
<key name="enabled-providers" type="as">
|
||||
<default>["huggingchat", "baichat"]</default>
|
||||
<default>["baichat", "huggingchat", "openaigpt35turbo", "openaigpt4", "catgpt"]</default>
|
||||
</key>
|
||||
<key name="latest-provider" type="s">
|
||||
<default>'huggingchat'</default>
|
||||
|
|
|
@ -112,6 +112,7 @@ class BavarderApplication(Adw.Application):
|
|||
|
||||
self.providers_data = self.settings.get_value("providers-data")
|
||||
print(self.providers_data)
|
||||
print(self.enabled_providers)
|
||||
|
||||
for provider, i in zip(self.enabled_providers, range(len(self.enabled_providers))):
|
||||
try:
|
||||
|
@ -122,12 +123,16 @@ class BavarderApplication(Adw.Application):
|
|||
self.providers[i] = PROVIDERS[provider](self.win, self, None)
|
||||
|
||||
self.win.provider_selector.set_model(self.provider_selector_model)
|
||||
self.win.provider_selector.connect('notify', self.on_provider_selector_notify)
|
||||
|
||||
for k, p in self.providers.items():
|
||||
if p.slug == self.latest_provider:
|
||||
self.win.provider_selector.set_selected(k)
|
||||
break
|
||||
|
||||
def on_provider_selector_notify(self, _unused, pspec):
|
||||
self.win.banner.set_revealed(False)
|
||||
|
||||
def on_about_action(self, widget, _):
|
||||
"""Callback for the app.about action."""
|
||||
about = Adw.AboutWindow(
|
||||
|
@ -152,7 +157,7 @@ class BavarderApplication(Adw.Application):
|
|||
)
|
||||
about.present()
|
||||
|
||||
def on_preferences_action(self, widget, _):
|
||||
def on_preferences_action(self, widget, *args, **kwargs):
|
||||
"""Callback for the app.preferences action."""
|
||||
print("app.preferences action activated")
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ class Preferences(Adw.PreferencesWindow):
|
|||
__gtype_name__ = "Preferences"
|
||||
|
||||
clear_after_send_switch = Gtk.Template.Child()
|
||||
provider_group = Gtk.Template.Child()
|
||||
|
||||
def __init__(self, application, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
@ -19,6 +20,8 @@ class Preferences(Adw.PreferencesWindow):
|
|||
"state-set", self.on_clear_after_send_switch_toggled
|
||||
)
|
||||
|
||||
self.setup_providers()
|
||||
|
||||
def on_clear_after_send_switch_toggled(self, *args):
|
||||
"""Callback for the clear_after_send_switch toggled event."""
|
||||
state = self.clear_after_send_switch.props.state
|
||||
|
@ -27,3 +30,10 @@ class Preferences(Adw.PreferencesWindow):
|
|||
self.settings.set_boolean("clear-after-send", True)
|
||||
else:
|
||||
self.settings.set_boolean("clear-after-send", False)
|
||||
|
||||
def setup_providers(self):
|
||||
for provider in self.app.providers.values():
|
||||
try:
|
||||
self.provider_group.add(provider.preferences())
|
||||
except TypeError: # no prefs
|
||||
pass
|
|
@ -1,7 +1,15 @@
|
|||
from .huggingchat import HuggingChatProvider
|
||||
from .baichat import BAIChatProvider
|
||||
from .openaigpt35turbo import OpenAIGPT35TurboProvider
|
||||
from .openaigpt4 import OpenAIGPT4Provider
|
||||
from .catgpt import CatGPTProvider
|
||||
from .openaitextdavinci003 import OpenAITextDavinci003
|
||||
|
||||
PROVIDERS = {
|
||||
'huggingchat': HuggingChatProvider,
|
||||
'baichat': BAIChatProvider,
|
||||
'catgpt': CatGPTProvider,
|
||||
'huggingchat': HuggingChatProvider,
|
||||
'openaigpt35turbo': OpenAIGPT35TurboProvider,
|
||||
'openaigpt4': OpenAIGPT4Provider,
|
||||
'openaitextdavinci003': OpenAITextDavinci003,
|
||||
}
|
|
@ -17,7 +17,7 @@ class BAIChatProvider(BavarderProvider):
|
|||
self.win.banner.set_revealed(False)
|
||||
return ""
|
||||
except socket.gaierror:
|
||||
self.win.banner.set_revealed(True)
|
||||
self.no_connection()
|
||||
return ""
|
||||
else:
|
||||
self.win.banner.set_revealed(False)
|
||||
|
@ -43,7 +43,7 @@ class BAIChatProvider(BavarderProvider):
|
|||
)
|
||||
|
||||
def save(self):
|
||||
return []
|
||||
return {}
|
||||
|
||||
def load(self, data):
|
||||
pass
|
|
@ -26,6 +26,23 @@ class BavarderProvider:
|
|||
def preferences(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def no_api_key(self, title=None):
|
||||
if title:
|
||||
self.win.banner.props.title = title
|
||||
else:
|
||||
self.win.banner.props.title = "No API key provided, you can provide one in settings"
|
||||
self.win.banner.props.button_label = "Open settings"
|
||||
self.win.banner.connect("button-clicked", self.app.on_preferences_action)
|
||||
self.win.banner.set_revealed(True)
|
||||
|
||||
def no_connection(self):
|
||||
self.win.banner.props.title = "No network connection"
|
||||
self.win.banner.props.button_label = ""
|
||||
self.win.banner.set_revealed(True)
|
||||
|
||||
def hide_banner(self):
|
||||
self.win.banner.set_revealed(False)
|
||||
|
||||
def about(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
67
src/provider/catgpt.py
Normal file
67
src/provider/catgpt.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
from .base import BavarderProvider
|
||||
|
||||
from random import choice, randint
|
||||
import string
|
||||
|
||||
class CatGPTProvider(BavarderProvider):
|
||||
name = "CatGPT"
|
||||
slug = "catgpt"
|
||||
|
||||
def __init__(self, win, app, *args, **kwargs):
|
||||
super().__init__(win, app, *args, **kwargs)
|
||||
self.chat = None
|
||||
|
||||
def ask(self, prompt):
|
||||
return ' '.join([
|
||||
self.pick_generator()()
|
||||
for i in range(randint(1, 12))
|
||||
])
|
||||
|
||||
|
||||
def pick_generator(self):
|
||||
if randint(1, 15) == 1:
|
||||
return choice([
|
||||
lambda: "ня" * randint(1, 4),
|
||||
lambda: "ニャン" * randint(1, 4),
|
||||
lambda: "喵" * randint(1, 4),
|
||||
lambda: "ña" * randint(1, 4),
|
||||
lambda: "ڽا" * randint(1, 4),
|
||||
lambda: "ম্যাও" * randint(1, 4)
|
||||
])
|
||||
|
||||
return choice([
|
||||
lambda: 'meow' * randint(1, 3),
|
||||
lambda: 'mew' * randint(1, 3),
|
||||
lambda: 'miau' * randint(1, 3),
|
||||
lambda: 'miaou' * randint(1, 3),
|
||||
lambda: 'miao' * randint(1, 3),
|
||||
lambda: 'nya' * randint(1, 3),
|
||||
lambda: 'm' + 'r' * randint(1, 6) + 'p',
|
||||
lambda: 'pur' + 'r' * randint(1, 6),
|
||||
lambda: 'nya' * randint(1, 3) + 'ny' + 'a' * randint(1, 10),
|
||||
])
|
||||
|
||||
|
||||
@property
|
||||
def require_api_key(self):
|
||||
return False
|
||||
|
||||
def preferences(self):
|
||||
self.no_preferences()
|
||||
|
||||
def about(self):
|
||||
about = Adw.AboutWindow(
|
||||
transient_for=self.props.active_window,
|
||||
application_name="Cat GPT",
|
||||
developer_name="0xMRTT",
|
||||
developers=["0xMRTT https://github.com/0xMRTT"],
|
||||
license_type=Gtk.License.GPL_3_0,
|
||||
version=version,
|
||||
copyright="© 2023 0xMRTT",
|
||||
)
|
||||
|
||||
def save(self):
|
||||
return {}
|
||||
|
||||
def load(self, data):
|
||||
pass
|
|
@ -16,11 +16,8 @@ class HuggingChatProvider(BavarderProvider):
|
|||
def ask(self, prompt):
|
||||
try:
|
||||
response = self.chat.ask(prompt)
|
||||
except KeyError:
|
||||
self.win.banner.set_revealed(False)
|
||||
return ""
|
||||
except socket.gaierror:
|
||||
self.win.banner.set_revealed(True)
|
||||
self.no_connection()
|
||||
return ""
|
||||
else:
|
||||
self.win.banner.set_revealed(False)
|
||||
|
@ -53,7 +50,7 @@ class HuggingChatProvider(BavarderProvider):
|
|||
)
|
||||
|
||||
def save(self):
|
||||
return []
|
||||
return {}
|
||||
|
||||
def load(self, data):
|
||||
pass
|
|
@ -5,7 +5,12 @@ providers_sources = [
|
|||
'__init__.py',
|
||||
'baichat.py',
|
||||
'base.py',
|
||||
'catgpt.py',
|
||||
'huggingchat.py',
|
||||
'openai.py',
|
||||
'openaigpt4.py',
|
||||
'openaigpt35turbo.py',
|
||||
'openaitextdavinci003.py',
|
||||
]
|
||||
|
||||
PY_INSTALLDIR.install_sources(providers_sources, subdir: providers_dir)
|
||||
|
|
77
src/provider/openai.py
Normal file
77
src/provider/openai.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
from .base import BavarderProvider
|
||||
|
||||
import openai
|
||||
import socket
|
||||
|
||||
from gi.repository import Gtk, Adw
|
||||
|
||||
class BaseOpenAIProvider(BavarderProvider):
|
||||
name = None
|
||||
slug = None
|
||||
model = None
|
||||
|
||||
def __init__(self, win, app, *args, **kwargs):
|
||||
super().__init__(win, app, *args, **kwargs)
|
||||
self.chat = openai.ChatCompletion
|
||||
|
||||
def ask(self, prompt):
|
||||
try:
|
||||
response = self.chat.create(model=self.model, messages=[{"role": "user", "content": prompt}])
|
||||
response = response.choices[0].message.content
|
||||
except openai.error.AuthenticationError:
|
||||
self.no_api_key()
|
||||
return ""
|
||||
except openai.error.InvalidRequestError:
|
||||
self.win.banner.props.title = "You don't have access to this model"
|
||||
self.win.banner.props.button_label = ""
|
||||
self.win.banner.set_revealed(True)
|
||||
return ""
|
||||
except socket.gaierror:
|
||||
self.no_connection()
|
||||
return ""
|
||||
else:
|
||||
self.hide_banner()
|
||||
self.update_response(response)
|
||||
return response
|
||||
|
||||
@property
|
||||
def require_api_key(self):
|
||||
return False
|
||||
|
||||
def preferences(self):
|
||||
self.expander = Adw.ExpanderRow()
|
||||
self.expander.props.title = self.name
|
||||
|
||||
self.api_row = Adw.PasswordEntryRow()
|
||||
self.api_row.connect("apply", self.on_apply)
|
||||
self.api_row.props.title = "API Key"
|
||||
self.api_row.set_show_apply_button(True)
|
||||
self.expander.add_row(self.api_row)
|
||||
|
||||
return self.expander
|
||||
|
||||
def on_apply(self, widget):
|
||||
self.hide_banner()
|
||||
api_key = self.api_row.get_text()
|
||||
print(api_key)
|
||||
openai.api_key = api_key
|
||||
|
||||
|
||||
def about(self):
|
||||
about = Adw.AboutWindow(
|
||||
transient_for=self.props.active_window,
|
||||
application_name=self.name,
|
||||
developer_name="OpenAI",
|
||||
developers=["0xMRTT https://github.com/0xMRTT"],
|
||||
license_type=Gtk.License.GPL_3_0,
|
||||
version=version,
|
||||
copyright="© 2023 0xMRTT",
|
||||
)
|
||||
|
||||
def save(self):
|
||||
return {
|
||||
"api_key": openai.api_key
|
||||
}
|
||||
|
||||
def load(self, data):
|
||||
openai.api_key = data["api_key"]
|
6
src/provider/openaigpt35turbo.py
Normal file
6
src/provider/openaigpt35turbo.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from .openai import BaseOpenAIProvider
|
||||
|
||||
class OpenAIGPT35TurboProvider(BaseOpenAIProvider):
|
||||
name = "OpenAI GPT 3.5 Turbo"
|
||||
slug = "openaigpt35turbo"
|
||||
model = "gpt-3.5-turbo"
|
6
src/provider/openaigpt4.py
Normal file
6
src/provider/openaigpt4.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from .openai import BaseOpenAIProvider
|
||||
|
||||
class OpenAIGPT4Provider(BaseOpenAIProvider):
|
||||
name = "OpenAI GPT 4"
|
||||
slug = "openaigpt4"
|
||||
model = "gpt-4"
|
6
src/provider/openaitextdavinci003.py
Normal file
6
src/provider/openaitextdavinci003.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from .openai import BaseOpenAIProvider
|
||||
|
||||
class OpenAITextDavinci003(BaseOpenAIProvider):
|
||||
name = "OpenAI Text Davinci 003"
|
||||
slug = "openaitextdavinci003"
|
||||
model = "text-davinci-003"
|
Loading…
Reference in a new issue