Dynavera/apps/onboarding/mcp.py

194 lines
No EOL
6 KiB
Python

import logging
import random

import httpx
from channels.db import database_sync_to_async
from django.conf import settings
from django.core.exceptions import ValidationError
from pgvector.django import CosineDistance

from apps.knowledge.models import RoleRagDocument
from apps.onboarding.models import OnboardingSession
logger = logging.getLogger(__name__)
def mcp_tool(name: str, description: str, input_schema: dict):
    """Return a decorator that tags a callable with MCP tool metadata.

    The decorator stores a dict on the decorated object as
    ``_mcp_tool_meta`` and returns that same object; it does not wrap it.

    Args:
        name: Tool name stored under the ``'name'`` key.
        description: Text stored under the ``'description'`` key.
        input_schema: Schema object stored (by reference, not copied) under
            the ``'inputSchema'`` key.

    Returns:
        A decorator taking one callable and returning it unchanged apart
        from the added ``_mcp_tool_meta`` attribute.
    """

    def decorator(func):
        # Attach metadata as an attribute instead of wrapping, so the object
        # keeps its identity and ``_collect_tools`` can find it via getattr.
        func._mcp_tool_meta = {
            'name': name,
            'description': description,
            'inputSchema': input_schema,
        }
        return func

    return decorator
