# Dynavera/gpu_server.py
# 2026-03-11 14:33:39 +00:00
#
# 267 lines
# No EOL
# 8.7 KiB
# Python

import logging
import os
import json
from contextlib import asynccontextmanager
from typing import Dict, Any
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse
from llama_cpp import Llama
from sentence_transformers import SentenceTransformer
# Process-wide logging: INFO and above, timestamped and tagged with the logger name.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("gpu-node")

# Model name handed to SentenceTransformer when the embedding model is loaded.
EMBED_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5"

# Location of the GGUF weights for the chat model; the LLM_MODEL_PATH
# environment variable overrides the built-in default.
LLM_MODEL_PATH = os.environ.get(
    "LLM_MODEL_PATH",
    "/app/models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
)

# Shared registry of loaded models, read by the request handlers through the
# "embed_model" and "llm" keys.
state: Dict[str, Any] = dict()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the models on startup and release them on shutdown.

    Startup stores the embedding model under ``state["embed_model"]`` and,
    when the GGUF file exists at ``LLM_MODEL_PATH``, the chat model under
    ``state["llm"]``.  A missing LLM file is logged but is not fatal, so the
    embedding endpoints keep working while chat requests answer 503.

    Any exception raised while loading is logged with its traceback and
    re-raised, which aborts application startup.

    Cleanup (clearing ``state`` and emptying the CUDA cache) runs in a
    ``finally`` block so it also happens after a partial load failure or when
    the application exits with an error.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"--- Initializing GPU Node on {device} ---")
    if device == "cpu":
        logger.warning("CUDA NOT DETECTED. Performance will be severely degraded.")

    try:
        try:
            # Load Embedding Model (Nomic)
            logger.info(f"Loading Embedding Model: {EMBED_MODEL_NAME}")
            state["embed_model"] = SentenceTransformer(
                EMBED_MODEL_NAME,
                # NOTE(review): trust_remote_code allows the model repository
                # to run its own Python code at load time; confirm the source
                # is trusted/pinned.
                trust_remote_code=True,
                device=device,
            )

            # Load Llama Model (GGUF)
            if not os.path.exists(LLM_MODEL_PATH):
                logger.error(f"LLM File not found at {LLM_MODEL_PATH}")
            else:
                logger.info(f"Loading LLM: {LLM_MODEL_PATH}")
                state["llm"] = Llama(
                    model_path=LLM_MODEL_PATH,
                    n_gpu_layers=-1,  # Offload all layers to GPU
                    n_ctx=8192,
                    n_batch=512,
                    verbose=False,
                )
            logger.info("--- GPU Node Ready ---")
        except Exception as e:
            # logger.exception keeps the traceback; a bare raise re-raises the
            # original exception unchanged.
            logger.exception(f"Failed to load models: {e}")
            raise

        yield
    finally:
        # Cleanup: drop model references first so the cache flush can
        # actually release their memory.
        state.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# ASGI application object; model loading and teardown are delegated to `lifespan`.
app = FastAPI(
    title="Agentic GPU Node",
    lifespan=lifespan,
)
@app.get("/health")
async def health():
    """Report service liveness and whether each model is currently loaded."""
    embed_loaded = state.get("embed_model") is not None
    llm_loaded = state.get("llm") is not None
    return dict(
        status="ok",
        embedding_ready=embed_loaded,
        llm_ready=llm_loaded,
    )
def _resolve_target_dimensions(payload: Dict[str, Any]) -> int:
    """Extract and validate the embedding width requested by the client.

    Args:
        payload: Parsed JSON request body.

    Returns:
        ``payload["target_dimensions"]`` as a positive ``int``.  Integers,
        integral floats (``768.0``) and integer strings (``"768"``) are
        accepted.

    Raises:
        HTTPException: 400 when the field is missing or empty, is a boolean,
            is a fractional or non-finite float, cannot be parsed as an
            integer, or is not greater than zero.
    """
    raw_target = payload.get("target_dimensions")
    if raw_target in (None, ""):
        raise HTTPException(status_code=400, detail="'target_dimensions' is required")

    # bool is an int subclass, so int(True) would quietly become 1.
    # float.is_integer() is False for fractional values, NaN and +/-inf;
    # without this check int() would truncate 767.9 to 767 and raise an
    # uncaught OverflowError for inf.
    not_integral = isinstance(raw_target, float) and not raw_target.is_integer()
    if isinstance(raw_target, bool) or not_integral:
        raise HTTPException(status_code=400, detail="'target_dimensions' must be an integer")

    try:
        target = int(raw_target)
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail="'target_dimensions' must be an integer") from exc

    if target <= 0:
        raise HTTPException(status_code=400, detail="'target_dimensions' must be > 0")
    return target
