Dynavera/apps/mlstore/services.py

405 lines
13 KiB
Python
Raw Normal View History

import asyncio
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple
from django.conf import settings
from mcp_agent.mcp_client import MCPClient
from .models import AgentModel, RoleRagDocument
logger = logging.getLogger(__name__)

# Resolve the directory where base-model weights are cached. The MCP server
# package's BASE_MODEL_CACHE_DIR wins when it can be imported; otherwise use
# <root>/model/base-model, where <root> is the directory two levels above the
# one holding this file (the repository root for apps/mlstore/services.py).
try:
    from mcp_agent.mcp_server import BASE_MODEL_CACHE_DIR
except ImportError:
    project_root = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    )
    BASE_MODEL_CACHE = os.path.join(project_root, "model", "base-model")
else:
    BASE_MODEL_CACHE = BASE_MODEL_CACHE_DIR

logger.info(f"Base model cache directory reference: {BASE_MODEL_CACHE}")
async def _call_mcp(tool: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Call one tool on the MCP HTTP bridge via MCPClient and return its reply.

    A fresh MCPClient is created for ``settings.MCP_AGENT_URL`` and is closed
    after the call, whether the call succeeded or failed.

    Args:
        tool: Name of the MCP tool to invoke (for example ``"embed"``).
        arguments: Tool arguments, passed unchanged to ``MCPClient.send``.

    Returns:
        The value returned by ``MCPClient.send``.

    Raises:
        AttributeError: If ``MCP_AGENT_URL`` is not defined in settings.
        Exception: Whatever ``MCPClient.send`` raises, re-raised after logging.
    """
    # Plain attribute access: a missing setting raises AttributeError.
    server_url = settings.MCP_AGENT_URL
    client = MCPClient(server_url)
    # Lazy %-style arguments: payloads (e.g. texts to embed) and responses
    # (e.g. embedding vectors) can be large, so they are only formatted into a
    # string when the corresponding log level is actually enabled.
    logger.info("MCP: Calling tool '%s' on %s", tool, server_url)
    logger.debug("MCP: Arguments for '%s': %s", tool, arguments)
    try:
        resp = await client.send(tool, arguments)
    except Exception as e:
        logger.error("MCP: Tool '%s' failed with error: %s", tool, e)
        raise
    else:
        logger.info("MCP: Tool '%s' completed successfully", tool)
        logger.debug("MCP: Response from '%s': %s", tool, resp)
        return resp
    finally:
        await client.close()
def fine_tune_model(
    base_model: str,
    training_files: List[str],
    hyperparams: Dict[str, Any],
    name: str,
    version: str,
) -> Dict[str, Any]:
    """Ask the MCP server to fine-tune ``base_model`` and block until it replies.

    The MCP tool ``fine_tune`` receives ``base_model``, ``training_files``,
    ``hyperparams``, ``name`` and ``version``. It is expected to return a
    dict-like reply with at least ``status`` and, on success, ``model_path``
    and ``version``.

    The async call is driven with ``asyncio.run``. Calling this from inside a
    running event loop therefore fails, and that failure is reported through
    the error dict described below.

    Returns:
        The MCP reply on success. If the MCP call (or reading ``status`` from
        its reply) raises, the exception is logged with its traceback and not
        propagated. A dict is returned instead with ``status`` set to
        ``"failed"``, ``error`` holding the exception text (or
        ``"Unknown error: <ExceptionType>"`` when that text is empty), and
        ``error_type`` holding the exception class name.
    """
    logger.info(f"Fine-tuning model: name={name}, version={version}, base_model={base_model}")
    logger.info(f"Training files count: {len(training_files)}")
    logger.debug(f"Training files: {training_files}")

    payload = {
        "base_model": base_model,
        "training_files": training_files,
        "hyperparams": hyperparams,
        "name": name,
        "version": version,
    }
    try:
        logger.info("Calling MCP fine_tune tool...")
        outcome = asyncio.run(_call_mcp("fine_tune", payload))
        logger.info(f"Fine-tune completed: status={outcome.get('status')}")
        logger.debug(f"Fine-tune result: {outcome}")
        return outcome
    except Exception as exc:
        exc_type = type(exc).__name__
        message = str(exc) or f"Unknown error: {exc_type}"
        logger.error(f"Fine-tune failed: {message}", exc_info=True)
        return {"status": "failed", "error": message, "error_type": exc_type}
def load_model_for_inference(model_path: str) -> Dict[str, Any]:
    """Ask the MCP server to load the model at ``model_path`` for serving.

    Sends ``{"model_path": model_path}`` to the MCP ``load_model`` tool and
    returns its reply unchanged. Any failure is logged with its traceback and
    then re-raised.
    """
    logger.info(f"Loading model for inference: {model_path}")
    request = {"model_path": model_path}
    try:
        response = asyncio.run(_call_mcp("load_model", request))
    except Exception as exc:
        logger.error(f"Failed to load model: {str(exc)}", exc_info=True)
        raise
    logger.info("Model loaded successfully")
    return response
def infer_with_model(model_path: str, prompt: str, options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Run ``prompt`` through a fine-tuned model hosted by the MCP server.

    Calls the MCP ``infer`` tool with ``model_path``, ``prompt`` and
    ``options``; a ``None`` or empty ``options`` is sent as ``{}``.
    Failures are logged with a traceback and re-raised.
    """
    logger.info(f"Running inference with model: {model_path}")
    logger.debug(f"Prompt length: {len(prompt)} characters")
    logger.debug(f"Inference options: {options}")

    request = {
        "model_path": model_path,
        "prompt": prompt,
        "options": options if options else {},
    }
    try:
        response = asyncio.run(_call_mcp("infer", request))
    except Exception as exc:
        logger.error(f"Inference failed: {str(exc)}", exc_info=True)
        raise

    logger.info("Inference completed successfully")
    key_summary = list(response.keys()) if isinstance(response, dict) else 'not a dict'
    logger.debug(f"Inference result keys: {key_summary}")
    return response
def register_model_in_db(name: str, version: str, model_path: str) -> AgentModel:
    """Insert an AgentModel row describing a trained model and return it.

    ``model_path`` is stored in the model's ``path`` field.

    NOTE: migrations are required after the model field change prior to using
    this in production.
    """
    fields = {"name": name, "version": version, "path": model_path}
    return AgentModel.objects.create(**fields)
# Sentence-transformers model used when the MCP ``embed`` tool is unusable.
_LOCAL_EMBEDDING_MODEL = "all-MiniLM-L6-v2"

# Loaded local embedding models, keyed by model name, so the fallback path loads
# weights once per process instead of on every call. Two threads racing on the
# first call may both load the model; the last one stored wins, which is harmless.
_LOCAL_EMBEDDERS: Dict[str, Any] = {}


def _get_local_embedder(model_name: str = _LOCAL_EMBEDDING_MODEL) -> Any:
    """Return the cached SentenceTransformer for ``model_name``, loading it on first use."""
    model = _LOCAL_EMBEDDERS.get(model_name)
    if model is None:
        # Imported lazily: only the fallback path needs sentence-transformers,
        # and an import failure surfaces through embed_texts' RuntimeError.
        from sentence_transformers import SentenceTransformer

        model = SentenceTransformer(model_name)
        _LOCAL_EMBEDDERS[model_name] = model
    return model


def embed_texts(texts: List[str]) -> List[List[float]]:
    """Generate embeddings for ``texts``, preferring the MCP ``embed`` tool.

    The MCP result is used only if it holds exactly one vector per input text.
    Otherwise, or if the MCP call raises, the texts are embedded locally with
    the sentence-transformers model ``all-MiniLM-L6-v2``.

    NOTE(review): the model behind the MCP ``embed`` tool is not visible here.
    If its vector size differs from the local model's, the two paths produce
    incompatible vectors for the same column. Confirm they match.

    Args:
        texts: Strings to embed.

    Returns:
        One embedding (list of floats) per input text. An empty input returns
        ``[]`` without contacting MCP or loading a model.

    Raises:
        RuntimeError: If the MCP path is unusable and local embedding also fails.
    """
    if not texts:
        return []
    logger.info(f"Embedding {len(texts)} texts")
    try:
        result = asyncio.run(_call_mcp("embed", {"texts": texts}))
        embeddings = result.get("embeddings", [])
        if embeddings and len(embeddings) == len(texts):
            logger.info(f"Successfully embedded {len(texts)} texts via MCP")
            return embeddings
        # Empty or wrong-sized MCP result: say why we are falling back.
        logger.warning(
            "MCP embedding returned %d vectors for %d texts, trying local fallback",
            len(embeddings or []),
            len(texts),
        )
    except Exception as e:
        logger.warning(f"MCP embedding failed, trying local fallback: {e}")
    try:
        model = _get_local_embedder()
        embeddings = model.encode(texts).tolist()
        logger.info(f"Successfully embedded {len(texts)} texts via local model")
        return embeddings
    except Exception as e:
        logger.error(f"Local embedding also failed: {e}")
        raise RuntimeError(f"Failed to embed texts: {e}") from e
def embed_text(text: str) -> List[float]:
    """Embed a single string.

    Sends a one-element batch through ``embed_texts`` and unwraps the result.

    Args:
        text: String to embed.

    Returns:
        The first vector returned for the one-element batch.
    """
    vectors = embed_texts([text])
    return vectors[0]
def search_similar_documents(
    query: str,
    role_uuid: str,
    top_k: int = 5,
    similarity_threshold: float = 0.0,
) -> List[Tuple[RoleRagDocument, float]]:
    """Return the role's documents most similar to ``query``.

    The query is embedded with ``embed_text`` and compared against stored
    document embeddings using pgvector's cosine-distance operator ``<=>``.
    Similarity is reported as ``1 - distance``.

    Args:
        query: Query text to embed and search for.
        role_uuid: UUID of the role whose documents are searched.
        top_k: Maximum number of results to return; values <= 0 yield ``[]``.
        similarity_threshold: Minimum similarity score to include a document.

    Returns:
        Up to ``top_k`` ``(RoleRagDocument, similarity)`` tuples, highest
        similarity first. The list is empty when the role has no embedded
        documents or none meets the threshold.

    Raises:
        ValueError: If the role is not found or the query cannot be embedded.
    """
    from django.db import connection

    from apps.orgs.models import Role

    # Validate the role and look for candidates before embedding, so a bad UUID
    # or an empty corpus does not cost an embedding call.
    try:
        role = Role.objects.get(uuid=role_uuid)
    except Role.DoesNotExist:
        raise ValueError(f"Role with UUID {role_uuid} not found")

    if top_k <= 0:
        return []

    # Only rows with a stored vector can be ranked.
    has_candidates = RoleRagDocument.objects.filter(
        role=role,
        embedding__isnull=False,
    ).exists()
    if not has_candidates:
        logger.warning(f"No documents with embeddings found for role {role.uuid}")
        return []

    try:
        query_embedding = embed_text(query)
    except Exception as e:
        logger.error(f"Failed to embed query: {e}")
        raise ValueError(f"Failed to embed query: {e}") from e
    logger.info(f"Embedded query: '{query[:50]}...' to {len(query_embedding)}D vector")

    # Text form accepted by a ``::vector`` cast: "[x1,x2,...]".
    vector_literal = "[" + ",".join(str(float(x)) for x in query_embedding) + "]"

    # Identifiers come from model metadata, never from caller input. Every
    # value is passed as a query parameter.
    meta = RoleRagDocument._meta
    qn = connection.ops.quote_name
    pk_col = qn(meta.pk.column)
    role_col = qn(meta.get_field("role").column)
    emb_col = qn(meta.get_field("embedding").column)
    query_sql = f"""
        SELECT {pk_col}, 1 - ({emb_col} <=> %s::vector) AS similarity
        FROM {qn(meta.db_table)}
        WHERE {role_col} = %s AND {emb_col} IS NOT NULL
        ORDER BY similarity DESC
        LIMIT %s
    """
    with connection.cursor() as cursor:
        cursor.execute(query_sql, [vector_literal, role.pk, top_k])
        rows = cursor.fetchall()

    if not rows:
        logger.info(f"No similar documents found for query in role {role.uuid}")
        return []

    # Rows arrive sorted by similarity (highest first) and already capped at
    # top_k by LIMIT; filtering keeps that order.
    scored = [(doc_id, score) for doc_id, score in rows if score >= similarity_threshold]
    if not scored:
        logger.info(f"No documents met similarity threshold {similarity_threshold}")
        return []

    # in_bulk maps pk -> instance. Walking `scored` keeps the SQL ordering and
    # skips any row deleted between the two queries.
    docs_by_id = RoleRagDocument.objects.in_bulk([doc_id for doc_id, _ in scored])
    results = [
        (docs_by_id[doc_id], score)
        for doc_id, score in scored
        if doc_id in docs_by_id
    ]
    logger.info(
        f"Found {len(results)} similar documents for query "
        f"(threshold={similarity_threshold}, top_k={top_k})"
    )
    return results
def batch_embed_documents(
    documents: List[RoleRagDocument],
    batch_size: int = 32,
    force_reembed: bool = False,
) -> Tuple[int, int]:
    """Embed documents' ``content`` in batches and persist the vectors.

    Args:
        documents: RoleRagDocument instances to consider.
        batch_size: Number of documents sent to ``embed_texts`` per call;
            must be at least 1.
        force_reembed: If True, re-embed documents that already have an embedding.

    Returns:
        Tuple of ``(num_embedded, num_failed)``. A document counts as embedded
        only after its batch has been saved with ``bulk_update``, so each
        selected document is counted exactly once.

    Raises:
        ValueError: If ``batch_size`` is less than 1.

    Note:
        Instances are updated in place. If a batch fails to save after its
        vectors were computed, those instances keep the unsaved vectors in
        memory and are counted as failed.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Explicit None/length test instead of `not doc.embedding`. If the field
    # yields NumPy arrays (as pgvector's Django VectorField does; presumably
    # the field type here, given the ::vector SQL in this module — confirm),
    # `not array` raises ValueError. This test works for None, lists and arrays.
    to_embed = [
        doc
        for doc in documents
        if force_reembed or doc.embedding is None or len(doc.embedding) == 0
    ]
    if not to_embed:
        logger.info("No documents to embed")
        return 0, 0

    num_embedded = 0
    num_failed = 0
    for batch_no, start in enumerate(range(0, len(to_embed), batch_size), start=1):
        batch = to_embed[start : start + batch_size]
        logger.info(
            f"Embedding batch {batch_no} "
            f"({len(batch)} documents)"
        )
        try:
            embeddings = embed_texts([doc.content for doc in batch])
            # zip() would silently leave trailing documents unembedded.
            if len(embeddings) != len(batch):
                raise RuntimeError(
                    f"expected {len(batch)} embeddings, got {len(embeddings)}"
                )
            for doc, embedding in zip(batch, embeddings):
                doc.embedding = embedding
            RoleRagDocument.objects.bulk_update(batch, ["embedding"], batch_size=500)
        except Exception as e:
            logger.exception(f"Failed to embed batch: {e}")
            num_failed += len(batch)
        else:
            num_embedded += len(batch)
            logger.info(f"Successfully embedded {len(batch)} documents")
    logger.info(
        f"Embedding complete: {num_embedded} embedded, {num_failed} failed"
    )
    return num_embedded, num_failed
def get_context_for_query(
    query: str,
    role_uuid: str,
    top_k: int = 5,
    similarity_threshold: float = 0.5,
) -> str:
    """Build a prompt-ready context string from documents similar to ``query``.

    Useful for augmenting prompts with retrieved context. Retrieval is
    best-effort: if the similarity search raises, a warning is logged and an
    empty string is returned.

    Args:
        query: Query text.
        role_uuid: UUID of the role whose documents are searched.
        top_k: Maximum number of documents to include.
        similarity_threshold: Minimum similarity score for a document.

    Returns:
        The cleaned text of each matching document, headed by
        ``[Source: <training file name or "unknown">, Similarity: <pct>]``.
        Parts are separated by ``---`` lines. Returns ``""`` if nothing usable
        was found.
    """
    # Citation/boilerplate markers, matched as whole words. A plain substring
    # test for "sources" also fires on "resources" and would discard ordinary
    # content lines such as "Human resources handles onboarding."
    citation_marker = re.compile(r"\b(?:references|sources|wikipedia)\b")

    def _clean_chunk_text(text: str) -> str:
        """Strip junk lines and duplicate paragraphs to keep context lean."""
        if not text:
            return ""
        # Remove "[Answer: ...]" annotations, including ones spanning lines.
        text = re.sub(r"\[\s*Answer\s*:.*?\]", "", text, flags=re.IGNORECASE | re.DOTALL)
        lines: List[str] = []
        for raw_line in text.splitlines():
            line = raw_line.strip()
            if not line:
                # Blank lines are kept: they delimit paragraphs below.
                lines.append("")
                continue
            lower = line.lower()
            if line.startswith("#"):
                continue
            if "do you have any questions" in lower or "feel free to ask" in lower:
                continue
            if citation_marker.search(lower):
                continue
            lines.append(line)
        cleaned = "\n".join(lines)
        paragraphs = [p.strip() for p in re.split(r"\n\s*\n+", cleaned) if p.strip()]
        # dict.fromkeys drops exact duplicates while keeping first-seen order.
        return "\n\n".join(dict.fromkeys(paragraphs))

    try:
        results = search_similar_documents(
            query=query,
            role_uuid=role_uuid,
            top_k=top_k,
            similarity_threshold=similarity_threshold,
        )
    except Exception as e:
        logger.warning(f"Failed to retrieve context: {e}")
        return ""
    if not results:
        return ""

    context_parts = []
    for doc, similarity in results:
        cleaned = _clean_chunk_text(doc.content)
        if not cleaned:
            continue
        source = doc.training_file.file_name if doc.training_file else "unknown"
        context_parts.append(
            f"[Source: {source}, Similarity: {similarity:.2%}]\n{cleaned}\n"
        )
    return "\n---\n".join(context_parts)