"""Celery tasks that run fine-tune and inference executions for agents.

Each task loads an ``AgentRun`` by UUID, marks it (and its agent) as running,
delegates the actual work to ``services``, and then records the outcome on the
run, on the agent, as a persisted ``AgentEvent`` and as channel-layer messages
sent to the agent's group.
"""

import logging
import traceback

from asgiref.sync import async_to_sync
from celery import shared_task
from channels.layers import get_channel_layer
from django.utils import timezone

from apps.orgs.models import TrainingFile

from . import services
from .models import Agent, AgentEvent, AgentModel, AgentRun

logger = logging.getLogger(__name__)


def _send_group_event(room_group_name: str, event_type: str, content: dict):
    """Send an ``mlstore_event`` message, stamped with the current time, to a group."""
    channel_layer = get_channel_layer()
    async_to_sync(channel_layer.group_send)(
        room_group_name,
        {
            "type": "mlstore_event",
            "event_type": event_type,
            "content": content,
            "timestamp": timezone.now().isoformat(),
        },
    )


def _persist_event(execution: AgentRun, event_type: str, content: dict):
    """Store an event for ``execution`` as an ``AgentEvent`` row."""
    AgentEvent.objects.create(
        execution=execution,
        event_type=event_type,
        content=content,
    )


def _update_agent_status(agent: Agent, status: str):
    """Set the agent's status and the matching start/completion timestamp, then save."""
    agent.status = status
    if status == "running":
        agent.started_at = timezone.now()
    elif status in ("completed", "failed"):
        agent.completed_at = timezone.now()
    agent.save()


def _next_version(existing_versions) -> str:
    """Return the next ``v<N>`` label after the highest one in ``existing_versions``.

    Versions are compared numerically, not as strings, so ``"v10"`` correctly
    outranks ``"v9"``. Entries that are not of the form ``v<int>`` are ignored;
    when nothing usable is found the result is ``"v1"``.
    """
    highest = 0
    for raw in existing_versions:
        if not isinstance(raw, str) or not raw.startswith("v"):
            continue
        try:
            number = int(raw[1:])
        except ValueError:
            continue
        highest = max(highest, number)
    return f"v{highest + 1}"


def _mark_running(execution: AgentRun, agent: Agent):
    """Flag the execution and its agent as running and stamp the start time."""
    execution.status = "running"
    execution.started_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "running")


def _fail_execution(execution: AgentRun, agent: Agent, room_group_name: str,
                    error_message: str, event_error):
    """Record a failed execution and notify listeners.

    ``error_message`` is stored on the execution and sent in the
    ``mlstore_error`` message; ``event_error`` is the value placed under
    ``"error"`` in the generic "error" event (sent and persisted).
    """
    execution.status = "failed"
    execution.error_message = error_message
    execution.completed_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "failed")

    payload = {"execution_id": str(execution.uuid), "error": event_error}
    _send_group_event(room_group_name, "error", payload)
    _persist_event(execution, "error", payload)
    async_to_sync(get_channel_layer().group_send)(
        room_group_name,
        {
            "type": "mlstore_error",
            "execution_id": str(execution.uuid),
            "error_message": error_message,
        },
    )


def _complete_execution(execution: AgentRun, agent: Agent, room_group_name: str,
                        output_data: dict, event_content: dict):
    """Record a completed execution and notify listeners.

    ``output_data`` is stored on the execution and sent in the
    ``mlstore_completed`` message; ``event_content`` is the content of the
    generic "completed" event (sent and persisted).
    """
    execution.status = "completed"
    execution.output_data = output_data
    execution.completed_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "completed")

    _send_group_event(room_group_name, "completed", event_content)
    _persist_event(execution, "completed", event_content)
    async_to_sync(get_channel_layer().group_send)(
        room_group_name,
        {
            "type": "mlstore_completed",
            "execution_id": str(execution.uuid),
            "output_data": execution.output_data,
        },
    )