def pad_and_normalize(embeddings: torch.Tensor, target_dimensions: int) -> torch.Tensor:
    """Resize the feature axis to ``target_dimensions`` and L2-normalize it.

    The last axis is treated as the feature axis, so both a batch of vectors
    ``(n, dim)`` and a single vector ``(dim,)`` are supported.

    Args:
        embeddings: Floating-point tensor whose last axis holds the features.
        target_dimensions: Desired size of the last axis.

    Returns:
        A tensor with the same leading shape whose last axis has
        ``target_dimensions`` entries: zero-padded on the right when the
        input is narrower, truncated when it is wider, then scaled to unit
        L2 norm along that axis.
    """
    curr_dim = embeddings.shape[-1]
    if curr_dim < target_dimensions:
        # A two-element pad spec applies to the last axis only: (left, right).
        embeddings = F.pad(embeddings, (0, target_dimensions - curr_dim), "constant", 0)
    elif curr_dim > target_dimensions:
        embeddings = embeddings[..., :target_dimensions]
    # Normalize after resizing so truncated vectors are unit length again.
    return F.normalize(embeddings, p=2, dim=-1)
@app.post("/v1/embeddings")
async def embeddings(request: Request):
    """Generate text embeddings in an OpenAI-style response envelope.

    Request body (JSON object):
        input: A string, or a list whose items are stringified; blank list
            items are dropped.
        target_dimensions: Required positive integer.  Vectors are
            zero-padded or truncated to this width and then L2-normalized.

    Texts that do not already begin with ``search_query:`` or
    ``search_document:`` are embedded with a ``search_query: `` prefix.

    Returns:
        ``{"object": "list", "data": [...], "model": ..., "usage": ...}``.
        The ``usage`` numbers count whitespace-separated words of the input,
        not model tokens.

    Raises:
        HTTPException: 400 for a malformed body or invalid fields, 503 when
            the embedding model is not loaded.
    """
    try:
        data = await request.json()
    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    input_data = data.get("input", "")
    target_dimensions = _resolve_target_dimensions(data)

    if isinstance(input_data, str):
        inputs = [input_data]
    elif isinstance(input_data, list):
        inputs = [str(item) for item in input_data if str(item).strip()]
    else:
        raise HTTPException(status_code=400, detail="'input' must be a string or list of strings")

    if not inputs:
        return {
            "object": "list",
            "data": [],
            "model": EMBED_MODEL_NAME,
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    model = state.get("embed_model")
    if model is None:
        raise HTTPException(status_code=503, detail="Embedding model not initialized")

    # Only the two real task prefixes count as "already prefixed"; a bare
    # startswith("search_") would also skip ordinary text such as
    # "search_engine tips".
    task_prefixes = ("search_query:", "search_document:")
    prefixed_inputs = [
        text if text.startswith(task_prefixes) else f"search_query: {text}"
        for text in inputs
    ]

    with torch.no_grad():
        vectors = model.encode(prefixed_inputs, convert_to_tensor=True)
        vectors = pad_and_normalize(vectors, target_dimensions=target_dimensions)
    vector_list = vectors.cpu().tolist()

    # Word count is a cheap stand-in for a token count; compute it once.
    word_count = sum(len(text.split()) for text in inputs)

    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                # NOTE(review): when blank list items were dropped above, this
                # index is the position among the kept items, not the
                # position in the caller's original list.
                "index": idx,
                "embedding": embedding,
            }
            for idx, embedding in enumerate(vector_list)
        ],
        "model": EMBED_MODEL_NAME,
        "usage": {
            "prompt_tokens": word_count,
            "total_tokens": word_count,
        },
    }
def _close_chunk(sentences: list) -> str:
    """Join sentences into one chunk that ends with terminal punctuation."""
    text = ". ".join(sentences)
    # The last sentence of the input often keeps its own terminator; do not
    # produce "end.." or "why?.".
    return text if text.endswith((".", "!", "?")) else text + "."


