253 lines
11 KiB
Python
253 lines
11 KiB
Python
import json
|
|
import logging
|
|
from enum import Enum
|
|
|
|
import httpx
|
|
from channels.db import database_sync_to_async
|
|
from channels.generic.websocket import AsyncWebsocketConsumer
|
|
from django.conf import settings
|
|
from django.db.models import Q
|
|
from django.utils import timezone
|
|
|
|
from apps.accounts.models import User, Role
|
|
from apps.accounts.permissions import get_organization_from_object, can_manage_organization
|
|
from apps.onboarding.consumers.prompts import OnboardingPrompts
|
|
from apps.onboarding.mcp import mcp_router, MCPRouter
|
|
from apps.onboarding.models import AgentConfig
|
|
from apps.onboarding.utils.content_moderator import ContentModerator
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    "BaseOnboardingConsumer",
    "LogType",
]

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
class LogType(Enum):
    """Kinds of events a consumer can emit; the value is sent as the event ``type``."""

    # General progress updates
    STATUS = "status"
    # Failures
    ERROR = "error"
    # Debug/verbose logs
    INFO = "info"
    # AI internal reasoning turns
    THOUGHT = "thought"
    # Emitted when an MCP tool is called
    TOOL_START = "tool_start"
    # Emitted when data comes back from a tool
    TOOL_RESULT = "tool_result"
    # Incremental token from a streaming LLM response
    STREAM_CHUNK = "stream_chunk"
    # The final completion signal
    COMPLETED = "completed"
