Dynavera/gpu_server.py
2026-03-22 15:34:06 +00:00

317 lines
No EOL
11 KiB
Python

import asyncio
import logging
import os
import json
import time
from contextlib import asynccontextmanager
from typing import Dict, Any
import numpy as np
from torch import cuda, no_grad, Tensor
import torch.nn.functional as F
import secrets
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from llama_cpp import Llama
from sentence_transformers import SentenceTransformer
# Root logging: timestamped records at INFO and above.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("gpu-node")

# --- Model configuration ----------------------------------------------------
# Model id handed to SentenceTransformer for /v1/embeddings and chunking.
EMBED_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5"
# GGUF file loaded by llama.cpp; overridable through the LLM_MODEL_PATH env var.
LLM_MODEL_PATH = os.getenv(
    "LLM_MODEL_PATH",
    "/app/models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
)
# Seconds since the LLM was last used before the inactivity watcher unloads it.
LLM_IDLE_TIMEOUT = int(os.getenv("LLM_IDLE_TIMEOUT", "1800"))
# Width every returned embedding is padded or truncated to.
TARGET_DIMENSIONS = 768

# --- Shared runtime state ---------------------------------------------------
# Holds the loaded models ("embed_model", "llm") and the "llm_last_used" stamp.
state: Dict[str, Any] = {}
# One permit: serializes every GPU section that acquires it.
gpu_semaphore = asyncio.Semaphore(1)
def _load_llm() -> Llama:
    """Build the llama.cpp model from LLM_MODEL_PATH and return it.

    n_gpu_layers=-1 asks llama-cpp-python to offload every layer to the GPU;
    the context window is 8192 tokens with a prompt batch size of 512.
    This call blocks, so async callers run it in an executor.
    """
    logger.info(f"Loading LLM: {LLM_MODEL_PATH}")
    llama_options = dict(
        model_path=LLM_MODEL_PATH,
        n_gpu_layers=-1,
        n_ctx=8192,
        n_batch=512,
        verbose=False,
    )
    return Llama(**llama_options)
def _unload_llm():
    """Drop the resident LLM from `state` so its memory can be reclaimed.

    Does nothing (and logs nothing) when no LLM is loaded. Also forgets the
    "llm_last_used" timestamp, which is only meaningful while a model is
    resident; the next lazy load stamps it again.
    """
    llm = state.pop("llm", None)
    if llm is None:
        return
    state.pop("llm_last_used", None)
    # Drop our reference; the model is freed once no in-flight request
    # still holds its own reference to it.
    del llm
    if cuda.is_available():
        # NOTE(review): this empties PyTorch's caching allocator; llama.cpp's
        # own GPU buffers are presumably released by the Llama object itself.
        cuda.empty_cache()
    logger.info("LLM unloaded due to inactivity")
def _llm_idle_expired() -> bool:
    """Return True when an LLM is loaded and unused for more than LLM_IDLE_TIMEOUT seconds."""
    last_used = state.get("llm_last_used")
    return (
        state.get("llm") is not None
        and last_used is not None
        and time.monotonic() - last_used > LLM_IDLE_TIMEOUT
    )


async def _inactivity_watcher():
    """Background task that unloads the LLM once it has been idle too long.

    Polls every 60 seconds. The idle test is repeated after acquiring
    gpu_semaphore, because a request may have used (and re-stamped) the model
    while this task was waiting for the GPU. Errors are logged and the loop
    keeps running, so one failed unload does not silently end idle handling.
    Runs until cancelled.
    """
    while True:
        await asyncio.sleep(60)
        try:
            if not _llm_idle_expired():
                continue
            async with gpu_semaphore:
                if _llm_idle_expired():
                    _unload_llm()
        except Exception:
            logger.exception("LLM inactivity check failed")
def _touch_llm():
    """Stamp the current monotonic time as the LLM's last use (read by the idle watcher)."""
    now = time.monotonic()
    state["llm_last_used"] = now
