Print debugging output
parent ede5a0a5a9
commit d55f6163b6

main.py: 9 changed lines (8 additions, 1 deletion)
@@ -7,6 +7,7 @@ from socket import AF_UNIX
 from socketserver import UnixStreamServer
 from urllib.parse import unquote
 
+from torch import float16
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
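The one addition in this hunk, from torch import float16, pairs with the torch_dtype=float16 argument already passed when the model is loaded further down. Half precision stores roughly 2 bytes per parameter instead of 4, so a 2.7B-parameter model needs about 5 GB of GPU memory rather than about 10 GB. A hedged check, not part of the commit, that half precision actually took effect:

from torch import float16
from transformers import AutoModelForCausalLM

# Load as in main.py; every parameter should then report float16.
# A float32 result would mean the requested dtype was ignored.
model = AutoModelForCausalLM.from_pretrained(
    'EleutherAI/gpt-neo-2.7B',
    torch_dtype=float16, low_cpu_mem_usage=True)
print(next(model.parameters()).dtype)  # expected: torch.float16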
@@ -20,9 +21,12 @@ class UnixHTTPServer(UnixStreamServer):
 class textgenHandler(BaseHTTPRequestHandler):
     def do_GET(self):
         prompt = unquote(self.path[1:])
+        print('Prompt')
+        print(prompt)
         input = tokenizer.encode(prompt, return_tensors='pt').to('cuda')
         output = tokenizer.decode(model.generate(
-            input, do_sample=True, max_length=500, top_p=0.9)[0])
+            input, do_sample=True, max_length=1000, top_p=0.9)[0])
+        print(output)
 
         self.send_response(200)
         self.send_header('Content-Type', 'text/plain')
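One thing to keep in mind with the max_length bump from 500 to 1000: in transformers' generate(), max_length counts the prompt tokens as well as the continuation, so a long URL-encoded prompt eats into the budget before any text is generated. A hedged sketch of the alternative knob, max_new_tokens, which caps only the generated part (assuming a transformers version that supports it; tokenizer, model, and prompt as in the handler, and input_ids used in place of input, which shadows a Python builtin):

# Caps only the continuation, regardless of prompt length.
input_ids = tokenizer.encode(prompt, return_tensors='pt').to('cuda')
output = tokenizer.decode(model.generate(
    input_ids, do_sample=True, max_new_tokens=500, top_p=0.9)[0])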
@@ -32,14 +36,17 @@ class textgenHandler(BaseHTTPRequestHandler):
 
 
 # Load model
+print('Loading model')
 tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-2.7B')
 model = AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-2.7B',
     torch_dtype=float16, low_cpu_mem_usage=True).to('cuda')
 
 
 # Create and start server
+print('Starting server')
 path = '/srv/http/pages/textgen'
 Path(path).unlink(missing_ok=True)
 server = UnixHTTPServer(path, textgenHandler)
 chmod(path, 666)
+print('Server ready')
 server.serve_forever()
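Two hedged notes on the startup code, neither part of this commit. First, chmod(path, 666) passes decimal 666, which is octal 0o1232 and sets an unusual permission pattern including the sticky bit; 0o666 (read/write for everyone) is presumably what was intended. Second, a quick smoke test for the whole setup, sending one GET over the Unix socket and printing the raw HTTP response with the generated text (assumes the server is running at the same path):

import socket

# Connect to the Unix socket the server listens on and request the
# prompt "Hello" (the handler reads the prompt from the URL path).
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.connect('/srv/http/pages/textgen')
sock.sendall(b'GET /Hello HTTP/1.0\r\n\r\n')

# Read until the server closes the connection, then print everything.
response = b''
while True:
    chunk = sock.recv(4096)
    if not chunk:
        break
    response += chunk
sock.close()
print(response.decode())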