# Dynavera/apps/onboarding/mcp.py

import httpx
from channels.db import database_sync_to_async
from django.conf import settings
from pgvector.django import CosineDistance
from apps.knowledge.models import RoleRagDocument
from apps.onboarding.models import OnboardingSession
class MCPRouter:
    """Route MCP-style tool calls for the onboarding flow.

    Two tools are exposed:

    * ``search_knowledge`` -- embeds a query via the inference service and
      runs a cosine-distance search over active ``RoleRagDocument`` rows for
      one role.
    * ``update_progress`` -- records a score and/or completed module in an
      ``OnboardingSession``'s ``state`` JSON.
    """

    def get_tool_definitions(self):
        """Return the tool descriptors advertised to the client.

        Each descriptor has a ``name``, a human-readable ``description`` and
        an ``inputSchema`` expressed as JSON Schema.
        """
        return [
            {
                "name": "search_knowledge",
                "description": "Search the RAG database for role-specific training content.",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "query": {"type": "string"},
                        "role_uuid": {"type": "string"}
                    },
                    "required": ["query", "role_uuid"]
                }
            },
            {
                "name": "update_progress",
                "description": "Update the user's score or current module in their session.",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "session_uuid": {"type": "string"},
                        "score": {"type": "integer"},
                        "completed_module": {"type": "string"}
                    },
                    "required": ["session_uuid"]
                }
            }
        ]

    async def handle_tool_call(self, name, arguments):
        """Dispatch a tool call to its implementation.

        Args:
            name: Tool name, as listed by ``get_tool_definitions``.
            arguments: Mapping of tool arguments; ``None`` is treated as an
                empty mapping.

        Returns:
            The tool's result, or ``{"error": ...}`` if *name* is unknown.
        """
        # Clients may omit arguments entirely; normalise so the tool
        # implementations can safely call ``.get`` on them.
        arguments = arguments or {}
        if name == "search_knowledge":
            return await self._search_knowledge(arguments)
        elif name == "update_progress":
            return await self._update_progress(arguments)
        return {"error": f"Tool {name} not found"}

    async def _get_embedding(self, text):
        """Fetch an embedding vector for *text* from the GPU inference node.

        POSTs ``{"input": text}`` to ``{settings.INFERENCE_URL}/v1/embeddings``
        and returns the ``embedding`` of the first entry in the response's
        ``data`` list.

        Raises:
            httpx.HTTPStatusError: if the service answers with a non-2xx status.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{settings.INFERENCE_URL}/v1/embeddings",
                json={"input": text},
            )
            # Fail with a clear HTTP error instead of an opaque KeyError when
            # the service returns an error body without a "data" field.
            response.raise_for_status()
            return response.json()["data"][0]["embedding"]

    async def _search_knowledge(self, args):
        """Embed ``args["query"]`` and return the closest documents for a role.

        Returns an empty list when ``query`` or ``role_uuid`` is missing or
        empty, without calling the inference service.
        """
        query = args.get("query")
        role_uuid = args.get("role_uuid")
        if not query or not role_uuid:
            return []
        query_vector = await self._get_embedding(query)
        return await self._search_knowledge_documents(role_uuid, query_vector)

    @database_sync_to_async
    def _search_knowledge_documents(self, role_uuid, query_vector, limit=5):
        """Return up to *limit* active documents nearest to *query_vector*.

        Args:
            role_uuid: UUID of the role whose documents are searched.
            query_vector: Embedding to compare against each document's
                ``embedding`` using cosine distance.
            limit: Maximum number of documents to return (default 5).

        Returns:
            A list of ``{"content", "source", "relevance"}`` dicts ordered by
            ascending cosine distance, where ``relevance = 1 - distance``
            rounded to 4 decimal places.
        """
        docs = (
            RoleRagDocument.objects.filter(
                role__uuid=role_uuid,
                is_active=True,
                # A NULL embedding yields a NULL distance, which would make
                # the relevance computation below raise TypeError.
                embedding__isnull=False,
            )
            .annotate(distance=CosineDistance("embedding", query_vector))
            .order_by("distance")[:limit]
        )
        return [
            {
                "content": doc.content,
                # Guard against a NULL/empty metadata value on the row.
                "source": (doc.metadata or {}).get("file_name", "Unknown Source"),
                # Cosine distance -> cosine similarity.
                "relevance": round(1 - doc.distance, 4),
            }
            for doc in docs
        ]

    @database_sync_to_async
    def _update_progress(self, args):
        """Record a score and/or a completed module in a session's state.

        Args:
            args: Must contain ``session_uuid``. May contain ``score`` (stored
                as ``state["last_score"]``) and ``completed_module`` (appended
                to the ``state["completed_modules"]`` list).

        Returns:
            ``{"status": "success", "new_state": state}`` after saving, or an
            ``{"error": ...}`` dict when ``session_uuid`` is missing or no
            matching session exists.
        """
        session_uuid = args.get("session_uuid")
        if not session_uuid:
            return {"error": "session_uuid is required"}
        try:
            session = OnboardingSession.objects.get(uuid=session_uuid)
        except OnboardingSession.DoesNotExist:
            return {"error": f"Session {session_uuid} not found"}

        # NOTE(review): this is an unlocked read-modify-write of the JSON
        # state; concurrent calls for the same session can overwrite each
        # other. Consider transaction.atomic() + select_for_update() if tool
        # calls can run in parallel.
        state = session.state or {}
        if "score" in args:
            state["last_score"] = args["score"]
        if "completed_module" in args:
            state.setdefault("completed_modules", []).append(args["completed_module"])
        session.state = state
        session.save()
        return {"status": "success", "new_state": state}