async def _ensure_llm() -> Llama:
    """Return the resident LLM, loading it first if it is not in `state`.

    Loading happens under gpu_semaphore and `state` is re-checked after the
    semaphore is acquired. Without this, concurrent requests that arrive
    while no model is loaded would each load their own full copy into GPU
    memory. The last-use timestamp is refreshed on every call.

    Raises:
        HTTPException: 503 when the file at LLM_MODEL_PATH does not exist.
    """
    llm = state.get("llm")
    if llm is None:
        # Fail fast, without queueing behind in-flight GPU work.
        if not os.path.exists(LLM_MODEL_PATH):
            raise HTTPException(status_code=503, detail="LLM model file not found.")
        async with gpu_semaphore:
            llm = state.get("llm")
            if llm is None:
                loop = asyncio.get_running_loop()
                llm = await loop.run_in_executor(None, _load_llm)
                state["llm"] = llm
    _touch_llm()
    return llm
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models at startup and release them at shutdown.

    Startup does three things:
    - Loads the SentenceTransformer embedding model on CUDA when available
      (CPU otherwise).
    - Eagerly loads the LLM if its file exists; otherwise it is loaded on
      first request.
    - Starts the inactivity watcher.

    Teardown runs in `finally`, so it also happens when the app exits with an
    error: the watcher is cancelled and awaited, `state` is cleared and the
    CUDA cache is emptied.

    Raises:
        Exception: any model-loading failure, re-raised after logging.
    """
    device = "cuda" if cuda.is_available() else "cpu"
    logger.info(f"--- Initializing GPU Node on {device} ---")
    if device == "cpu":
        logger.warning("CUDA NOT DETECTED. Performance will be severely degraded.")
    try:
        logger.info(f"Loading Embedding Model: {EMBED_MODEL_NAME}")
        state["embed_model"] = SentenceTransformer(
            EMBED_MODEL_NAME,
            trust_remote_code=True,
            device=device,
        )
        if not os.path.exists(LLM_MODEL_PATH):
            logger.warning(f"LLM file not found at {LLM_MODEL_PATH} — will load on first request")
        else:
            state["llm"] = _load_llm()
            _touch_llm()
        logger.info(f"--- GPU Node Ready (LLM idle timeout: {LLM_IDLE_TIMEOUT}s) ---")
    except Exception as e:
        logger.exception(f"Failed to load models: {e}")
        # Bare raise keeps the original traceback intact.
        raise
    watcher = asyncio.create_task(_inactivity_watcher())
    try:
        yield
    finally:
        watcher.cancel()
        # Await the cancelled task so it does not outlive the app
        # ("Task was destroyed but it is pending").
        try:
            await watcher
        except asyncio.CancelledError:
            pass
        except Exception:
            logger.exception("Inactivity watcher terminated with an error")
        state.clear()
        if cuda.is_available():
            cuda.empty_cache()
# Interactive docs and the OpenAPI schema are disabled (docs/redoc/openapi URLs set to None).
app = FastAPI(
    title="Agentic GPU Node",
    lifespan=lifespan,
    docs_url=None,
    redoc_url=None,
    openapi_url=None,
)
_security = HTTPBasic()
# HTTP Basic credentials checked by require_auth; read once at import time.
_API_USER = os.getenv("INFERENCE_USERNAME", "admin")
_API_PASS = os.getenv("INFERENCE_PASSWORD", "changeme")
if _API_PASS == "changeme":
    # The server binds 0.0.0.0 by default, so a guessable password is a real exposure.
    logger.warning(
        "INFERENCE_PASSWORD is using the built-in default; set it before exposing this node."
    )
def require_auth(credentials: HTTPBasicCredentials = Depends(_security)):
    """FastAPI dependency that enforces HTTP Basic auth.

    The supplied credentials must match INFERENCE_USERNAME/INFERENCE_PASSWORD.
    Both fields are compared with secrets.compare_digest before the outcome
    is inspected, so the check never short-circuits on the username.

    Raises:
        HTTPException: 401 with a `WWW-Authenticate: Basic` challenge on mismatch.
    """
    field_matches = [
        secrets.compare_digest(supplied.encode(), expected.encode())
        for supplied, expected in (
            (credentials.username, _API_USER),
            (credentials.password, _API_PASS),
        )
    ]
    if not all(field_matches):
        raise HTTPException(
            status_code=401,
            detail="Unauthorized",
            headers={"WWW-Authenticate": "Basic"},
        )
@app.get("/health", dependencies=[Depends(require_auth)])
async def health():
    """Authenticated liveness probe reporting which models are resident in `state`."""
    readiness = {
        flag: state.get(key) is not None
        for flag, key in (("embedding_ready", "embed_model"), ("llm_ready", "llm"))
    }
    return {"status": "ok", **readiness}
def pad_and_normalize(embeddings: Tensor, target_dimensions: int) -> Tensor:
    """Resize every row of an embedding matrix to `target_dimensions`, then L2-normalize it.

    Assumes a 2-D (rows, features) tensor. Column 1 is treated as the feature
    axis, and F.pad pads the last axis.

    - Narrower rows are right-padded with zeros.
    - Wider rows keep only their first `target_dimensions` columns.

    Normalizing after the resize means truncated vectors are still unit length.
    """
    deficit = target_dimensions - embeddings.shape[1]
    if deficit > 0:
        resized = F.pad(embeddings, (0, deficit), "constant", 0)
    elif deficit < 0:
        resized = embeddings[:, :target_dimensions]
    else:
        resized = embeddings
    return F.normalize(resized, p=2, dim=1)
@app.post("/v1/embeddings", dependencies=[Depends(require_auth)])
async def embeddings(request: Request):
    """OpenAI-style embeddings endpoint backed by the SentenceTransformer model.

    The JSON body's "input" may be a string or a list. List items are
    converted with str(), and items that are blank after stripping are
    skipped. Each returned embedding carries the position of its item in
    the ORIGINAL list as "index", so callers can map results back to inputs
    even when blanks were dropped. Texts that do not start with "search_" get
    the "search_query: " prefix. Vectors are resized to TARGET_DIMENSIONS and
    L2-normalized.

    Raises:
        HTTPException: 400 for an unparsable or non-object body or an
            unsupported "input" type; 503 when the embedding model is not loaded.
    """
    try:
        data = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    input_data = data.get("input", "")
    input_kind = type(input_data).__name__
    input_count = len(input_data) if isinstance(input_data, list) else (1 if isinstance(input_data, str) else 0)
    logger.info("/v1/embeddings request received: input_kind=%s input_count=%s", input_kind, input_count)
    logger.info("/v1/embeddings using target_dimensions=%s", TARGET_DIMENSIONS)
    # (original position, text) pairs. Keeping the position stops skipped
    # blank items from shifting the "index" of the embeddings after them.
    if isinstance(input_data, str):
        indexed_inputs = [(0, input_data)]
    elif isinstance(input_data, list):
        indexed_inputs = []
        for position, item in enumerate(input_data):
            text = str(item)
            if text.strip():
                indexed_inputs.append((position, text))
    else:
        logger.warning("/v1/embeddings bad input type: %s", input_kind)
        raise HTTPException(status_code=400, detail="'input' must be a string or list of strings")
    if not indexed_inputs:
        return {
            "object": "list",
            "data": [],
            "model": EMBED_MODEL_NAME,
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }
    model = state.get("embed_model")
    if model is None:
        raise HTTPException(status_code=503, detail="Embedding model not initialized")
    inputs = [text for _, text in indexed_inputs]
    # Add the "search_query: " task prefix unless the caller already supplied
    # a "search_*" prefix.
    prefixed_inputs = [
        text if text.startswith("search_") else f"search_query: {text}"
        for text in inputs
    ]

    def _encode():
        with no_grad():
            vectors = model.encode(prefixed_inputs, convert_to_tensor=True)
            vectors = pad_and_normalize(vectors, target_dimensions=TARGET_DIMENSIONS)
            # Copy to host inside the worker thread so the event loop never
            # waits on the device-to-host transfer.
            return vectors.cpu().tolist()

    loop = asyncio.get_running_loop()
    async with gpu_semaphore:
        vector_list = await loop.run_in_executor(None, _encode)
    # Approximation: whitespace-separated words, not model tokens.
    word_count = sum(len(text.split()) for text in inputs)
    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "index": position,
                "embedding": embedding,
            }
            for (position, _), embedding in zip(indexed_inputs, vector_list)
        ],
        "model": EMBED_MODEL_NAME,
        "usage": {
            "prompt_tokens": word_count,
            "total_tokens": word_count,
        },
    }
def _group_sentences(sentences, breakpoints):
    """Join `sentences` into chunks, closing a chunk after each index in `breakpoints`.

    Splitting on ". " removes the period from every sentence that was
    followed by ". ", so each closed chunk gets "." appended back. The final
    chunk keeps whatever ending the text had and only gets a "." when it
    lacks terminal punctuation. Text that already ends in "." therefore no
    longer produces a chunk ending in "..".
    """
    chunks = []
    start = 0
    for idx in breakpoints:
        chunks.append(". ".join(sentences[start:idx + 1]) + ".")
        start = idx + 1
    tail = ". ".join(sentences[start:])
    if not tail.endswith((".", "!", "?")):
        tail += "."
    chunks.append(tail)
    return chunks


@app.post("/v1/semantic-chunk", dependencies=[Depends(require_auth)])
async def semantic_chunk(request: Request):
    """Split text into semantically coherent chunks and embed each chunk.

    How chunks are formed:
    - The text (newlines turned into spaces) is split into sentences on ". ".
    - A new chunk starts after sentence i when the cosine distance between
      sentences i and i+1 exceeds the `threshold` percentile (default 95) of
      all adjacent distances.
    - Each chunk is embedded with the "search_document: " prefix, resized to
      TARGET_DIMENSIONS and L2-normalized.
    - Text with fewer than two sentences is returned as a single chunk.

    Returns:
        {"chunks": [...], "embeddings": [...]}, one embedding per chunk.

    Raises:
        HTTPException:
            400 for an unparsable or non-object body, non-string text, or a
                threshold that is not a number in [0, 100].
            413 when the text exceeds 50000 characters.
            503 when the embedding model is not loaded.
    """
    try:
        data = await request.json()
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    raw_text = data.get("text", "")
    threshold_percentile = data.get("threshold", 95)
    raw_text_len = len(raw_text) if isinstance(raw_text, str) else -1
    logger.info("/v1/semantic-chunk request received: text_len=%s threshold=%s", raw_text_len, threshold_percentile)
    logger.info("/v1/semantic-chunk using target_dimensions=%s", TARGET_DIMENSIONS)
    if not raw_text:
        logger.info("/v1/semantic-chunk empty text payload")
        return {"chunks": [], "embeddings": []}
    if not isinstance(raw_text, str):
        raise HTTPException(status_code=400, detail="'text' must be a string")
    if len(raw_text) > 50000:
        logger.warning("/v1/semantic-chunk payload too large: text_len=%s", len(raw_text))
        raise HTTPException(status_code=413, detail="Text block too large. Please batch on the client.")
    # np.percentile raises on non-numeric or out-of-range q; reject it up front
    # as a client error. bool is excluded explicitly because it subclasses int.
    if (
        isinstance(threshold_percentile, bool)
        or not isinstance(threshold_percentile, (int, float))
        or not 0 <= threshold_percentile <= 100
    ):
        raise HTTPException(status_code=400, detail="'threshold' must be a number between 0 and 100")
    model = state.get("embed_model")
    if model is None:
        logger.error("/v1/semantic-chunk embedding model not initialized")
        raise HTTPException(status_code=503, detail="Embedding model not initialized")
    sentences = [s.strip() for s in raw_text.replace('\n', ' ').split('. ') if s.strip()]

    def _chunk_and_embed():
        with no_grad():
            if len(sentences) < 2:
                single = model.encode([f"search_document: {raw_text}"], convert_to_tensor=True)
                single = pad_and_normalize(single, target_dimensions=TARGET_DIMENSIONS)
                return {"chunks": [raw_text], "embeddings": single.cpu().tolist()}
            # NOTE(review): sentences are embedded without a task prefix, unlike
            # the chunk embeddings below; confirm this is intended.
            s_embeddings = model.encode(sentences, convert_to_tensor=True)
            # One batched call over adjacent pairs (i, i+1) instead of a
            # per-pair loop that syncs the device on every .item().
            similarities = F.cosine_similarity(s_embeddings[:-1], s_embeddings[1:], dim=1).tolist()
            distances = [1 - sim for sim in similarities]
            breakpoint_threshold = np.percentile(distances, threshold_percentile)
            breakpoints = [i for i, d in enumerate(distances) if d > breakpoint_threshold]
            chunks = _group_sentences(sentences, breakpoints)
            final_embeddings = model.encode(
                [f"search_document: {c}" for c in chunks],
                convert_to_tensor=True,
            )
            final_embeddings = pad_and_normalize(final_embeddings, target_dimensions=TARGET_DIMENSIONS)
            return {"chunks": chunks, "embeddings": final_embeddings.cpu().tolist()}

    loop = asyncio.get_running_loop()
    async with gpu_semaphore:
        return await loop.run_in_executor(None, _chunk_and_embed)
@app.post("/v1/chat/completions", dependencies=[Depends(require_auth)])
async def chat_completions(request: Request):
    """OpenAI-style chat completion endpoint served by the llama.cpp model.

    Reads these fields from the JSON body:
    - "messages" (list)
    - "stream" (default False)
    - "temperature" (default 0.7)
    - "max_tokens" (default 1024)

    Generation stops on the Llama 3 end tokens. With "stream" set, chunks
    are sent as server-sent events as they are generated, followed by
    "data: [DONE]". gpu_semaphore is held for the whole generation.

    Raises:
        HTTPException:
            400 for an unparsable or non-object body or non-list "messages".
            503 when the model file is missing (from _ensure_llm).
            500 when non-streaming inference fails.
    """
    try:
        data = await request.json()
    except Exception as e:
        raw_body = await request.body()
        preview = raw_body[:500].decode("utf-8", errors="replace")
        logger.error(f"Invalid JSON payload for chat completions: {e}; body_preview={preview}")
        raise HTTPException(status_code=400, detail="Invalid JSON payload")
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    messages = data.get("messages", [])
    if not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="'messages' must be a list")
    stream = data.get("stream", False)
    logger.info(f"Chat completion request: {len(messages)} messages, stream={stream}")
    llm = await _ensure_llm()
    loop = asyncio.get_running_loop()
    completion_kwargs = dict(
        messages=messages,
        temperature=data.get("temperature", 0.7),
        max_tokens=data.get("max_tokens", 1024),
        stop=["<|eot_id|>", "<|end_of_text|>"],
    )

    if stream:
        async def _stream_response():
            end_marker = object()
            try:
                async with gpu_semaphore:
                    chunk_iter = await loop.run_in_executor(
                        None, lambda: llm.create_chat_completion(stream=True, **completion_kwargs)
                    )
                    while True:
                        # Advance the generator one step per executor hop, so
                        # tokens reach the client as they are produced rather
                        # than after the whole completion is buffered. The
                        # default argument to next() turns StopIteration into
                        # a sentinel; StopIteration cannot cross a future.
                        chunk = await loop.run_in_executor(None, next, chunk_iter, end_marker)
                        if chunk is end_marker:
                            break
                        yield f"data: {json.dumps(chunk)}\n\n"
            except Exception:
                # Headers are already sent, so the status cannot change; log
                # and let the server abort the stream.
                logger.exception("Streaming inference error")
                raise
            yield "data: [DONE]\n\n"

        return StreamingResponse(_stream_response(), media_type="text/event-stream")

    try:
        async with gpu_semaphore:
            return await loop.run_in_executor(
                None, lambda: llm.create_chat_completion(stream=False, **completion_kwargs)
            )
    except Exception as e:
        logger.exception(f"Inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn

    # Bind address, port and auto-reload can be overridden from the
    # environment; the defaults match the previous hard-coded values.
    # Auto-reload restarts the server on source changes, which re-runs the
    # lifespan model loading, so deployments may want GPU_SERVER_RELOAD=0.
    uvicorn.run(
        "gpu_server:app",
        host=os.getenv("GPU_SERVER_HOST", "0.0.0.0"),
        port=int(os.getenv("GPU_SERVER_PORT", "8001")),
        reload=os.getenv("GPU_SERVER_RELOAD", "1").strip().lower() not in ("0", "false", "no"),
    )