Added mcp tweaks and fixed failing tests with profanities

This commit is contained in:
Viswamedha Nalabotu 2026-03-18 01:04:16 +00:00
parent 6aa98b2839
commit 20ac7f471c
2 changed files with 111 additions and 49 deletions

View file

@ -1,7 +1,7 @@
import httpx
import logging import logging
import random import random
import httpx
from channels.db import database_sync_to_async from channels.db import database_sync_to_async
from django.conf import settings from django.conf import settings
from django.db.models import Q from django.db.models import Q
@ -12,40 +12,35 @@ from apps.knowledge.models import RoleRagDocument
from apps.onboarding.models import OnboardingSession from apps.onboarding.models import OnboardingSession
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
mcp_meta_value = 'mcp_tool_meta'
_MCP_TOOL_META = 'mcp_tool_meta'
def mcp_tool(name, description, input_schema): def mcp_tool(name, description, input_schema):
def decorator(func): def decorator(func):
setattr(func, mcp_meta_value, { setattr(func, _MCP_TOOL_META, {
'name': name, 'name': name,
'description': description, 'description': description,
'inputSchema': input_schema, 'inputSchema': input_schema,
}) })
return func return func
return decorator return decorator
def _collect_tools(class_namespace): def _collect_tools(class_namespace):
tools = [] tools = []
for method_name, value in class_namespace.items(): for method_name, value in class_namespace.items():
metadata = getattr(value, mcp_meta_value, None) metadata = getattr(value, _MCP_TOOL_META, None)
if not metadata: if not metadata:
continue continue
tools.append({
tools.append( 'name': metadata['name'],
{ 'method': method_name,
'name': metadata['name'], 'description': metadata['description'],
'method': method_name, 'inputSchema': metadata['inputSchema'],
'description': metadata['description'], })
'inputSchema': metadata['inputSchema'],
}
)
return tools return tools
class MCPRouter: class MCPRouter:
def get_tool_definitions(self): def get_tool_definitions(self):
return self.tools return self.tools
@ -58,11 +53,7 @@ class MCPRouter:
method = getattr(self, method_name, None) method = getattr(self, method_name, None)
if method: if method:
result = await method(arguments) result = await method(arguments)
logger.info( logger.info('MCP tool call completed: tool=%s result=%s', name, result)
'MCP tool call completed: tool=%s result=%s',
name,
result,
)
return result return result
logger.warning('MCP tool call rejected: unknown tool=%s', name) logger.warning('MCP tool call rejected: unknown tool=%s', name)
@ -73,9 +64,7 @@ class MCPRouter:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
settings.INFERENCE_EMBEDDINGS_ENDPOINT, settings.INFERENCE_EMBEDDINGS_ENDPOINT,
json={ json={'input': text},
'input': text,
},
) )
response.raise_for_status() response.raise_for_status()
embedding = response.json()['data'][0]['embedding'] embedding = response.json()['data'][0]['embedding']
@ -130,11 +119,7 @@ class MCPRouter:
} }
for d in docs for d in docs
] ]
logger.info( logger.info('MCP search_knowledge_documents completed: role_uuid=%s results=%s', role_uuid, len(results))
'MCP search_knowledge_documents completed: role_uuid=%s results=%s',
role_uuid,
len(results),
)
return results return results
@mcp_tool( @mcp_tool(
@ -152,7 +137,11 @@ class MCPRouter:
) )
@database_sync_to_async @database_sync_to_async
def _update_progress(self, args): def _update_progress(self, args):
session = OnboardingSession.objects.get(uuid=args.get('session_uuid')) session_uuid = args.get('session_uuid')
session = OnboardingSession.objects.filter(uuid=session_uuid).first()
if session is None:
logger.warning('MCP update_progress session not found: session_uuid=%s', session_uuid)
return {'error': 'Session not found'}
state = session.state or {} state = session.state or {}
if 'score' in args: if 'score' in args:
@ -162,10 +151,7 @@ class MCPRouter:
session.state = state session.state = state
session.save() session.save()
logger.info( logger.info('MCP update_progress completed: session_uuid=%s', session_uuid)
'MCP update_progress completed: session_uuid=%s',
args.get('session_uuid'),
)
return {'status': 'success', 'new_state': state} return {'status': 'success', 'new_state': state}
@mcp_tool( @mcp_tool(
@ -181,12 +167,10 @@ class MCPRouter:
}, },
) )
async def _random_int(self, args): async def _random_int(self, args):
min_value = args.get('min')
max_value = args.get('max')
try: try:
min_value = int(min_value) min_value = int(args.get('min'))
max_value = int(max_value) max_value = int(args.get('max'))
except Exception: except (TypeError, ValueError):
logger.warning('MCP random_int invalid args: %s', args) logger.warning('MCP random_int invalid args: %s', args)
return {'error': 'min and max must be integers'} return {'error': 'min and max must be integers'}
@ -194,15 +178,11 @@ class MCPRouter:
min_value, max_value = max_value, min_value min_value, max_value = max_value, min_value
value = random.randint(min_value, max_value) value = random.randint(min_value, max_value)
logger.info( logger.info('MCP random_int generated value=%s range=[%s,%s]', value, min_value, max_value)
'MCP random_int generated value=%s range=[%s,%s]',
value,
min_value,
max_value,
)
return {'value': value, 'min': min_value, 'max': max_value} return {'value': value, 'min': min_value, 'max': max_value}
tools = _collect_tools(locals()) tools = _collect_tools(locals())
_tool_name_to_method = {tool['name']: tool['method'] for tool in tools} _tool_name_to_method = {tool['name']: tool['method'] for tool in tools}
mcp_router = MCPRouter()
mcp_router = MCPRouter()

