Added mcp tweaks and fixed failing tests with profanities

This commit is contained in:
Viswamedha Nalabotu 2026-03-18 01:04:16 +00:00
parent 6aa98b2839
commit 20ac7f471c
2 changed files with 111 additions and 49 deletions

View file

@ -1,7 +1,7 @@
import httpx
import logging
import random
import httpx
from channels.db import database_sync_to_async
from django.conf import settings
from django.db.models import Q
@ -12,40 +12,35 @@ from apps.knowledge.models import RoleRagDocument
from apps.onboarding.models import OnboardingSession
logger = logging.getLogger(__name__)
mcp_meta_value = 'mcp_tool_meta'
_MCP_TOOL_META = 'mcp_tool_meta'
def mcp_tool(name, description, input_schema):
def decorator(func):
setattr(func, mcp_meta_value, {
setattr(func, _MCP_TOOL_META, {
'name': name,
'description': description,
'inputSchema': input_schema,
})
return func
return decorator
def _collect_tools(class_namespace):
tools = []
for method_name, value in class_namespace.items():
metadata = getattr(value, mcp_meta_value, None)
metadata = getattr(value, _MCP_TOOL_META, None)
if not metadata:
continue
tools.append(
{
'name': metadata['name'],
'method': method_name,
'description': metadata['description'],
'inputSchema': metadata['inputSchema'],
}
)
tools.append({
'name': metadata['name'],
'method': method_name,
'description': metadata['description'],
'inputSchema': metadata['inputSchema'],
})
return tools
class MCPRouter:
def get_tool_definitions(self):
return self.tools
@ -58,11 +53,7 @@ class MCPRouter:
method = getattr(self, method_name, None)
if method:
result = await method(arguments)
logger.info(
'MCP tool call completed: tool=%s result=%s',
name,
result,
)
logger.info('MCP tool call completed: tool=%s result=%s', name, result)
return result
logger.warning('MCP tool call rejected: unknown tool=%s', name)
@ -73,9 +64,7 @@ class MCPRouter:
async with httpx.AsyncClient() as client:
response = await client.post(
settings.INFERENCE_EMBEDDINGS_ENDPOINT,
json={
'input': text,
},
json={'input': text},
)
response.raise_for_status()
embedding = response.json()['data'][0]['embedding']
@ -130,11 +119,7 @@ class MCPRouter:
}
for d in docs
]
logger.info(
'MCP search_knowledge_documents completed: role_uuid=%s results=%s',
role_uuid,
len(results),
)
logger.info('MCP search_knowledge_documents completed: role_uuid=%s results=%s', role_uuid, len(results))
return results
@mcp_tool(
@ -152,7 +137,11 @@ class MCPRouter:
)
@database_sync_to_async
def _update_progress(self, args):
session = OnboardingSession.objects.get(uuid=args.get('session_uuid'))
session_uuid = args.get('session_uuid')
session = OnboardingSession.objects.filter(uuid=session_uuid).first()
if session is None:
logger.warning('MCP update_progress session not found: session_uuid=%s', session_uuid)
return {'error': 'Session not found'}
state = session.state or {}
if 'score' in args:
@ -162,10 +151,7 @@ class MCPRouter:
session.state = state
session.save()
logger.info(
'MCP update_progress completed: session_uuid=%s',
args.get('session_uuid'),
)
logger.info('MCP update_progress completed: session_uuid=%s', session_uuid)
return {'status': 'success', 'new_state': state}
@mcp_tool(
@ -181,12 +167,10 @@ class MCPRouter:
},
)
async def _random_int(self, args):
min_value = args.get('min')
max_value = args.get('max')
try:
min_value = int(min_value)
max_value = int(max_value)
except Exception:
min_value = int(args.get('min'))
max_value = int(args.get('max'))
except (TypeError, ValueError):
logger.warning('MCP random_int invalid args: %s', args)
return {'error': 'min and max must be integers'}
@ -194,15 +178,11 @@ class MCPRouter:
min_value, max_value = max_value, min_value
value = random.randint(min_value, max_value)
logger.info(
'MCP random_int generated value=%s range=[%s,%s]',
value,
min_value,
max_value,
)
logger.info('MCP random_int generated value=%s range=[%s,%s]', value, min_value, max_value)
return {'value': value, 'min': min_value, 'max': max_value}
tools = _collect_tools(locals())
_tool_name_to_method = {tool['name']: tool['method'] for tool in tools}
mcp_router = MCPRouter()
mcp_router = MCPRouter()

