Hide stderr output

llama.py takes over as main.py minus the stderr merge, and the old GPT-Neo server is kept as gpt.py, so llama.cpp's progress output no longer leaks into the HTTP response.

parent 404bf34696
commit fd89bcb232
gpt.py (new file, 54 lines)
@@ -0,0 +1,54 @@
+from http.server import HTTPServer, BaseHTTPRequestHandler
+from os import chmod
+from pathlib import Path
+from socket import AF_UNIX
+from socketserver import UnixStreamServer
+from urllib.parse import unquote
+
+from torch import float16
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+# https://stackoverflow.com/questions/21650370/setting-up-an-http-server-that-listens-over-a-file-socket
+class UnixHTTPServer(UnixStreamServer):
+    def get_request(self):
+        request, client_address = super(UnixHTTPServer, self).get_request()
+        return (request, ["local", 0])
+
+
+class textgenHandler(BaseHTTPRequestHandler):
+    def do_GET(self):
+        prompt = unquote(self.path[1:])
+        print('Prompt')
+        print(prompt)
+
+        if prompt == 'favicon.ico':
+            return
+
+        input = tokenizer.encode(prompt, return_tensors='pt').to('cuda')
+        output = tokenizer.decode(model.generate(
+            input, do_sample=True, max_length=500, top_p=0.9)[0])
+        print(output)
+
+        self.send_response(200)
+        self.send_header('Content-Type', 'text/plain')
+        self.send_header('Content-Length', str(len(output.encode('utf-8'))))
+        self.end_headers()
+        self.wfile.write(output.encode('utf-8'))
+
+
+# Load model
+print('Loading model')
+tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-2.7B')
+model = AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-2.7B',
+    torch_dtype=float16, low_cpu_mem_usage=True).to('cuda')
+
+
+# Create and start server
+print('Starting server')
+path = '/srv/http/pages/textgen'
+Path(path).unlink(missing_ok=True)
+server = UnixHTTPServer(path, textgenHandler)
+chmod(path, 0o666)
+print('Server ready')
+server.serve_forever()
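Note: gpt.py serves plain HTTP over a Unix domain socket, so it can be exercised without a TCP port. A minimal stdlib client sketch (the socket path is taken from the code above; the prompt string is just an example):

import http.client
import socket
from urllib.parse import quote

class UnixHTTPConnection(http.client.HTTPConnection):
    # HTTPConnection that dials a Unix domain socket instead of TCP.
    def __init__(self, socket_path):
        super().__init__('localhost')
        self.socket_path = socket_path

    def connect(self):
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.connect(self.socket_path)
        self.sock = sock

conn = UnixHTTPConnection('/srv/http/pages/textgen')
conn.request('GET', '/' + quote('Once upon a time'))
print(conn.getresponse().read().decode('utf-8'))

The same request works from a shell via curl's --unix-socket flag.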
llama.py (deleted, 50 lines)
@@ -1,50 +0,0 @@
-import os
-import subprocess
-import time
-import threading
-from flask import Flask, Response
-
-app = Flask(__name__)
-
-
-@app.route("/<prompt>")
-def llama(prompt):
-    if prompt == "favicon.ico":
-        return Response(status=204)
-
-    def generate():
-        try:
-            process = subprocess.Popen(
-                [
-                    "/opt/llama.cpp/main",
-                    "-ngl",
-                    "32",
-                    "-m",
-                    "/opt/llama.cpp/models/ggml-vicuna-7b-1.1-q4_0.bin",
-                    "-n",
-                    "1024",
-                    "-p",
-                    f"### Human: {prompt}\n### Assistant:",
-                ],
-                stderr=subprocess.STDOUT,
-                stdout=subprocess.PIPE,
-            )
-            for c in iter(lambda: process.stdout.read(1), b""):
-                yield c
-        finally:
-            process.kill()
-
-    return Response(generate(), mimetype="text/plain")
-
-
-path = "/srv/http/pages/textgen"
-
-
-def fixperms():
-    time.sleep(0.1)
-    os.chmod(path, 0o660)
-
-
-threading.Thread(target=fixperms).start()
-
-app.run(host="unix://" + path)
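Aside on fixperms: app.run blocks and only creates the socket file once the server is up, so the helper thread sleeps briefly before loosening its permissions. A less timing-sensitive sketch (same path; the timeout value is hypothetical) would poll for the socket instead:

import os
import time

def fixperms(path="/srv/http/pages/textgen", timeout=5.0):
    # Wait until the server has actually created the socket, then chmod it,
    # rather than hoping 0.1 s is always enough.
    deadline = time.monotonic() + timeout
    while not os.path.exists(path):
        if time.monotonic() >= deadline:
            raise TimeoutError(path + ' never appeared')
        time.sleep(0.05)
    os.chmod(path, 0o660)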
main.py
@@ -1,54 +1,49 @@
-from http.server import HTTPServer, BaseHTTPRequestHandler
-from os import chmod
-from pathlib import Path
-from socket import AF_UNIX
-from socketserver import UnixStreamServer
-from urllib.parse import unquote
+import os
+import subprocess
+import time
+import threading
+from flask import Flask, Response
 
-from torch import float16
-from transformers import AutoModelForCausalLM, AutoTokenizer
+app = Flask(__name__)
 
 
-# https://stackoverflow.com/questions/21650370/setting-up-an-http-server-that-listens-over-a-file-socket
-class UnixHTTPServer(UnixStreamServer):
-    def get_request(self):
-        request, client_address = super(UnixHTTPServer, self).get_request()
-        return (request, ["local", 0])
+@app.route("/<prompt>")
+def llama(prompt):
+    if prompt == "favicon.ico":
+        return Response(status=204)
+
+    def generate():
+        try:
+            process = subprocess.Popen(
+                [
+                    "/opt/llama.cpp/main",
+                    "-ngl",
+                    "32",
+                    "-m",
+                    "/opt/llama.cpp/models/ggml-vicuna-7b-1.1-q4_0.bin",
+                    "-n",
+                    "1024",
+                    "-p",
+                    f"### Human: {prompt}\n### Assistant:",
+                ],
+                stdout=subprocess.PIPE,
+            )
+            for c in iter(lambda: process.stdout.read(1), b""):
+                yield c
+        finally:
+            process.kill()
+
+    return Response(generate(), mimetype="text/plain")
 
 
-class textgenHandler(BaseHTTPRequestHandler):
-    def do_GET(self):
-        prompt = unquote(self.path[1:])
-        print('Prompt')
-        print(prompt)
-
-        if prompt == 'favicon.ico':
-            return
-
-        input = tokenizer.encode(prompt, return_tensors='pt').to('cuda')
-        output = tokenizer.decode(model.generate(
-            input, do_sample=True, max_length=500, top_p=0.9)[0])
-        print(output)
-
-        self.send_response(200)
-        self.send_header('Content-Type', 'text/plain')
-        self.send_header('Content-Length', str(len(output.encode('utf-8'))))
-        self.end_headers()
-        self.wfile.write(output.encode('utf-8'))
+path = "/srv/http/pages/textgen"
 
 
-# Load model
-print('Loading model')
-tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-2.7B')
-model = AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-2.7B',
-    torch_dtype=float16, low_cpu_mem_usage=True).to('cuda')
+def fixperms():
+    time.sleep(0.1)
+    os.chmod(path, 0o660)
 
 
-# Create and start server
-print('Starting server')
-path = '/srv/http/pages/textgen'
-Path(path).unlink(missing_ok=True)
-server = UnixHTTPServer(path, textgenHandler)
-chmod(path, 0o666)
-print('Server ready')
-server.serve_forever()
+threading.Thread(target=fixperms).start()
+
+app.run(host="unix://" + path)
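The functional change hiding in this rename is the dropped stderr=subprocess.STDOUT: previously llama.cpp's diagnostics were merged into stdout and streamed to the HTTP client, whereas now the child inherits the server's stderr and the noise stays in the server log. A small sketch of the options, using a stand-in chatty command:

import subprocess

# Stand-in for a command that writes to both streams.
cmd = ['python3', '-c', "import sys; print('out'); print('err', file=sys.stderr)"]

# Old behaviour: stderr merged into stdout, so the client saw it.
merged = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
print(merged.stdout)     # b'out\nerr\n'

# This commit: stderr inherited, ends up on the server's own stderr.
inherited = subprocess.run(cmd, stdout=subprocess.PIPE)
print(inherited.stdout)  # b'out\n'

# Alternative: discard stderr entirely.
silenced = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
print(silenced.stdout)   # b'out\n'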