|
|
|
|
class BaseOnboardingConsumer(AsyncWebsocketConsumer):
    """
    Base consumer for all onboarding-related WebSocket consumers.

    Provides the plumbing shared by concrete consumers:

    * an authentication gate in ``connect``;
    * JSON parsing, content moderation and dispatch of incoming frames to
      ``action_<name>`` coroutine methods looked up on the instance;
    * LLM helpers: ``orchestrate`` (multi-turn, with tool calls) and
      ``stream_llm`` (single-turn, streamed);
    * event emission to the client via ``send_log`` / ``send_error``;
    * async wrappers around the ORM lookups used by the consumers.
    """

    user: User
    router: MCPRouter
    logger: logging.Logger = logger
    # Created once when the class body is executed and shared by all instances.
    moderator: ContentModerator = ContentModerator()

    async def connect(self):
        """
        Accept the socket for authenticated users and close it otherwise.

        For authenticated users the ``parse_extra`` hook runs and the shared
        MCP router is bound to ``self.router`` before the socket is accepted.
        """
        # ``.get`` so that a scope without a "user" entry is rejected cleanly
        # instead of raising KeyError.
        self.user = self.scope.get("user")
        if self.user is None or not self.user.is_authenticated:
            self.logger.warning("WebSocket connection rejected: unauthenticated user attempted to connect")
            return await self.close()
        self.parse_extra()
        self.router = mcp_router
        self.logger.info(f"WebSocket connected: user={self.user.full_name}")
        return await self.accept()

    async def disconnect(self, close_code: int):
        """
        Log the disconnect.

        ``self.user`` may be unset, ``None`` or an object without ``full_name``
        when the connection was rejected in ``connect``, so the name is read
        defensively rather than via ``self.user.full_name``.
        """
        user = getattr(self, "user", None)
        user_label = getattr(user, "full_name", "<anonymous>")
        self.logger.info(f"WebSocket disconnected: user={user_label} close_code={close_code}")

    async def receive(self, text_data: str | None = None, bytes_data: bytes | None = None):
        """
        Main entry point for incoming messages.

        Expects a JSON object in ``text_data``. The ``query`` and ``message``
        fields (when present and truthy) are checked by the moderator, then
        the frame is dispatched to ``action_<action>``; ``action`` defaults to
        ``"message"``. Every failure is reported to the client through
        ``send_error`` instead of being raised.

        Args:
            text_data: Text frame payload.
            bytes_data: Binary frame payload. Accepted so a binary frame does
                not raise ``TypeError``; binary frames are rejected with an
                error event.
        """
        try:
            if text_data is None:
                return await self.send_error("Only JSON text frames are supported on this endpoint.")

            try:
                data = self.from_json(text_data)
            except ValueError:
                # json.JSONDecodeError is a ValueError subclass.
                return await self.send_error("Invalid message: payload is not valid JSON.")
            if not isinstance(data, dict):
                return await self.send_error("Invalid message: expected a JSON object.")

            for field in ("query", "message"):
                value = data.get(field)
                if value and not self.moderator.is_clean(str(value)):
                    return await self.send_error("Message blocked: content did not pass moderation.")

            action_name = data.get("action", "message")
            method = getattr(self, f"action_{action_name}", None)
            if not callable(method):
                return await self.send_error(f"Action '{action_name}' not supported on this endpoint.")

            self.logger.info(f"Dispatching action: {action_name}")
            await method(data)
        except Exception as e:
            # Log first: if sending the error event fails as well, the
            # original failure is still recorded.
            self.logger.exception(f"WebSocket receive critical failure: {str(e)}")
            await self.send_error("An unexpected error occurred when processing the event.")

    async def orchestrate(self, message: str, config: AgentConfig, minimum_turns: int = 2, maximum_turns: int = 5,
                          max_tokens: int | None = None, raise_on_error: bool = False,
                          request_timeout: float = settings.INFERENCE_REQUEST_TIMEOUT) -> str:
        """
        Orchestrates a multi-turn conversation with the agent, including tool calls and reasoning steps.

        Each turn posts the running message list to the chat-completions
        endpoint. If the reply contains tool calls they are executed through
        ``self.router`` and their results appended before the next turn.
        Otherwise the reply text is censored by the moderator; while
        ``turn < minimum_turns`` a forced-reasoning prompt is appended and the
        loop continues, else the text is returned.

        Args:
            message: The user message that starts the conversation.
            config: Agent configuration; its ``system_prompt`` is used when
                truthy, otherwise the default onboarding system prompt.
            minimum_turns: Turns below this number never return a final answer.
            maximum_turns: Upper bound on the number of requests made.
            max_tokens: Per-request token limit; falsy values fall back to 1024.
            raise_on_error: Re-raise the failure instead of returning an
                ``"Error: ..."`` string.
            request_timeout: Timeout passed to the HTTP client.

        Returns:
            The last censored assistant text (possibly ``""`` if every turn
            ended in tool calls), or ``"Error: <details>"`` when a turn failed
            and ``raise_on_error`` is false.

        Raises:
            Exception: Whatever the failing turn raised, only when
                ``raise_on_error`` is true.
        """
        resolved_max_tokens = max_tokens or 1024
        messages = [
            {"role": "system", "content": config.system_prompt or OnboardingPrompts.default_system_prompt()},
            {"role": "user", "content": message},
        ]
        last_content = ""

        async with httpx.AsyncClient(timeout=request_timeout, auth=settings.INFERENCE_AUTH) as client:
            for turn in range(1, maximum_turns + 1):
                await self.send_log(LogType.THOUGHT, f"Agent reasoning (Turn {turn})...")
                try:
                    response = await client.post(
                        settings.INFERENCE_CHAT_COMPLETIONS_ENDPOINT,
                        json={
                            "messages": messages,
                            "tools": self.router.get_tool_definitions(),
                            "tool_choice": "auto",
                            "max_tokens": resolved_max_tokens,
                        },
                    )
                    response.raise_for_status()
                    ai_message = response.json()["choices"][0]["message"]
                    messages.append(ai_message)

                    tool_calls = ai_message.get("tool_calls")
                    if tool_calls:
                        # Feed every tool result back, then ask the model again.
                        for tool_call in tool_calls:
                            messages.append(await self._run_tool_call(tool_call))
                        continue

                    last_content = self.moderator.censor(str(ai_message.get("content") or "").strip())
                    if turn < minimum_turns:
                        messages.append({
                            "role": "user",
                            "content": OnboardingPrompts.force_reasoning_prompt(),
                        })
                        continue
                    return last_content
                except Exception as e:
                    await self.send_error(f"AI Orchestration failed: {str(e)}")
                    if raise_on_error:
                        # Bare raise keeps the original traceback intact.
                        raise
                    return f"Error: {str(e)}"

        # Turn budget exhausted without a final answer being returned in-loop.
        return last_content

    async def _run_tool_call(self, tool_call: dict) -> dict:
        """Execute one model-requested tool call and return the ``tool`` message to append."""
        function = tool_call["function"]
        fn_name = function["name"]
        # Treat missing/empty arguments as "no arguments" instead of letting
        # json.loads fail on an empty string.
        fn_args = json.loads(function.get("arguments") or "{}")
        await self.send_log(LogType.TOOL_START, f"Calling tool: {fn_name}", content=fn_args)
        result = await self.router.handle_tool_call(fn_name, fn_args)
        await self.send_log(LogType.TOOL_RESULT, f"Tool {fn_name} returned data", content=result)
        return {
            "role": "tool",
            "tool_call_id": tool_call["id"],
            "name": fn_name,
            "content": self.to_json(result),
        }

    async def stream_llm(self, config, prompt: str, *, max_tokens: int = 1024, stop: list[str] | None = None,
                         system_prompt_suffix: str | None = None) -> str | None:
        """
        Single-turn streaming LLM call. Sends STREAM_CHUNK events for each token and returns the full text.

        Args:
            config: Agent configuration; a falsy value short-circuits to ``None``.
            prompt: The user prompt.
            max_tokens: Token limit sent with the request.
            stop: Optional stop sequences; only sent when truthy.
            system_prompt_suffix: Optional text appended to the system prompt
                after a blank line.

        Returns:
            The concatenated, stripped response text, or ``None`` when there is
            no config, the response is empty, or the request failed (failures
            are logged, not raised).
        """
        if not config:
            return None

        system_prompt = config.system_prompt or OnboardingPrompts.default_system_prompt()
        if system_prompt_suffix:
            system_prompt = system_prompt + "\n\n" + system_prompt_suffix

        payload: dict = {
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            "max_tokens": max_tokens,
            "stream": True,
        }
        if stop:
            payload["stop"] = stop

        try:
            chunks: list[str] = []
            async with httpx.AsyncClient(timeout=settings.INFERENCE_STREAM_TIMEOUT, auth=settings.INFERENCE_AUTH) as client:
                async with client.stream("POST", settings.INFERENCE_CHAT_COMPLETIONS_ENDPOINT, json=payload) as response:
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        # Server-sent events allow both "data: x" and "data:x";
                        # the strip() below removes the optional space.
                        if not line.startswith("data:"):
                            continue
                        data = line[5:].strip()
                        if data == "[DONE]":
                            break
                        try:
                            chunk_obj = json.loads(data)
                            choices = chunk_obj.get("choices") or []
                            if not choices:
                                # Nothing to emit for a chunk without choices.
                                continue
                            choice = choices[0]
                            delta = (choice.get("delta") or {}).get("content")
                            if delta:
                                chunks.append(delta)
                                await self.send_log(LogType.STREAM_CHUNK, delta)
                            if choice.get("finish_reason") == "length":
                                self.logger.warning("LLM response truncated (finish_reason=length)")
                                await self.send_log(LogType.STATUS, "Response was cut off, try increasing Max Tokens.")
                        except Exception:
                            # Best effort: one bad chunk must not abort the
                            # stream, but leave a trace for debugging.
                            self.logger.debug("Skipping unprocessable stream chunk", exc_info=True)
                            continue
            return "".join(chunks).strip() or None
        except Exception as e:
            self.logger.exception("Streaming LLM call failed: %s", e)
            return None

    async def send_log(self, log_type: LogType, message: str, content: str | dict | None = None):
        """
        Log an event locally and send it to the client as JSON.

        The client payload has the keys ``type``, ``message``, ``content`` and
        ``timestamp``. The local log line truncates message and content.
        ERROR events are logged at error level, everything else at info level.
        """
        log = self.logger.error if log_type == LogType.ERROR else self.logger.info
        log(f"[{log_type.value}]: message={str(message)[:100]} content={str(content)[:60]}")
        await self.send(self.to_json({
            "type": log_type.value,
            "message": message,
            "content": content,
            "timestamp": self.get_timestamp(),
        }))

    async def send_error(self, message: str):
        """Send ``message`` to the client as an ERROR event."""
        await self.send_log(LogType.ERROR, message)

    def parse_extra(self):
        """
        Override for custom parsing

        Called from ``connect`` after authentication succeeds and before the
        socket is accepted. The base implementation does nothing.
        """
        pass

    def to_json(self, data: dict | list) -> str:
        """Serialize ``data`` to JSON; values json cannot encode natively are passed through ``str``."""
        return json.dumps(data, default=str)

    def from_json(self, data: str) -> dict | list:
        """Parse a JSON document; raises ``json.JSONDecodeError`` on invalid input."""
        return json.loads(data)

    def get_timestamp(self):
        """Return ``timezone.now()`` as an ISO 8601 string."""
        return timezone.now().isoformat()

    def parse_max_tokens(self, max_tokens: int | None) -> int | None:
        """
        Validate a client-supplied token limit.

        Returns:
            ``max_tokens`` when it is a positive ``int``, otherwise ``None``.
            ``bool`` is rejected explicitly: it is an ``int`` subclass, so
            ``True`` would otherwise pass as a limit of 1.
        """
        if isinstance(max_tokens, bool) or not isinstance(max_tokens, int):
            return None
        return max_tokens if max_tokens > 0 else None

    @database_sync_to_async
    def get_config(self, config_uuid) -> AgentConfig:
        """Fetch the ``AgentConfig`` with this uuid via ``objects.get`` (which raises if there is no single match)."""
        return AgentConfig.objects.get(uuid=config_uuid)

    @database_sync_to_async
    def get_role(self, role_uuid) -> Role:
        """Fetch the ``Role`` with this uuid via ``objects.get`` (which raises if there is no single match)."""
        return Role.objects.get(uuid=role_uuid)

    @database_sync_to_async
    def get_config_by_type(self, role_uuid, agent_type) -> AgentConfig | None:
        """
        Resolve the agent config of ``agent_type`` for a role.

        Prefers the most recently updated config attached to the role itself.
        Otherwise falls back to the most recently updated config that has no
        role and belongs to an organization containing the role. Returns
        ``None`` when neither exists.
        """
        role_specific = AgentConfig.objects.filter(
            role__uuid=role_uuid,
            agent_type=agent_type,
        ).order_by('-updated_at').first()
        if role_specific:
            return role_specific

        return AgentConfig.objects.filter(
            organization__roles__uuid=role_uuid,
            role__isnull=True,
            agent_type=agent_type,
        ).order_by('-updated_at').first()

    @database_sync_to_async
    def get_config_for_user(self, config_uuid) -> AgentConfig | None:
        """Return the config with this uuid if the current user owns or is a member of its organization, else ``None``."""
        return AgentConfig.objects.filter(uuid=config_uuid).filter(
            Q(organization__owner__id=self.user.id) | Q(organization__members__id=self.user.id)
        ).first()

    @database_sync_to_async
    def can_manage_role(self, role_uuid):
        """Return ``False`` for an unknown role, else whether the current user can manage the role's organization."""
        role = Role.objects.filter(uuid=role_uuid).first()
        if role is None:
            return False
        return can_manage_organization(self.user, get_organization_from_object(role))

    @database_sync_to_async
    def can_access_role(self, role_uuid):
        """Return ``False`` for an unknown role, else whether the current user is a role member or can manage its organization."""
        role = Role.objects.filter(uuid=role_uuid).first()
        if role is None:
            return False
        return role.members.filter(id=self.user.id).exists() or can_manage_organization(self.user, get_organization_from_object(role))
|