View file

@ -1,13 +1,95 @@
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.test import TestCase from django.test import TestCase
from unittest.mock import AsyncMock, patch
from apps.accounts.models import Organization, Role from apps.accounts.models import Organization, Role
from apps.onboarding.consumers import OnboardingConsumer from apps.onboarding.consumers import OnboardingGenerateConsumer
from apps.onboarding.models import AgentConfig from apps.onboarding.models import AgentConfig
from apps.onboarding.utils.content_moderator import ContentModerator
User = get_user_model() User = get_user_model()
_PROFANE = "fuck" # Common profanity...
class ContentModeratorTests(TestCase):
def setUp(self):
self.moderator = ContentModerator()
def test_clean_text_passes(self):
self.assertTrue(self.moderator.is_clean("What is the onboarding process?"))
def test_profane_text_blocked(self):
self.assertFalse(self.moderator.is_clean(f"this is {_PROFANE} content"))
def test_non_string_input_passes(self):
self.assertTrue(self.moderator.is_clean(None))
self.assertTrue(self.moderator.is_clean(123))
def test_censor_replaces_profanity(self):
result = self.moderator.censor(f"this is {_PROFANE} content")
self.assertNotIn(_PROFANE, result)
def test_censor_passes_clean_text_unchanged(self):
text = "Please review the onboarding materials."
self.assertEqual(self.moderator.censor(text), text)
def test_censor_non_string_returned_as_is(self):
self.assertIsNone(self.moderator.censor(None))
self.assertEqual(self.moderator.censor(42), 42)
class ConsumerModerationTests(TestCase):
def setUp(self):
self.user = User.objects.create_user(
email_address='moderation-test@example.com',
password='pass1234',
first_name='Mod',
last_name='Tester',
date_of_birth='1995-05-05',
is_manager=True,
)
self.org = Organization.objects.create(name='Moderation Test Org', owner=self.user)
self.org.members.add(self.user)
self.role = Role.objects.create(name='Mod Role', organization=self.org)
self.consumer = OnboardingGenerateConsumer()
self.consumer.user = self.user
def _run_receive(self, payload: str):
return async_to_sync(self.consumer.receive)(payload)
def test_clean_query_is_dispatched(self):
import json
self.consumer.send_error = AsyncMock()
with patch.object(self.consumer, 'action_start_full_onboarding', new=AsyncMock()) as mock_action:
self._run_receive(json.dumps({"action": "start_full_onboarding", "query": "Tell me about onboarding"}))
mock_action.assert_called_once()
self.consumer.send_error.assert_not_called()
def test_profane_query_is_blocked(self):
import json
self.consumer.send_error = AsyncMock()
self._run_receive(json.dumps({"action": "start_full_onboarding", "query": f"this is {_PROFANE} content"}))
self.consumer.send_error.assert_called_once()
args = self.consumer.send_error.call_args[0]
self.assertIn("moderation", args[0].lower())
def test_profane_message_field_is_blocked(self):
import json
self.consumer.send_error = AsyncMock()
self._run_receive(json.dumps({"action": "start_full_onboarding", "message": f"this is {_PROFANE} content"}))
self.consumer.send_error.assert_called_once()
args = self.consumer.send_error.call_args[0]
self.assertIn("moderation", args[0].lower())
def test_clean_message_field_is_dispatched(self):
import json
self.consumer.send_error = AsyncMock()
with patch.object(self.consumer, 'action_start_full_onboarding', new=AsyncMock()) as mock_action:
self._run_receive(json.dumps({"action": "start_full_onboarding", "message": "begin onboarding"}))
mock_action.assert_called_once()
self.consumer.send_error.assert_not_called()
class OnboardingConsumerConfigSelectionTests(TestCase): class OnboardingConsumerConfigSelectionTests(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user( self.user = User.objects.create_user(
@ -24,7 +106,7 @@ class OnboardingConsumerConfigSelectionTests(TestCase):
self.quant_role = Role.objects.create(name='Quant Role Consumer', organization=self.org) self.quant_role = Role.objects.create(name='Quant Role Consumer', organization=self.org)
self.ux_role = Role.objects.create(name='UX Role Consumer', organization=self.org) self.ux_role = Role.objects.create(name='UX Role Consumer', organization=self.org)
self.consumer = OnboardingConsumer() self.consumer = OnboardingGenerateConsumer()
def test_get_config_by_type_prefers_exact_role(self): def test_get_config_by_type_prefers_exact_role(self):
quant_cfg = AgentConfig.objects.create( quant_cfg = AgentConfig.objects.create(