src: add provider notify
This commit is contained in:
parent
ab7010a8db
commit
67e83a4fb5
|
@ -17,7 +17,7 @@
|
||||||
<default>false</default>
|
<default>false</default>
|
||||||
</key>
|
</key>
|
||||||
<key name="enabled-providers" type="as">
|
<key name="enabled-providers" type="as">
|
||||||
<default>["huggingchat", "baichat"]</default>
|
<default>["baichat", "huggingchat", "openaigpt35turbo", "openaigpt4", "catgpt"]</default>
|
||||||
</key>
|
</key>
|
||||||
<key name="latest-provider" type="s">
|
<key name="latest-provider" type="s">
|
||||||
<default>'huggingchat'</default>
|
<default>'huggingchat'</default>
|
||||||
|
|
|
@ -112,6 +112,7 @@ class BavarderApplication(Adw.Application):
|
||||||
|
|
||||||
self.providers_data = self.settings.get_value("providers-data")
|
self.providers_data = self.settings.get_value("providers-data")
|
||||||
print(self.providers_data)
|
print(self.providers_data)
|
||||||
|
print(self.enabled_providers)
|
||||||
|
|
||||||
for provider, i in zip(self.enabled_providers, range(len(self.enabled_providers))):
|
for provider, i in zip(self.enabled_providers, range(len(self.enabled_providers))):
|
||||||
try:
|
try:
|
||||||
|
@ -122,12 +123,16 @@ class BavarderApplication(Adw.Application):
|
||||||
self.providers[i] = PROVIDERS[provider](self.win, self, None)
|
self.providers[i] = PROVIDERS[provider](self.win, self, None)
|
||||||
|
|
||||||
self.win.provider_selector.set_model(self.provider_selector_model)
|
self.win.provider_selector.set_model(self.provider_selector_model)
|
||||||
|
self.win.provider_selector.connect('notify', self.on_provider_selector_notify)
|
||||||
|
|
||||||
for k, p in self.providers.items():
|
for k, p in self.providers.items():
|
||||||
if p.slug == self.latest_provider:
|
if p.slug == self.latest_provider:
|
||||||
self.win.provider_selector.set_selected(k)
|
self.win.provider_selector.set_selected(k)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def on_provider_selector_notify(self, _unused, pspec):
|
||||||
|
self.win.banner.set_revealed(False)
|
||||||
|
|
||||||
def on_about_action(self, widget, _):
|
def on_about_action(self, widget, _):
|
||||||
"""Callback for the app.about action."""
|
"""Callback for the app.about action."""
|
||||||
about = Adw.AboutWindow(
|
about = Adw.AboutWindow(
|
||||||
|
@ -152,7 +157,7 @@ class BavarderApplication(Adw.Application):
|
||||||
)
|
)
|
||||||
about.present()
|
about.present()
|
||||||
|
|
||||||
def on_preferences_action(self, widget, _):
|
def on_preferences_action(self, widget, *args, **kwargs):
|
||||||
"""Callback for the app.preferences action."""
|
"""Callback for the app.preferences action."""
|
||||||
print("app.preferences action activated")
|
print("app.preferences action activated")
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ class Preferences(Adw.PreferencesWindow):
|
||||||
__gtype_name__ = "Preferences"
|
__gtype_name__ = "Preferences"
|
||||||
|
|
||||||
clear_after_send_switch = Gtk.Template.Child()
|
clear_after_send_switch = Gtk.Template.Child()
|
||||||
|
provider_group = Gtk.Template.Child()
|
||||||
|
|
||||||
def __init__(self, application, **kwargs):
|
def __init__(self, application, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
@ -19,6 +20,8 @@ class Preferences(Adw.PreferencesWindow):
|
||||||
"state-set", self.on_clear_after_send_switch_toggled
|
"state-set", self.on_clear_after_send_switch_toggled
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.setup_providers()
|
||||||
|
|
||||||
def on_clear_after_send_switch_toggled(self, *args):
|
def on_clear_after_send_switch_toggled(self, *args):
|
||||||
"""Callback for the clear_after_send_switch toggled event."""
|
"""Callback for the clear_after_send_switch toggled event."""
|
||||||
state = self.clear_after_send_switch.props.state
|
state = self.clear_after_send_switch.props.state
|
||||||
|
@ -27,3 +30,10 @@ class Preferences(Adw.PreferencesWindow):
|
||||||
self.settings.set_boolean("clear-after-send", True)
|
self.settings.set_boolean("clear-after-send", True)
|
||||||
else:
|
else:
|
||||||
self.settings.set_boolean("clear-after-send", False)
|
self.settings.set_boolean("clear-after-send", False)
|
||||||
|
|
||||||
|
def setup_providers(self):
|
||||||
|
for provider in self.app.providers.values():
|
||||||
|
try:
|
||||||
|
self.provider_group.add(provider.preferences())
|
||||||
|
except TypeError: # no prefs
|
||||||
|
pass
|
|
@ -1,7 +1,15 @@
|
||||||
from .huggingchat import HuggingChatProvider
|
from .huggingchat import HuggingChatProvider
|
||||||
from .baichat import BAIChatProvider
|
from .baichat import BAIChatProvider
|
||||||
|
from .openaigpt35turbo import OpenAIGPT35TurboProvider
|
||||||
|
from .openaigpt4 import OpenAIGPT4Provider
|
||||||
|
from .catgpt import CatGPTProvider
|
||||||
|
from .openaitextdavinci003 import OpenAITextDavinci003
|
||||||
|
|
||||||
PROVIDERS = {
|
PROVIDERS = {
|
||||||
'huggingchat': HuggingChatProvider,
|
|
||||||
'baichat': BAIChatProvider,
|
'baichat': BAIChatProvider,
|
||||||
|
'catgpt': CatGPTProvider,
|
||||||
|
'huggingchat': HuggingChatProvider,
|
||||||
|
'openaigpt35turbo': OpenAIGPT35TurboProvider,
|
||||||
|
'openaigpt4': OpenAIGPT4Provider,
|
||||||
|
'openaitextdavinci003': OpenAITextDavinci003,
|
||||||
}
|
}
|
|
@ -17,7 +17,7 @@ class BAIChatProvider(BavarderProvider):
|
||||||
self.win.banner.set_revealed(False)
|
self.win.banner.set_revealed(False)
|
||||||
return ""
|
return ""
|
||||||
except socket.gaierror:
|
except socket.gaierror:
|
||||||
self.win.banner.set_revealed(True)
|
self.no_connection()
|
||||||
return ""
|
return ""
|
||||||
else:
|
else:
|
||||||
self.win.banner.set_revealed(False)
|
self.win.banner.set_revealed(False)
|
||||||
|
@ -43,7 +43,7 @@ class BAIChatProvider(BavarderProvider):
|
||||||
)
|
)
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
return []
|
return {}
|
||||||
|
|
||||||
def load(self, data):
|
def load(self, data):
|
||||||
pass
|
pass
|
|
@ -26,6 +26,23 @@ class BavarderProvider:
|
||||||
def preferences(self):
|
def preferences(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def no_api_key(self, title=None):
|
||||||
|
if title:
|
||||||
|
self.win.banner.props.title = title
|
||||||
|
else:
|
||||||
|
self.win.banner.props.title = "No API key provided, you can provide one in settings"
|
||||||
|
self.win.banner.props.button_label = "Open settings"
|
||||||
|
self.win.banner.connect("button-clicked", self.app.on_preferences_action)
|
||||||
|
self.win.banner.set_revealed(True)
|
||||||
|
|
||||||
|
def no_connection(self):
|
||||||
|
self.win.banner.props.title = "No network connection"
|
||||||
|
self.win.banner.props.button_label = ""
|
||||||
|
self.win.banner.set_revealed(True)
|
||||||
|
|
||||||
|
def hide_banner(self):
|
||||||
|
self.win.banner.set_revealed(False)
|
||||||
|
|
||||||
def about(self):
|
def about(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
67
src/provider/catgpt.py
Normal file
67
src/provider/catgpt.py
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
from .base import BavarderProvider
|
||||||
|
|
||||||
|
from random import choice, randint
|
||||||
|
import string
|
||||||
|
|
||||||
|
class CatGPTProvider(BavarderProvider):
|
||||||
|
name = "CatGPT"
|
||||||
|
slug = "catgpt"
|
||||||
|
|
||||||
|
def __init__(self, win, app, *args, **kwargs):
|
||||||
|
super().__init__(win, app, *args, **kwargs)
|
||||||
|
self.chat = None
|
||||||
|
|
||||||
|
def ask(self, prompt):
|
||||||
|
return ' '.join([
|
||||||
|
self.pick_generator()()
|
||||||
|
for i in range(randint(1, 12))
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def pick_generator(self):
|
||||||
|
if randint(1, 15) == 1:
|
||||||
|
return choice([
|
||||||
|
lambda: "ня" * randint(1, 4),
|
||||||
|
lambda: "ニャン" * randint(1, 4),
|
||||||
|
lambda: "喵" * randint(1, 4),
|
||||||
|
lambda: "ña" * randint(1, 4),
|
||||||
|
lambda: "ڽا" * randint(1, 4),
|
||||||
|
lambda: "ম্যাও" * randint(1, 4)
|
||||||
|
])
|
||||||
|
|
||||||
|
return choice([
|
||||||
|
lambda: 'meow' * randint(1, 3),
|
||||||
|
lambda: 'mew' * randint(1, 3),
|
||||||
|
lambda: 'miau' * randint(1, 3),
|
||||||
|
lambda: 'miaou' * randint(1, 3),
|
||||||
|
lambda: 'miao' * randint(1, 3),
|
||||||
|
lambda: 'nya' * randint(1, 3),
|
||||||
|
lambda: 'm' + 'r' * randint(1, 6) + 'p',
|
||||||
|
lambda: 'pur' + 'r' * randint(1, 6),
|
||||||
|
lambda: 'nya' * randint(1, 3) + 'ny' + 'a' * randint(1, 10),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def require_api_key(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def preferences(self):
|
||||||
|
self.no_preferences()
|
||||||
|
|
||||||
|
def about(self):
|
||||||
|
about = Adw.AboutWindow(
|
||||||
|
transient_for=self.props.active_window,
|
||||||
|
application_name="Cat GPT",
|
||||||
|
developer_name="0xMRTT",
|
||||||
|
developers=["0xMRTT https://github.com/0xMRTT"],
|
||||||
|
license_type=Gtk.License.GPL_3_0,
|
||||||
|
version=version,
|
||||||
|
copyright="© 2023 0xMRTT",
|
||||||
|
)
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def load(self, data):
|
||||||
|
pass
|
|
@ -16,11 +16,8 @@ class HuggingChatProvider(BavarderProvider):
|
||||||
def ask(self, prompt):
|
def ask(self, prompt):
|
||||||
try:
|
try:
|
||||||
response = self.chat.ask(prompt)
|
response = self.chat.ask(prompt)
|
||||||
except KeyError:
|
|
||||||
self.win.banner.set_revealed(False)
|
|
||||||
return ""
|
|
||||||
except socket.gaierror:
|
except socket.gaierror:
|
||||||
self.win.banner.set_revealed(True)
|
self.no_connection()
|
||||||
return ""
|
return ""
|
||||||
else:
|
else:
|
||||||
self.win.banner.set_revealed(False)
|
self.win.banner.set_revealed(False)
|
||||||
|
@ -53,7 +50,7 @@ class HuggingChatProvider(BavarderProvider):
|
||||||
)
|
)
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
return []
|
return {}
|
||||||
|
|
||||||
def load(self, data):
|
def load(self, data):
|
||||||
pass
|
pass
|
|
@ -5,7 +5,12 @@ providers_sources = [
|
||||||
'__init__.py',
|
'__init__.py',
|
||||||
'baichat.py',
|
'baichat.py',
|
||||||
'base.py',
|
'base.py',
|
||||||
|
'catgpt.py',
|
||||||
'huggingchat.py',
|
'huggingchat.py',
|
||||||
|
'openai.py',
|
||||||
|
'openaigpt4.py',
|
||||||
|
'openaigpt35turbo.py',
|
||||||
|
'openaitextdavinci003.py',
|
||||||
]
|
]
|
||||||
|
|
||||||
PY_INSTALLDIR.install_sources(providers_sources, subdir: providers_dir)
|
PY_INSTALLDIR.install_sources(providers_sources, subdir: providers_dir)
|
||||||
|
|
77
src/provider/openai.py
Normal file
77
src/provider/openai.py
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
from .base import BavarderProvider
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import socket
|
||||||
|
|
||||||
|
from gi.repository import Gtk, Adw
|
||||||
|
|
||||||
|
class BaseOpenAIProvider(BavarderProvider):
|
||||||
|
name = None
|
||||||
|
slug = None
|
||||||
|
model = None
|
||||||
|
|
||||||
|
def __init__(self, win, app, *args, **kwargs):
|
||||||
|
super().__init__(win, app, *args, **kwargs)
|
||||||
|
self.chat = openai.ChatCompletion
|
||||||
|
|
||||||
|
def ask(self, prompt):
|
||||||
|
try:
|
||||||
|
response = self.chat.create(model=self.model, messages=[{"role": "user", "content": prompt}])
|
||||||
|
response = response.choices[0].message.content
|
||||||
|
except openai.error.AuthenticationError:
|
||||||
|
self.no_api_key()
|
||||||
|
return ""
|
||||||
|
except openai.error.InvalidRequestError:
|
||||||
|
self.win.banner.props.title = "You don't have access to this model"
|
||||||
|
self.win.banner.props.button_label = ""
|
||||||
|
self.win.banner.set_revealed(True)
|
||||||
|
return ""
|
||||||
|
except socket.gaierror:
|
||||||
|
self.no_connection()
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
self.hide_banner()
|
||||||
|
self.update_response(response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
@property
|
||||||
|
def require_api_key(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def preferences(self):
|
||||||
|
self.expander = Adw.ExpanderRow()
|
||||||
|
self.expander.props.title = self.name
|
||||||
|
|
||||||
|
self.api_row = Adw.PasswordEntryRow()
|
||||||
|
self.api_row.connect("apply", self.on_apply)
|
||||||
|
self.api_row.props.title = "API Key"
|
||||||
|
self.api_row.set_show_apply_button(True)
|
||||||
|
self.expander.add_row(self.api_row)
|
||||||
|
|
||||||
|
return self.expander
|
||||||
|
|
||||||
|
def on_apply(self, widget):
|
||||||
|
self.hide_banner()
|
||||||
|
api_key = self.api_row.get_text()
|
||||||
|
print(api_key)
|
||||||
|
openai.api_key = api_key
|
||||||
|
|
||||||
|
|
||||||
|
def about(self):
|
||||||
|
about = Adw.AboutWindow(
|
||||||
|
transient_for=self.props.active_window,
|
||||||
|
application_name=self.name,
|
||||||
|
developer_name="OpenAI",
|
||||||
|
developers=["0xMRTT https://github.com/0xMRTT"],
|
||||||
|
license_type=Gtk.License.GPL_3_0,
|
||||||
|
version=version,
|
||||||
|
copyright="© 2023 0xMRTT",
|
||||||
|
)
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
return {
|
||||||
|
"api_key": openai.api_key
|
||||||
|
}
|
||||||
|
|
||||||
|
def load(self, data):
|
||||||
|
openai.api_key = data["api_key"]
|
6
src/provider/openaigpt35turbo.py
Normal file
6
src/provider/openaigpt35turbo.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from .openai import BaseOpenAIProvider
|
||||||
|
|
||||||
|
class OpenAIGPT35TurboProvider(BaseOpenAIProvider):
|
||||||
|
name = "OpenAI GPT 3.5 Turbo"
|
||||||
|
slug = "openaigpt35turbo"
|
||||||
|
model = "gpt-3.5-turbo"
|
6
src/provider/openaigpt4.py
Normal file
6
src/provider/openaigpt4.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from .openai import BaseOpenAIProvider
|
||||||
|
|
||||||
|
class OpenAIGPT4Provider(BaseOpenAIProvider):
|
||||||
|
name = "OpenAI GPT 4"
|
||||||
|
slug = "openaigpt4"
|
||||||
|
model = "gpt-4"
|
6
src/provider/openaitextdavinci003.py
Normal file
6
src/provider/openaitextdavinci003.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from .openai import BaseOpenAIProvider
|
||||||
|
|
||||||
|
class OpenAITextDavinci003(BaseOpenAIProvider):
|
||||||
|
name = "OpenAI Text Davinci 003"
|
||||||
|
slug = "openaitextdavinci003"
|
||||||
|
model = "text-davinci-003"
|
Loading…
Reference in a new issue