@app.post("/v1/semantic-chunk")
async def semantic_chunk(request: Request):
    """Split raw text into semantically cohesive chunks and embed each chunk.

    Request body (JSON object):
        text: The text to chunk (at most 50000 characters).
        threshold: Percentile (0-100, default 95) of the cosine distances
            between neighbouring sentences; a chunk boundary is placed after
            every sentence whose distance to the next one exceeds it.
        target_dimensions: Required positive integer width of the returned
            embeddings.

    Returns:
        ``{"chunks": [...], "embeddings": [...]}`` with one embedding per
        chunk.  Empty or whitespace-only text yields two empty lists.

    Raises:
        HTTPException: 400 for a malformed body or invalid fields, 413 when
            the text is too long, 503 when the embedding model is not loaded.
    """
    try:
        data = await request.json()
    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    raw_text = data.get("text", "")
    target_dimensions = _resolve_target_dimensions(data)

    # np.percentile raises for values outside [0, 100]; reject bad input as a
    # client error instead of letting it become a 500.  NaN fails the range
    # comparison as well.
    try:
        threshold_percentile = float(data.get("threshold", 95))
    except (TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=400, detail="'threshold' must be a number between 0 and 100"
        ) from exc
    if not 0 <= threshold_percentile <= 100:
        raise HTTPException(status_code=400, detail="'threshold' must be a number between 0 and 100")

    if not isinstance(raw_text, str):
        raise HTTPException(status_code=400, detail="'text' must be a string")
    if not raw_text.strip():
        return {"chunks": [], "embeddings": []}
    if len(raw_text) > 50000:
        raise HTTPException(status_code=413, detail="Text block too large. Please batch on the client.")

    model = state.get("embed_model")
    if model is None:
        raise HTTPException(status_code=503, detail="Embedding model not initialized")

    # Split by sentences
    sentences = [s.strip() for s in raw_text.replace('\n', ' ').split('. ') if s.strip()]

    if len(sentences) < 2:
        # Nothing to compare: return the whole text as a single chunk.
        with torch.no_grad():
            single = model.encode([f"search_document: {raw_text}"], convert_to_tensor=True)
            single = pad_and_normalize(single, target_dimensions=target_dimensions)
        return {
            "chunks": [raw_text],
            "embeddings": single.cpu().tolist(),
        }

    # Generate sentence embeddings to find breakpoints via cosine distance.
    # Row i of the batched result compares sentence i with sentence i + 1,
    # replacing one device round-trip per pair with a single call.
    with torch.no_grad():
        s_embeddings = model.encode(sentences, convert_to_tensor=True)
        similarities = F.cosine_similarity(s_embeddings[:-1], s_embeddings[1:], dim=1)
    distances = [1 - sim for sim in similarities.float().cpu().tolist()]

    breakpoint_threshold = np.percentile(distances, threshold_percentile)
    indices = [i for i, d in enumerate(distances) if d > breakpoint_threshold]

    chunks = []
    start = 0
    for idx in indices:
        # A breakpoint at idx closes the chunk after sentence idx.
        chunks.append(_close_chunk(sentences[start : idx + 1]))
        start = idx + 1
    chunks.append(_close_chunk(sentences[start:]))

    with torch.no_grad():
        final_embeddings = model.encode(
            [f"search_document: {c}" for c in chunks],
            convert_to_tensor=True,
        )
        final_embeddings = pad_and_normalize(final_embeddings, target_dimensions=target_dimensions)

    return {
        "chunks": chunks,
        "embeddings": final_embeddings.cpu().tolist(),
    }
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """Run a chat completion for an OpenAI-style request.

    Request body (JSON object):
        messages: Non-empty list of chat messages, passed to the model as-is.
        stream: When truthy, the reply is sent as server-sent events.
        temperature: Sampling temperature (default 0.7).
        max_tokens: Generation limit (default 1024).

    Returns:
        The completion returned by the model, or a ``StreamingResponse`` of
        ``text/event-stream`` frames when streaming was requested.

    Raises:
        HTTPException: 400 for a malformed body or invalid ``messages``,
            503 when the LLM is not loaded, 500 when inference fails.
    """
    try:
        data = await request.json()
    except Exception as e:
        raw_body = await request.body()
        preview = raw_body[:500].decode("utf-8", errors="replace")
        logger.error(f"Invalid JSON payload for chat completions: {e}; body_preview={preview}")
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from e

    # Valid JSON that is not an object (e.g. a list) has no .get(); answer
    # with a client error rather than an unhandled AttributeError.
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    messages = data.get("messages", [])
    if not isinstance(messages, list) or not messages:
        raise HTTPException(status_code=400, detail="'messages' must be a non-empty list")
    stream = data.get("stream", False)

    # Log incoming request details
    logger.info(f"Chat completion request: {len(messages)} messages, stream={stream}")

    llm = state.get("llm")
    if not llm:
        raise HTTPException(status_code=503, detail="LLM not initialized or model file missing.")

    try:
        response = llm.create_chat_completion(
            messages=messages,
            stream=stream,
            temperature=data.get("temperature", 0.7),
            max_tokens=data.get("max_tokens", 1024),
            stop=["<|eot_id|>", "<|end_of_text|>"],
        )
        if stream:
            # NOTE(review): the streamed response is consumed later by
            # llm_streamer, so failures during generation surface there and
            # not in this try block.
            return StreamingResponse(
                llm_streamer(response),
                media_type="text/event-stream",
            )
        return response
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"Inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def llm_streamer(response_iterator):
    """Turn an iterable of completion chunks into server-sent-event frames.

    Every chunk is JSON-encoded and emitted as ``data: <json>`` followed by a
    blank line; once the iterable is exhausted a final ``data: [DONE]`` frame
    is emitted.
    """
    as_frame = "data: {}\n\n".format
    for piece in response_iterator:
        encoded = json.dumps(piece)
        yield as_frame(encoded)
    yield as_frame("[DONE]")
if __name__ == "__main__":
    # Direct-execution entry point: serve the app on all interfaces, port
    # 8001, with uvicorn's reload option switched on.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8001, "reload": True}
    uvicorn.run("gpu_server:app", **server_options)