2026-01-27 22:17:22 +00:00
|
|
|
import logging
|
|
|
|
|
import traceback
|
|
|
|
|
from asgiref.sync import async_to_sync
|
2026-01-20 17:21:28 +00:00
|
|
|
from celery import shared_task
|
|
|
|
|
from channels.layers import get_channel_layer
|
2026-01-27 22:17:22 +00:00
|
|
|
from django.utils import timezone
|
|
|
|
|
|
|
|
|
|
from apps.orgs.models import TrainingFile
|
2026-01-20 17:21:28 +00:00
|
|
|
from . import services
|
2026-01-27 22:17:22 +00:00
|
|
|
from .models import Agent, AgentEvent, AgentModel, AgentRun
|
2026-01-25 17:29:37 +00:00
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
2026-01-20 17:21:28 +00:00
|
|
|
|
|
|
|
|
def _send_group_event(room_group_name: str, event_type: str, content: dict):
    """Broadcast one ``mlstore_event`` message to a channel-layer group.

    Args:
        room_group_name: Name of the channel-layer group to send to.
        event_type: Label copied into the message's ``event_type`` field.
        content: Payload copied into the message's ``content`` field.

    The message also carries a ``timestamp`` holding the current time as an
    ISO-8601 string. The async ``group_send`` is driven synchronously so the
    helper can be called from ordinary (non-async) code.
    """
    layer = get_channel_layer()
    send_to_group = async_to_sync(layer.group_send)
    # NOTE(review): "type" is presumably what the consumer uses to pick its
    # handler -- confirm against the consumer implementation.
    message = {
        "type": "mlstore_event",
        "event_type": event_type,
        "content": content,
        "timestamp": timezone.now().isoformat(),
    }
    send_to_group(room_group_name, message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _persist_event(execution: AgentRun, event_type: str, content: dict):
    """Create an ``AgentEvent`` row recording an event for ``execution``.

    Args:
        execution: The run the event belongs to.
        event_type: Label stored in the event's ``event_type`` field.
        content: Payload stored in the event's ``content`` field.
    """
    event_fields = {
        "execution": execution,
        "event_type": event_type,
        "content": content,
    }
    AgentEvent.objects.create(**event_fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _update_agent_status(agent: Agent, status: str):
    """Set ``agent.status``, stamp the matching timestamp, and save the agent.

    Args:
        agent: The agent to update; it is saved before returning.
        status: New status value. "running" stamps ``started_at``;
            "completed" and "failed" stamp ``completed_at``; any other value
            changes only ``status``.
    """
    # Which timestamp attribute, if any, a given status transition touches.
    timestamp_for_status = {
        "running": "started_at",
        "completed": "completed_at",
        "failed": "completed_at",
    }

    agent.status = status
    timestamp_attr = timestamp_for_status.get(status)
    if timestamp_attr is not None:
        setattr(agent, timestamp_attr, timezone.now())
    agent.save()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _next_model_version(name: str) -> str:
    """Return the next ``v<N>`` version label for models called ``name``.

    Every stored label of the form ``v<int>`` is parsed and the numeric
    maximum is incremented. The comparison is done on integers on purpose:
    sorting the version strings themselves is lexicographic, which ranks
    "v9" above "v10" and would keep handing out "v10" again. Labels that do
    not follow the ``v<int>`` pattern are ignored; when no label is usable
    the result is "v1".
    """
    highest = 0
    stored_labels = AgentModel.objects.filter(name=name).values_list("version", flat=True)
    for label in stored_labels:
        if not isinstance(label, str) or not label.startswith("v"):
            continue
        try:
            highest = max(highest, int(label[1:]))
        except ValueError:
            # Not a numeric suffix (e.g. "v1-beta"); it cannot be incremented.
            continue
    return f"v{highest + 1}"


@shared_task
def start_fine_tune_run_task(execution_id: str):
    """Run a fine-tune job for the ``AgentRun`` identified by ``execution_id``.

    The run and its agent are marked "running", fine-tune parameters are read
    from ``execution.input_data`` (falling back to the agent's current model
    and, for training files, to the organization's unprocessed
    ``TrainingFile`` rows), and ``services.fine_tune_model`` is called.

    On a result dict whose ``status`` is "completed", a new ``AgentModel`` is
    created and attached to the agent, organization training files that were
    used are flagged as processed, and completion events are sent and
    persisted. Any other result, or any exception raised after the "started"
    event, marks the run and the agent as "failed" and emits error events.

    Args:
        execution_id: UUID of the ``AgentRun`` to process.

    Returns:
        A dict with ``status`` and ``execution_id`` plus:
        ``error`` ("execution_not_found") when the run does not exist;
        ``model_id`` when ``status`` is "completed";
        ``result`` when ``status`` is "failed" (fine-tune returned a
        non-completed result);
        ``error`` when ``status`` is "error" (an exception was raised).
    """
    logger.info(f"Fine-tune run task started for execution: {execution_id}")
    try:
        execution = AgentRun.objects.get(uuid=execution_id)
    except AgentRun.DoesNotExist:
        logger.error(f"Execution not found: {execution_id}")
        return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}

    agent = execution.agent
    room_group_name = f"mlstore_agent_{agent.uuid}"
    logger.info(f"Agent: {agent.uuid}, User: {execution.user.email_address}")

    def _mark_failed(error_message: str, error_detail) -> None:
        """Record the failure on the run and agent, then emit error events."""
        execution.status = "failed"
        execution.error_message = error_message
        execution.completed_at = timezone.now()
        execution.save()
        _update_agent_status(agent, "failed")

        error_event = {"execution_id": str(execution.uuid), "error": error_detail}
        _send_group_event(room_group_name, "error", error_event)
        _persist_event(execution, "error", error_event)
        async_to_sync(get_channel_layer().group_send)(
            room_group_name,
            {
                "type": "mlstore_error",
                "execution_id": str(execution.uuid),
                "error_message": error_message,
            },
        )

    execution.status = "running"
    execution.started_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "running")
    logger.info(f"Execution {execution_id} status updated to 'running'")

    # Kept as a function-level import, as in the original code.
    from apps.mlstore.services import BASE_MODEL_CACHE
    logger.info(f"Base model cache directory: {BASE_MODEL_CACHE}")

    input_data = execution.input_data or {}
    base_model = input_data.get("base_model") or agent.model.name

    training_files = input_data.get("training_files") or []
    org_training_files = []
    if not training_files and agent.organization:
        # No explicit files were supplied: use the organization's backlog of
        # unprocessed uploads instead.
        org_training_files = list(
            TrainingFile.objects.filter(
                organization=agent.organization,
                is_processed=False,
            ).select_related('uploaded_by')
        )
        training_files = [tf.file.path for tf in org_training_files if tf.file]
        logger.info(f"Fetched {len(training_files)} training files from organization {agent.organization.name}")

    hyperparams = input_data.get("hyperparams") or {}
    name = input_data.get("name") or agent.model.name

    # An explicit version wins; otherwise continue the v<N> sequence.
    version = input_data.get("version") or _next_model_version(name)

    logger.info(f"Fine-tune parameters: base_model={base_model}, name={name}, version={version}")

    started_event = {"execution_id": str(execution.uuid), "action": "fine_tune"}
    _send_group_event(room_group_name, "started", started_event)
    _persist_event(execution, "started", started_event)

    try:
        result = services.fine_tune_model(base_model, training_files, hyperparams, name, version)
        # Only a dict can carry a status; anything else is reported as a
        # failed result below instead of raising AttributeError here.
        result_status = result.get("status") if isinstance(result, dict) else None
        logger.info(f"Fine-tune result received: {result_status}")
        logger.debug(f"Full fine-tune result: {result}")

        if result_status == "completed":
            model_path = result.get("model_path") or result.get("path") or ""
            model_version = result.get("version") or version
            new_model = AgentModel.objects.create(name=name, version=model_version, path=model_path)
            agent.model = new_model
            agent.save()
            logger.info(f"Fine-tune completed. New model created: {new_model.uuid} at {model_path}")

            if org_training_files:
                # Flag the organization files so the next run skips them.
                file_ids = [tf.id for tf in org_training_files]
                TrainingFile.objects.filter(id__in=file_ids).update(is_processed=True)
                logger.info(f"Marked {len(org_training_files)} training files as processed")

            execution.status = "completed"
            execution.output_data = {
                "result": result,
                "model_id": new_model.id,
                "model_uuid": str(new_model.uuid),
            }
            execution.completed_at = timezone.now()
            execution.save()
            _update_agent_status(agent, "completed")
            logger.info(f"Execution {execution_id} completed successfully")

            completed_event = {
                "execution_id": str(execution.uuid),
                "model_id": new_model.id,
                "model_path": model_path,
            }
            _send_group_event(room_group_name, "completed", completed_event)
            _persist_event(execution, "completed", completed_event)

            async_to_sync(get_channel_layer().group_send)(
                room_group_name,
                {
                    "type": "mlstore_completed",
                    "execution_id": str(execution.uuid),
                    "output_data": execution.output_data,
                },
            )

            return {"status": "completed", "execution_id": execution_id, "model_id": new_model.id}

        logger.warning(f"Fine-tune did not complete successfully. Status: {result_status}")
        _mark_failed(str(result), result)
        return {"status": "failed", "execution_id": execution_id, "result": result}

    except Exception as e:
        # logger.exception records the traceback, so no separate
        # traceback.print_exc() is needed.
        logger.exception(f"Fine-tune task failed with exception for execution {execution_id}: {str(e)}")
        _mark_failed(str(e), str(e))
        return {"status": "error", "execution_id": execution_id, "error": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@shared_task
def infer_run_task(execution_id: str):
    """Run inference for the ``AgentRun`` identified by ``execution_id``.

    The run and its agent are marked "running", then the prompt is read from
    ``execution.input_data`` (``prompt``, falling back to ``query``) together
    with ``options``. A missing prompt fails the run with "prompt_required".
    Otherwise the agent's model is preloaded on a best-effort basis and
    ``services.infer_with_model`` is called; its result is stored in
    ``execution.output_data`` and completion events are sent and persisted.
    Any exception raised during inference marks the run and the agent as
    "failed" and emits error events.

    Args:
        execution_id: UUID of the ``AgentRun`` to process.

    Returns:
        A dict with ``status`` and ``execution_id``. ``status`` is "error"
        (with ``error`` set to "execution_not_found") when the run does not
        exist, "completed" on success, and "failed" (with ``error``) when the
        prompt is missing or inference raised.
    """
    logger.info(f"Inference run task started for execution: {execution_id}")
    try:
        execution = AgentRun.objects.get(uuid=execution_id)
    except AgentRun.DoesNotExist:
        logger.error(f"Execution not found: {execution_id}")
        return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}

    agent = execution.agent
    room_group_name = f"mlstore_agent_{agent.uuid}"
    logger.info(f"Agent: {agent.uuid}, User: {execution.user.email_address}")

    def _mark_failed(error_message: str) -> None:
        """Record the failure on the run and agent, then emit error events."""
        execution.status = "failed"
        execution.error_message = error_message
        execution.completed_at = timezone.now()
        execution.save()
        _update_agent_status(agent, "failed")

        error_event = {"execution_id": str(execution.uuid), "error": error_message}
        _send_group_event(room_group_name, "error", error_event)
        _persist_event(execution, "error", error_event)
        async_to_sync(get_channel_layer().group_send)(
            room_group_name,
            {
                "type": "mlstore_error",
                "execution_id": str(execution.uuid),
                "error_message": error_message,
            },
        )

    execution.status = "running"
    execution.started_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "running")
    logger.info(f"Execution {execution_id} status updated to 'running'")

    input_data = execution.input_data or {}
    prompt = input_data.get("prompt") or input_data.get("query") or ""
    options = input_data.get("options") or {}
    logger.info(f"Prompt length: {len(prompt)} characters")

    if not prompt:
        logger.warning(f"No prompt provided for inference run {execution_id}")
        _mark_failed("prompt_required")
        return {"status": "failed", "execution_id": execution_id, "error": "prompt_required"}

    started_event = {"execution_id": str(execution.uuid), "action": "infer"}
    _send_group_event(room_group_name, "started", started_event)
    _persist_event(execution, "started", started_event)

    try:
        try:
            logger.info(f"Loading model: {agent.model.path}")
            services.load_model_for_inference(agent.model.path)
        except Exception as e:
            # Preloading is best-effort: a failure is logged and inference is
            # still attempted below.
            logger.warning(f"Failed to preload model: {str(e)}")

        logger.info(f"Starting inference with model: {agent.model.path}")
        result = services.infer_with_model(agent.model.path, prompt, options)

        execution.status = "completed"
        execution.output_data = {"result": result}
        execution.completed_at = timezone.now()
        execution.save()
        _update_agent_status(agent, "completed")
        logger.info(f"Inference execution {execution_id} completed successfully")

        completed_event = {"execution_id": str(execution.uuid), "result": result}
        _send_group_event(room_group_name, "completed", completed_event)
        _persist_event(execution, "completed", completed_event)
        async_to_sync(get_channel_layer().group_send)(
            room_group_name,
            {
                "type": "mlstore_completed",
                "execution_id": str(execution.uuid),
                "output_data": execution.output_data,
            },
        )
        return {"status": "completed", "execution_id": execution_id}

    except Exception as e:
        # logger.exception records the traceback, so no separate
        # traceback.print_exc() is needed.
        logger.exception(f"Inference task failed with exception for execution {execution_id}: {str(e)}")
        _mark_failed(str(e))
        return {"status": "failed", "execution_id": execution_id, "error": str(e)}
|