"""Helpers for calling the MCP agent server and for role-scoped RAG retrieval.

The MCP wrappers (fine-tune, load, infer, embed) are synchronous facades over
the async ``MCPClient``. The RAG helpers embed ``RoleRagDocument`` rows and run
a cosine-similarity search against them using pgvector's ``<=>`` operator.
"""

import asyncio
import concurrent.futures
import functools
import logging
import os
import re
from typing import Any, Dict, List, Optional, Sequence, Tuple

from django.conf import settings
from django.db import connection

from mcp_agent.mcp_client import MCPClient

from .models import AgentModel, RoleRagDocument

logger = logging.getLogger(__name__)

try:
    from mcp_agent.mcp_server import BASE_MODEL_CACHE_DIR

    BASE_MODEL_CACHE = BASE_MODEL_CACHE_DIR
except ImportError:
    # Fallback when the MCP server package is not importable: use
    # <dir three dirname() steps above this file>/model/base-model.
    project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    BASE_MODEL_CACHE = os.path.join(project_root, "model", "base-model")

logger.info(f"Base model cache directory reference: {BASE_MODEL_CACHE}")

# Model name used by the local sentence-transformers fallback in embed_texts().
# NOTE(review): if the MCP "embed" tool uses a different model, the fallback
# vectors may have a different dimension than stored ones — confirm they match.
_LOCAL_EMBEDDING_MODEL = "all-MiniLM-L6-v2"

# Patterns used by _clean_chunk_text(), compiled once at import time.
_ANSWER_TAG_RE = re.compile(r"\[\s*Answer\s*:.*?\]", re.IGNORECASE | re.DOTALL)
_PARAGRAPH_SPLIT_RE = re.compile(r"\n\s*\n+")
_BOILERPLATE_PHRASES = ("do you have any questions", "feel free to ask")
# Whole-word match so lines mentioning e.g. "resources" are not dropped.
_ATTRIBUTION_WORD_RE = re.compile(r"\b(?:references|sources|wikipedia)\b")