@shared_task
def start_fine_tune_run_task(execution_id: str):
    """Run a fine-tune for the ``AgentRun`` identified by ``execution_id``.

    Parameters are read from ``execution.input_data`` (``base_model``,
    ``training_files``, ``hyperparams``, ``name``, ``version``), falling back
    to the agent's current model name, the organization's unprocessed
    ``TrainingFile`` rows, and the next free ``v<N>`` version for the name.

    On success a new ``AgentModel`` is created and attached to the agent, and
    any organization training files that were picked up are marked processed.

    Returns:
        A dict with ``status`` of ``"completed"``, ``"failed"`` (the service
        returned a non-completed result) or ``"error"`` (the execution was not
        found or an exception was raised), plus ``execution_id`` and details.
    """
    logger.info(f"Fine-tune run task started for execution: {execution_id}")
    try:
        execution = AgentRun.objects.get(uuid=execution_id)
    except AgentRun.DoesNotExist:
        logger.error(f"Execution not found: {execution_id}")
        return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}

    agent = execution.agent
    room_group_name = f"mlstore_agent_{agent.uuid}"
    logger.info(f"Agent: {agent.uuid}, User: {execution.user.email_address}")

    _mark_running(execution, agent)
    logger.info(f"Execution {execution_id} status updated to 'running'")

    # Everything after the run is marked "running" lives inside this try so
    # that a failure while resolving parameters (missing agent model, database
    # error, ...) marks the run as failed instead of leaving it running.
    try:
        from apps.mlstore.services import BASE_MODEL_CACHE
        logger.info(f"Base model cache directory: {BASE_MODEL_CACHE}")

        input_data = execution.input_data or {}
        base_model = input_data.get("base_model") or agent.model.name
        training_files = input_data.get("training_files") or []

        # Fall back to the organization's unprocessed uploads when the caller
        # did not supply explicit training files.
        org_training_files = []
        if not training_files and agent.organization:
            org_training_files = list(
                TrainingFile.objects.filter(
                    organization=agent.organization,
                    is_processed=False,
                ).select_related('uploaded_by')
            )
            training_files = [tf.file.path for tf in org_training_files if tf.file]
            logger.info(f"Fetched {len(training_files)} training files from organization {agent.organization.name}")

        hyperparams = input_data.get("hyperparams") or {}
        name = input_data.get("name") or agent.model.name

        version = input_data.get("version")
        if not version:
            # Pick the next version numerically; ordering the version column
            # in SQL would compare strings and rank "v9" above "v10".
            known_versions = AgentModel.objects.filter(name=name).values_list("version", flat=True)
            version = _next_version(known_versions)

        logger.info(f"Fine-tune parameters: base_model={base_model}, name={name}, version={version}")

        started_payload = {"execution_id": str(execution.uuid), "action": "fine_tune"}
        _send_group_event(room_group_name, "started", started_payload)
        _persist_event(execution, "started", started_payload)

        result = services.fine_tune_model(base_model, training_files, hyperparams, name, version)
        # Guard against a non-dict result so it takes the regular failure path
        # below instead of raising AttributeError on .get().
        result_status = result.get("status") if isinstance(result, dict) else None
        logger.info(f"Fine-tune result received: {result_status}")
        logger.debug(f"Full fine-tune result: {result}")

        if result_status == "completed":
            model_path = result.get("model_path") or result.get("path") or ""
            model_version = result.get("version") or version
            new_model = AgentModel.objects.create(name=name, version=model_version, path=model_path)
            agent.model = new_model
            agent.save()
            logger.info(f"Fine-tune completed. New model created: {new_model.uuid} at {model_path}")

            if org_training_files:
                file_ids = [tf.id for tf in org_training_files]
                TrainingFile.objects.filter(id__in=file_ids).update(is_processed=True)
                logger.info(f"Marked {len(org_training_files)} training files as processed")

            _complete_execution(
                execution,
                agent,
                room_group_name,
                output_data={
                    "result": result,
                    "model_id": new_model.id,
                    "model_uuid": str(new_model.uuid),
                },
                event_content={
                    "execution_id": str(execution.uuid),
                    "model_id": new_model.id,
                    "model_path": model_path,
                },
            )
            logger.info(f"Execution {execution_id} completed successfully")
            return {"status": "completed", "execution_id": execution_id, "model_id": new_model.id}

        logger.warning(f"Fine-tune did not complete successfully. Status: {result_status}")
        _fail_execution(execution, agent, room_group_name, str(result), result)
        return {"status": "failed", "execution_id": execution_id, "result": result}
    except Exception as e:
        logger.error(f"Fine-tune task failed with exception for execution {execution_id}: {str(e)}", exc_info=True)
        traceback.print_exc()
        _fail_execution(execution, agent, room_group_name, str(e), str(e))
        return {"status": "error", "execution_id": execution_id, "error": str(e)}


@shared_task
def infer_run_task(execution_id: str):
    """Run inference for the ``AgentRun`` identified by ``execution_id``.

    The prompt is read from ``execution.input_data`` (``prompt`` or ``query``)
    and passed, together with ``options``, to the agent's current model.

    Returns:
        A dict with ``status`` of ``"completed"``, ``"failed"`` (missing
        prompt or an exception during inference) or ``"error"`` (execution
        not found), plus ``execution_id`` and, on failure, ``error``.
    """
    logger.info(f"Inference run task started for execution: {execution_id}")
    try:
        execution = AgentRun.objects.get(uuid=execution_id)
    except AgentRun.DoesNotExist:
        logger.error(f"Execution not found: {execution_id}")
        return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}

    agent = execution.agent
    room_group_name = f"mlstore_agent_{agent.uuid}"
    logger.info(f"Agent: {agent.uuid}, User: {execution.user.email_address}")

    _mark_running(execution, agent)
    logger.info(f"Execution {execution_id} status updated to 'running'")

    input_data = execution.input_data or {}
    prompt = input_data.get("prompt") or input_data.get("query") or ""
    options = input_data.get("options") or {}
    logger.info(f"Prompt length: {len(prompt)} characters")

    if not prompt:
        logger.warning(f"No prompt provided for inference run {execution_id}")
        _fail_execution(execution, agent, room_group_name, "prompt_required", "prompt_required")
        return {"status": "failed", "execution_id": execution_id, "error": "prompt_required"}

    started_payload = {"execution_id": str(execution.uuid), "action": "infer"}
    _send_group_event(room_group_name, "started", started_payload)
    _persist_event(execution, "started", started_payload)

    try:
        # Preloading is best-effort: a failure here is logged and inference
        # is still attempted below.
        try:
            logger.info(f"Loading model: {agent.model.path}")
            services.load_model_for_inference(agent.model.path)
        except Exception as e:
            logger.warning(f"Failed to preload model: {str(e)}")

        logger.info(f"Starting inference with model: {agent.model.path}")
        result = services.infer_with_model(agent.model.path, prompt, options)

        _complete_execution(
            execution,
            agent,
            room_group_name,
            output_data={"result": result},
            event_content={"execution_id": str(execution.uuid), "result": result},
        )
        logger.info(f"Inference execution {execution_id} completed successfully")
        return {"status": "completed", "execution_id": execution_id}
    except Exception as e:
        logger.error(f"Inference task failed with exception for execution {execution_id}: {str(e)}", exc_info=True)
        traceback.print_exc()
        _fail_execution(execution, agent, room_group_name, str(e), str(e))
        return {"status": "failed", "execution_id": execution_id, "error": str(e)}