format: with ruff and black

This commit is contained in:
0xMRTT 2023-05-01 01:45:31 +02:00
parent f47b690f5b
commit 5a855e3236
Signed by: 0xmrtt
GPG key ID: 19C1449A774028BD
18 changed files with 131 additions and 121 deletions

View file

@ -21,13 +21,12 @@ import sys
import gi
import sys
import threading
import socket
import json
gi.require_version("Gtk", "4.0")
gi.require_version("Adw", "1")
gi.require_version("Gdk", "4.0")
gi.require_version('Gst', '1.0')
gi.require_version("Gst", "1.0")
from gi.repository import Gtk, Gio, Adw, Gdk, GLib, Gst
from .window import BavarderWindow
@ -35,12 +34,12 @@ from .preferences import Preferences
from .constants import app_id, version
from hgchat import HGChat
from gtts import gTTS
from tempfile import NamedTemporaryFile
from .provider import PROVIDERS
class BavarderApplication(Adw.Application):
"""The main application singleton class."""
@ -65,15 +64,17 @@ class BavarderApplication(Adw.Application):
self.enabled_providers = set(self.settings.get_strv("enabled-providers"))
self.latest_provider = self.settings.get_string("latest-provider")
self.latest_provider = "huggingchat"
# GStreamer playbin object and related setup
Gst.init(None)
self.player = Gst.ElementFactory.make('playbin', 'player')
self.player = Gst.ElementFactory.make("playbin", "player")
self.pipeline = Gst.Pipeline()
# bus = self.player.get_bus()
# bus.add_signal_watch()
# bus.connect('message', self.on_gst_message)
self.player_event = threading.Event() # An event for letting us know when Gst is done playing
self.player_event = (
threading.Event()
) # An event for letting us know when Gst is done playing
def on_quit(self, action, param):
"""Called when the user activates the Quit action."""
@ -114,17 +115,21 @@ class BavarderApplication(Adw.Application):
print(self.providers_data)
print(self.enabled_providers)
for provider, i in zip(self.enabled_providers, range(len(self.enabled_providers))):
for provider, i in zip(
self.enabled_providers, range(len(self.enabled_providers))
):
try:
self.provider_selector_model.append(PROVIDERS[provider].name)
self.providers[i] = PROVIDERS[provider](self.win, self, self.providers_data[i])
self.providers[i] = PROVIDERS[provider](
self.win, self, self.providers_data[i]
)
except KeyError:
self.providers[i] = PROVIDERS[provider](self.win, self, None)
self.win.provider_selector.set_model(self.provider_selector_model)
self.win.provider_selector.connect('notify', self.on_provider_selector_notify)
self.win.provider_selector.connect("notify", self.on_provider_selector_notify)
for k, p in self.providers.items():
if p.slug == self.latest_provider:
self.win.provider_selector.set_selected(k)
@ -252,13 +257,11 @@ class BavarderApplication(Adw.Application):
print(exc)
def _play_audio(self, path):
uri = 'file://' + path
self.player.set_property('uri', uri)
uri = "file://" + path
self.player.set_property("uri", uri)
self.pipeline.add(self.player)
self.pipeline.set_state(Gst.State.PLAYING)
self.player.set_state(Gst.State.PLAYING)
def on_listen_action(self, widget, _):
"""Callback for the app.listen action."""

View file

@ -35,5 +35,5 @@ class Preferences(Adw.PreferencesWindow):
for provider in self.app.providers.values():
try:
self.provider_group.add(provider.preferences())
except TypeError: # no prefs
pass
except TypeError: # no prefs
pass

View file

