Added mcp tweaks and fixed failing tests with profanities
This commit is contained in:
parent
6aa98b2839
commit
20ac7f471c
2 changed files with 111 additions and 49 deletions
|
|
@ -1,7 +1,7 @@
|
|||
import httpx
|
||||
import logging
|
||||
import random
|
||||
|
||||
import httpx
|
||||
from channels.db import database_sync_to_async
|
||||
from django.conf import settings
|
||||
from django.db.models import Q
|
||||
|
|
@ -12,40 +12,35 @@ from apps.knowledge.models import RoleRagDocument
|
|||
from apps.onboarding.models import OnboardingSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
mcp_meta_value = 'mcp_tool_meta'
|
||||
|
||||
_MCP_TOOL_META = 'mcp_tool_meta'
|
||||
|
||||
def mcp_tool(name, description, input_schema):
|
||||
def decorator(func):
|
||||
setattr(func, mcp_meta_value, {
|
||||
setattr(func, _MCP_TOOL_META, {
|
||||
'name': name,
|
||||
'description': description,
|
||||
'inputSchema': input_schema,
|
||||
})
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _collect_tools(class_namespace):
|
||||
tools = []
|
||||
for method_name, value in class_namespace.items():
|
||||
metadata = getattr(value, mcp_meta_value, None)
|
||||
metadata = getattr(value, _MCP_TOOL_META, None)
|
||||
if not metadata:
|
||||
continue
|
||||
|
||||
tools.append(
|
||||
{
|
||||
tools.append({
|
||||
'name': metadata['name'],
|
||||
'method': method_name,
|
||||
'description': metadata['description'],
|
||||
'inputSchema': metadata['inputSchema'],
|
||||
}
|
||||
)
|
||||
|
||||
})
|
||||
return tools
|
||||
|
||||
|
||||
class MCPRouter:
|
||||
|
||||
def get_tool_definitions(self):
|
||||
return self.tools
|
||||
|
||||
|
|
@ -58,11 +53,7 @@ class MCPRouter:
|
|||
method = getattr(self, method_name, None)
|
||||
if method:
|
||||
result = await method(arguments)
|
||||
logger.info(
|
||||
'MCP tool call completed: tool=%s result=%s',
|
||||
name,
|
||||
result,
|
||||
)
|
||||
logger.info('MCP tool call completed: tool=%s result=%s', name, result)
|
||||
return result
|
||||
|
||||
logger.warning('MCP tool call rejected: unknown tool=%s', name)
|
||||
|
|
@ -73,9 +64,7 @@ class MCPRouter:
|
|||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
settings.INFERENCE_EMBEDDINGS_ENDPOINT,
|
||||
json={
|
||||
'input': text,
|
||||
},
|
||||
json={'input': text},
|
||||
)
|
||||
response.raise_for_status()
|
||||
embedding = response.json()['data'][0]['embedding']
|
||||
|
|
@ -130,11 +119,7 @@ class MCPRouter:
|
|||
}
|
||||
for d in docs
|
||||
]
|
||||
logger.info(
|
||||
'MCP search_knowledge_documents completed: role_uuid=%s results=%s',
|
||||
role_uuid,
|
||||
len(results),
|
||||
)
|
||||
logger.info('MCP search_knowledge_documents completed: role_uuid=%s results=%s', role_uuid, len(results))
|
||||
return results
|
||||
|
||||
@mcp_tool(
|
||||
|
|
@ -152,7 +137,11 @@ class MCPRouter:
|
|||
)
|
||||
@database_sync_to_async
|
||||
def _update_progress(self, args):
|
||||
session = OnboardingSession.objects.get(uuid=args.get('session_uuid'))
|
||||
session_uuid = args.get('session_uuid')
|
||||
session = OnboardingSession.objects.filter(uuid=session_uuid).first()
|
||||
if session is None:
|
||||
logger.warning('MCP update_progress session not found: session_uuid=%s', session_uuid)
|
||||
return {'error': 'Session not found'}
|
||||
|
||||
state = session.state or {}
|
||||
if 'score' in args:
|
||||
|
|
@ -162,10 +151,7 @@ class MCPRouter:
|
|||
|
||||
session.state = state
|
||||
session.save()
|
||||
logger.info(
|
||||
'MCP update_progress completed: session_uuid=%s',
|
||||
args.get('session_uuid'),
|
||||
)
|
||||
logger.info('MCP update_progress completed: session_uuid=%s', session_uuid)
|
||||
return {'status': 'success', 'new_state': state}
|
||||
|
||||
@mcp_tool(
|
||||
|
|
@ -181,12 +167,10 @@ class MCPRouter:
|
|||
},
|
||||
)
|
||||
async def _random_int(self, args):
|
||||
min_value = args.get('min')
|
||||
max_value = args.get('max')
|
||||
try:
|
||||
min_value = int(min_value)
|
||||
max_value = int(max_value)
|
||||
except Exception:
|
||||
min_value = int(args.get('min'))
|
||||
max_value = int(args.get('max'))
|
||||
except (TypeError, ValueError):
|
||||
logger.warning('MCP random_int invalid args: %s', args)
|
||||
return {'error': 'min and max must be integers'}
|
||||
|
||||
|
|
@ -194,15 +178,11 @@ class MCPRouter:
|
|||
min_value, max_value = max_value, min_value
|
||||
|
||||
value = random.randint(min_value, max_value)
|
||||
logger.info(
|
||||
'MCP random_int generated value=%s range=[%s,%s]',
|
||||
value,
|
||||
min_value,
|
||||
max_value,
|
||||
)
|
||||
logger.info('MCP random_int generated value=%s range=[%s,%s]', value, min_value, max_value)
|
||||
return {'value': value, 'min': min_value, 'max': max_value}
|
||||
|
||||
tools = _collect_tools(locals())
|
||||
_tool_name_to_method = {tool['name']: tool['method'] for tool in tools}
|
||||
|
||||
|
||||
mcp_router = MCPRouter()
|
||||
|
|
@ -1,13 +1,95 @@
|
|||
from asgiref.sync import async_to_sync
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import TestCase
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from apps.accounts.models import Organization, Role
|
||||
from apps.onboarding.consumers import OnboardingConsumer
|
||||
from apps.onboarding.consumers import OnboardingGenerateConsumer
|
||||
from apps.onboarding.models import AgentConfig
|
||||
from apps.onboarding.utils.content_moderator import ContentModerator
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
_PROFANE = "fuck" # Common profanity...
|
||||
|
||||
class ContentModeratorTests(TestCase):
|
||||
def setUp(self):
|
||||
self.moderator = ContentModerator()
|
||||
|
||||
def test_clean_text_passes(self):
|
||||
self.assertTrue(self.moderator.is_clean("What is the onboarding process?"))
|
||||
|
||||
def test_profane_text_blocked(self):
|
||||
self.assertFalse(self.moderator.is_clean(f"this is {_PROFANE} content"))
|
||||
|
||||
def test_non_string_input_passes(self):
|
||||
self.assertTrue(self.moderator.is_clean(None))
|
||||
self.assertTrue(self.moderator.is_clean(123))
|
||||
|
||||
def test_censor_replaces_profanity(self):
|
||||
result = self.moderator.censor(f"this is {_PROFANE} content")
|
||||
self.assertNotIn(_PROFANE, result)
|
||||
|
||||
def test_censor_passes_clean_text_unchanged(self):
|
||||
text = "Please review the onboarding materials."
|
||||
self.assertEqual(self.moderator.censor(text), text)
|
||||
|
||||
def test_censor_non_string_returned_as_is(self):
|
||||
self.assertIsNone(self.moderator.censor(None))
|
||||
self.assertEqual(self.moderator.censor(42), 42)
|
||||
|
||||
|
||||
class ConsumerModerationTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user(
|
||||
email_address='moderation-test@example.com',
|
||||
password='pass1234',
|
||||
first_name='Mod',
|
||||
last_name='Tester',
|
||||
date_of_birth='1995-05-05',
|
||||
is_manager=True,
|
||||
)
|
||||
self.org = Organization.objects.create(name='Moderation Test Org', owner=self.user)
|
||||
self.org.members.add(self.user)
|
||||
self.role = Role.objects.create(name='Mod Role', organization=self.org)
|
||||
self.consumer = OnboardingGenerateConsumer()
|
||||
self.consumer.user = self.user
|
||||
|
||||
def _run_receive(self, payload: str):
|
||||
return async_to_sync(self.consumer.receive)(payload)
|
||||
|
||||
def test_clean_query_is_dispatched(self):
|
||||
import json
|
||||
self.consumer.send_error = AsyncMock()
|
||||
with patch.object(self.consumer, 'action_start_full_onboarding', new=AsyncMock()) as mock_action:
|
||||
self._run_receive(json.dumps({"action": "start_full_onboarding", "query": "Tell me about onboarding"}))
|
||||
mock_action.assert_called_once()
|
||||
self.consumer.send_error.assert_not_called()
|
||||
|
||||
def test_profane_query_is_blocked(self):
|
||||
import json
|
||||
self.consumer.send_error = AsyncMock()
|
||||
self._run_receive(json.dumps({"action": "start_full_onboarding", "query": f"this is {_PROFANE} content"}))
|
||||
self.consumer.send_error.assert_called_once()
|
||||
args = self.consumer.send_error.call_args[0]
|
||||
self.assertIn("moderation", args[0].lower())
|
||||
|
||||
def test_profane_message_field_is_blocked(self):
|
||||
import json
|
||||
self.consumer.send_error = AsyncMock()
|
||||
self._run_receive(json.dumps({"action": "start_full_onboarding", "message": f"this is {_PROFANE} content"}))
|
||||
self.consumer.send_error.assert_called_once()
|
||||
args = self.consumer.send_error.call_args[0]
|
||||
self.assertIn("moderation", args[0].lower())
|
||||
|
||||
def test_clean_message_field_is_dispatched(self):
|
||||
import json
|
||||
self.consumer.send_error = AsyncMock()
|
||||
with patch.object(self.consumer, 'action_start_full_onboarding', new=AsyncMock()) as mock_action:
|
||||
self._run_receive(json.dumps({"action": "start_full_onboarding", "message": "begin onboarding"}))
|
||||
mock_action.assert_called_once()
|
||||
self.consumer.send_error.assert_not_called()
|
||||
|
||||
class OnboardingConsumerConfigSelectionTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user(
|
||||
|
|
@ -24,7 +106,7 @@ class OnboardingConsumerConfigSelectionTests(TestCase):
|
|||
self.quant_role = Role.objects.create(name='Quant Role Consumer', organization=self.org)
|
||||
self.ux_role = Role.objects.create(name='UX Role Consumer', organization=self.org)
|
||||
|
||||
self.consumer = OnboardingConsumer()
|
||||
self.consumer = OnboardingGenerateConsumer()
|
||||
|
||||
def test_get_config_by_type_prefers_exact_role(self):
|
||||
quant_cfg = AgentConfig.objects.create(
|
||||
|
|
|
|||
Loading…
Reference in a new issue