View file

@ -1,13 +1,95 @@
from asgiref.sync import async_to_sync
from django.contrib.auth import get_user_model
from django.test import TestCase
from unittest.mock import AsyncMock, patch
from apps.accounts.models import Organization, Role
from apps.onboarding.consumers import OnboardingConsumer
from apps.onboarding.consumers import OnboardingGenerateConsumer
from apps.onboarding.models import AgentConfig
from apps.onboarding.utils.content_moderator import ContentModerator
User = get_user_model()
_PROFANE = "fuck" # Common profanity...
class ContentModeratorTests(TestCase):
def setUp(self):
self.moderator = ContentModerator()
def test_clean_text_passes(self):
self.assertTrue(self.moderator.is_clean("What is the onboarding process?"))
def test_profane_text_blocked(self):
self.assertFalse(self.moderator.is_clean(f"this is {_PROFANE} content"))
def test_non_string_input_passes(self):
self.assertTrue(self.moderator.is_clean(None))
self.assertTrue(self.moderator.is_clean(123))
def test_censor_replaces_profanity(self):
result = self.moderator.censor(f"this is {_PROFANE} content")
self.assertNotIn(_PROFANE, result)
def test_censor_passes_clean_text_unchanged(self):
text = "Please review the onboarding materials."
self.assertEqual(self.moderator.censor(text), text)
def test_censor_non_string_returned_as_is(self):
self.assertIsNone(self.moderator.censor(None))
self.assertEqual(self.moderator.censor(42), 42)
class ConsumerModerationTests(TestCase):
def setUp(self):
self.user = User.objects.create_user(
email_address='moderation-test@example.com',
password='pass1234',
first_name='Mod',
last_name='Tester',
date_of_birth='1995-05-05',
is_manager=True,
)
self.org = Organization.objects.create(name='Moderation Test Org', owner=self.user)
self.org.members.add(self.user)
self.role = Role.objects.create(name='Mod Role', organization=self.org)
self.consumer = OnboardingGenerateConsumer()
self.consumer.user = self.user
def _run_receive(self, payload: str):
return async_to_sync(self.consumer.receive)(payload)
def test_clean_query_is_dispatched(self):
import json
self.consumer.send_error = AsyncMock()
with patch.object(self.consumer, 'action_start_full_onboarding', new=AsyncMock()) as mock_action:
self._run_receive(json.dumps({"action": "start_full_onboarding", "query": "Tell me about onboarding"}))
mock_action.assert_called_once()
self.consumer.send_error.assert_not_called()
def test_profane_query_is_blocked(self):
import json
self.consumer.send_error = AsyncMock()
self._run_receive(json.dumps({"action": "start_full_onboarding", "query": f"this is {_PROFANE} content"}))
self.consumer.send_error.assert_called_once()
args = self.consumer.send_error.call_args[0]
self.assertIn("moderation", args[0].lower())
def test_profane_message_field_is_blocked(self):
import json
self.consumer.send_error = AsyncMock()
self._run_receive(json.dumps({"action": "start_full_onboarding", "message": f"this is {_PROFANE} content"}))
self.consumer.send_error.assert_called_once()
args = self.consumer.send_error.call_args[0]
self.assertIn("moderation", args[0].lower())
def test_clean_message_field_is_dispatched(self):
import json
self.consumer.send_error = AsyncMock()
with patch.object(self.consumer, 'action_start_full_onboarding', new=AsyncMock()) as mock_action:
self._run_receive(json.dumps({"action": "start_full_onboarding", "message": "begin onboarding"}))
mock_action.assert_called_once()
self.consumer.send_error.assert_not_called()
class OnboardingConsumerConfigSelectionTests(TestCase):
def setUp(self):
self.user = User.objects.create_user(
@ -24,7 +106,7 @@ class OnboardingConsumerConfigSelectionTests(TestCase):
self.quant_role = Role.objects.create(name='Quant Role Consumer', organization=self.org)
self.ux_role = Role.objects.create(name='UX Role Consumer', organization=self.org)
self.consumer = OnboardingConsumer()
self.consumer = OnboardingGenerateConsumer()
def test_get_config_by_type_prefers_exact_role(self):
quant_cfg = AgentConfig.objects.create(