@ -12,16 +12,16 @@ from .hfgpt2 import HuggingFaceGPT2Provider
from .hfdialogpt import HuggingFaceDialoGPTLargeProvider
PROVIDERS = {
'alpacalora': AlpacaLoRAProvider,
'baichat': BAIChatProvider,
'catgpt': CatGPTProvider,
'hfdialogpt': HuggingFaceDialoGPTLargeProvider,
'hfgoogleflant5xxl': HuggingFaceGoogleFlanT5XXLProvider,
'hfgoogleflanu12': HuggingFaceGoogleFlanU12Provider,
'hfgpt2': HuggingFaceGPT2Provider,
'hfopenassistantsft1pythia12b': HuggingFaceOpenAssistantSFT1PythiaProvider,
'huggingchat': HuggingChatProvider,
'openaigpt35turbo': OpenAIGPT35TurboProvider,
'openaigpt4': OpenAIGPT4Provider,
'openaitextdavinci003': OpenAITextDavinci003,
}
"alpacalora": AlpacaLoRAProvider,
"baichat": BAIChatProvider,
"catgpt": CatGPTProvider,
"hfdialogpt": HuggingFaceDialoGPTLargeProvider,
"hfgoogleflant5xxl": HuggingFaceGoogleFlanT5XXLProvider,
"hfgoogleflanu12": HuggingFaceGoogleFlanU12Provider,
"hfgpt2": HuggingFaceGPT2Provider,
"hfopenassistantsft1pythia12b": HuggingFaceOpenAssistantSFT1PythiaProvider,
"huggingchat": HuggingChatProvider,
"openaigpt35turbo": OpenAIGPT35TurboProvider,
"openaigpt4": OpenAIGPT4Provider,
"openaitextdavinci003": OpenAITextDavinci003,
}

View file

@ -4,6 +4,8 @@ import socket
import requests
from gi.repository import Gtk, Adw
class AlpacaLoRAProvider(BavarderProvider):
name = "Alpaca-LoRA"
slug = "alpacalora"
@ -11,20 +13,22 @@ class AlpacaLoRAProvider(BavarderProvider):
def __init__(self, win, app, *args, **kwargs):
super().__init__(win, app, *args, **kwargs)
def ask(self, prompt):
try:
response = requests.post("https://tloen-alpaca-lora.hf.space/run/predict", json={
"data": [
prompt,
prompt,
0.1,
0.75,
40,
4,
128,
]
}).json()
response = requests.post(
"https://tloen-alpaca-lora.hf.space/run/predict",
json={
"data": [
prompt,
prompt,
0.1,
0.75,
40,
4,
128,
]
},
).json()
except socket.gaierror:
self.no_connection()
return ""
@ -40,7 +44,7 @@ class AlpacaLoRAProvider(BavarderProvider):
def preferences(self):
self.no_preferences()
def about(self):
about = Adw.AboutWindow(
transient_for=self.props.active_window,
@ -51,9 +55,9 @@ class AlpacaLoRAProvider(BavarderProvider):
version=version,
copyright="© 2023 0xMRTT",
)
def save(self):
return {}
def load(self, data):
pass
pass

View file