def _collect_tools(class_namespace):
tools = []
for method_name, value in class_namespace.items():
metadata = getattr(value, '_mcp_tool_meta', None)
if not metadata:
continue
tools.append(
{
'name': metadata['name'],
'method': method_name,
'description': metadata['description'],
'inputSchema': metadata['inputSchema'],
}
)
return tools
class MCPRouter:
    """Route MCP tool calls to the ``@mcp_tool``-tagged methods of this class.

    The class attributes ``tools`` and ``_tool_name_to_method`` are built at
    the end of the class body from every attribute carrying
    ``_mcp_tool_meta``.
    """

    def get_tool_definitions(self):
        """Return the list of tool definition dicts collected for this class."""
        return self.tools

    async def handle_tool_call(self, name, arguments):
        """Invoke the tool registered under ``name`` and return its result.

        Args:
            name: Tool name as declared in ``@mcp_tool(name=...)``.
            arguments: Dict of tool arguments; a falsy value is treated as
                an empty dict.

        Returns:
            Whatever the tool method returns, or
            ``{'error': 'Tool <name> not found'}`` when no tool matches.
        """
        logger.info('MCP tool call received: tool=%s args=%s', name, arguments)
        arguments = arguments or {}
        method_name = self._tool_name_to_method.get(name)
        if method_name:
            method = getattr(self, method_name, None)
            if method:
                result = await method(arguments)
                logger.info(
                    'MCP tool call completed: tool=%s result=%s',
                    name,
                    result,
                )
                return result
        logger.warning('MCP tool call rejected: unknown tool=%s', name)
        return {'error': f'Tool {name} not found'}

    async def _get_embedding(self, text):
        """Fetch the embedding for ``text`` from the inference endpoint.

        POSTs ``{'input': text}`` to ``settings.INFERENCE_EMBEDDINGS_ENDPOINT``
        and returns ``data[0].embedding`` from the JSON response.

        Raises:
            httpx.HTTPStatusError: If the endpoint answers with an error
                status (via ``raise_for_status``).
        """
        logger.info('MCP embedding request started')
        async with httpx.AsyncClient() as client:
            response = await client.post(
                settings.INFERENCE_EMBEDDINGS_ENDPOINT,
                json={'input': text},
            )
            response.raise_for_status()
            embedding = response.json()['data'][0]['embedding']
        logger.info('MCP embedding request completed')
        return embedding

    @mcp_tool(
        name='search_knowledge',
        description='Search the RAG database for role-specific training content.',
        input_schema={
            'type': 'object',
            'properties': {
                'query': {'type': 'string'},
                'role_uuid': {'type': 'string'},
            },
            'required': ['query', 'role_uuid'],
        },
    )
    async def _search_knowledge(self, args):
        """Embed ``args['query']`` and return the closest role documents.

        Returns an empty list when ``query`` or ``role_uuid`` is missing or
        falsy; otherwise the list produced by
        ``_search_knowledge_documents``.
        """
        query = args.get('query')
        role_uuid = args.get('role_uuid')
        if not query or not role_uuid:
            logger.warning('MCP search_knowledge missing query or role_uuid')
            return []
        query_vector = await self._get_embedding(query)
        return await self._search_knowledge_documents(role_uuid, query_vector)

    @database_sync_to_async
    def _search_knowledge_documents(self, role_uuid, query_vector):
        """Return up to five active documents for the role, nearest first.

        Documents are ordered by cosine distance between their ``embedding``
        and ``query_vector``. Each result is a dict with ``content``,
        ``source`` (the ``file_name`` metadata entry, or
        ``'Unknown Source'``) and ``relevance`` (``1 - distance`` rounded to
        four decimals).
        """
        docs = RoleRagDocument.objects.filter(
            role__uuid=role_uuid,
            is_active=True,
        ).annotate(
            distance=CosineDistance('embedding', query_vector)
        ).order_by('distance')[:5]
        results = [
            {
                'content': d.content,
                # ``metadata`` may be null; fall back to an empty dict so a
                # single document cannot fail the whole search.
                'source': (d.metadata or {}).get('file_name', 'Unknown Source'),
                'relevance': round(1 - d.distance, 4),
            }
            for d in docs
        ]
        logger.info(
            'MCP search_knowledge_documents completed: role_uuid=%s results=%s',
            role_uuid,
            len(results),
        )
        return results

    @mcp_tool(
        name='update_progress',
        description="Update the user's score or current module in their session.",
        input_schema={
            'type': 'object',
            'properties': {
                'session_uuid': {'type': 'string'},
                'score': {'type': 'integer'},
                'completed_module': {'type': 'string'},
            },
            'required': ['session_uuid'],
        },
    )
    @database_sync_to_async
    def _update_progress(self, args):
        """Record a score and/or a completed module on an onboarding session.

        Args:
            args: Dict with ``session_uuid`` (required) and optional
                ``score`` / ``completed_module``.

        Returns:
            ``{'status': 'success', 'new_state': <state dict>}`` after
            saving, or ``{'error': <message>}`` when ``session_uuid`` is
            missing, malformed, or matches no session.
        """
        session_uuid = args.get('session_uuid')
        if not session_uuid:
            logger.warning('MCP update_progress missing session_uuid')
            return {'error': 'session_uuid is required'}
        try:
            session = OnboardingSession.objects.get(uuid=session_uuid)
        except (OnboardingSession.DoesNotExist, ValidationError):
            # Tool arguments come from the caller of the MCP tool; an unknown
            # or malformed id is reported as a tool error, like the other
            # tools do, instead of raising out of ``handle_tool_call``.
            logger.warning(
                'MCP update_progress session not found: session_uuid=%s',
                session_uuid,
            )
            return {'error': f'Session {session_uuid} not found'}
        state = session.state or {}
        if 'score' in args:
            state['last_score'] = args['score']
        if 'completed_module' in args:
            state.setdefault('completed_modules', []).append(args['completed_module'])
        session.state = state
        session.save()
        logger.info(
            'MCP update_progress completed: session_uuid=%s',
            args.get('session_uuid'),
        )
        return {'status': 'success', 'new_state': state}

    @mcp_tool(
        name='random_int',
        description='Generate a random integer in an inclusive range.',
        input_schema={
            'type': 'object',
            'properties': {
                'min': {'type': 'integer'},
                'max': {'type': 'integer'},
            },
            'required': ['min', 'max'],
        },
    )
    async def _random_int(self, args):
        """Return a random integer between ``min`` and ``max`` inclusive.

        Both bounds are coerced with ``int()``; if they arrive reversed they
        are swapped. Returns ``{'value', 'min', 'max'}`` on success, or
        ``{'error': 'min and max must be integers'}`` when either bound
        cannot be converted.
        """
        min_value = args.get('min')
        max_value = args.get('max')
        try:
            min_value = int(min_value)
            max_value = int(max_value)
        except (TypeError, ValueError, OverflowError):
            # TypeError: missing/None; ValueError: non-numeric text or NaN;
            # OverflowError: infinite float.
            logger.warning('MCP random_int invalid args: %s', args)
            return {'error': 'min and max must be integers'}
        if min_value > max_value:
            min_value, max_value = max_value, min_value
        value = random.randint(min_value, max_value)
        logger.info(
            'MCP random_int generated value=%s range=[%s,%s]',
            value,
            min_value,
            max_value,
        )
        return {'value': value, 'min': min_value, 'max': max_value}

    # Must stay at the end of the class body: ``locals()`` here is the class
    # namespace, so only methods defined above are picked up.
    tools = _collect_tools(locals())
    _tool_name_to_method = {tool['name']: tool['method'] for tool in tools}