2026-02-26 01:32:04 +00:00
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import json
|
|
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
|
from typing import Dict, Any
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
from fastapi import FastAPI, Request, HTTPException
|
|
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
|
from llama_cpp import Llama
|
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
|
|
|
|
# Root logging: INFO and above, each record stamped with time, level and logger name.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("gpu-node")

# Model identifier passed to SentenceTransformer when the embedding model is loaded.
EMBED_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5"

# Path of the GGUF weights; the LLM_MODEL_PATH environment variable overrides the default.
LLM_MODEL_PATH = os.environ.get(
    "LLM_MODEL_PATH",
    "/app/models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
)

# Width that returned embeddings are padded or truncated to.
TARGET_DIMENSIONS = 1536

# Process-wide registry of loaded models, keyed "embed_model" and "llm".
state: Dict[str, Any] = dict()
|
|
|
|
|
|
|
|
|
|
def _load_models(device: str) -> None:
    """Populate ``state`` with the embedding model and, when its file exists, the LLM."""
    # Load Embedding Model (Nomic)
    logger.info(f"Loading Embedding Model: {EMBED_MODEL_NAME}")
    state["embed_model"] = SentenceTransformer(
        EMBED_MODEL_NAME,
        trust_remote_code=True,
        device=device,
    )

    # Load Llama Model (GGUF). A missing file is reported but is not fatal:
    # "llm" simply stays absent from ``state``.
    if not os.path.exists(LLM_MODEL_PATH):
        logger.error(f"LLM File not found at {LLM_MODEL_PATH}")
        return

    logger.info(f"Loading LLM: {LLM_MODEL_PATH}")
    state["llm"] = Llama(
        model_path=LLM_MODEL_PATH,
        n_gpu_layers=-1,  # Offload all layers to GPU
        n_ctx=8192,
        n_batch=512,
        verbose=False,
    )


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the models on startup and release them on shutdown.

    The device is "cuda" when ``torch.cuda.is_available()`` and "cpu"
    otherwise (with a warning). The embedding model is stored under
    ``state["embed_model"]`` and the LLM under ``state["llm"]``.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Raises:
        Exception: Any error raised while constructing the models is logged
            with its traceback and re-raised.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"--- Initializing GPU Node on {device} ---")

    if device == "cpu":
        logger.warning("CUDA NOT DETECTED. Performance will be severely degraded.")

    try:
        try:
            _load_models(device)
            logger.info("--- GPU Node Ready ---")
        except Exception as e:
            # logger.exception keeps the traceback; bare raise preserves the original error.
            logger.exception(f"Failed to load models: {e}")
            raise

        yield
    finally:
        # Cleanup runs on normal shutdown and also when loading or serving fails,
        # so a partially loaded model is never left referenced in ``state``.
        state.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
# ASGI application object; `lifespan` is registered as its startup/shutdown handler.
app = FastAPI(title="Agentic GPU Node", lifespan=lifespan)
|
|
|
|
|
|
|
|
|
|
|
2026-02-27 00:45:34 +00:00
|
|
|
@app.get("/health")
async def health():
    """Report service status and whether each model is present in ``state``."""
    embed_loaded = state.get("embed_model") is not None
    llm_loaded = state.get("llm") is not None
    return dict(
        status="ok",
        embedding_ready=embed_loaded,
        llm_ready=llm_loaded,
    )
