163 lines
4.3 KiB
Python
163 lines
4.3 KiB
Python
import json
import logging

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer

from apps.agents.models import Agent, AgentExecution, AgentEvent
from apps.agents.tasks import start_agent_task_mcp
|
|
|
|
|
|
class AgentConsumer(AsyncWebsocketConsumer):
    """WebSocket consumer for one agent's execution stream.

    The agent is identified by the ``agent_id`` URL route kwarg. Only an
    authenticated user who owns that agent may connect. On connect the socket
    joins the channel-layer group ``agent_<agent_id>``.

    Client -> server messages are JSON objects with an ``action`` key:
    ``"start_agent"`` (optional ``input_data``) or ``"stop_agent"``
    (``execution_id``).

    Group messages whose ``type`` is ``agent_event``, ``agent_completed``
    or ``agent_error`` are dispatched by Channels to the methods of the
    same name. Those methods relay them to the client as JSON.
    """

    _logger = logging.getLogger(__name__)

    async def _send_json(self, payload):
        """Serialize ``payload`` with ``json.dumps`` and send it as a text frame."""
        await self.send(text_data=json.dumps(payload))

    async def connect(self):
        """Authenticate and authorize the user, then join the agent's group.

        The handshake is rejected (closed before ``accept``) when the user is
        missing or unauthenticated, or when they do not own the agent named in
        the URL. This keeps users from subscribing to another user's stream.
        """
        # .get(): without auth middleware there is no "user" key in the scope.
        self.user = self.scope.get("user")
        self.agent_id = self.scope["url_route"]["kwargs"].get("agent_id")
        self.room_group_name = f"agent_{self.agent_id}"
        # Tracks whether group_add happened, so disconnect() only discards
        # a membership that actually exists.
        self._joined_group = False

        if self.user is None or not self.user.is_authenticated:
            await self.close()
            return

        # Authorization: same ownership lookup handle_start_agent relies on.
        agent = await self.get_agent(self.agent_id, self.user)
        if agent is None:
            await self.close()
            return

        await self.channel_layer.group_add(self.room_group_name, self.channel_name)
        self._joined_group = True
        await self.accept()
        await self._send_json({
            "type": "connection",
            "message": "Connected to agent stream",
            "agent_id": str(self.agent_id),
        })

    async def disconnect(self, close_code):
        """Leave the agent group if ``connect`` joined it.

        :param close_code: WebSocket close code supplied by Channels (unused).
        """
        # getattr: connect() may have raised before setting the flag.
        if getattr(self, "_joined_group", False):
            await self.channel_layer.group_discard(self.room_group_name, self.channel_name)

    async def receive(self, text_data=None, bytes_data=None):
        """Parse a client message and dispatch on its ``action`` field.

        :param text_data: Text frame payload, expected to be a JSON object.
        :param bytes_data: Binary frame payload. Binary frames are rejected
            with an error message. The parameter must exist because Channels
            passes it by keyword.
        """
        if text_data is None:
            await self._send_json({
                "type": "error",
                "message": "Binary frames are not supported",
            })
            return

        try:
            data = json.loads(text_data)
        except json.JSONDecodeError:
            await self._send_json({"type": "error", "message": "Invalid JSON"})
            return

        # json.loads can return a list/number/string; handlers need a dict.
        if not isinstance(data, dict):
            await self._send_json({
                "type": "error",
                "message": "Expected a JSON object",
            })
            return

        action = data.get("action")
        handlers = {
            "start_agent": self.handle_start_agent,
            "stop_agent": self.handle_stop_agent,
        }
        # isinstance guard: an unhashable action (e.g. a list) would make
        # dict.get raise TypeError.
        handler = handlers.get(action) if isinstance(action, str) else None
        if handler is None:
            await self._send_json({
                "type": "error",
                "message": f"Unknown action: {action}",
            })
            return

        try:
            await handler(data)
        except Exception:
            # Top-level boundary: log the details server-side and send the
            # client a generic message, so internal errors are not exposed.
            self._logger.exception("Error handling websocket action %r", action)
            await self._send_json({
                "type": "error",
                "message": "Internal server error",
            })

    async def handle_start_agent(self, data):
        """Create an execution for this consumer's agent and queue its task.

        :param data: Parsed client message. ``input_data`` is stored on the
            execution; a missing key or explicit ``null`` becomes ``{}``.
        """
        input_data = data.get("input_data")
        if input_data is None:
            input_data = {}

        agent = await self.get_agent(self.agent_id, self.user)
        if agent is None:
            await self._send_json({"type": "error", "message": "Agent not found"})
            return

        execution = await self.create_execution(agent, self.user, input_data)
        execution_id = str(execution.uuid)

        await self._send_json({
            "type": "execution_started",
            "execution_id": execution_id,
            "agent_id": str(agent.uuid),
            "message": f"Agent execution {execution.uuid} queued",
        })

        self._logger.info("[Consumer] Queuing MCP execution for %s", execution_id)
        try:
            # .delay() presumably publishes to a task broker (Celery-style), which
            # is blocking I/O. Run it in a worker thread so the event loop is
            # not stalled.
            await database_sync_to_async(start_agent_task_mcp.delay)(execution_id)
        except Exception:
            self._logger.exception("Error queuing agent task for %s", execution_id)
            # The execution row already exists. Mark it failed so it is not
            # left in its initial state with no task to run it.
            await self.update_execution_status(execution, "failed")
            await self._send_json({
                "type": "execution_error",
                "execution_id": execution_id,
                "error_message": "Failed to queue agent execution",
            })

    async def handle_stop_agent(self, data):
        """Mark one of the user's executions as ``failed`` at the user's request.

        :param data: Parsed client message carrying ``execution_id``.
        """
        execution = await self.get_execution(data.get("execution_id"), self.user)
        if execution is None:
            await self._send_json({"type": "error", "message": "Execution not found"})
            return

        # NOTE(review): this only changes the stored status. Nothing visible here
        # signals a running task to halt; confirm the worker checks the status.
        await self.update_execution_status(execution, "failed")
        await self._send_json({
            "type": "execution_stopped",
            "execution_id": str(execution.uuid),
            "message": "Agent execution stopped by user",
        })

    async def agent_event(self, event):
        """Relay an ``agent_event`` group message to the client."""
        await self._send_json({
            "type": "agent_event",
            "event_type": event["event_type"],
            "content": event["content"],
            "timestamp": event["timestamp"],
        })

    async def agent_completed(self, event):
        """Relay an ``agent_completed`` group message as ``execution_completed``."""
        await self._send_json({
            "type": "execution_completed",
            "execution_id": event["execution_id"],
            "output_data": event["output_data"],
            "message": "Agent execution completed",
        })

    async def agent_error(self, event):
        """Relay an ``agent_error`` group message as ``execution_error``."""
        await self._send_json({
            "type": "execution_error",
            "execution_id": event["execution_id"],
            "error_message": event["error_message"],
        })

    @database_sync_to_async
    def get_agent(self, agent_id, user):
        """Return the ``Agent`` with this uuid owned by ``user``, or ``None``."""
        try:
            return Agent.objects.get(uuid=agent_id, user=user)
        except Agent.DoesNotExist:
            return None

    @database_sync_to_async
    def get_execution(self, execution_id, user):
        """Return the ``AgentExecution`` with this uuid owned by ``user``, or ``None``."""
        try:
            return AgentExecution.objects.get(uuid=execution_id, user=user)
        except AgentExecution.DoesNotExist:
            return None

    @database_sync_to_async
    def create_execution(self, agent, user, input_data):
        """Create and return a new ``AgentExecution`` for ``agent`` and ``user``."""
        return AgentExecution.objects.create(
            agent=agent,
            user=user,
            input_data=input_data,
        )

    @database_sync_to_async
    def update_execution_status(self, execution, status):
        """Set ``execution.status``, persist only that column, and return ``execution``."""
        execution.status = status
        # update_fields: a full save() would write back every field of this
        # possibly stale in-memory copy. That could clobber values another
        # process (e.g. the running task) saved after it was loaded.
        execution.save(update_fields=["status"])
        return execution

    @database_sync_to_async
    def create_event(self, execution, event_type, content):
        """Create and return an ``AgentEvent`` attached to ``execution``."""
        return AgentEvent.objects.create(
            execution=execution,
            event_type=event_type,
            content=content,
        )
|