async def _call_mcp(tool: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Call ``tool`` on the MCP HTTP bridge via ``MCPClient`` and return its response.

    The server URL is read from ``settings.MCP_AGENT_URL``. A new client is
    created per call and is always closed, even when the call fails.

    Raises:
        AttributeError: If ``settings.MCP_AGENT_URL`` is not defined.
        Exception: Anything raised by ``MCPClient.send``; it is logged and re-raised.
    """
    server_url = getattr(settings, "MCP_AGENT_URL")
    client = MCPClient(server_url)
    logger.info(f"MCP: Calling tool '{tool}' on {server_url}")
    logger.debug(f"MCP: Arguments for '{tool}': {arguments}")
    try:
        resp = await client.send(tool, arguments)
        logger.info(f"MCP: Tool '{tool}' completed successfully")
        logger.debug(f"MCP: Response from '{tool}': {resp}")
        return resp
    except Exception as e:
        logger.error(f"MCP: Tool '{tool}' failed with error: {str(e)}")
        raise
    finally:
        await client.close()


def _run_mcp(tool: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Run ``_call_mcp(tool, arguments)`` to completion from synchronous code.

    ``asyncio.run`` refuses to start when the current thread already has a
    running event loop. In that case the call is executed with its own loop on
    a short-lived worker thread. The caller blocks until the result (or
    exception) is available either way.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop in this thread: the normal synchronous path.
        return asyncio.run(_call_mcp(tool, arguments))
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        # The coroutine is created inside the worker so it is never left un-awaited.
        future = pool.submit(lambda: asyncio.run(_call_mcp(tool, arguments)))
        return future.result()


def fine_tune_model(
    base_model: str,
    training_files: List[str],
    hyperparams: Dict[str, Any],
    name: str,
    version: str,
) -> Dict[str, Any]:
    """Synchronously request a fine-tune run on the MCP server.

    Expects the MCP tool `fine_tune` to accept:
      {base_model, training_files, hyperparams, name, version}
    and to return a JSON-like dict containing at least `status` and on
    success `model_path` and `version`.

    Never raises: any failure is logged and reported as a dict with
    ``status="failed"``, ``error`` (message) and ``error_type`` (exception class name).
    """
    logger.info(f"Fine-tuning model: name={name}, version={version}, base_model={base_model}")
    logger.info(f"Training files count: {len(training_files)}")
    logger.debug(f"Training files: {training_files}")
    try:
        logger.info("Calling MCP fine_tune tool...")
        result = _run_mcp("fine_tune", {
            "base_model": base_model,
            "training_files": training_files,
            "hyperparams": hyperparams,
            "name": name,
            "version": version,
        })
        # Kept inside the try so a non-dict response is also reported as "failed".
        logger.info(f"Fine-tune completed: status={result.get('status')}")
        logger.debug(f"Fine-tune result: {result}")
        return result
    except Exception as e:
        error_msg = str(e) or f"Unknown error: {type(e).__name__}"
        logger.error(f"Fine-tune failed: {error_msg}", exc_info=True)
        return {
            "status": "failed",
            "error": error_msg,
            "error_type": type(e).__name__,
        }


def load_model_for_inference(model_path: str) -> Dict[str, Any]:
    """Tell the MCP server to load a model into memory/serving for inference.

    Expects the MCP tool `load_model` with {model_path} returning status info.

    Raises:
        Exception: Any error from the MCP call, after logging it.
    """
    logger.info(f"Loading model for inference: {model_path}")
    try:
        result = _run_mcp("load_model", {"model_path": model_path})
        logger.info("Model loaded successfully")
        return result
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}", exc_info=True)
        raise


def infer_with_model(model_path: str, prompt: str, options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Request inference from the MCP server using a previously fine-tuned model.

    Calls the MCP tool `infer` with {model_path, prompt, options}; ``options``
    defaults to an empty dict when omitted.

    Raises:
        Exception: Any error from the MCP call, after logging it.
    """
    logger.info(f"Running inference with model: {model_path}")
    logger.debug(f"Prompt length: {len(prompt)} characters")
    logger.debug(f"Inference options: {options}")
    try:
        result = _run_mcp("infer", {"model_path": model_path, "prompt": prompt, "options": options or {}})
        logger.info("Inference completed successfully")
        logger.debug(f"Inference result keys: {list(result.keys()) if isinstance(result, dict) else 'not a dict'}")
        return result
    except Exception as e:
        logger.error(f"Inference failed: {str(e)}", exc_info=True)
        raise


def register_model_in_db(name: str, version: str, model_path: str) -> AgentModel:
    """Convenience DB helper: create and return an AgentModel record.

    NOTE: migrations are required after the model field change prior to using this in production.
    """
    return AgentModel.objects.create(name=name, version=version, path=model_path)


@functools.lru_cache(maxsize=1)
def _get_local_embedding_model():
    """Load the local sentence-transformers fallback model once and reuse it.

    The import is deferred so the dependency is only needed when the fallback
    is actually used. Failures are not cached, so a later call retries.
    """
    from sentence_transformers import SentenceTransformer

    return SentenceTransformer(_LOCAL_EMBEDDING_MODEL)


def embed_texts(texts: List[str]) -> List[List[float]]:
    """Generate embeddings for texts using the MCP embedding service.

    Falls back to a local sentence-transformers model if the MCP call fails or
    returns a different number of vectors than texts.

    Args:
        texts: List of text strings to embed.

    Returns:
        List of embedding vectors (list of floats), one per input text, in
        order. An empty input returns an empty list without any remote call.

    Raises:
        RuntimeError: If both MCP and local embedding fail.
    """
    if not texts:
        return []

    logger.info(f"Embedding {len(texts)} texts")
    try:
        result = _run_mcp("embed", {"texts": texts})
        embeddings = result.get("embeddings", []) or []
        if embeddings and len(embeddings) == len(texts):
            logger.info(f"Successfully embedded {len(texts)} texts via MCP")
            return embeddings
        logger.warning(
            f"MCP returned {len(embeddings)} embeddings for {len(texts)} texts, "
            f"trying local fallback"
        )
    except Exception as e:
        logger.warning(f"MCP embedding failed, trying local fallback: {e}")

    try:
        embeddings = _get_local_embedding_model().encode(texts).tolist()
        logger.info(f"Successfully embedded {len(texts)} texts via local model")
        return embeddings
    except Exception as e:
        logger.error(f"Local embedding also failed: {e}")
        raise RuntimeError(f"Failed to embed texts: {e}") from e


def embed_text(text: str) -> List[float]:
    """Generate embedding for a single text.

    Args:
        text: Text string to embed.

    Returns:
        Embedding vector (list of floats).

    Raises:
        RuntimeError: If embedding fails (see ``embed_texts``).
    """
    return embed_texts([text])[0]


def _to_vector_literal(values: Sequence[float]) -> str:
    """Format numbers as a pgvector text literal such as ``[0.1,0.2]``.

    The result is bound as a query parameter and cast with ``::vector`` in SQL.
    """
    return "[" + ",".join(repr(float(v)) for v in values) + "]"


def search_similar_documents(
    query: str,
    role_uuid: str,
    top_k: int = 5,
    similarity_threshold: float = 0.0,
) -> List[Tuple[RoleRagDocument, float]]:
    """Search for documents similar to the query using vector similarity.

    Similarity is ``1 - cosine_distance`` as computed by pgvector's ``<=>``.

    Args:
        query: Query text to embed and search for.
        role_uuid: UUID of role to scope search.
        top_k: Number of top results to return; values <= 0 return no results.
        similarity_threshold: Minimum similarity score (0-1) to include results.

    Returns:
        List of (RoleRagDocument, similarity_score) tuples, ordered by similarity DESC.

    Raises:
        ValueError: If role not found or embedding fails.
    """
    # Function-level import kept from the original; presumably avoids an
    # import cycle with apps.orgs — verify before hoisting.
    from apps.orgs.models import Role

    try:
        query_embedding = embed_text(query)
        logger.info(f"Embedded query: '{query[:50]}...' to {len(query_embedding)}D vector")
    except Exception as e:
        logger.error(f"Failed to embed query: {e}")
        raise ValueError(f"Failed to embed query: {e}") from e

    try:
        role = Role.objects.get(uuid=role_uuid)
    except Role.DoesNotExist:
        raise ValueError(f"Role with UUID {role_uuid} not found") from None

    if top_k <= 0:
        # Avoid sending LIMIT 0 / a negative LIMIT to the database.
        return []

    if not RoleRagDocument.objects.filter(role=role, embedding__isnull=False).exists():
        logger.warning(f"No documents with embeddings found for role {role.uuid}")
        return []

    # Table name comes from model metadata (not user input) and is quoted;
    # all user-dependent values are passed as bound parameters.
    table = connection.ops.quote_name(RoleRagDocument._meta.db_table)
    query_sql = f"""
        SELECT id, 1 - (embedding <=> %s::vector) AS similarity
        FROM {table}
        WHERE role_id = %s AND embedding IS NOT NULL
        ORDER BY similarity DESC
        LIMIT %s
    """
    with connection.cursor() as cursor:
        cursor.execute(query_sql, [_to_vector_literal(query_embedding), role.pk, top_k])
        doc_ids_with_scores = cursor.fetchall()

    if not doc_ids_with_scores:
        logger.info(f"No similar documents found for query in role {role.uuid}")
        return []

    # Rows are already sorted by similarity DESC, so thresholding the top_k
    # rows equals thresholding everything and then taking top_k.
    filtered_docs = [
        (doc_id, score)
        for doc_id, score in doc_ids_with_scores
        if score >= similarity_threshold
    ][:top_k]

    if not filtered_docs:
        logger.info(f"No documents met similarity threshold {similarity_threshold}")
        return []

    doc_scores = dict(filtered_docs)
    documents = RoleRagDocument.objects.filter(id__in=list(doc_scores))
    results = [(doc, doc_scores[doc.id]) for doc in documents if doc.id in doc_scores]
    # The ORM query does not preserve the SQL ordering, so re-sort here.
    results.sort(key=lambda pair: pair[1], reverse=True)

    logger.info(
        f"Found {len(results)} similar documents for query "
        f"(threshold={similarity_threshold}, top_k={top_k})"
    )
    return results


def _has_embedding(doc: RoleRagDocument) -> bool:
    """Return True when ``doc.embedding`` is set and non-empty.

    Uses explicit None/length checks instead of truthiness, which is ambiguous
    for array-like vector values.
    """
    embedding = doc.embedding
    return embedding is not None and len(embedding) > 0


def batch_embed_documents(
    documents: List[RoleRagDocument],
    batch_size: int = 32,
    force_reembed: bool = False,
) -> Tuple[int, int]:
    """Batch embed documents that don't have embeddings yet.

    Args:
        documents: List of RoleRagDocument instances to embed.
        batch_size: Number of documents to embed per API call (must be >= 1).
        force_reembed: If True, re-embed documents that already have embeddings.

    Returns:
        Tuple of (num_embedded, num_failed). A document counts as embedded only
        after its batch has been saved with ``bulk_update``.

    Raises:
        ValueError: If ``batch_size`` is less than 1.

    Note:
        Updates documents in-place with embedding values. If saving a batch
        fails, its in-memory instances may still hold the new (unsaved) values.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    to_embed = [doc for doc in documents if force_reembed or not _has_embedding(doc)]
    if not to_embed:
        logger.info("No documents to embed")
        return 0, 0

    num_embedded = 0
    num_failed = 0
    for batch_number, start in enumerate(range(0, len(to_embed), batch_size), start=1):
        batch = to_embed[start:start + batch_size]
        logger.info(f"Embedding batch {batch_number} ({len(batch)} documents)")
        try:
            embeddings = embed_texts([doc.content for doc in batch])
            if len(embeddings) != len(batch):
                # zip() would silently leave trailing documents unembedded.
                raise ValueError(
                    f"expected {len(batch)} embeddings, got {len(embeddings)}"
                )
            for doc, embedding in zip(batch, embeddings):
                doc.embedding = embedding
            RoleRagDocument.objects.bulk_update(batch, ["embedding"], batch_size=500)
        except Exception as e:
            logger.error(f"Failed to embed batch: {e}")
            num_failed += len(batch)
            continue
        # Count only after the save succeeded, so a batch is never both
        # "embedded" and "failed".
        num_embedded += len(batch)
        logger.info(f"Successfully embedded {len(batch)} documents")

    logger.info(f"Embedding complete: {num_embedded} embedded, {num_failed} failed")
    return num_embedded, num_failed


def _clean_chunk_text(text: str) -> str:
    """Strip junk lines and deduplicate paragraphs to keep context lean.

    Removes ``[Answer: ...]`` tags, lines starting with ``#``, lines containing
    chat boilerplate phrases, and lines mentioning the whole words
    "references", "sources" or "wikipedia". Then drops duplicate paragraphs,
    keeping the first occurrence.
    """
    if not text:
        return ""
    text = _ANSWER_TAG_RE.sub("", text)

    kept_lines: List[str] = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            # Preserve blank lines so paragraph boundaries survive the split below.
            kept_lines.append("")
            continue
        lower = line.lower()
        if line.startswith("#"):
            continue
        if any(phrase in lower for phrase in _BOILERPLATE_PHRASES):
            continue
        if _ATTRIBUTION_WORD_RE.search(lower):
            continue
        kept_lines.append(line)

    cleaned = "\n".join(kept_lines)
    paragraphs = [p.strip() for p in _PARAGRAPH_SPLIT_RE.split(cleaned) if p.strip()]
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return "\n\n".join(dict.fromkeys(paragraphs))


def get_context_for_query(
    query: str,
    role_uuid: str,
    top_k: int = 5,
    similarity_threshold: float = 0.5,
) -> str:
    """Get context string from similar documents for a query.

    Useful for augmenting prompts with retrieved context. Best-effort: any
    retrieval error is logged as a warning and yields an empty string.

    Args:
        query: Query text.
        role_uuid: UUID of role to search within.
        top_k: Number of top results to include.
        similarity_threshold: Minimum similarity score.

    Returns:
        Formatted context string with source attribution. Chunks are separated
        by ``\\n---\\n``. An empty string means there is no usable context.
    """
    try:
        results = search_similar_documents(
            query=query,
            role_uuid=role_uuid,
            top_k=top_k,
            similarity_threshold=similarity_threshold,
        )
    except Exception as e:
        logger.warning(f"Failed to retrieve context: {e}")
        return ""

    context_parts = []
    for doc, similarity in results:
        cleaned = _clean_chunk_text(doc.content)
        if not cleaned:
            continue
        source = doc.training_file.file_name if doc.training_file else "unknown"
        context_parts.append(
            f"[Source: {source}, Similarity: {similarity:.2%}]\n{cleaned}\n"
        )

    return "\n---\n".join(context_parts)