Add provider type to chat providers
This commit is contained in:
parent
0123a80883
commit
02b02edbbb
5 changed files with 10 additions and 9 deletions
|
@ -1,7 +1,7 @@
|
|||
from .hfbasechat import BaseHFChatProvider
|
||||
from .hfbasechat import BaseHFChatProvider, ProviderType
|
||||
|
||||
class BlenderBotProvider(BaseHFChatProvider):
|
||||
name = "BlenderBot"
|
||||
description = "An open domain chatbot"
|
||||
provider = "facebook/blenderbot-400M-distill"
|
||||
|
||||
provider_type = ProviderType.TEXT
|
|
@ -1,6 +1,7 @@
|
|||
from .hfbasechat import BaseHFChatProvider
|
||||
from .hfbasechat import BaseHFChatProvider, ProviderType
|
||||
|
||||
class DialoGPTProvider(BaseHFChatProvider):
|
||||
name = "DialoGPT"
|
||||
description = "A State-of-the-Art Large-scale Pretrained Response generation model"
|
||||
provider = "microsoft/DialoGPT-large"
|
||||
provider_type = ProviderType.CHAT
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from .hfbasechat import BaseHFChatProvider
|
||||
from .hfbasechat import BaseHFChatProvider, ProviderType
|
||||
|
||||
class GoogleFlant5XXLProvider(BaseHFChatProvider):
|
||||
name = "Google Flan T5 XXL"
|
||||
description = "A better Text-To-Text Transfer Transformer (T5) model"
|
||||
provider = "google/flan-t5-xxl"
|
||||
chat_mode = False
|
||||
provider_type = ProviderType.TEXT
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from .hfbasechat import BaseHFChatProvider
|
||||
from .hfbasechat import BaseHFChatProvider, ProviderType
|
||||
|
||||
class GPT2Provider(BaseHFChatProvider):
|
||||
name = "GPT 2"
|
||||
description = "GPT-2 is a transformers model pretrained on a very large corpus of English data in a self-supervised fashion"
|
||||
provider = "gpt2"
|
||||
chat_mode = False
|
||||
provider_type = ProviderType.TEXT
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from .base import BaseProvider
|
||||
from .base import BaseProvider, ProviderType
|
||||
|
||||
import requests
|
||||
|
||||
|
@ -23,7 +23,7 @@ class BaseHFChatProvider(BaseProvider):
|
|||
|
||||
return response.json()
|
||||
|
||||
if self.chat_mode:
|
||||
if self.provider_type == ProviderType.CHAT:
|
||||
output = query({
|
||||
"inputs": {
|
||||
"past_user_inputs": [i['content'] for i in chat if i['role'] == self.app.user_name],
|
||||
|
|
Loading…
Reference in a new issue