# ChatGLM6B / main.py
# Uploaded by sea9696; last commit d1d6c39:
# "Duplicate from josStorer/ChatGLM-6B-Int4-API-OpenAI-Compatible"
import json
from typing import List
import torch
from fastapi import FastAPI, Request, status, HTTPException
from pydantic import BaseModel
from torch.cuda import get_device_properties
from transformers import AutoModel, AutoTokenizer
from sse_starlette.sse import EventSourceResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import os
# NOTE(review): `transformers` is imported before this line runs. If the installed
# version resolves TRANSFORMERS_CACHE at import time, this assignment does not change
# its default cache location. Model loading in this file passes `cache_dir` explicitly.
os.environ['TRANSFORMERS_CACHE'] = ".cache"

# Model identity and on-disk location.
# model_name is the only value accepted in a request's "model" field.
model_name = 'chatglm-6b-int4'
model_path = "./models/models--silver--chatglm-6b-int4-slim/snapshots/02e096b3805c579caf5741a6d8eddd5ba7a74e0d"
cache_dir = './models'
# NOTE(review): not referenced anywhere in the visible code; possibly meant for CPU
# quantization kernel loading -- confirm before relying on or removing it.
kernel_path = "models/models--silver--chatglm-6b-int4-slim/quantization_kernels.so"

# Quantization width passed to model.quantize().
bits = 4
# GPU 0 must report more than this many GiB of total memory for the GPU path.
min_memory = 5.5

# Both are populated by the startup hook.
tokenizer = model = None
app = FastAPI()

# Fully permissive CORS: every origin, method and header, with credentials allowed.
# NOTE(review): the CORS spec forbids a literal "*" origin together with credentials.
# Starlette's CORSMiddleware appears to handle this by echoing the caller's Origin,
# which would grant any site credentialed access. Confirm this is intended.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
@app.on_event('startup')
def init():
    """Load the tokenizer and quantized model into the module-level globals.

    Registered as the FastAPI startup hook.

    Device selection:
    - GPU path: used only when CUDA is available and device 0 reports more than
      ``min_memory`` GiB. The model is cast to half precision, quantized to ``bits``
      and moved to CUDA.
    - CPU path: used otherwise. The model is cast to float and quantized.

    The model is then switched to eval mode. Finally, if the ``ngrok_token``
    environment variable is set, ``ngrok_connect()`` is called to open a public tunnel.
    """
    global tokenizer, model
    hub_options = {"trust_remote_code": True, "cache_dir": cache_dir}
    tokenizer = AutoTokenizer.from_pretrained(model_path, **hub_options)
    model = AutoModel.from_pretrained(model_path, **hub_options)

    cuda_ok = torch.cuda.is_available()
    # Total memory of CUDA device 0 in GiB; only queried when CUDA exists.
    gpu_gib = get_device_properties(0).total_memory / 1024 ** 3 if cuda_ok else None

    if cuda_ok and gpu_gib > min_memory:
        model = model.half().quantize(bits=bits).cuda()
        print("Using GPU")
    else:
        # NOTE(review): the module-level kernel_path is not passed here; confirm whether
        # CPU quantization is expected to receive it.
        model = model.float().quantize(bits=bits)
        if not cuda_ok:
            print("No GPU available")
        else:
            print("Total Memory: ", gpu_gib)
        print("Using CPU")

    model = model.eval()

    if "ngrok_token" in os.environ:
        ngrok_connect()
class Message(BaseModel):
    """One chat message in a /chat/completions request body.

    Both fields are required (neither has a default).
    """

    # Sender role. The completions endpoint acts on 'user', 'system' and 'assistant';
    # messages with any other role are not added to the chat history.
    role: str
    # Message text.
    content: str
class Body(BaseModel):
    """Request body accepted by POST /chat/completions.

    All fields are required (none has a default).
    """

    # Conversation so far. The last entry must have role 'user'; its content becomes
    # the question and the earlier entries become the model's history.
    messages: List[Message]
    # Must equal the module-level model_name, otherwise the request gets HTTP 400.
    model: str
    # Only streaming responses are implemented; False gets HTTP 400.
    stream: bool
    # Used as max(2048, max_tokens) for generation max_length, so values at or
    # below 2048 have no effect.
    max_tokens: int
@app.get("/")
def read_root():
    """Return a fixed JSON greeting for GET /."""
    greeting = dict(Hello="World!")
    return greeting
def _build_history(messages):
    """Pair earlier chat messages into ChatGLM ``(query, response)`` history tuples.

    Pairing rules:
    - Each 'system' or 'assistant' message is paired with the most recent preceding
      'user' message content, or with '' if no user message has been seen yet.
    - A 'user' message followed directly by another 'user' message is overwritten
      and therefore not included.
    - Messages with any other role are ignored.
    """
    history = []
    user_question = ''
    for message in messages:
        if message.role == 'user':
            user_question = message.content
        elif message.role in ('system', 'assistant'):
            history.append((user_question, message.content))
    return history


@app.post("/chat/completions")
async def completions(body: Body, request: Request):
    """Stream a ChatGLM answer to the last user message as server-sent events.

    Each event carries ``{"response": <text generated so far>}`` as JSON. A final
    ``[DONE]`` event follows once generation finishes. Streaming stops early if the
    client disconnects.

    Raises:
        HTTPException(400): ``stream`` is false, ``model`` is not ``model_name``,
            the message list is empty, or the last message is not from 'user'.
    """
    if not body.stream or body.model != model_name:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "Not Implemented")
    # Guard the empty list explicitly. Indexing [-1] on it used to raise IndexError,
    # which reached the client as HTTP 500.
    if not body.messages or body.messages[-1].role != 'user':
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")
    question = body.messages[-1].content
    # The trailing user message is the question itself, so only earlier messages
    # contribute to the history.
    history = _build_history(body.messages[:-1])
    # Generation length is at least 2048; max_tokens can only raise it.
    max_length = max(2048, body.max_tokens)

    async def event_generator():
        # NOTE(review): stream_chat looks like a synchronous generator. Iterating it here
        # probably blocks the event loop while each step is generated. Consider a thread
        # pool if the server must stay responsive during generation.
        for response in model.stream_chat(tokenizer, question, history, max_length=max_length):
            if await request.is_disconnected():
                return
            yield json.dumps({"response": response[0]})
        yield "[DONE]"

    return EventSourceResponse(event_generator())
def ngrok_connect():
    """Open an ngrok HTTP tunnel to local port 8000 and print its public URL.

    Uses the ``./ngrok`` binary. The auth token is read from the ``ngrok_token``
    environment variable, so a KeyError is raised if it is unset.
    """
    from pyngrok import ngrok, conf

    config = conf.PyngrokConfig(ngrok_path="./ngrok")
    conf.set_default(config)
    token = os.environ["ngrok_token"]
    ngrok.set_auth_token(token)
    tunnel = ngrok.connect(8000)
    print(tunnel.public_url)
if __name__ == "__main__":
    # Serve this module's `app` with uvicorn, with auto-reload enabled.
    # Host and port are left at uvicorn's defaults.
    uvicorn.run(app="main:app", app_dir=".", reload=True)