import json
import logging

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer
from django.core.exceptions import ValidationError
from django.utils import timezone

from .models import Agent, AgentEvent, AgentRun
from .tasks import infer_run_task, start_fine_tune_run_task

logger = logging.getLogger(__name__)


class MLStoreConsumer(AsyncWebsocketConsumer):
    """WebSocket consumer bound to a single agent (``agent_id`` URL kwarg).

    Authenticated clients join the channel-layer group
    ``mlstore_agent_<agent_id>`` and may send JSON commands of the form
    ``{"action": "fine_tune" | "infer" | "stop" | "stop_agent", ...}``.
    Group messages dispatched to ``mlstore_event`` / ``mlstore_completed`` /
    ``mlstore_error`` are forwarded to the client as JSON.
    """

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    async def connect(self):
        """Reject anonymous users and unknown agents; otherwise join the agent group."""
        self.user = self.scope["user"]
        self.agent_id = self.scope["url_route"]["kwargs"].get("agent_id")
        # Set before any early return so disconnect() can always reference it.
        self.room_group_name = f"mlstore_agent_{self.agent_id}"

        if not self.user.is_authenticated:
            await self.close()
            return

        agent = await self.get_agent(self.agent_id)
        if not agent:
            await self.close()
            return

        await self.channel_layer.group_add(self.room_group_name, self.channel_name)
        await self.accept()
        await self._send_payload({
            "type": "connection",
            "message": "Connected to mlstore agent stream",
            "agent_id": str(self.agent_id),
        })

    async def disconnect(self, close_code):
        """Leave the agent group (a no-op if the group was never joined)."""
        await self.channel_layer.group_discard(self.room_group_name, self.channel_name)

    # ------------------------------------------------------------------
    # Client -> server commands
    # ------------------------------------------------------------------

    async def receive(self, text_data=None, bytes_data=None):
        """Parse one client frame and dispatch it to the matching action handler.

        Channels invokes this with ``bytes_data`` (and no ``text_data``) for
        binary frames, so both keywords must be accepted or the call raises
        ``TypeError`` before any error handling runs.
        """
        if text_data is None:
            await self._send_error("Binary frames are not supported")
            return

        try:
            data = json.loads(text_data)
        except json.JSONDecodeError:
            await self._send_error("Invalid JSON")
            return

        # A JSON array/scalar would otherwise fail on data.get() below.
        if not isinstance(data, dict):
            await self._send_error("Expected a JSON object")
            return

        action = data.get("action")
        try:
            if action == "fine_tune":
                await self.handle_fine_tune(data)
            elif action == "infer":
                await self.handle_infer(data)
            elif action in ("stop_agent", "stop"):
                await self.handle_stop(data)
            else:
                await self._send_error(f"Unknown action: {action}")
        except Exception as e:
            # NOTE(review): str(e) is forwarded to the client unchanged, which
            # may expose internal details; consider a generic message.
            logger.exception("mlstore action %r failed", action)
            await self._send_error(str(e))

    async def handle_fine_tune(self, data):
        """Create a run for this agent and queue the fine-tune task for it."""
        await self._start_run(data, start_fine_tune_run_task, "Fine-tune")

    async def handle_infer(self, data):
        """Create a run for this agent and queue the inference task for it."""
        await self._start_run(data, infer_run_task, "Inference")

    async def handle_stop(self, data):
        """Mark a run of *this* agent as ``"failed"`` at the client's request.

        The lookup is restricted to this connection's agent so a client cannot
        stop runs belonging to other agents.
        NOTE(review): this only updates database state; it does not revoke the
        queued/running task, and it does not check run ownership by user.
        """
        execution_id = data.get("execution_id")
        if not execution_id:
            await self._send_error("execution_id required to stop")
            return

        execution = await self.get_execution(execution_id, agent_id=self.agent_id)
        if not execution:
            await self._send_error("Execution not found")
            return

        await self.update_execution_status(execution, "failed")
        await self._send_payload({
            "type": "execution_error",
            "execution_id": str(execution.uuid),
            "error_message": "Execution stopped by user",
        })

    async def _start_run(self, data, task, label):
        """Create an AgentRun from ``data["input_data"]`` and enqueue *task* with its uuid.

        *label* prefixes the human-readable "queued" message sent to the client.
        """
        agent = await self.get_agent(self.agent_id)
        if not agent:
            await self._send_error("Agent not found")
            return

        input_data = data.get("input_data") or {}
        execution = await self.create_run(agent, self.user, input_data)
        run_id = str(execution.uuid)
        # Announce first so the client knows the id before any task events arrive.
        await self._send_payload({
            "type": "execution_started",
            "execution_id": run_id,
            "agent_id": str(agent.uuid),
            "message": f"{label} run {run_id} queued",
        })
        await self._enqueue(task, run_id)

    # ------------------------------------------------------------------
    # Group -> client forwarding
    # ------------------------------------------------------------------

    async def mlstore_event(self, event):
        """Forward an intermediate agent event to the client."""
        await self._send_payload({
            "type": "mlstore_event",
            "event_type": event["event_type"],
            "content": event["content"],
            "timestamp": event["timestamp"],
        })

    async def mlstore_completed(self, event):
        """Forward a run-completed notification to the client."""
        await self._send_payload({
            "type": "execution_completed",
            "execution_id": event["execution_id"],
            "output_data": event["output_data"],
            "message": "Execution completed",
        })

    async def mlstore_error(self, event):
        """Forward a run-failed notification to the client."""
        await self._send_payload({
            "type": "execution_error",
            "execution_id": event["execution_id"],
            "error_message": event["error_message"],
        })

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    async def _send_payload(self, payload):
        """Serialize *payload* as JSON and send it as a text frame."""
        await self.send(text_data=json.dumps(payload))

    async def _send_error(self, message):
        """Send a ``{"type": "error", "message": ...}`` frame."""
        await self._send_payload({"type": "error", "message": message})

    @database_sync_to_async
    def _enqueue(self, task, run_id):
        """Call ``task.delay(run_id)`` in a worker thread.

        ``delay`` is a synchronous broker publish; running it here keeps it
        off the event loop.
        """
        task.delay(run_id)

    @database_sync_to_async
    def get_agent(self, agent_id):
        """Return the Agent with this uuid, or None if missing or malformed."""
        try:
            return Agent.objects.get(uuid=agent_id)
        except (Agent.DoesNotExist, ValidationError, ValueError):
            # ValidationError: Django's UUIDField rejects malformed ids this way.
            return None

    @database_sync_to_async
    def create_run(self, agent, user, input_data):
        """Create and return a new AgentRun for *agent* started by *user*."""
        return AgentRun.objects.create(
            agent=agent,
            user=user,
            input_data=input_data,
        )

    @database_sync_to_async
    def get_execution(self, execution_id, agent_id=None):
        """Return the AgentRun with this uuid, or None if missing or malformed.

        When *agent_id* is given, the run must also belong to the agent with
        that uuid.
        """
        lookup = {"uuid": execution_id}
        if agent_id is not None:
            lookup["agent__uuid"] = agent_id
        try:
            return AgentRun.objects.get(**lookup)
        except (AgentRun.DoesNotExist, ValidationError, ValueError):
            return None

    @database_sync_to_async
    def update_execution_status(self, execution, status):
        """Set *status* and a completion timestamp on the run and, best-effort, its agent."""
        now = timezone.now()
        execution.status = status
        execution.completed_at = now
        execution.save()
        try:
            agent = execution.agent
            agent.status = status
            agent.completed_at = now
            agent.save()
        except Exception:
            # Best-effort: failing to mirror status onto the agent must not
            # undo or block the run update, but it should not be silent either.
            logger.exception(
                "Failed to propagate status %r to agent of run %s",
                status,
                execution.uuid,
            )
        return execution

    @database_sync_to_async
    def create_event(self, execution, event_type, content):
        """Persist and return an AgentEvent attached to *execution*."""
        return AgentEvent.objects.create(
            execution=execution,
            event_type=event_type,
            content=content,
        )