import json
import logging
from enum import Enum

import httpx
from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer
from django.conf import settings
from django.db.models import Q
from django.utils import timezone

from apps.accounts.models import User, Role
from apps.accounts.permissions import get_organization_from_object, can_manage_organization
from apps.onboarding.consumers.prompts import OnboardingPrompts
from apps.onboarding.mcp import mcp_router, MCPRouter
from apps.onboarding.models import AgentConfig
from apps.onboarding.utils.content_moderator import ContentModerator

__all__ = ["BaseOnboardingConsumer", "LogType"]

logger = logging.getLogger(__name__)


class LogType(Enum):
    """Event types sent to the client in the ``type`` field of each WebSocket message."""

    STATUS = "status"  # General progress updates
    ERROR = "error"  # Failures
    INFO = "info"  # Debug/Verbose logs
    THOUGHT = "thought"  # AI internal reasoning turns
    TOOL_START = "tool_start"  # When an MCP tool is called
    TOOL_RESULT = "tool_result"  # When data comes back from a tool
    STREAM_CHUNK = "stream_chunk"  # Incremental token from a streaming LLM response
    COMPLETED = "completed"  # The final completion signal


class BaseOnboardingConsumer(AsyncWebsocketConsumer):
    """
    Base consumer for all onboarding-related WebSocket consumers.

    Rejects unauthenticated connections, parses incoming JSON messages,
    runs them through content moderation, and dispatches them to
    ``action_<name>`` methods (expected to be provided by subclasses).
    Also offers LLM helpers (``orchestrate`` for multi-turn tool-calling
    conversations, ``stream_llm`` for single-turn streaming) that push
    progress events to the client via ``send_log``.
    """

    user: User
    router: MCPRouter
    logger: logging.Logger = logger
    # Class-level attribute: a single moderator instance shared by all consumers.
    moderator: ContentModerator = ContentModerator()

    ### Connection Management ###

    async def connect(self):
        """Accept the socket for authenticated users and close it otherwise."""
        # scope.get: a missing "user" key (no auth middleware) is treated as unauthenticated.
        self.user = self.scope.get("user")
        if self.user is None or not self.user.is_authenticated:
            self.logger.warning("WebSocket connection rejected: unauthenticated user attempted to connect")
            return await self.close()

        self.parse_extra()
        self.router = mcp_router
        self.logger.info("WebSocket connected: user=%s", self._describe_user())
        return await self.accept()

    async def disconnect(self, close_code: int):
        """Log the disconnect.

        ``self.user`` may be anonymous or unset here (e.g. after a rejected
        ``connect``), so the user label is derived defensively.
        """
        self.logger.info("WebSocket disconnected: user=%s close_code=%s", self._describe_user(), close_code)

    def _describe_user(self) -> str:
        """Return a log-friendly label for the current user without assuming authentication."""
        user = getattr(self, "user", None)
        if user is None:
            return "<unknown>"
        if not getattr(user, "is_authenticated", False):
            return "<anonymous>"
        return getattr(user, "full_name", None) or str(user)

    ### Event Handling ###

    async def receive(self, text_data: str | None = None, bytes_data: bytes | None = None):
        """
        Main entry point for incoming messages.

        Expects a JSON object. The ``query`` and ``message`` fields, when
        present, must pass moderation. The ``action`` field (default
        ``"message"``) selects the ``action_<name>`` handler, which is
        awaited with the parsed payload.

        Args:
            text_data: Text frame payload.
            bytes_data: Binary frame payload; binary frames are rejected.
        """
        try:
            if text_data is None:
                return await self.send_error("Binary messages are not supported on this endpoint.")

            try:
                data = self.from_json(text_data)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                self.logger.warning("WebSocket receive: payload is not valid JSON")
                return await self.send_error("Invalid message: payload is not valid JSON.")

            if not isinstance(data, dict):
                return await self.send_error("Invalid message: expected a JSON object.")

            for field in ("query", "message"):
                value = data.get(field)
                if value and not self.moderator.is_clean(str(value)):
                    return await self.send_error("Message blocked: content did not pass moderation.")

            action_name = data.get("action", "message")
            method = getattr(self, f"action_{action_name}", None)
            if method:
                self.logger.info("Dispatching action: %s", action_name)
                await method(data)
            else:
                await self.send_error(f"Action '{action_name}' not supported on this endpoint.")
        except Exception as e:
            # Log first so the traceback is kept even if sending the error also fails.
            self.logger.exception("WebSocket receive critical failure: %s", e)
            await self.send_error("An unexpected error occurred when processing the event.")

    ### MCP Handling ###

    async def orchestrate(self, message: str, config: AgentConfig, minimum_turns: int = 2,
                          maximum_turns: int = 5, max_tokens: int | None = None,
                          raise_on_error: bool = False, request_timeout: float = 60.0) -> str:
        """
        Orchestrates a multi-turn conversation with the agent, including tool calls and reasoning steps.

        Each turn posts the full message history and the router's tool
        definitions to the inference endpoint. Requested tool calls are run
        through ``self.router`` and their results appended as ``tool``
        messages. A plain-content reply ends the loop once at least
        ``minimum_turns`` turns have run. An earlier reply is answered with
        ``OnboardingPrompts.force_reasoning_prompt()`` to keep the agent going.

        Args:
            message: The user's request.
            config: Agent configuration; its ``system_prompt`` overrides the default prompt.
            minimum_turns: Earliest turn on which a plain answer is accepted.
            maximum_turns: Maximum number of inference requests.
            max_tokens: Per-response token limit; falsy values fall back to 1024.
            raise_on_error: Re-raise failures instead of returning an ``"Error: ..."`` string.
            request_timeout: httpx timeout in seconds for each inference request.

        Returns:
            One of:
            - the moderated final answer;
            - the last moderated answer seen if the turn budget runs out
              (empty string if there was none);
            - ``"Error: <reason>"`` on failure when ``raise_on_error`` is False.
        """
        resolved_max_tokens = max_tokens or 1024
        messages = [
            {"role": "system", "content": config.system_prompt or OnboardingPrompts.default_system_prompt()},
            {"role": "user", "content": message},
        ]
        last_content = ""

        async with httpx.AsyncClient(timeout=request_timeout, auth=settings.INFERENCE_AUTH) as client:
            for turn in range(1, maximum_turns + 1):
                await self.send_log(LogType.THOUGHT, f"Agent reasoning (Turn {turn})...")
                try:
                    response = await client.post(
                        settings.INFERENCE_CHAT_COMPLETIONS_ENDPOINT,
                        json={
                            "messages": messages,
                            "tools": self.router.get_tool_definitions(),
                            "tool_choice": "auto",
                            "max_tokens": resolved_max_tokens,
                        },
                    )
                    response.raise_for_status()
                    ai_message = response.json()["choices"][0]["message"]
                    messages.append(ai_message)

                    tool_calls = ai_message.get("tool_calls")
                    if tool_calls:
                        for tool_call in tool_calls:
                            fn_name = tool_call["function"]["name"]
                            fn_args = self._parse_tool_arguments(tool_call["function"].get("arguments"))
                            await self.send_log(LogType.TOOL_START, f"Calling tool: {fn_name}", content=fn_args)
                            result = await self.router.handle_tool_call(fn_name, fn_args)
                            await self.send_log(LogType.TOOL_RESULT, f"Tool {fn_name} returned data", content=result)
                            messages.append({
                                "role": "tool",
                                "tool_call_id": tool_call["id"],
                                "name": fn_name,
                                "content": self.to_json(result),
                            })
                        # Let the model see the tool results on the next turn.
                        continue

                    last_content = self.moderator.censor(str(ai_message.get("content") or "").strip())
                    if turn < minimum_turns:
                        # Answer came too early: ask the agent to keep reasoning.
                        messages.append({
                            "role": "user",
                            "content": OnboardingPrompts.force_reasoning_prompt(),
                        })
                        continue
                    return last_content
                except Exception as e:
                    await self.send_error(f"AI Orchestration failed: {str(e)}")
                    if raise_on_error:
                        raise
                    return f"Error: {str(e)}"

        self.logger.warning("Orchestration reached maximum_turns=%s without an accepted final answer", maximum_turns)
        return last_content

    @staticmethod
    def _parse_tool_arguments(raw) -> dict:
        """Decode a tool call's ``arguments`` field into a dict.

        Arguments are normally a JSON string. An already-decoded dict is
        passed through, and an empty or missing value becomes ``{}``, so
        argument-less tool calls do not abort the orchestration.
        """
        if isinstance(raw, dict):
            return raw
        if raw is None or (isinstance(raw, str) and not raw.strip()):
            return {}
        return json.loads(raw)

    async def stream_llm(self, config, prompt: str, *, max_tokens: int | None = 1024,
                         stop: list[str] | None = None, system_prompt_suffix: str | None = None,
                         request_timeout: float = 120.0) -> str | None:
        """Single-turn streaming LLM call.

        Sends STREAM_CHUNK events for each token and returns the full text.

        Args:
            config: Agent configuration; a falsy value short-circuits to ``None``.
            prompt: User message sent after the system prompt.
            max_tokens: Response token limit. ``None``/0 falls back to 1024
                (as in ``orchestrate``) rather than sending ``null`` upstream.
            stop: Optional stop sequences forwarded to the endpoint.
            system_prompt_suffix: Extra text appended to the system prompt.
            request_timeout: httpx timeout in seconds for the streaming request.

        Returns:
            The concatenated, stripped completion text. Returns ``None`` when
            there is no config, the stream yields no text, or the request fails
            (the failure is logged).
        """
        if not config:
            return None

        system_prompt = config.system_prompt or OnboardingPrompts.default_system_prompt()
        if system_prompt_suffix:
            system_prompt = system_prompt + "\n\n" + system_prompt_suffix

        payload: dict = {
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            "max_tokens": max_tokens or 1024,
            "stream": True,
        }
        if stop:
            payload["stop"] = stop

        try:
            chunks: list[str] = []
            async with httpx.AsyncClient(timeout=request_timeout, auth=settings.INFERENCE_AUTH) as client:
                async with client.stream("POST", settings.INFERENCE_CHAT_COMPLETIONS_ENDPOINT, json=payload) as response:
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        # Server-sent events: only "data:" lines carry payloads (space after the colon is optional).
                        if not line.startswith("data:"):
                            continue
                        data = line[len("data:"):].strip()
                        if data == "[DONE]":
                            break

                        # Only malformed events are skipped; send failures propagate to the outer handler.
                        try:
                            choice = json.loads(data)["choices"][0]
                            delta = (choice.get("delta") or {}).get("content")
                            finish_reason = choice.get("finish_reason")
                        except (ValueError, KeyError, IndexError, TypeError, AttributeError):
                            self.logger.debug("Skipping unparseable stream event: %s", data[:100])
                            continue

                        if isinstance(delta, str) and delta:
                            chunks.append(delta)
                            await self.send_log(LogType.STREAM_CHUNK, delta)
                        if finish_reason == "length":
                            self.logger.warning("LLM response truncated (finish_reason=length)")
                            await self.send_log(LogType.STATUS, "Response was cut off, try increasing Max Tokens.")
            return "".join(chunks).strip() or None
        except Exception as e:
            self.logger.exception("Streaming LLM call failed: %s", e)
            return None

    ### Regular Helpers ###

    async def send_log(self, log_type: LogType, message: str, content: str | dict | list | None = None):
        """Log an event server-side and push it to the client as JSON.

        The client receives ``type``, ``message``, ``content`` and
        ``timestamp``. The server log line truncates the message to 100
        characters and the content to 60.
        """
        log = self.logger.error if log_type == LogType.ERROR else self.logger.info
        log("[%s]: message=%s content=%s", log_type.value, str(message)[:100], str(content)[:60])
        await self.send(self.to_json({
            "type": log_type.value,
            "message": message,
            "content": content,
            "timestamp": self.get_timestamp(),
        }))

    async def send_error(self, message: str):
        """Send an ``error`` event to the client (also logged at error level)."""
        await self.send_log(LogType.ERROR, message)

    def parse_extra(self):
        """
        Override for custom parsing
        """
        pass

    def to_json(self, data: object) -> str:
        """Serialize ``data``; non-JSON types (datetimes, UUIDs, ...) fall back to ``str()``."""
        return json.dumps(data, default=str)

    def from_json(self, data: str) -> dict | list:
        """Parse a JSON string (raises ``ValueError`` on invalid JSON)."""
        return json.loads(data)

    def get_timestamp(self):
        """Return the current time in ISO 8601 format (via ``django.utils.timezone.now``)."""
        return timezone.now().isoformat()

    def parse_max_tokens(self, max_tokens: int | None) -> int | None:
        """Return ``max_tokens`` if it is a positive integer, otherwise ``None``.

        ``bool`` is rejected explicitly: it subclasses ``int``, so ``True``
        would otherwise be accepted as a limit of 1.
        """
        if isinstance(max_tokens, int) and not isinstance(max_tokens, bool) and max_tokens > 0:
            return max_tokens
        return None

    ### Database Helpers ###

    @database_sync_to_async
    def get_config(self, config_uuid):
        """Fetch an AgentConfig by UUID (raises ``AgentConfig.DoesNotExist`` if absent)."""
        return AgentConfig.objects.get(uuid=config_uuid)

    @database_sync_to_async
    def get_role(self, role_uuid):
        """Fetch a Role by UUID (raises ``Role.DoesNotExist`` if absent)."""
        return Role.objects.get(uuid=role_uuid)

    @database_sync_to_async
    def get_config_by_type(self, role_uuid, agent_type):
        """Return the most recently updated config of ``agent_type`` for a role.

        A config attached to the role itself wins. Otherwise the newest
        organization-wide config (``role`` is null) of the role's organization
        is returned, or ``None``.
        """
        role_specific = AgentConfig.objects.filter(
            role__uuid=role_uuid,
            agent_type=agent_type,
        ).order_by('-updated_at').first()
        if role_specific:
            return role_specific

        return AgentConfig.objects.filter(
            organization__roles__uuid=role_uuid,
            role__isnull=True,
            agent_type=agent_type,
        ).order_by('-updated_at').first()

    @database_sync_to_async
    def get_config_for_user(self, config_uuid):
        """Return the config if the current user owns or belongs to its organization, else ``None``."""
        return AgentConfig.objects.filter(uuid=config_uuid).filter(
            Q(organization__owner__id=self.user.id) | Q(organization__members__id=self.user.id)
        ).first()

    @database_sync_to_async
    def can_manage_role(self, role_uuid):
        """True if the role exists and the current user can manage its organization."""
        role = Role.objects.filter(uuid=role_uuid).first()
        if role is None:
            return False
        return can_manage_organization(self.user, get_organization_from_object(role))

    @database_sync_to_async
    def can_access_role(self, role_uuid):
        """True if the role exists and the user is one of its members or can manage its organization."""
        role = Role.objects.filter(uuid=role_uuid).first()
        if role is None:
            return False
        return role.members.filter(id=self.user.id).exists() or can_manage_organization(
            self.user, get_organization_from_object(role)
        )