"""Agentic GPU node: authenticated FastAPI service for embeddings, semantic chunking and chat.

* Embeddings come from a SentenceTransformer model loaded at startup.
* Chat completions come from a llama.cpp model that is loaded at startup when the
  model file exists (otherwise on the first chat request) and unloaded after
  ``LLM_IDLE_TIMEOUT`` seconds without use.
* Every GPU-bound job (model load/unload, encoding, inference) is serialized
  through ``gpu_semaphore``.
"""

import asyncio
import gc
import json
import logging
import os
import secrets
import threading
import time
from contextlib import asynccontextmanager, suppress
from typing import Any, Callable, Dict, List

import numpy as np
import torch.nn.functional as F
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from llama_cpp import Llama
from sentence_transformers import SentenceTransformer
from torch import Tensor, cuda, no_grad

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("gpu-node")

EMBED_MODEL_NAME = "nomic-ai/nomic-embed-text-v1.5"
LLM_MODEL_PATH = os.getenv("LLM_MODEL_PATH", "/app/models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf")
LLM_IDLE_TIMEOUT = int(os.getenv("LLM_IDLE_TIMEOUT", "1800"))
TARGET_DIMENSIONS = 768
# Stop strings passed to every chat completion (Llama 3 special end tokens).
LLM_STOP_TOKENS = ["<|eot_id|>", "<|end_of_text|>"]

# Shared model registry: "embed_model", "llm", and "llm_last_used" (a time.monotonic() stamp).
state: Dict[str, Any] = {}
# Serializes all GPU work; at most one job touches the models at a time.
gpu_semaphore = asyncio.Semaphore(1)


def _load_llm() -> Llama:
    """Construct the llama.cpp model from LLM_MODEL_PATH (blocking; run off the event loop)."""
    logger.info(f"Loading LLM: {LLM_MODEL_PATH}")
    return Llama(model_path=LLM_MODEL_PATH, n_gpu_layers=-1, n_ctx=8192, n_batch=512, verbose=False)


def _unload_llm() -> None:
    """Drop the cached LLM reference, run the GC and empty the CUDA cache.

    Callers must hold ``gpu_semaphore`` so no inference is using the model.
    """
    llm = state.pop("llm", None)
    del llm
    gc.collect()
    if cuda.is_available():
        cuda.empty_cache()
    logger.info("LLM unloaded due to inactivity")


def _touch_llm() -> None:
    """Record the current monotonic time as the LLM's last use."""
    state["llm_last_used"] = time.monotonic()


def _llm_idle_expired() -> bool:
    """Return True when an LLM is loaded and unused for longer than LLM_IDLE_TIMEOUT seconds."""
    last_used = state.get("llm_last_used")
    return (
        state.get("llm") is not None
        and last_used is not None
        and time.monotonic() - last_used > LLM_IDLE_TIMEOUT
    )


async def _inactivity_watcher() -> None:
    """Once a minute, unload the LLM if it has been idle longer than LLM_IDLE_TIMEOUT."""
    while True:
        await asyncio.sleep(60)
        try:
            if not _llm_idle_expired():
                continue
            async with gpu_semaphore:
                # Re-check under the lock: while we waited, a request may have fetched
                # (and touched) the model. Unloading then would not free it -- the
                # request still holds a reference -- and the next request would reload it.
                if _llm_idle_expired():
                    _unload_llm()
        except Exception:
            # Keep the watcher alive; a dead watcher would silently disable unloading.
            logger.exception("LLM inactivity watcher iteration failed")


async def _ensure_llm() -> Llama:
    """Return the loaded LLM, loading it (under gpu_semaphore) if necessary.

    Raises:
        HTTPException: 503 if the model file does not exist.
    """
    if state.get("llm") is None:
        async with gpu_semaphore:
            # Another request may have loaded the model while we waited for the lock.
            if state.get("llm") is None:
                if not os.path.exists(LLM_MODEL_PATH):
                    raise HTTPException(status_code=503, detail="LLM model file not found.")
                loop = asyncio.get_running_loop()
                state["llm"] = await loop.run_in_executor(None, _load_llm)
    # No await between here and the return, so the watcher cannot interleave and
    # its locked re-check will see this fresh timestamp.
    _touch_llm()
    return state["llm"]


async def _start_gpu_job(func: Callable[[], Any]) -> asyncio.Future:
    """Acquire gpu_semaphore and start ``func`` in the default executor.

    The semaphore is released by a done-callback when the worker thread finishes,
    not when the awaiting coroutine exits. A cancelled or disconnected caller
    therefore cannot let a second GPU job start while this thread is still running.
    Await the returned future through ``asyncio.shield``: cancelling it directly
    would mark it done (and release the semaphore) before the thread has stopped.
    """
    await gpu_semaphore.acquire()
    try:
        job = asyncio.get_running_loop().run_in_executor(None, func)
    except BaseException:
        gpu_semaphore.release()
        raise
    job.add_done_callback(lambda _job: gpu_semaphore.release())
    return job


