Stream chat-completion chunks as they are generated and run garbage collection when unloading the LLM

This commit is contained in:
Viswamedha Nalabotu 2026-03-22 17:15:42 +00:00
parent 1eada257b9
commit bf9eb6efb5

View file

@ -1,4 +1,5 @@
import asyncio import asyncio
import gc
import logging import logging
import os import os
import json import json
@ -38,6 +39,7 @@ def _load_llm() -> Llama:
def _unload_llm(): def _unload_llm():
llm = state.pop("llm", None) llm = state.pop("llm", None)
del llm del llm
gc.collect()
if cuda.is_available(): if cuda.is_available():
cuda.empty_cache() cuda.empty_cache()
logger.info("LLM unloaded due to inactivity") logger.info("LLM unloaded due to inactivity")
@ -56,8 +58,9 @@ def _touch_llm():
state["llm_last_used"] = time.monotonic() state["llm_last_used"] = time.monotonic()
async def _ensure_llm() -> Llama: async def _ensure_llm() -> Llama:
llm = state.get("llm") if state.get("llm") is None:
if llm is None: async with gpu_semaphore:
if state.get("llm") is None:
if not os.path.exists(LLM_MODEL_PATH): if not os.path.exists(LLM_MODEL_PATH):
raise HTTPException(status_code=503, detail="LLM model file not found.") raise HTTPException(status_code=503, detail="LLM model file not found.")
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -225,6 +228,7 @@ async def semantic_chunk(request: Request):
single = pad_and_normalize(single, target_dimensions=TARGET_DIMENSIONS) single = pad_and_normalize(single, target_dimensions=TARGET_DIMENSIONS)
return {"chunks": [raw_text], "embeddings": single.cpu().tolist()} return {"chunks": [raw_text], "embeddings": single.cpu().tolist()}
with no_grad():
s_embeddings = model.encode(sentences, convert_to_tensor=True) s_embeddings = model.encode(sentences, convert_to_tensor=True)
distances = [ distances = [
1 - F.cosine_similarity(s_embeddings[i].unsqueeze(0), s_embeddings[i+1].unsqueeze(0)).item() 1 - F.cosine_similarity(s_embeddings[i].unsqueeze(0), s_embeddings[i+1].unsqueeze(0)).item()
@ -286,20 +290,33 @@ async def chat_completions(request: Request):
try: try:
if stream: if stream:
def _infer_stream(): sentinel = object()
return llm.create_chat_completion(
async def _stream_response():
queue: asyncio.Queue = asyncio.Queue()
_loop = asyncio.get_event_loop()
def _produce():
try:
for chunk in llm.create_chat_completion(
messages=messages, messages=messages,
stream=True, stream=True,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
stop=["<|eot_id|>", "<|end_of_text|>"], stop=["<|eot_id|>", "<|end_of_text|>"],
) ):
_loop.call_soon_threadsafe(queue.put_nowait, chunk)
finally:
_loop.call_soon_threadsafe(queue.put_nowait, sentinel)
async def _stream_response():
async with gpu_semaphore: async with gpu_semaphore:
chunks = await loop.run_in_executor(None, lambda: list(_infer_stream())) fut = _loop.run_in_executor(None, _produce)
for chunk in chunks: while True:
yield f"data: {json.dumps(chunk)}\n\n" item = await queue.get()
if item is sentinel:
break
yield f"data: {json.dumps(item)}\n\n"
await fut
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return StreamingResponse(_stream_response(), media_type="text/event-stream") return StreamingResponse(_stream_response(), media_type="text/event-stream")