"""WebSocket consumer that lets an agent's owner start/stop executions and
receive that agent's events."""

import asyncio
import json
import logging
from uuid import UUID

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer

from apps.agents.models import Agent, AgentEvent, AgentExecution
from apps.agents.tasks import start_agent_task_mcp

logger = logging.getLogger(__name__)


def _is_valid_uuid(value):
    """Return True if ``str(value)`` parses as a UUID.

    Used to turn malformed client-supplied ids into "not found" instead of
    letting the ORM raise a validation error whose text reaches the client.
    NOTE(review): assumes the models' ``uuid`` fields only hold UUID values.
    """
    try:
        UUID(str(value))
    except ValueError:
        return False
    return True


class AgentConsumer(AsyncWebsocketConsumer):
    """Per-agent WebSocket endpoint.

    The agent is selected by the ``agent_id`` URL route kwarg. Clients send
    JSON objects whose ``action`` is ``"start_agent"`` or ``"stop_agent"``.
    Messages sent to the ``agent_<agent_id>`` channel group are forwarded to
    the client by ``agent_event``, ``agent_completed`` and ``agent_error``.
    """

    async def connect(self):
        """Accept the socket only for an authenticated owner of the agent."""
        self.user = self.scope.get("user")
        self.agent_id = self.scope["url_route"]["kwargs"].get("agent_id")
        # Set before any early return so disconnect() can always read it.
        self.room_group_name = f"agent_{self.agent_id}"

        if self.user is None or not self.user.is_authenticated:
            await self.close()
            return

        # Same ownership rule as handle_start_agent: without it any
        # authenticated user could subscribe to another user's agent stream.
        if await self.get_agent(self.agent_id, self.user) is None:
            await self.close()
            return

        await self.channel_layer.group_add(self.room_group_name, self.channel_name)
        await self.accept()
        await self.send(json.dumps({
            "type": "connection",
            "message": "Connected to agent stream",
            "agent_id": str(self.agent_id),
        }))

    async def disconnect(self, close_code):
        """Remove this channel from the agent group."""
        group = getattr(self, "room_group_name", None)
        if group:
            await self.channel_layer.group_discard(group, self.channel_name)

    async def receive(self, text_data=None, bytes_data=None):
        """Parse a client frame and dispatch it by its ``action`` key.

        Channels passes binary frames as ``bytes_data`` only, so both
        parameters must be accepted; binary frames are answered with an error.
        """
        if text_data is None:
            await self._send_error("Binary frames are not supported")
            return

        try:
            data = json.loads(text_data)
        except json.JSONDecodeError:
            await self._send_error("Invalid JSON")
            return

        if not isinstance(data, dict):
            await self._send_error("Expected a JSON object")
            return

        action = data.get("action")
        try:
            if action == "start_agent":
                await self.handle_start_agent(data)
            elif action == "stop_agent":
                await self.handle_stop_agent(data)
            else:
                await self._send_error(f"Unknown action: {action}")
        except Exception as e:
            # Top-level boundary for client messages: log the traceback and
            # report the failure to the client instead of dropping the socket.
            logger.exception("Error handling action %r", action)
            await self._send_error(str(e))

    async def handle_start_agent(self, data):
        """Create an execution of this consumer's agent and queue it.

        Sends ``execution_started`` once the row exists, and
        ``execution_error`` if the task cannot be queued.
        """
        input_data = data.get("input_data", {})
        agent = await self.get_agent(self.agent_id, self.user)
        if not agent:
            await self._send_error("Agent not found")
            return

        execution = await self.create_execution(agent, self.user, input_data)
        execution_id = str(execution.uuid)
        await self.send(json.dumps({
            "type": "execution_started",
            "execution_id": execution_id,
            "agent_id": str(agent.uuid),
            "message": f"Agent execution {execution.uuid} queued",
        }))

        logger.info("Queuing MCP execution for %s", execution_id)
        try:
            # .delay() publishes to the broker synchronously; run it off the
            # event loop so a slow broker does not stall every socket.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, start_agent_task_mcp.delay, execution_id)
        except Exception as e:
            logger.exception("Error queuing agent task for %s", execution_id)
            # No task was queued, so nothing else will move this execution out
            # of its initial state; record the failure.
            await self.update_execution_status(execution, "failed")
            await self.send(json.dumps({
                "type": "execution_error",
                "execution_id": execution_id,
                "error_message": str(e),
            }))

    async def handle_stop_agent(self, data):
        """Mark one of the user's executions as stopped (stored as ``'failed'``)."""
        execution = await self.get_execution(data.get("execution_id"), self.user)
        if not execution:
            await self._send_error("Execution not found")
            return

        # NOTE(review): this only updates the stored status; it does not signal
        # the running task, and it overwrites the status even if the execution
        # already finished. Confirm that is intended.
        await self.update_execution_status(execution, "failed")
        await self.send(json.dumps({
            "type": "execution_stopped",
            "execution_id": str(execution.uuid),
            "message": "Agent execution stopped by user",
        }))

    async def _send_error(self, message):
        """Send a ``{"type": "error", "message": ...}`` frame to this client."""
        await self.send(json.dumps({"type": "error", "message": message}))

    # --- channel-layer group message handlers ------------------------------

    async def agent_event(self, event):
        """Forward an agent event from the group to the client."""
        await self.send(json.dumps({
            "type": "agent_event",
            "event_type": event["event_type"],
            "content": event["content"],
            "timestamp": event["timestamp"],
        }))

    async def agent_completed(self, event):
        """Forward an execution-completed notification to the client."""
        await self.send(json.dumps({
            "type": "execution_completed",
            "execution_id": event["execution_id"],
            "output_data": event["output_data"],
            "message": "Agent execution completed",
        }))

    async def agent_error(self, event):
        """Forward an execution error notification to the client."""
        await self.send(json.dumps({
            "type": "execution_error",
            "execution_id": event["execution_id"],
            "error_message": event["error_message"],
        }))

    # --- database helpers (run in a thread via database_sync_to_async) -----

    @database_sync_to_async
    def get_agent(self, agent_id, user):
        """Return ``user``'s Agent with this uuid, or None if absent/malformed."""
        if not _is_valid_uuid(agent_id):
            return None
        try:
            return Agent.objects.get(uuid=agent_id, user=user)
        except Agent.DoesNotExist:
            return None

    @database_sync_to_async
    def get_execution(self, execution_id, user):
        """Return ``user``'s AgentExecution with this uuid, or None if absent/malformed."""
        if not _is_valid_uuid(execution_id):
            return None
        try:
            return AgentExecution.objects.get(uuid=execution_id, user=user)
        except AgentExecution.DoesNotExist:
            return None

    @database_sync_to_async
    def create_execution(self, agent, user, input_data):
        """Insert and return a new AgentExecution for ``agent`` owned by ``user``."""
        return AgentExecution.objects.create(
            agent=agent,
            user=user,
            input_data=input_data,
        )

    @database_sync_to_async
    def update_execution_status(self, execution, status):
        """Set and persist only ``execution.status``; return the execution.

        Writing just the ``status`` column avoids overwriting fields that
        another process (e.g. the queued task) may have updated since this
        instance was loaded. NOTE(review): if the model has an ``auto_now``
        timestamp, add it to ``update_fields``.
        """
        execution.status = status
        execution.save(update_fields=["status"])
        return execution

    @database_sync_to_async
    def create_event(self, execution, event_type, content):
        """Insert and return an AgentEvent attached to ``execution``."""
        return AgentEvent.objects.create(
            execution=execution,
            event_type=event_type,
            content=content,
        )