async def _run_on_gpu(func: Callable[[], Any]) -> Any:
    """Run ``func`` in the executor while holding gpu_semaphore and return its result."""
    return await asyncio.shield(await _start_gpu_job(func))


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models at startup, run the idle watcher, and release everything at shutdown."""
    device = "cuda" if cuda.is_available() else "cpu"
    logger.info(f"--- Initializing GPU Node on {device} ---")
    if device == "cpu":
        logger.warning("CUDA NOT DETECTED. Performance will be severely degraded.")
    if _API_PASS == "changeme":
        logger.warning("INFERENCE_PASSWORD is the default 'changeme'; set a real password.")

    try:
        logger.info(f"Loading Embedding Model: {EMBED_MODEL_NAME}")
        state["embed_model"] = SentenceTransformer(EMBED_MODEL_NAME, trust_remote_code=True, device=device)
        if not os.path.exists(LLM_MODEL_PATH):
            logger.warning(f"LLM file not found at {LLM_MODEL_PATH} — will load on first request")
        else:
            state["llm"] = _load_llm()
            _touch_llm()
        logger.info(f"--- GPU Node Ready (LLM idle timeout: {LLM_IDLE_TIMEOUT}s) ---")
    except Exception:
        logger.exception("Failed to load models")
        raise

    watcher = asyncio.create_task(_inactivity_watcher())
    try:
        yield
    finally:
        watcher.cancel()
        with suppress(asyncio.CancelledError):
            await watcher
        state.clear()
        gc.collect()
        if cuda.is_available():
            cuda.empty_cache()


app = FastAPI(title="Agentic GPU Node", lifespan=lifespan, docs_url=None, redoc_url=None, openapi_url=None)

_security = HTTPBasic()
_API_USER = os.getenv("INFERENCE_USERNAME", "admin")
_API_PASS = os.getenv("INFERENCE_PASSWORD", "changeme")


def require_auth(credentials: HTTPBasicCredentials = Depends(_security)) -> None:
    """FastAPI dependency enforcing HTTP Basic auth against INFERENCE_USERNAME / INFERENCE_PASSWORD.

    Both fields are checked with secrets.compare_digest, and both checks always run
    before the result is inspected.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header on mismatch.
    """
    valid_user = secrets.compare_digest(credentials.username.encode(), _API_USER.encode())
    valid_pass = secrets.compare_digest(credentials.password.encode(), _API_PASS.encode())
    if not (valid_user and valid_pass):
        raise HTTPException(status_code=401, detail="Unauthorized", headers={"WWW-Authenticate": "Basic"})


async def _read_json_object(request: Request, label: str) -> Dict[str, Any]:
    """Parse the request body as a JSON object.

    Raises:
        HTTPException: 400 if the body is not valid JSON or is not a JSON object.
    """
    try:
        data = await request.json()
    except Exception as e:
        raw_body = await request.body()
        preview = raw_body[:500].decode("utf-8", errors="replace")
        logger.error(f"Invalid JSON payload for {label}: {e}; body_preview={preview}")
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from e
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="JSON payload must be an object")
    return data


@app.get("/health", dependencies=[Depends(require_auth)])
async def health():
    """Report whether the embedding model and the LLM are currently loaded."""
    return {
        "status": "ok",
        "embedding_ready": state.get("embed_model") is not None,
        "llm_ready": state.get("llm") is not None,
    }


def pad_and_normalize(embeddings: Tensor, target_dimensions: int) -> Tensor:
    """Zero-pad or truncate each row to ``target_dimensions`` columns, then L2-normalize rows.

    Expects a 2-D (batch, dim) tensor: dim 1 is padded/sliced and normalized.
    """
    curr_dim = embeddings.shape[1]
    if curr_dim < target_dimensions:
        embeddings = F.pad(embeddings, (0, target_dimensions - curr_dim), "constant", 0)
    elif curr_dim > target_dimensions:
        embeddings = embeddings[:, :target_dimensions]
    return F.normalize(embeddings, p=2, dim=1)


@app.post("/v1/embeddings", dependencies=[Depends(require_auth)])
async def embeddings(request: Request):
    """OpenAI-style embeddings endpoint.

    ``input`` may be a string or a list (items are stringified). Blank inputs are
    skipped. Each returned item's ``index`` is its position in the request, so
    callers can still map vectors back to inputs. Inputs that do not already start
    with ``search_`` get the ``search_query: `` task prefix.

    Raises:
        HTTPException: 400 for a malformed body or input type; 503 if the model is not loaded.
    """
    data = await _read_json_object(request, "embeddings")
    input_data = data.get("input", "")
    input_kind = type(input_data).__name__
    input_count = len(input_data) if isinstance(input_data, list) else (1 if isinstance(input_data, str) else 0)
    logger.info("/v1/embeddings request received: input_kind=%s input_count=%s", input_kind, input_count)
    logger.info("/v1/embeddings using target_dimensions=%s", TARGET_DIMENSIONS)

    if isinstance(input_data, str):
        raw_items = [input_data]
    elif isinstance(input_data, list):
        raw_items = [str(item) for item in input_data]
    else:
        logger.warning("/v1/embeddings bad input type: %s", input_kind)
        raise HTTPException(status_code=400, detail="'input' must be a string or list of strings")

    # Remember each kept input's original position so skipping blanks cannot shift indices.
    positions = [pos for pos, text in enumerate(raw_items) if text.strip()]
    inputs = [raw_items[pos] for pos in positions]
    if not inputs:
        return {
            "object": "list",
            "data": [],
            "model": EMBED_MODEL_NAME,
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    model = state.get("embed_model")
    if model is None:
        raise HTTPException(status_code=503, detail="Embedding model not initialized")

    prefixed_inputs = [text if text.startswith("search_") else f"search_query: {text}" for text in inputs]

    def _encode() -> List[List[float]]:
        # The device-to-host copy also happens here, under the GPU lock and off the event loop.
        with no_grad():
            vectors = model.encode(prefixed_inputs, convert_to_tensor=True)
            return pad_and_normalize(vectors, target_dimensions=TARGET_DIMENSIONS).cpu().tolist()

    vector_list = await _run_on_gpu(_encode)

    # Whitespace word counts -- an approximation, not tokenizer token counts.
    word_count = sum(len(text.split()) for text in inputs)
    return {
        "object": "list",
        "data": [
            {"object": "embedding", "index": pos, "embedding": embedding}
            for pos, embedding in zip(positions, vector_list)
        ],
        "model": EMBED_MODEL_NAME,
        "usage": {"prompt_tokens": word_count, "total_tokens": word_count},
    }


def _group_sentences(sentences: List[str], boundaries: List[int]) -> List[str]:
    """Join ``sentences`` into chunks, closing a chunk after each index in ``boundaries``.

    Splitting on ". " consumed the period of every sentence except the last. That
    period is restored on each non-final chunk. The final chunk gets a period only
    if it does not already end with ".", "!" or "?" -- this avoids "World..".
    """
    chunks = []
    start = 0
    for idx in boundaries:
        chunks.append(". ".join(sentences[start : idx + 1]) + ".")
        start = idx + 1
    tail = ". ".join(sentences[start:])
    if not tail.endswith((".", "!", "?")):
        tail += "."
    chunks.append(tail)
    return chunks


@app.post("/v1/semantic-chunk", dependencies=[Depends(require_auth)])
async def semantic_chunk(request: Request):
    """Split ``text`` into semantically coherent chunks and embed each chunk.

    The text is split into sentences on ". " and every sentence is embedded.
    A chunk boundary goes after each sentence whose cosine distance to the next
    one exceeds the ``threshold`` percentile (default 95) of all adjacent
    distances. Returns ``{"chunks": [...], "embeddings": [...]}`` with one
    ``search_document:``-prefixed embedding per chunk.

    Raises:
        HTTPException: 400 for a malformed body, non-string text or invalid threshold;
            413 for text over 50 000 characters; 503 if the model is not loaded.
    """
    data = await _read_json_object(request, "semantic-chunk")
    raw_text = data.get("text", "")
    threshold_percentile = data.get("threshold", 95)
    raw_text_len = len(raw_text) if isinstance(raw_text, str) else -1
    logger.info("/v1/semantic-chunk request received: text_len=%s threshold=%s", raw_text_len, threshold_percentile)
    logger.info("/v1/semantic-chunk using target_dimensions=%s", TARGET_DIMENSIONS)

    if not raw_text:
        logger.info("/v1/semantic-chunk empty text payload")
        return {"chunks": [], "embeddings": []}
    if not isinstance(raw_text, str):
        raise HTTPException(status_code=400, detail="'text' must be a string")
    if len(raw_text) > 50000:
        logger.warning("/v1/semantic-chunk payload too large: text_len=%s", len(raw_text))
        raise HTTPException(status_code=413, detail="Text block too large. Please batch on the client.")
    # np.percentile only accepts values in [0, 100]; reject bad input up front instead of a 500.
    if (
        isinstance(threshold_percentile, bool)
        or not isinstance(threshold_percentile, (int, float))
        or not 0 <= threshold_percentile <= 100
    ):
        raise HTTPException(status_code=400, detail="'threshold' must be a number between 0 and 100")

    model = state.get("embed_model")
    if model is None:
        logger.error("/v1/semantic-chunk embedding model not initialized")
        raise HTTPException(status_code=503, detail="Embedding model not initialized")

    sentences = [s.strip() for s in raw_text.replace("\n", " ").split(". ") if s.strip()]

    def _embed_documents(texts: List[str]) -> List[List[float]]:
        with no_grad():
            vectors = model.encode([f"search_document: {t}" for t in texts], convert_to_tensor=True)
            return pad_and_normalize(vectors, target_dimensions=TARGET_DIMENSIONS).cpu().tolist()

    def _chunk_and_embed() -> Dict[str, Any]:
        if len(sentences) < 2:
            return {"chunks": [raw_text], "embeddings": _embed_documents([raw_text])}
        with no_grad():
            s_embeddings = model.encode(sentences, convert_to_tensor=True)
            # Similarity of each sentence to its successor, in one batched call.
            similarities = F.cosine_similarity(s_embeddings[:-1], s_embeddings[1:], dim=1).tolist()
        distances = [1 - sim for sim in similarities]
        breakpoint_threshold = np.percentile(distances, threshold_percentile)
        boundaries = [i for i, d in enumerate(distances) if d > breakpoint_threshold]
        chunks = _group_sentences(sentences, boundaries)
        return {"chunks": chunks, "embeddings": _embed_documents(chunks)}

    return await _run_on_gpu(_chunk_and_embed)


async def _stream_chat(llm: Llama, completion_kwargs: Dict[str, Any]):
    """Yield server-sent-event lines for a streamed llama.cpp chat completion.

    A worker thread produces chunks and hands them to this coroutine through an
    asyncio.Queue. gpu_semaphore stays held until the worker thread exits, not
    merely until this generator is closed. Closing the generator early (e.g. on
    client disconnect) asks the worker to stop at its next chunk.
    """
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    sentinel = object()
    stop_requested = threading.Event()

    def _produce() -> None:
        try:
            for chunk in llm.create_chat_completion(stream=True, **completion_kwargs):
                if stop_requested.is_set():
                    break
                loop.call_soon_threadsafe(queue.put_nowait, chunk)
        finally:
            # Touch before the job completes (and the semaphore is released), so the
            # idle watcher's locked re-check sees this use.
            _touch_llm()
            loop.call_soon_threadsafe(queue.put_nowait, sentinel)

    job = await _start_gpu_job(_produce)
    try:
        while True:
            item = await queue.get()
            if item is sentinel:
                break
            yield f"data: {json.dumps(item)}\n\n"
        try:
            await asyncio.shield(job)  # re-raise any error from the worker thread
        except Exception:
            logger.exception("Streaming inference error")
            raise
        yield "data: [DONE]\n\n"
    finally:
        stop_requested.set()


@app.post("/v1/chat/completions", dependencies=[Depends(require_auth)])
async def chat_completions(request: Request):
    """OpenAI-style chat completion endpoint backed by the llama.cpp model.

    With a truthy ``stream``, responds with server-sent events: one ``data:`` line
    per llama.cpp chunk, then ``data: [DONE]``. Otherwise returns the llama.cpp
    completion result as-is.

    Raises:
        HTTPException: 400 for a malformed body or missing messages; 503 if the model
            file is missing; 500 if non-streaming inference fails.
    """
    data = await _read_json_object(request, "chat completions")
    messages = data.get("messages", [])
    if not isinstance(messages, list) or not messages:
        raise HTTPException(status_code=400, detail="'messages' must be a non-empty list")
    stream = data.get("stream", False)
    logger.info(f"Chat completion request: {len(messages)} messages, stream={stream}")

    llm = await _ensure_llm()
    completion_kwargs = {
        "messages": messages,
        "temperature": data.get("temperature", 0.7),
        "max_tokens": data.get("max_tokens", 1024),
        "stop": list(LLM_STOP_TOKENS),
    }

    if stream:
        return StreamingResponse(_stream_chat(llm, completion_kwargs), media_type="text/event-stream")

    def _infer() -> Dict[str, Any]:
        try:
            return llm.create_chat_completion(stream=False, **completion_kwargs)
        finally:
            # Record the use before the semaphore is released (see _stream_chat).
            _touch_llm()

    try:
        return await _run_on_gpu(_infer)
    except Exception as e:
        logger.exception(f"Inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("gpu_server:app", host="0.0.0.0", port=8001, reload=True)