|
|
|
|
|
|
|
|
|
|
|
2026-02-26 01:32:04 +00:00
|
|
|
def pad_and_normalize(embeddings: torch.Tensor, target_dim: "int | None" = None) -> torch.Tensor:
    """Resize embeddings along their last axis and L2-normalize them.

    Vectors narrower than the target width are right-padded with zeros,
    wider ones are truncated, and the result is scaled to unit L2 norm.

    Args:
        embeddings: Tensor whose last axis is the embedding dimension. Both a
            single vector and a batch of vectors are accepted.
        target_dim: Output width. Defaults to ``TARGET_DIMENSIONS`` (1536,
            the width used for pgvector compatibility) when omitted.

    Returns:
        A tensor with the same leading axes and a last axis of ``target_dim``.

    Raises:
        ValueError: If ``target_dim`` is not positive.
    """
    if target_dim is None:
        target_dim = TARGET_DIMENSIONS
    if target_dim <= 0:
        raise ValueError("target_dim must be a positive integer")

    curr_dim = embeddings.shape[-1]
    if curr_dim < target_dim:
        # (0, k) pads only the end of the last axis.
        embeddings = F.pad(embeddings, (0, target_dim - curr_dim), "constant", 0)
    elif curr_dim > target_dim:
        embeddings = embeddings[..., :target_dim]

    # Normalize after resizing so the returned vectors have unit length.
    return F.normalize(embeddings, p=2, dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _coerce_embedding_inputs(input_data):
    """Turn the request's ``input`` field into a list of non-blank strings."""
    if isinstance(input_data, str):
        candidates = [input_data]
    elif isinstance(input_data, list):
        candidates = [str(item) for item in input_data]
    else:
        raise HTTPException(status_code=400, detail="'input' must be a string or list of strings")
    # Blank entries are dropped for both shapes, so a lone "" is treated like [""].
    return [text for text in candidates if text.strip()]


@app.post("/v1/embeddings")
async def embeddings(request: Request):
    """Generate text embeddings in an OpenAI-style response envelope.

    The JSON body's ``input`` may be a string or a list; list items are
    converted with ``str()``. Blank entries are dropped before encoding, so
    each returned ``index`` is the position among the remaining inputs.
    Inputs that do not already start with a ``search_query:`` or
    ``search_document:`` task prefix are encoded as ``search_query: <text>``.

    Raises:
        HTTPException: 400 for a malformed or non-object JSON body or an
            ``input`` of the wrong type; 503 when the embedding model is
            not loaded.
    """
    try:
        data = await request.json()
    except ValueError as e:
        # JSON and byte-decoding errors are both ValueError subclasses.
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from e
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    inputs = _coerce_embedding_inputs(data.get("input", ""))

    if not inputs:
        return {
            "object": "list",
            "data": [],
            "model": EMBED_MODEL_NAME,
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    model = state.get("embed_model")
    if model is None:
        raise HTTPException(status_code=503, detail="Embedding model not initialized")

    # Only the two real task prefixes count as "already prefixed"; ordinary text
    # that merely begins with "search_" still gets the query prefix.
    task_prefixes = ("search_query:", "search_document:")
    prefixed_inputs = [
        text if text.startswith(task_prefixes) else f"search_query: {text}"
        for text in inputs
    ]

    with torch.no_grad():
        vectors = model.encode(prefixed_inputs, convert_to_tensor=True)
        vectors = pad_and_normalize(vectors)

    vector_list = vectors.cpu().tolist()

    # Whitespace-separated word count of the un-prefixed inputs, reported as tokens.
    token_estimate = sum(len(text.split()) for text in inputs)

    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "index": idx,
                "embedding": embedding,
            }
            for idx, embedding in enumerate(vector_list)
        ],
        "model": EMBED_MODEL_NAME,
        "usage": {
            "prompt_tokens": token_estimate,
            "total_tokens": token_estimate,
        },
    }
|
|
|
|
|
|
|
|
|
|
def _join_sentences(sentences):
    """Rebuild a chunk from split sentences, ending it with exactly one terminator."""
    text = ". ".join(sentences)
    # The final sentence of the input usually keeps its own punctuation because
    # the split only consumes ". " separators; do not append a second period.
    return text if text.endswith((".", "!", "?")) else text + "."


