diff --git a/main.py b/main.py
index 13b3ee4..edad4a8 100755
--- a/main.py
+++ b/main.py
@@ -7,6 +7,7 @@
 from socket import AF_UNIX
 from socketserver import UnixStreamServer
 from urllib.parse import unquote
 
+from torch import float16
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -20,9 +21,12 @@ class UnixHTTPServer(UnixStreamServer):
 class textgenHandler(BaseHTTPRequestHandler):
     def do_GET(self):
         prompt = unquote(self.path[1:])
+        print('Prompt')
+        print(prompt)
         input = tokenizer.encode(prompt, return_tensors='pt').to('cuda')
         output = tokenizer.decode(model.generate(
-            input, do_sample=True, max_length=500, top_p=0.9)[0])
+            input, do_sample=True, max_length=1000, top_p=0.9)[0])
+        print(output)
 
         self.send_response(200)
         self.send_header('Content-Type', 'text/plain')
@@ -32,14 +36,17 @@ class textgenHandler(BaseHTTPRequestHandler):
 
 
 # Load model
+print('Loading model')
 tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-2.7B')
 model = AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-2.7B',
     torch_dtype=float16,
     low_cpu_mem_usage=True).to('cuda')
 
 # Create and start server
+print('Starting server')
 path = '/srv/http/pages/textgen'
 Path(path).unlink(missing_ok=True)
 server = UnixHTTPServer(path, textgenHandler)
 chmod(path, 666)
+print('Server ready')
 server.serve_forever()
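
Not part of the patch: a minimal client sketch for exercising the server above. It assumes the patched main.py is running on a CUDA-capable machine and that the current user can read the socket at /srv/http/pages/textgen; the UnixHTTPConnection helper is hypothetical glue written for this example, not something the server code provides.

import socket
from http.client import HTTPConnection
from urllib.parse import quote

class UnixHTTPConnection(HTTPConnection):
    """http.client connection that dials a Unix domain socket instead of TCP."""
    def __init__(self, socket_path):
        super().__init__('localhost')
        self.socket_path = socket_path

    def connect(self):
        # Replace the usual TCP connect with a connect to the Unix socket.
        self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.sock.connect(self.socket_path)

conn = UnixHTTPConnection('/srv/http/pages/textgen')
# The handler reads the prompt from the request path, so percent-encode it.
conn.request('GET', '/' + quote('Once upon a time'))
print(conn.getresponse().read().decode())

The same request can be made with curl --unix-socket /srv/http/pages/textgen http://localhost/Once%20upon%20a%20time, which is a quick way to check the server before pointing anything else at it.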