Dynavera/gpu_server.py

317 lines
11 KiB
Python
Raw Normal View History

import asyncio
import logging
import os
import json
2026-03-22 15:34:06 +00:00
import time
from contextlib import asynccontextmanager
from typing import Dict, Any
import numpy as np
2026-03-11 14:37:50 +00:00
from torch import cuda, no_grad, Tensor
import torch.nn.functional as F
import secrets
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from llama_cpp import Llama
from sentence_transformers import SentenceTransformer
# Root logging setup: INFO and above, with timestamp, level and logger name.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("gpu-node")

# Identifier of the sentence-transformers embedding model served by this node.
EMBED_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5"
# Width every returned embedding is padded or truncated to.
TARGET_DIMENSIONS = 768
# Path of the GGUF chat model; overridable through the environment.
LLM_MODEL_PATH = os.environ.get(
    "LLM_MODEL_PATH",
    "/app/models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
)
# Seconds without LLM use after which the inactivity watcher unloads it.
LLM_IDLE_TIMEOUT = int(os.environ.get("LLM_IDLE_TIMEOUT", "1800"))

# Process-wide registry; this file stores the keys "embed_model", "llm" and
# "llm_last_used" in it.
state: Dict[str, Any] = {}
# Binary semaphore: at most one coroutine at a time runs model work.
gpu_semaphore = asyncio.Semaphore(1)
2026-03-22 15:34:06 +00:00
def _load_llm() -> Llama:
    """Construct the llama.cpp chat model from ``LLM_MODEL_PATH``.

    The call is synchronous; async callers in this file run it through an
    executor so the event loop is not blocked while the model loads.

    Returns:
        A ready-to-use ``Llama`` instance.
    """
    logger.info(f"Loading LLM: {LLM_MODEL_PATH}")
    llama_options = {
        "model_path": LLM_MODEL_PATH,
        # -1 asks llama-cpp-python to offload every layer to the GPU.
        "n_gpu_layers": -1,
        "n_ctx": 8192,
        "n_batch": 512,
        "verbose": False,
    }
    return Llama(**llama_options)
def _unload_llm():
    """Remove the LLM from ``state`` and release cached CUDA memory.

    Safe to call when no LLM is registered: the pop is a no-op in that case.
    """
    # Discard the popped value immediately so this function holds no
    # reference to the model.
    state.pop("llm", None)
    if cuda.is_available():
        cuda.empty_cache()
    logger.info("LLM unloaded due to inactivity")
async def _inactivity_watcher():
    """Background task that unloads the LLM after ``LLM_IDLE_TIMEOUT`` seconds idle.

    Wakes once a minute. When the LLM is loaded and its last-use timestamp is
    older than the timeout, it takes ``gpu_semaphore`` and unloads the model.
    The loop never returns; it ends only when the task is cancelled.
    """

    def _is_idle() -> bool:
        """Return True when an LLM is loaded and has exceeded the idle timeout."""
        last_used = state.get("llm_last_used")
        return (
            state.get("llm") is not None
            and last_used is not None
            and time.monotonic() - last_used > LLM_IDLE_TIMEOUT
        )

    while True:
        await asyncio.sleep(60)
        if not _is_idle():
            continue
        async with gpu_semaphore:
            # Waiting for the semaphore can take a long time (another request
            # may be running inference). Re-check afterwards so a model that
            # was used while we were queued is not unloaded.
            if _is_idle():
                _unload_llm()
def _touch_llm():
    """Stamp the current monotonic time as the LLM's last-use moment."""
    state.update(llm_last_used=time.monotonic())