@app.post("/v1/semantic-chunk")
async def semantic_chunk(request: Request):
    """Split raw text into chunks at large semantic jumps and embed each chunk.

    Newlines are replaced with spaces and the text is split on ``". "``.
    Consecutive sentence embeddings are compared by cosine distance, and a
    chunk boundary is placed after every sentence whose distance to the next
    one exceeds the ``threshold`` percentile (default 95) of all distances.
    Each chunk is embedded with the ``search_document: `` prefix, then padded
    and normalized.

    Returns:
        ``{"chunks": [...], "embeddings": [...]}`` with one embedding per chunk.

    Raises:
        HTTPException: 400 for a malformed or non-object JSON body, a
            non-string ``text`` or a ``threshold`` outside 0-100; 413 when
            the text exceeds 50000 characters; 503 when the embedding model
            is not loaded.
    """
    try:
        data = await request.json()
    except ValueError as e:
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from e
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    raw_text = data.get("text", "")
    threshold_percentile = data.get("threshold", 95)

    if not raw_text:
        return {"chunks": [], "embeddings": []}
    if not isinstance(raw_text, str):
        raise HTTPException(status_code=400, detail="'text' must be a string")

    # np.percentile rejects values outside [0, 100]; report that as a client error.
    threshold_ok = (
        isinstance(threshold_percentile, (int, float))
        and not isinstance(threshold_percentile, bool)
        and 0 <= threshold_percentile <= 100
    )
    if not threshold_ok:
        raise HTTPException(status_code=400, detail="'threshold' must be a number between 0 and 100")

    if len(raw_text) > 50000:
        raise HTTPException(status_code=413, detail="Text block too large. Please batch on the client.")

    # Split by sentences
    sentences = [s.strip() for s in raw_text.replace('\n', ' ').split('. ') if s.strip()]
    if not sentences:
        # Whitespace-only input: nothing to chunk or embed.
        return {"chunks": [], "embeddings": []}

    model = state.get("embed_model")
    if model is None:
        raise HTTPException(status_code=503, detail="Embedding model not initialized")

    if len(sentences) < 2:
        # Same pad/normalize treatment as the multi-chunk path so every returned
        # embedding has the same width.
        with torch.no_grad():
            single = model.encode([f"search_document: {raw_text}"], convert_to_tensor=True)
            single = pad_and_normalize(single)
        return {
            "chunks": [raw_text],
            "embeddings": single.cpu().tolist(),
        }

    # Generate sentence embeddings to find breakpoints via cosine distance
    with torch.no_grad():
        s_embeddings = model.encode(sentences, convert_to_tensor=True)
        # One batched call compares every sentence with its successor.
        similarities = F.cosine_similarity(s_embeddings[:-1], s_embeddings[1:], dim=1)
        distances = (1 - similarities).cpu().tolist()

    breakpoint_threshold = np.percentile(distances, threshold_percentile)
    indices = [i for i, d in enumerate(distances) if d > breakpoint_threshold]

    chunks = []
    start = 0
    for idx in indices:
        # distances[idx] separates sentence idx from idx + 1, so the chunk ends at idx.
        chunks.append(_join_sentences(sentences[start : idx + 1]))
        start = idx + 1
    chunks.append(_join_sentences(sentences[start:]))

    with torch.no_grad():
        final_embeddings = model.encode(
            [f"search_document: {c}" for c in chunks],
            convert_to_tensor=True,
        )
        final_embeddings = pad_and_normalize(final_embeddings)

    return {
        "chunks": chunks,
        "embeddings": final_embeddings.cpu().tolist(),
    }
|
|
|
|
|
|
|
|
|
|
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """Run an OpenAI-style chat completion against the loaded LLM.

    Reads ``messages``, ``stream`` (default False), ``temperature``
    (default 0.7) and ``max_tokens`` (default 1024) from the JSON body.
    With ``stream`` set, the result is returned as a ``text/event-stream``
    response produced by ``llm_streamer``; otherwise the completion object
    is returned directly.

    Raises:
        HTTPException: 400 for a malformed or non-object JSON body or a
            non-list ``messages``; 503 when the LLM is not loaded; 500 when
            inference fails.
    """
    try:
        data = await request.json()
    except Exception as e:
        # Log the start of the body to help diagnose what the client sent.
        raw_body = await request.body()
        preview = raw_body[:500].decode("utf-8", errors="replace")
        logger.error(f"Invalid JSON payload for chat completions: {e}; body_preview={preview}")
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from e

    # Valid JSON that is not an object (a list, a number, ...) has no .get().
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    messages = data.get("messages", [])
    if not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="'messages' must be a list")
    stream = data.get("stream", False)

    # Log incoming request details
    logger.info(f"Chat completion request: {len(messages)} messages, stream={stream}")

    llm = state.get("llm")
    if llm is None:
        raise HTTPException(status_code=503, detail="LLM not initialized or model file missing.")

    try:
        response = llm.create_chat_completion(
            messages=messages,
            stream=stream,
            temperature=data.get("temperature", 0.7),
            max_tokens=data.get("max_tokens", 1024),
            stop=["<|eot_id|>", "<|end_of_text|>"],
        )
    except Exception as e:
        # Keep the traceback in the log; the client only sees the message.
        logger.exception(f"Inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    if stream:
        return StreamingResponse(
            llm_streamer(response),
            media_type="text/event-stream",
        )

    return response
|
|
|
|
|
|
|
|
|
|
async def llm_streamer(response_iterator):
    """Wrap each item of ``response_iterator`` as a server-sent event.

    Every item is JSON-encoded into a ``data: ...`` frame terminated by a
    blank line; a final ``data: [DONE]`` frame follows the last item.
    """
    as_sse_frame = "data: {}\n\n".format
    for piece in response_iterator:
        encoded = json.dumps(piece)
        yield as_sse_frame(encoded)
    # Sentinel frame marking the end of the stream.
    yield as_sse_frame("[DONE]")
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct-execution entry point: serve "gpu_server:app" on all interfaces,
    # port 8001, with uvicorn's auto-reload switched on.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8001, "reload": True}
    uvicorn.run("gpu_server:app", **server_options)
|