format: with ruff and black
parent f47b690f5b
commit 5a855e3236

31  src/main.py
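Below is a minimal sketch of how a formatting pass like this one can be reproduced locally. The exact commands, options, and target paths used for this commit are not recorded in the diff, so everything in the snippet is an assumption based on the commit title and the style of the changes (double quotes, wrapped long calls, adjusted blank lines).

    # Hypothetical reproduction of the ruff + black pass; "src/" and all flags are assumptions.
    import subprocess

    # Apply ruff autofixes first (check=False: ruff exits non-zero if unfixable findings remain).
    subprocess.run(["ruff", "check", "--fix", "src/"], check=False)

    # Black normalizes quotes to double quotes, wraps long lines, and normalizes blank lines.
    subprocess.run(["black", "src/"], check=True)
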
@@ -21,13 +21,12 @@ import sys
 import gi
-import sys
 import threading
 import socket
 import json
 
 gi.require_version("Gtk", "4.0")
 gi.require_version("Adw", "1")
 gi.require_version("Gdk", "4.0")
-gi.require_version('Gst', '1.0')
+gi.require_version("Gst", "1.0")
 
 from gi.repository import Gtk, Gio, Adw, Gdk, GLib, Gst
 from .window import BavarderWindow

@@ -35,12 +34,12 @@ from .preferences import Preferences
 
 from .constants import app_id, version
 
 from hgchat import HGChat
 from gtts import gTTS
 from tempfile import NamedTemporaryFile
 
 from .provider import PROVIDERS
 
 
 class BavarderApplication(Adw.Application):
     """The main application singleton class."""

@@ -65,15 +64,17 @@ class BavarderApplication(Adw.Application):
         self.enabled_providers = set(self.settings.get_strv("enabled-providers"))
         self.latest_provider = self.settings.get_string("latest-provider")
         self.latest_provider = "huggingchat"
 
         # GStreamer playbin object and related setup
         Gst.init(None)
-        self.player = Gst.ElementFactory.make('playbin', 'player')
+        self.player = Gst.ElementFactory.make("playbin", "player")
         self.pipeline = Gst.Pipeline()
         # bus = self.player.get_bus()
         # bus.add_signal_watch()
         # bus.connect('message', self.on_gst_message)
-        self.player_event = threading.Event() # An event for letting us know when Gst is done playing
+        self.player_event = (
+            threading.Event()
+        )  # An event for letting us know when Gst is done playing
 
     def on_quit(self, action, param):
         """Called when the user activates the Quit action."""

@@ -114,17 +115,21 @@ class BavarderApplication(Adw.Application):
         print(self.providers_data)
         print(self.enabled_providers)
 
-        for provider, i in zip(self.enabled_providers, range(len(self.enabled_providers))):
+        for provider, i in zip(
+            self.enabled_providers, range(len(self.enabled_providers))
+        ):
             try:
                 self.provider_selector_model.append(PROVIDERS[provider].name)
 
-                self.providers[i] = PROVIDERS[provider](self.win, self, self.providers_data[i])
+                self.providers[i] = PROVIDERS[provider](
+                    self.win, self, self.providers_data[i]
+                )
             except KeyError:
                 self.providers[i] = PROVIDERS[provider](self.win, self, None)
 
         self.win.provider_selector.set_model(self.provider_selector_model)
-        self.win.provider_selector.connect('notify', self.on_provider_selector_notify)
+        self.win.provider_selector.connect("notify", self.on_provider_selector_notify)
 
         for k, p in self.providers.items():
             if p.slug == self.latest_provider:
                 self.win.provider_selector.set_selected(k)

@@ -252,13 +257,11 @@ class BavarderApplication(Adw.Application):
             print(exc)
 
     def _play_audio(self, path):
-        uri = 'file://' + path
-        self.player.set_property('uri', uri)
+        uri = "file://" + path
+        self.player.set_property("uri", uri)
         self.pipeline.add(self.player)
         self.pipeline.set_state(Gst.State.PLAYING)
         self.player.set_state(Gst.State.PLAYING)
 
-
-
     def on_listen_action(self, widget, _):
         """Callback for the app.listen action."""

@@ -35,5 +35,5 @@ class Preferences(Adw.PreferencesWindow):
         for provider in self.app.providers.values():
             try:
                 self.provider_group.add(provider.preferences())
-            except TypeError: # no prefs
-                pass
+            except TypeError:  # no prefs
+                pass

@@ -12,16 +12,16 @@ from .hfgpt2 import HuggingFaceGPT2Provider
 from .hfdialogpt import HuggingFaceDialoGPTLargeProvider
 
 PROVIDERS = {
-    'alpacalora': AlpacaLoRAProvider,
-    'baichat': BAIChatProvider,
-    'catgpt': CatGPTProvider,
-    'hfdialogpt': HuggingFaceDialoGPTLargeProvider,
-    'hfgoogleflant5xxl': HuggingFaceGoogleFlanT5XXLProvider,
-    'hfgoogleflanu12': HuggingFaceGoogleFlanU12Provider,
-    'hfgpt2': HuggingFaceGPT2Provider,
-    'hfopenassistantsft1pythia12b': HuggingFaceOpenAssistantSFT1PythiaProvider,
-    'huggingchat': HuggingChatProvider,
-    'openaigpt35turbo': OpenAIGPT35TurboProvider,
-    'openaigpt4': OpenAIGPT4Provider,
-    'openaitextdavinci003': OpenAITextDavinci003,
-}
+    "alpacalora": AlpacaLoRAProvider,
+    "baichat": BAIChatProvider,
+    "catgpt": CatGPTProvider,
+    "hfdialogpt": HuggingFaceDialoGPTLargeProvider,
+    "hfgoogleflant5xxl": HuggingFaceGoogleFlanT5XXLProvider,
+    "hfgoogleflanu12": HuggingFaceGoogleFlanU12Provider,
+    "hfgpt2": HuggingFaceGPT2Provider,
+    "hfopenassistantsft1pythia12b": HuggingFaceOpenAssistantSFT1PythiaProvider,
+    "huggingchat": HuggingChatProvider,
+    "openaigpt35turbo": OpenAIGPT35TurboProvider,
+    "openaigpt4": OpenAIGPT4Provider,
+    "openaitextdavinci003": OpenAITextDavinci003,
+}

@@ -4,6 +4,8 @@ import socket
 import requests
 
 from gi.repository import Gtk, Adw
+
+
 class AlpacaLoRAProvider(BavarderProvider):
     name = "Alpaca-LoRA"
     slug = "alpacalora"

@@ -11,20 +13,22 @@ class AlpacaLoRAProvider(BavarderProvider):
     def __init__(self, win, app, *args, **kwargs):
         super().__init__(win, app, *args, **kwargs)
 
-
     def ask(self, prompt):
         try:
-            response = requests.post("https://tloen-alpaca-lora.hf.space/run/predict", json={
-                "data": [
-                    prompt,
-                    prompt,
-                    0.1,
-                    0.75,
-                    40,
-                    4,
-                    128,
-                ]
-            }).json()
+            response = requests.post(
+                "https://tloen-alpaca-lora.hf.space/run/predict",
+                json={
+                    "data": [
+                        prompt,
+                        prompt,
+                        0.1,
+                        0.75,
+                        40,
+                        4,
+                        128,
+                    ]
+                },
+            ).json()
         except socket.gaierror:
             self.no_connection()
             return ""

@@ -40,7 +44,7 @@ class AlpacaLoRAProvider(BavarderProvider):
 
     def preferences(self):
         self.no_preferences()
-
+
     def about(self):
         about = Adw.AboutWindow(
             transient_for=self.props.active_window,

@@ -51,9 +55,9 @@ class AlpacaLoRAProvider(BavarderProvider):
             version=version,
             copyright="© 2023 0xMRTT",
         )
-
+
     def save(self):
         return {}
 
     def load(self, data):
-        pass
+        pass

@@ -3,9 +3,11 @@ from .base import BavarderProvider
 from baichat_py import BAIChat
 import socket
+
+
 class BAIChatProvider(BavarderProvider):
     name = "BAI Chat"
     slug = "baichat"
 
     def __init__(self, win, app, *args, **kwargs):
         super().__init__(win, app, *args, **kwargs)
         self.chat = BAIChat(sync=True)

@@ -30,7 +32,7 @@ class BAIChatProvider(BavarderProvider):
 
     def preferences(self):
         self.no_preferences()
-
+
     def about(self):
         about = Adw.AboutWindow(
             transient_for=self.props.active_window,

@@ -46,4 +48,4 @@ class BAIChatProvider(BavarderProvider):
         return {}
 
     def load(self, data):
-        pass
+        pass

@@ -1,7 +1,7 @@
 class BavarderProvider:
     name = None
     slug = None
 
     def __init__(self, win, app, data, *args, **kwargs):
         self.win = win
         self.banner = win.banner

@@ -30,7 +30,9 @@ class BavarderProvider:
         if title:
             self.win.banner.props.title = title
         else:
-            self.win.banner.props.title = "No API key provided, you can provide one in settings"
+            self.win.banner.props.title = (
+                "No API key provided, you can provide one in settings"
+            )
         self.win.banner.props.button_label = "Open settings"
         self.win.banner.connect("button-clicked", self.app.on_preferences_action)
         self.win.banner.set_revealed(True)

@@ -53,4 +55,4 @@ class BavarderProvider:
         raise NotImplementedError()
 
     def load(self, data):
-        raise NotImplementedError()
+        raise NotImplementedError()

@@ -1,7 +1,7 @@
 from .base import BavarderProvider
 
 from random import choice, randint
 import string
 
 class CatGPTProvider(BavarderProvider):
     name = "CatGPT"

@@ -12,35 +12,34 @@ class CatGPTProvider(BavarderProvider):
         self.chat = None
 
     def ask(self, prompt):
-        return ' '.join([
-            self.pick_generator()()
-            for i in range(randint(1, 12))
-        ])
+        return " ".join([self.pick_generator()() for i in range(randint(1, 12))])
 
-
     def pick_generator(self):
         if randint(1, 15) == 1:
-            return choice([
-                lambda: "ня" * randint(1, 4),
-                lambda: "ニャン" * randint(1, 4),
-                lambda: "喵" * randint(1, 4),
-                lambda: "ña" * randint(1, 4),
-                lambda: "ڽا" * randint(1, 4),
-                lambda: "ম্যাও" * randint(1, 4)
-            ])
-
-        return choice([
-            lambda: 'meow' * randint(1, 3),
-            lambda: 'mew' * randint(1, 3),
-            lambda: 'miau' * randint(1, 3),
-            lambda: 'miaou' * randint(1, 3),
-            lambda: 'miao' * randint(1, 3),
-            lambda: 'nya' * randint(1, 3),
-            lambda: 'm' + 'r' * randint(1, 6) + 'p',
-            lambda: 'pur' + 'r' * randint(1, 6),
-            lambda: 'nya' * randint(1, 3) + 'ny' + 'a' * randint(1, 10),
-        ])
+            return choice(
+                [
+                    lambda: "ня" * randint(1, 4),
+                    lambda: "ニャン" * randint(1, 4),
+                    lambda: "喵" * randint(1, 4),
+                    lambda: "ña" * randint(1, 4),
+                    lambda: "ڽا" * randint(1, 4),
+                    lambda: "ম্যাও" * randint(1, 4),
+                ]
+            )
+
+        return choice(
+            [
+                lambda: "meow" * randint(1, 3),
+                lambda: "mew" * randint(1, 3),
+                lambda: "miau" * randint(1, 3),
+                lambda: "miaou" * randint(1, 3),
+                lambda: "miao" * randint(1, 3),
+                lambda: "nya" * randint(1, 3),
+                lambda: "m" + "r" * randint(1, 6) + "p",
+                lambda: "pur" + "r" * randint(1, 6),
+                lambda: "nya" * randint(1, 3) + "ny" + "a" * randint(1, 10),
+            ]
+        )
 
     @property
     def require_api_key(self):

@@ -48,7 +47,7 @@ class CatGPTProvider(BavarderProvider):
 
     def preferences(self):
         self.no_preferences()
-
+
     def about(self):
         about = Adw.AboutWindow(
             transient_for=self.props.active_window,

@@ -64,4 +63,4 @@ class CatGPTProvider(BavarderProvider):
         return {}
 
     def load(self, data):
-        pass
+        pass

@@ -3,6 +3,7 @@ import json
 import socket
 import requests
 
+
 class HuggingFaceDialoGPTLargeProvider(BaseHFProvider):
     name = "DialoGPT"
     slug = "dialogpt"

@@ -11,16 +12,16 @@ class HuggingFaceDialoGPTLargeProvider(BaseHFProvider):
 
     def ask(self, prompt):
         try:
-            payload = json.dumps({
-                "inputs": {
-                    #"past_user_inputs": ["Which movie is the best ?"],
-                    #"generated_responses": ["It's Die Hard for sure."],
-                    "text": prompt
-                },
-            })
-            headers = {
-                'Content-Type': 'application/json'
-            }
+            payload = json.dumps(
+                {
+                    "inputs": {
+                        # "past_user_inputs": ["Which movie is the best ?"],
+                        # "generated_responses": ["It's Die Hard for sure."],
+                        "text": prompt
+                    },
+                }
+            )
+            headers = {"Content-Type": "application/json"}
             if self.authorization:
                 headers["Authorization"] = f"Bearer {self.api_key}"
             url = f"https://api-inference.huggingface.co/models/{self.model}"

@@ -41,4 +42,4 @@ class HuggingFaceDialoGPTLargeProvider(BaseHFProvider):
         self.hide_banner()
         print(response)
         self.update_response(response)
-        return response
+        return response

@@ -1,7 +1,8 @@
 from .huggingface import BaseHFProvider
 
+
 class HuggingFaceGoogleFlanT5XXLProvider(BaseHFProvider):
     name = "Google Flan T5 XXL"
     slug = "hfgoogleflant5xxl"
     model = "google/flan-t5-xxl"
-    authorization = False
+    authorization = False

@@ -1,7 +1,8 @@
 from .huggingface import BaseHFProvider
 
+
 class HuggingFaceGoogleFlanU12Provider(BaseHFProvider):
     name = "Google Flan U12"
     slug = "hfgoogleflanu12"
     model = "google/flan-ul2"
-    authorization = False
+    authorization = False

@@ -1,7 +1,8 @@
 from .huggingface import BaseHFProvider
 
+
 class HuggingFaceGPT2Provider(BaseHFProvider):
     name = "GPT 2"
     slug = "gpt2"
     model = "gpt2"
-    authorization = False
+    authorization = False

@@ -1,7 +1,8 @@
 from .huggingface import BaseHFProvider
 
+
 class HuggingFaceOpenAssistantSFT1PythiaProvider(BaseHFProvider):
     name = "Open-Assistant SFT-1 12B Model "
     slug = "hfopenassistantsft1pythia12b"
     model = "OpenAssistant/oasst-sft-1-pythia-12b"
-    authorization = False
+    authorization = False

@@ -5,6 +5,8 @@ import socket
 
 
 from gi.repository import Gtk, Adw
+
+
 class HuggingChatProvider(BavarderProvider):
     name = "Hugging Chat"
     slug = "huggingchat"

@@ -37,7 +39,7 @@ class HuggingChatProvider(BavarderProvider):
 
     def preferences(self):
         self.no_preferences()
-
+
     def about(self):
         about = Adw.AboutWindow(
             transient_for=self.props.active_window,

@@ -48,9 +50,9 @@ class HuggingChatProvider(BavarderProvider):
             version=version,
             copyright="© 2023 0xMRTT",
         )
-
+
     def save(self):
         return {}
 
     def load(self, data):
-        pass
+        pass

@@ -1,15 +1,12 @@
 import requests
 import json
 
-url = "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
-
-
 from .base import BavarderProvider
 
 import socket
 
 from gi.repository import Gtk, Adw
 
 class BaseHFProvider(BavarderProvider):
     name = None
     slug = None

@@ -22,12 +19,8 @@ class BaseHFProvider(BavarderProvider):
 
     def ask(self, prompt):
         try:
-            payload = json.dumps({
-                "inputs": prompt
-            })
-            headers = {
-                'Content-Type': 'application/json'
-            }
+            payload = json.dumps({"inputs": prompt})
+            headers = {"Content-Type": "application/json"}
             if self.authorization:
                 headers["Authorization"] = f"Bearer {self.api_key}"
             url = f"https://api-inference.huggingface.co/models/{self.model}"

@@ -75,7 +68,6 @@ class BaseHFProvider(BavarderProvider):
         self.api_key = self.api_row.get_text()
         print(self.api_key)
 
-
    def about(self):
         about = Adw.AboutWindow(
             transient_for=self.props.active_window,

@@ -86,14 +78,12 @@ class BaseHFProvider(BavarderProvider):
             version=version,
             copyright="© 2023 0xMRTT",
         )
-
+
     def save(self):
         if self.authorization:
-            return {
-                "api_key": self.api_key
-            }
+            return {"api_key": self.api_key}
         return {}
 
     def load(self, data):
         if self.authorization:
-            self.api_key = data["api_key"]
+            self.api_key = data["api_key"]

@@ -5,6 +5,7 @@ import socket
 
 from gi.repository import Gtk, Adw
 
+
 class BaseOpenAIProvider(BavarderProvider):
     name = None
     slug = None

@@ -16,7 +17,9 @@ class BaseOpenAIProvider(BavarderProvider):
 
     def ask(self, prompt):
         try:
-            response = self.chat.create(model=self.model, messages=[{"role": "user", "content": prompt}])
+            response = self.chat.create(
+                model=self.model, messages=[{"role": "user", "content": prompt}]
+            )
             response = response.choices[0].message.content
         except openai.error.AuthenticationError:
             self.no_api_key()

@@ -56,7 +59,6 @@ class BaseOpenAIProvider(BavarderProvider):
         print(api_key)
         openai.api_key = api_key
 
-
     def about(self):
         about = Adw.AboutWindow(
             transient_for=self.props.active_window,

@@ -67,11 +69,9 @@ class BaseOpenAIProvider(BavarderProvider):
             version=version,
             copyright="© 2023 0xMRTT",
         )
-
+
     def save(self):
-        return {
-            "api_key": openai.api_key
-        }
+        return {"api_key": openai.api_key}
 
     def load(self, data):
-        openai.api_key = data["api_key"]
+        openai.api_key = data["api_key"]

@@ -1,6 +1,7 @@
 from .openai import BaseOpenAIProvider
 
+
 class OpenAIGPT35TurboProvider(BaseOpenAIProvider):
     name = "OpenAI GPT 3.5 Turbo"
     slug = "openaigpt35turbo"
-    model = "gpt-3.5-turbo"
+    model = "gpt-3.5-turbo"

@@ -1,6 +1,7 @@
 from .openai import BaseOpenAIProvider
 
+
 class OpenAIGPT4Provider(BaseOpenAIProvider):
     name = "OpenAI GPT 4"
     slug = "openaigpt4"
-    model = "gpt-4"
+    model = "gpt-4"

@@ -1,6 +1,7 @@
 from .openai import BaseOpenAIProvider
 
+
 class OpenAITextDavinci003(BaseOpenAIProvider):
     name = "OpenAI Text Davinci 003"
     slug = "openaitextdavinci003"
-    model = "text-davinci-003"
+    model = "text-davinci-003"