diff --git a/gpu_server.py b/gpu_server.py index 5584d55..6eb6623 100644 --- a/gpu_server.py +++ b/gpu_server.py @@ -1,3 +1,4 @@ +import asyncio import logging import os import json @@ -23,6 +24,7 @@ LLM_MODEL_PATH = os.getenv("LLM_MODEL_PATH", "/app/models/Meta-Llama-3.1-8B-Inst TARGET_DIMENSIONS = 768 state: Dict[str, Any] = {} +gpu_semaphore = asyncio.Semaphore(1) @asynccontextmanager async def lifespan(app: FastAPI): @@ -117,9 +119,15 @@ async def embeddings(request: Request): for text in inputs ] - with no_grad(): - vectors = model.encode(prefixed_inputs, convert_to_tensor=True) - vectors = pad_and_normalize(vectors, target_dimensions=TARGET_DIMENSIONS) + loop = asyncio.get_event_loop() + + def _encode(): + with no_grad(): + vectors = model.encode(prefixed_inputs, convert_to_tensor=True) + return pad_and_normalize(vectors, target_dimensions=TARGET_DIMENSIONS) + + async with gpu_semaphore: + vectors = await loop.run_in_executor(None, _encode) vector_list = vectors.cpu().tolist() @@ -162,42 +170,43 @@ async def semantic_chunk(request: Request): logger.error("/v1/semantic-chunk embedding model not initialized") raise HTTPException(status_code=503, detail="Embedding model not initialized") + loop = asyncio.get_event_loop() sentences = [s.strip() for s in raw_text.replace('\n', ' ').split('. ') if s.strip()] - if len(sentences) < 2: - single = model.encode([f"search_document: {raw_text}"], convert_to_tensor=True) - single = pad_and_normalize(single, target_dimensions=TARGET_DIMENSIONS) - return { - "chunks": [raw_text], - "embeddings": single.cpu().tolist(), - } - s_embeddings = model.encode(sentences, convert_to_tensor=True) - distances = [ - 1 - F.cosine_similarity(s_embeddings[i].unsqueeze(0), s_embeddings[i+1].unsqueeze(0)).item() - for i in range(len(s_embeddings) - 1) - ] + def _chunk_and_embed(): + if len(sentences) < 2: + single = model.encode([f"search_document: {raw_text}"], convert_to_tensor=True) + single = pad_and_normalize(single, target_dimensions=TARGET_DIMENSIONS) + return {"chunks": [raw_text], "embeddings": single.cpu().tolist()} - breakpoint_threshold = np.percentile(distances, threshold_percentile) - indices = [i for i, d in enumerate(distances) if d > breakpoint_threshold] + s_embeddings = model.encode(sentences, convert_to_tensor=True) + distances = [ + 1 - F.cosine_similarity(s_embeddings[i].unsqueeze(0), s_embeddings[i+1].unsqueeze(0)).item() + for i in range(len(s_embeddings) - 1) + ] - chunks = [] - start = 0 - for idx in indices: - chunks.append(". ".join(sentences[start : idx + 1]) + ".") - start = idx + 1 - chunks.append(". ".join(sentences[start:]) + ".") + breakpoint_threshold = np.percentile(distances, threshold_percentile) + indices = [i for i, d in enumerate(distances) if d > breakpoint_threshold] - with no_grad(): - final_embeddings = model.encode( - [f"search_document: {c}" for c in chunks], - convert_to_tensor=True - ) - final_embeddings = pad_and_normalize(final_embeddings, target_dimensions=TARGET_DIMENSIONS) + chunks = [] + start = 0 + for idx in indices: + chunks.append(". ".join(sentences[start : idx + 1]) + ".") + start = idx + 1 + chunks.append(". ".join(sentences[start:]) + ".") - return { - "chunks": chunks, - "embeddings": final_embeddings.cpu().tolist() - } + with no_grad(): + final_embeddings = model.encode( + [f"search_document: {c}" for c in chunks], + convert_to_tensor=True + ) + final_embeddings = pad_and_normalize(final_embeddings, target_dimensions=TARGET_DIMENSIONS) + + return {"chunks": chunks, "embeddings": final_embeddings.cpu().tolist()} + + async with gpu_semaphore: + result = await loop.run_in_executor(None, _chunk_and_embed) + return result @app.post("/v1/chat/completions") async def chat_completions(request: Request): @@ -218,30 +227,47 @@ async def chat_completions(request: Request): if not llm: raise HTTPException(status_code=503, detail="LLM not initialized or model file missing.") - try: - response = llm.create_chat_completion( + loop = asyncio.get_event_loop() + temperature = data.get("temperature", 0.7) + max_tokens = data.get("max_tokens", 1024) + + def _infer(): + return llm.create_chat_completion( messages=messages, - stream=stream, - temperature=data.get("temperature", 0.7), - max_tokens=data.get("max_tokens", 1024), - stop=["<|eot_id|>", "<|end_of_text|>"] + stream=False, + temperature=temperature, + max_tokens=max_tokens, + stop=["<|eot_id|>", "<|end_of_text|>"], ) + try: if stream: - return StreamingResponse( - llm_streamer(response), - media_type="text/event-stream" - ) + # For streaming, run inference in executor and stream results back + def _infer_stream(): + return llm.create_chat_completion( + messages=messages, + stream=True, + temperature=temperature, + max_tokens=max_tokens, + stop=["<|eot_id|>", "<|end_of_text|>"], + ) + async def _stream_response(): + async with gpu_semaphore: + chunks = await loop.run_in_executor(None, lambda: list(_infer_stream())) + for chunk in chunks: + yield f"data: {json.dumps(chunk)}\n\n" + yield "data: [DONE]\n\n" + + return StreamingResponse(_stream_response(), media_type="text/event-stream") + + async with gpu_semaphore: + response = await loop.run_in_executor(None, _infer) return response except Exception as e: logger.error(f"Inference error: {e}") raise HTTPException(status_code=500, detail=str(e)) -async def llm_streamer(response_iterator): - for chunk in response_iterator: - yield f"data: {json.dumps(chunk)}\n\n" - yield "data: [DONE]\n\n" if __name__ == "__main__": import uvicorn