@ -3,9 +3,11 @@ from .base import BavarderProvider
from baichat_py import BAIChat
import socket
class BAIChatProvider(BavarderProvider):
name = "BAI Chat"
slug = "baichat"
def __init__(self, win, app, *args, **kwargs):
super().__init__(win, app, *args, **kwargs)
self.chat = BAIChat(sync=True)
@ -30,7 +32,7 @@ class BAIChatProvider(BavarderProvider):
def preferences(self):
self.no_preferences()
def about(self):
about = Adw.AboutWindow(
transient_for=self.props.active_window,
@ -46,4 +48,4 @@ class BAIChatProvider(BavarderProvider):
return {}
def load(self, data):
pass
pass

View file

@ -1,7 +1,7 @@
class BavarderProvider:
name = None
slug = None
def __init__(self, win, app, data, *args, **kwargs):
self.win = win
self.banner = win.banner
@ -30,7 +30,9 @@ class BavarderProvider:
if title:
self.win.banner.props.title = title
else:
self.win.banner.props.title = "No API key provided, you can provide one in settings"
self.win.banner.props.title = (
"No API key provided, you can provide one in settings"
)
self.win.banner.props.button_label = "Open settings"
self.win.banner.connect("button-clicked", self.app.on_preferences_action)
self.win.banner.set_revealed(True)
@ -53,4 +55,4 @@ class BavarderProvider:
raise NotImplementedError()
def load(self, data):
raise NotImplementedError()
raise NotImplementedError()

View file

@ -1,7 +1,7 @@
from .base import BavarderProvider
from random import choice, randint
import string
class CatGPTProvider(BavarderProvider):
name = "CatGPT"
@ -12,35 +12,34 @@ class CatGPTProvider(BavarderProvider):
self.chat = None
def ask(self, prompt):
return ' '.join([
self.pick_generator()()
for i in range(randint(1, 12))
])
return " ".join([self.pick_generator()() for i in range(randint(1, 12))])
def pick_generator(self):
if randint(1, 15) == 1:
return choice([
lambda: "ня" * randint(1, 4),
lambda: "ニャン" * randint(1, 4),
lambda: "" * randint(1, 4),
lambda: "ña" * randint(1, 4),
lambda: "ڽا" * randint(1, 4),
lambda: "ম্যাও" * randint(1, 4)
])
return choice([
lambda: 'meow' * randint(1, 3),
lambda: 'mew' * randint(1, 3),
lambda: 'miau' * randint(1, 3),
lambda: 'miaou' * randint(1, 3),
lambda: 'miao' * randint(1, 3),
lambda: 'nya' * randint(1, 3),
lambda: 'm' + 'r' * randint(1, 6) + 'p',
lambda: 'pur' + 'r' * randint(1, 6),
lambda: 'nya' * randint(1, 3) + 'ny' + 'a' * randint(1, 10),
])
return choice(
[
lambda: "ня" * randint(1, 4),
lambda: "ニャン" * randint(1, 4),
lambda: "" * randint(1, 4),
lambda: "ña" * randint(1, 4),
lambda: "ڽا" * randint(1, 4),
lambda: "ম্যাও" * randint(1, 4),
]
)
return choice(
[
lambda: "meow" * randint(1, 3),
lambda: "mew" * randint(1, 3),
lambda: "miau" * randint(1, 3),
lambda: "miaou" * randint(1, 3),
lambda: "miao" * randint(1, 3),
lambda: "nya" * randint(1, 3),
lambda: "m" + "r" * randint(1, 6) + "p",
lambda: "pur" + "r" * randint(1, 6),
lambda: "nya" * randint(1, 3) + "ny" + "a" * randint(1, 10),
]
)
@property
def require_api_key(self):
@ -48,7 +47,7 @@ class CatGPTProvider(BavarderProvider):
def preferences(self):
self.no_preferences()
def about(self):
about = Adw.AboutWindow(
transient_for=self.props.active_window,
@ -64,4 +63,4 @@ class CatGPTProvider(BavarderProvider):
return {}
def load(self, data):
pass
pass

View file

@ -3,6 +3,7 @@ import json
import socket
import requests
class HuggingFaceDialoGPTLargeProvider(BaseHFProvider):
name = "DialoGPT"
slug = "dialogpt"
@ -11,16 +12,16 @@ class HuggingFaceDialoGPTLargeProvider(BaseHFProvider):
def ask(self, prompt):
try:
payload = json.dumps({
"inputs": {
#"past_user_inputs": ["Which movie is the best ?"],
#"generated_responses": ["It's Die Hard for sure."],
"text": prompt
},
})
headers = {
'Content-Type': 'application/json'
}
payload = json.dumps(
{
"inputs": {
# "past_user_inputs": ["Which movie is the best ?"],
# "generated_responses": ["It's Die Hard for sure."],
"text": prompt
},
}
)
headers = {"Content-Type": "application/json"}
if self.authorization:
headers["Authorization"] = f"Bearer {self.api_key}"
url = f"https://api-inference.huggingface.co/models/{self.model}"
@ -41,4 +42,4 @@ class HuggingFaceDialoGPTLargeProvider(BaseHFProvider):
self.hide_banner()
print(response)
self.update_response(response)
return response
return response

View file

@ -1,7 +1,8 @@
from .huggingface import BaseHFProvider
class HuggingFaceGoogleFlanT5XXLProvider(BaseHFProvider):
name = "Google Flan T5 XXL"
slug = "hfgoogleflant5xxl"
model = "google/flan-t5-xxl"
authorization = False
authorization = False

View file

@ -1,7 +1,8 @@
from .huggingface import BaseHFProvider
class HuggingFaceGoogleFlanU12Provider(BaseHFProvider):
name = "Google Flan U12"
slug = "hfgoogleflanu12"
model = "google/flan-ul2"
authorization = False
authorization = False

View file

@ -1,7 +1,8 @@
from .huggingface import BaseHFProvider
class HuggingFaceGPT2Provider(BaseHFProvider):
name = "GPT 2"
slug = "gpt2"
model = "gpt2"
authorization = False
authorization = False

View file

@ -1,7 +1,8 @@
from .huggingface import BaseHFProvider
class HuggingFaceOpenAssistantSFT1PythiaProvider(BaseHFProvider):
name = "Open-Assistant SFT-1 12B Model "
slug = "hfopenassistantsft1pythia12b"
model = "OpenAssistant/oasst-sft-1-pythia-12b"
authorization = False
authorization = False

View file

@ -5,6 +5,8 @@ import socket
from gi.repository import Gtk, Adw
class HuggingChatProvider(BavarderProvider):
name = "Hugging Chat"
slug = "huggingchat"
@ -37,7 +39,7 @@ class HuggingChatProvider(BavarderProvider):
def preferences(self):
self.no_preferences()
def about(self):
about = Adw.AboutWindow(
transient_for=self.props.active_window,
@ -48,9 +50,9 @@ class HuggingChatProvider(BavarderProvider):
version=version,
copyright="© 2023 0xMRTT",
)
def save(self):
return {}
def load(self, data):
pass
pass

View file

@ -1,15 +1,12 @@
import requests
import json
url = "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
from .base import BavarderProvider
import socket
from gi.repository import Gtk, Adw
class BaseHFProvider(BavarderProvider):
name = None
slug = None
@ -22,12 +19,8 @@ class BaseHFProvider(BavarderProvider):
def ask(self, prompt):
try:
payload = json.dumps({
"inputs": prompt
})
headers = {
'Content-Type': 'application/json'
}
payload = json.dumps({"inputs": prompt})
headers = {"Content-Type": "application/json"}
if self.authorization:
headers["Authorization"] = f"Bearer {self.api_key}"
url = f"https://api-inference.huggingface.co/models/{self.model}"
@ -75,7 +68,6 @@ class BaseHFProvider(BavarderProvider):
self.api_key = self.api_row.get_text()
print(self.api_key)
def about(self):
about = Adw.AboutWindow(
transient_for=self.props.active_window,
@ -86,14 +78,12 @@ class BaseHFProvider(BavarderProvider):
version=version,
copyright="© 2023 0xMRTT",
)
def save(self):
if self.authorization:
return {
"api_key": self.api_key
}
return {"api_key": self.api_key}
return {}
def load(self, data):
if self.authorization:
self.api_key = data["api_key"]
self.api_key = data["api_key"]

View file

@ -5,6 +5,7 @@ import socket
from gi.repository import Gtk, Adw
class BaseOpenAIProvider(BavarderProvider):
name = None
slug = None
@ -16,7 +17,9 @@ class BaseOpenAIProvider(BavarderProvider):
def ask(self, prompt):
try:
response = self.chat.create(model=self.model, messages=[{"role": "user", "content": prompt}])
response = self.chat.create(
model=self.model, messages=[{"role": "user", "content": prompt}]
)
response = response.choices[0].message.content
except openai.error.AuthenticationError:
self.no_api_key()
@ -56,7 +59,6 @@ class BaseOpenAIProvider(BavarderProvider):
print(api_key)
openai.api_key = api_key
def about(self):
about = Adw.AboutWindow(
transient_for=self.props.active_window,
@ -67,11 +69,9 @@ class BaseOpenAIProvider(BavarderProvider):
version=version,
copyright="© 2023 0xMRTT",
)
def save(self):
return {
"api_key": openai.api_key
}
return {"api_key": openai.api_key}
def load(self, data):
openai.api_key = data["api_key"]
openai.api_key = data["api_key"]

View file

@ -1,6 +1,7 @@
from .openai import BaseOpenAIProvider
class OpenAIGPT35TurboProvider(BaseOpenAIProvider):
name = "OpenAI GPT 3.5 Turbo"
slug = "openaigpt35turbo"
model = "gpt-3.5-turbo"
model = "gpt-3.5-turbo"

View file

@ -1,6 +1,7 @@
from .openai import BaseOpenAIProvider
class OpenAIGPT4Provider(BaseOpenAIProvider):
name = "OpenAI GPT 4"
slug = "openaigpt4"
model = "gpt-4"
model = "gpt-4"

View file

@ -1,6 +1,7 @@
from .openai import BaseOpenAIProvider
class OpenAITextDavinci003(BaseOpenAIProvider):
name = "OpenAI Text Davinci 003"
slug = "openaitextdavinci003"
model = "text-davinci-003"
model = "text-davinci-003"