async def _ensure_llm() -> Llama:
    """Return the loaded LLM, loading it on demand if it is not in ``state``.

    Loading happens under ``gpu_semaphore`` with a second look at ``state``
    after the semaphore is acquired. This keeps two concurrent requests from
    each loading their own copy of the model, and keeps the load from
    overlapping with other GPU work guarded by the same semaphore.

    Must NOT be called while the caller already holds ``gpu_semaphore``
    (the semaphore is not re-entrant).

    Returns:
        The shared ``Llama`` instance.

    Raises:
        HTTPException: 503 when the model file does not exist.
    """
    llm = state.get("llm")
    if llm is None:
        if not os.path.exists(LLM_MODEL_PATH):
            raise HTTPException(status_code=503, detail="LLM model file not found.")
        async with gpu_semaphore:
            # Another request may have finished loading while we waited.
            llm = state.get("llm")
            if llm is None:
                loop = asyncio.get_running_loop()
                llm = await loop.run_in_executor(None, _load_llm)
                state["llm"] = llm
    _touch_llm()
    return llm
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load models on startup, release them on shutdown.

    Startup loads the embedding model, loads the LLM when its file exists
    (otherwise it is loaded lazily on the first chat request), and starts the
    inactivity watcher. Shutdown cancels the watcher and waits for it to
    finish, clears ``state`` and empties the CUDA cache. The cleanup is in a
    ``finally`` so it also runs if the application exits with an error.

    Raises:
        Exception: whatever model loading raised, re-raised after logging.
    """
    device = "cuda" if cuda.is_available() else "cpu"
    logger.info(f"--- Initializing GPU Node on {device} ---")
    if device == "cpu":
        logger.warning("CUDA NOT DETECTED. Performance will be severely degraded.")

    try:
        logger.info(f"Loading Embedding Model: {EMBED_MODEL_NAME}")
        state["embed_model"] = SentenceTransformer(
            EMBED_MODEL_NAME,
            trust_remote_code=True,
            device=device,
        )
        if not os.path.exists(LLM_MODEL_PATH):
            logger.warning(f"LLM file not found at {LLM_MODEL_PATH} — will load on first request")
        else:
            state["llm"] = _load_llm()
            _touch_llm()
        logger.info(f"--- GPU Node Ready (LLM idle timeout: {LLM_IDLE_TIMEOUT}s) ---")
    except Exception as e:
        # logger.exception keeps the traceback; bare raise keeps the original one.
        logger.exception(f"Failed to load models: {e}")
        raise

    watcher = asyncio.create_task(_inactivity_watcher())
    try:
        yield
    finally:
        watcher.cancel()
        try:
            # Let the cancellation actually land before tearing down state,
            # so the task is not left pending when the loop closes.
            await watcher
        except asyncio.CancelledError:
            pass
        state.clear()
        if cuda.is_available():
            cuda.empty_cache()
# ASGI application; the docs, ReDoc and OpenAPI routes are all switched off.
app = FastAPI(
    title="Agentic GPU Node",
    lifespan=lifespan,
    docs_url=None,
    redoc_url=None,
    openapi_url=None,
)

# HTTP Basic credentials expected on every endpoint.
# NOTE(review): the fallback values below are well-known placeholders —
# confirm that deployments always set both environment variables.
_security = HTTPBasic()
_API_USER = os.getenv("INFERENCE_USERNAME", "admin")
_API_PASS = os.getenv("INFERENCE_PASSWORD", "changeme")
def require_auth(credentials: HTTPBasicCredentials = Depends(_security)):
    """FastAPI dependency that rejects requests with wrong Basic credentials.

    Username and password are both compared with ``secrets.compare_digest``,
    and both comparisons always run, regardless of the first one's outcome.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when
            either value does not match.
    """
    supplied = (credentials.username, credentials.password)
    expected = (_API_USER, _API_PASS)
    # A list (not a generator) so every comparison is evaluated.
    matches = [
        secrets.compare_digest(got.encode(), want.encode())
        for got, want in zip(supplied, expected)
    ]
    if all(matches):
        return
    raise HTTPException(status_code=401, detail="Unauthorized", headers={"WWW-Authenticate": "Basic"})
@app.get("/health", dependencies=[Depends(require_auth)])
async def health():
    """Report liveness plus whether each model is currently held in ``state``."""

    def _loaded(key: str) -> bool:
        return state.get(key) is not None

    report = {"status": "ok"}
    report["embedding_ready"] = _loaded("embed_model")
    report["llm_ready"] = _loaded("llm")
    return report
2026-03-11 14:37:50 +00:00
def pad_and_normalize(embeddings: Tensor, target_dimensions: int) -> Tensor:
    """Resize embeddings along dim 1 to ``target_dimensions`` and L2-normalize.

    Narrower inputs are zero-padded on the right, wider inputs are truncated,
    and inputs that already match are left at their width. Each row is then
    scaled to unit L2 norm.

    Args:
        embeddings: Tensor indexed as ``[row, feature]``.
        target_dimensions: Desired size of dim 1.

    Returns:
        The resized, row-normalized tensor.
    """
    missing = target_dimensions - embeddings.shape[1]
    if missing > 0:
        # (0, missing): nothing on the left of the last dim, zeros on the right.
        embeddings = F.pad(embeddings, (0, missing), mode="constant", value=0)
    elif missing < 0:
        embeddings = embeddings[:, :target_dimensions]
    return F.normalize(embeddings, p=2, dim=1)
@app.post("/v1/embeddings", dependencies=[Depends(require_auth)])
async def embeddings(request: Request):
    """Embed one string or a list of strings.

    Request JSON: ``{"input": <str | list>}``. List items are converted with
    ``str()`` and blank items are skipped. Each returned entry's ``index`` is
    the item's position in the ORIGINAL input list, so callers can still match
    results to inputs when blank items were dropped.

    Texts that do not already start with a ``search_query:`` or
    ``search_document:`` task prefix are embedded as ``search_query: <text>``.

    Returns:
        A dict with ``object``, ``data`` (one entry per embedded text),
        ``model`` and ``usage`` (whitespace-split word counts, not real
        tokenizer counts).

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body or an ``input``
            that is neither string nor list; 503 when the embedding model is
            not loaded.
    """
    try:
        data = await request.json()
    except ValueError:
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        logger.warning("/v1/embeddings invalid JSON payload")
        raise HTTPException(status_code=400, detail="Invalid JSON payload")
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    input_data = data.get("input", "")
    input_kind = type(input_data).__name__
    input_count = len(input_data) if isinstance(input_data, list) else (1 if isinstance(input_data, str) else 0)
    logger.info("/v1/embeddings request received: input_kind=%s input_count=%s", input_kind, input_count)
    logger.info("/v1/embeddings using target_dimensions=%s", TARGET_DIMENSIONS)

    # Keep (original position, text) pairs so filtering blanks does not
    # renumber the results.
    if isinstance(input_data, str):
        indexed_inputs = [(0, input_data)]
    elif isinstance(input_data, list):
        indexed_inputs = [
            (position, str(item))
            for position, item in enumerate(input_data)
            if str(item).strip()
        ]
    else:
        logger.warning("/v1/embeddings bad input type: %s", input_kind)
        raise HTTPException(status_code=400, detail="'input' must be a string or list of strings")

    if not indexed_inputs:
        return {
            "object": "list",
            "data": [],
            "model": EMBED_MODEL_NAME,
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    model = state.get("embed_model")
    if model is None:
        raise HTTPException(status_code=503, detail="Embedding model not initialized")

    texts = [text for _, text in indexed_inputs]
    # Only a real task prefix suppresses the default one; a bare "search_"
    # test would also match ordinary text such as "search_engine ...".
    task_prefixes = ("search_query:", "search_document:")
    prefixed_inputs = [
        text if text.startswith(task_prefixes) else f"search_query: {text}"
        for text in texts
    ]

    loop = asyncio.get_running_loop()

    def _encode():
        with no_grad():
            vectors = model.encode(prefixed_inputs, convert_to_tensor=True)
            vectors = pad_and_normalize(vectors, target_dimensions=TARGET_DIMENSIONS)
        # Device-to-host copy and list conversion stay in the worker thread
        # so they do not block the event loop.
        return vectors.cpu().tolist()

    async with gpu_semaphore:
        vector_list = await loop.run_in_executor(None, _encode)

    word_count = sum(len(text.split()) for text in texts)
    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "index": position,
                "embedding": embedding,
            }
            for (position, _), embedding in zip(indexed_inputs, vector_list)
        ],
        "model": EMBED_MODEL_NAME,
        "usage": {
            "prompt_tokens": word_count,
            "total_tokens": word_count,
        },
    }
@app.post("/v1/semantic-chunk", dependencies=[Depends(require_auth)])
async def semantic_chunk(request: Request):
    """Split text into semantically coherent chunks and embed each chunk.

    Request JSON: ``{"text": <str>, "threshold": <0-100, default 95>}``.
    The text is split into sentences on ``". "``; a chunk boundary is placed
    after every sentence whose cosine distance to the next sentence exceeds
    the ``threshold`` percentile of all adjacent-sentence distances. Each
    chunk is embedded as ``search_document: <chunk>``.

    Returns:
        ``{"chunks": [...], "embeddings": [...]}``; both lists are empty when
        ``text`` is missing or empty.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, a non-string
            ``text`` or a ``threshold`` outside 0-100; 413 when the text is
            longer than 50000 characters; 503 when the embedding model is not
            loaded.
    """
    try:
        data = await request.json()
    except ValueError:
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        logger.warning("/v1/semantic-chunk invalid JSON payload")
        raise HTTPException(status_code=400, detail="Invalid JSON payload")
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    raw_text = data.get("text", "")
    threshold_percentile = data.get("threshold", 95)
    raw_text_len = len(raw_text) if isinstance(raw_text, str) else -1
    logger.info("/v1/semantic-chunk request received: text_len=%s threshold=%s", raw_text_len, threshold_percentile,)
    logger.info("/v1/semantic-chunk using target_dimensions=%s", TARGET_DIMENSIONS)

    if not raw_text:
        logger.info("/v1/semantic-chunk empty text payload")
        return {"chunks": [], "embeddings": []}
    if not isinstance(raw_text, str):
        raise HTTPException(status_code=400, detail="'text' must be a string")
    # np.percentile raises on anything outside [0, 100]; reject it up front.
    # The chained comparison is also False for NaN.
    if not (isinstance(threshold_percentile, (int, float)) and 0 <= threshold_percentile <= 100):
        raise HTTPException(status_code=400, detail="'threshold' must be a number between 0 and 100")
    if len(raw_text) > 50000:
        logger.warning("/v1/semantic-chunk payload too large: text_len=%s", len(raw_text))
        raise HTTPException(status_code=413, detail="Text block too large. Please batch on the client.")

    model = state.get("embed_model")
    if model is None:
        logger.error("/v1/semantic-chunk embedding model not initialized")
        raise HTTPException(status_code=503, detail="Embedding model not initialized")

    loop = asyncio.get_running_loop()
    sentences = [s.strip() for s in raw_text.replace('\n', ' ').split('. ') if s.strip()]

    def _join_sentences(parts):
        """Re-join sentences, adding a final period only when one is missing."""
        joined = ". ".join(parts)
        # The last sentence of the text usually keeps its own terminator
        # (the split only consumes ". "), so avoid producing "..".
        return joined if joined.endswith((".", "!", "?")) else joined + "."

    def _chunk_and_embed():
        with no_grad():
            if len(sentences) < 2:
                single = model.encode([f"search_document: {raw_text}"], convert_to_tensor=True)
                single = pad_and_normalize(single, target_dimensions=TARGET_DIMENSIONS)
                return {"chunks": [raw_text], "embeddings": single.cpu().tolist()}

            s_embeddings = model.encode(sentences, convert_to_tensor=True)
            # One batched call over all adjacent pairs instead of one GPU
            # round-trip per pair.
            similarities = F.cosine_similarity(s_embeddings[:-1], s_embeddings[1:], dim=1)
            distances = [1 - sim for sim in similarities.cpu().tolist()]

            breakpoint_threshold = np.percentile(distances, threshold_percentile)
            indices = [i for i, d in enumerate(distances) if d > breakpoint_threshold]

            chunks = []
            start = 0
            for idx in indices:
                chunks.append(_join_sentences(sentences[start : idx + 1]))
                start = idx + 1
            chunks.append(_join_sentences(sentences[start:]))

            final_embeddings = model.encode(
                [f"search_document: {c}" for c in chunks],
                convert_to_tensor=True,
            )
            final_embeddings = pad_and_normalize(final_embeddings, target_dimensions=TARGET_DIMENSIONS)
            return {"chunks": chunks, "embeddings": final_embeddings.cpu().tolist()}

    async with gpu_semaphore:
        result = await loop.run_in_executor(None, _chunk_and_embed)
    return result
@app.post("/v1/chat/completions", dependencies=[Depends(require_auth)])
async def chat_completions(request: Request):
    """Run a chat completion on the local LLM.

    Request JSON: ``{"messages": [...], "stream": bool, "temperature": float,
    "max_tokens": int}`` (defaults: no streaming, 0.7, 1024).

    With ``stream`` set, the reply is a ``text/event-stream`` response. The
    completion is generated in full inside the worker thread first and only
    then sent chunk by chunk, ending with ``data: [DONE]``. If generation
    fails, a single ``data: {"error": ...}`` event is sent instead.

    Returns:
        The completion dict from llama.cpp, or a ``StreamingResponse``.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body or a non-list
            ``messages``; 503 when the model file is missing; 500 when
            non-streaming inference fails.
    """
    try:
        data = await request.json()
    except Exception as e:
        raw_body = await request.body()
        preview = raw_body[:500].decode("utf-8", errors="replace")
        logger.error(f"Invalid JSON payload for chat completions: {e}; body_preview={preview}")
        raise HTTPException(status_code=400, detail="Invalid JSON payload")
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    messages = data.get("messages", [])
    if not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="'messages' must be a list")
    stream = data.get("stream", False)
    logger.info(f"Chat completion request: {len(messages)} messages, stream={stream}")

    llm = await _ensure_llm()
    loop = asyncio.get_running_loop()

    # Shared by the streaming and non-streaming paths.
    completion_kwargs = {
        "messages": messages,
        "temperature": data.get("temperature", 0.7),
        "max_tokens": data.get("max_tokens", 1024),
        "stop": ["<|eot_id|>", "<|end_of_text|>"],
    }

    def _infer():
        return llm.create_chat_completion(stream=False, **completion_kwargs)

    def _infer_stream():
        # Drain the generator inside the worker thread so token generation
        # never runs on the event loop.
        return list(llm.create_chat_completion(stream=True, **completion_kwargs))

    if stream:
        async def _stream_response():
            try:
                async with gpu_semaphore:
                    chunks = await loop.run_in_executor(None, _infer_stream)
                # Count idle time from the end of inference, not its start.
                _touch_llm()
            except Exception as e:
                # This runs after the response has started, so the handler's
                # own error path cannot report it; surface it in-band.
                logger.exception(f"Inference error: {e}")
                yield f"data: {json.dumps({'error': {'message': str(e)}})}\n\n"
                return
            for chunk in chunks:
                yield f"data: {json.dumps(chunk)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(_stream_response(), media_type="text/event-stream")

    try:
        async with gpu_semaphore:
            response = await loop.run_in_executor(None, _infer)
        # Count idle time from the end of inference, not its start.
        _touch_llm()
        return response
    except Exception as e:
        logger.exception(f"Inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when running this file directly.
    import uvicorn

    # NOTE(review): reload=True enables uvicorn's auto-reloader — confirm this
    # entry point is only used for development.
    uvicorn.run(
        "gpu_server:app",
        host="0.0.0.0",
        port=8001,
        reload=True,
    )