Load and run model
This commit is contained in:
parent
0b70a7303f
commit
ede5a0a5a9
21
main.py
21
main.py
|
@ -5,6 +5,9 @@ from os import chmod
|
|||
from pathlib import Path
|
||||
from socket import AF_UNIX
|
||||
from socketserver import UnixStreamServer
|
||||
from urllib.parse import unquote
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
# https://stackoverflow.com/questions/21650370/setting-up-an-http-server-that-listens-over-a-file-socket
|
||||
|
@ -16,15 +19,25 @@ class UnixHTTPServer(UnixStreamServer):
|
|||
|
||||
class textgenHandler(BaseHTTPRequestHandler):
    """Serve text-generation requests over HTTP.

    The request path (minus the leading '/') is treated as a
    percent-encoded prompt; the response body is the model's generated
    continuation as UTF-8 plain text.
    """

    def do_GET(self):
        # Path looks like '/<percent-encoded prompt>': drop the leading
        # slash and decode it back into plain text.
        prompt = unquote(self.path[1:])

        # NOTE(review): assumes module-level `tokenizer` and `model` are
        # already loaded and that a CUDA device is available — confirm.
        # Renamed from `input` to avoid shadowing the builtin.
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to('cuda')
        output = tokenizer.decode(model.generate(
            input_ids, do_sample=True, max_length=500, top_p=0.9)[0])

        # Encode once up front: Content-Length must be the BYTE length of
        # the body. len(output) counts characters, which is wrong whenever
        # the generated text contains non-ASCII.
        body = output.encode('utf-8')

        self.send_response(200)
        # Declare the charset so clients decode the body correctly.
        self.send_header('Content-Type', 'text/plain; charset=utf-8')
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)
|
||||
|
||||
|
||||
# Load model
import torch

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-2.7B')

# Bare `float16` is not defined by any import visible in this diff, which
# would raise NameError at load time; use the explicit torch.float16.
# Half precision + low_cpu_mem_usage keeps the 2.7B model's footprint down
# while it is staged onto the GPU.
model = AutoModelForCausalLM.from_pretrained(
    'EleutherAI/gpt-neo-2.7B',
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to('cuda')
|
||||
|
||||
|
||||
# Create and start server

# The handler is exposed at a fixed unix-domain socket path.
path = '/srv/http/pages/textgen'

# A stale socket file from a previous run would make bind() fail with
# "Address already in use", so clear it first; missing_ok covers the
# first-run case where no file exists yet.
socket_file = Path(path)
socket_file.unlink(missing_ok=True)

server = UnixHTTPServer(path, textgenHandler)
|
||||
|
|
Loading…
Reference in a new issue