2026-01-27 22:17:22 +00:00
|
|
|
import logging
|
2026-02-08 15:34:26 +00:00
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
|
import time
|
2026-01-27 22:17:22 +00:00
|
|
|
import traceback
|
2026-02-08 15:34:26 +00:00
|
|
|
from hashlib import sha256
|
2026-01-27 22:17:22 +00:00
|
|
|
from asgiref.sync import async_to_sync
|
2026-01-20 17:21:28 +00:00
|
|
|
from celery import shared_task
|
|
|
|
|
from channels.layers import get_channel_layer
|
2026-01-27 22:17:22 +00:00
|
|
|
from django.utils import timezone
|
2026-02-08 15:34:26 +00:00
|
|
|
from django.db import transaction
|
2026-01-27 22:17:22 +00:00
|
|
|
|
2026-02-08 15:34:26 +00:00
|
|
|
from apps.orgs.models import TrainingFile, Role
|
2026-01-20 17:21:28 +00:00
|
|
|
from . import services
|
2026-02-08 15:34:26 +00:00
|
|
|
from .models import Agent, AgentEvent, AgentModel, AgentRun, RoleRagDocument
|
2026-01-25 17:29:37 +00:00
|
|
|
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
2026-01-20 17:21:28 +00:00
|
|
|
|
2026-02-08 15:34:26 +00:00
|
|
|
|
|
|
|
|
def _get_mem_info() -> str:
    """Summarise this process's memory usage for log lines.

    Reads ``/proc/self/status`` and reports its ``VmRSS``, ``VmHWM`` and
    ``VmSize`` fields, using ``?`` for any field that is absent. Any failure
    (for example a platform without procfs) yields the sentinel string
    ``"mem_info_unavailable"`` instead of an exception.
    """
    try:
        with open('/proc/self/status', 'r', encoding='utf-8') as status_file:
            raw_lines = status_file.read().splitlines()
        fields = {}
        for raw in raw_lines:
            if ':' not in raw:
                continue
            # Key is everything before the first colon; value is the rest, trimmed.
            key, _, value = raw.partition(':')
            fields[key] = value.strip()
        rss = fields.get('VmRSS', '?')
        hwm = fields.get('VmHWM', '?')
        size = fields.get('VmSize', '?')
        return f"VmRSS={rss}, VmHWM={hwm}, VmSize={size}"
    except Exception:
        return "mem_info_unavailable"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _estimate_tokens(text: str) -> int:
    """Return a rough token count for *text*.

    Every run of word characters counts as one token, and so does every
    individual non-space, non-word character (punctuation, symbols). Empty
    or ``None`` input counts as zero.
    """
    return len(re.findall(r"\w+|[^\s\w]", text)) if text else 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _split_semantic_units(text: str) -> list[str]:
    """Break *text* into sentence-sized units for chunking.

    Paragraphs are separated by one or more blank lines. Each paragraph is
    then cut after ``.``, ``!`` or ``?`` wherever whitespace follows. Pieces
    are stripped and empty ones dropped. If no sentence survives, the list of
    stripped paragraphs is returned instead.
    """
    paragraph_break = r"\n\s*\n+"
    sentence_break = r"(?<=[.!?])\s+"

    paragraphs = [piece.strip() for piece in re.split(paragraph_break, text) if piece.strip()]
    units = [
        sentence.strip()
        for paragraph in paragraphs
        for sentence in re.split(sentence_break, paragraph)
        if sentence.strip()
    ]
    return units if units else paragraphs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _split_oversized_unit(unit: str, max_tokens: int) -> list[str]:
    """Cut *unit* into consecutive slices of at most *max_tokens* estimated tokens.

    Token boundaries use the same pattern as ``_estimate_tokens``. Each slice
    runs from the start of its first token to the end of its last token, so
    the original spacing and punctuation inside a slice are kept verbatim.
    Matches are streamed rather than collected, keeping memory flat for very
    large units.
    """
    pieces: list[str] = []
    start = end = 0
    count = 0
    for match in re.finditer(r"\w+|[^\s\w]", unit):
        if count == 0:
            start = match.start()
        end = match.end()
        count += 1
        if count == max_tokens:
            pieces.append(unit[start:end])
            count = 0
    if count:
        pieces.append(unit[start:end])
    return pieces


def _chunk_text(text: str, max_tokens: int = 400, overlap_tokens: int = 50) -> list[str]:
    """Split *text* into overlapping chunks of at most *max_tokens* estimated tokens.

    The text is broken into sentence units by ``_split_semantic_units``, and
    token counts come from ``_estimate_tokens``. Units are packed greedily into
    chunks. A unit larger than *max_tokens* (for example a paragraph with no
    sentence breaks, such as a CSV or JSON file without blank lines) is first
    cut into token-bounded slices, so it cannot become one oversized chunk.
    When a chunk is emitted, its trailing units totalling at most
    *overlap_tokens* are carried into the next chunk. The oldest carried units
    are dropped if the incoming unit would otherwise not fit.

    Args:
        text: Text to chunk. Empty or ``None`` yields ``[]``.
        max_tokens: Upper bound on estimated tokens per chunk. Must be >= 1.
        overlap_tokens: Token budget for the carried-over overlap. A value <= 0
            disables overlap.

    Returns:
        Chunk strings, each made of its units joined with single spaces.

    Raises:
        ValueError: If *max_tokens* is less than 1.
    """
    if not text:
        return []
    if max_tokens < 1:
        raise ValueError("max_tokens must be at least 1")

    # (unit text, estimated tokens) pairs. Counting once here avoids
    # re-tokenising units during the overlap bookkeeping below.
    units: list[tuple[str, int]] = []
    for unit in _split_semantic_units(text):
        unit_tokens = _estimate_tokens(unit)
        if unit_tokens == 0:
            continue
        if unit_tokens <= max_tokens:
            units.append((unit, unit_tokens))
            continue
        for piece in _split_oversized_unit(unit, max_tokens):
            units.append((piece, _estimate_tokens(piece)))

    logger.info(
        "Semantic chunking units=%s max_tokens=%s overlap_tokens=%s mem=%s",
        len(units),
        max_tokens,
        overlap_tokens,
        _get_mem_info(),
    )

    chunks: list[str] = []
    current: list[tuple[str, int]] = []
    current_tokens = 0

    for unit, unit_tokens in units:
        if current and current_tokens + unit_tokens > max_tokens:
            chunk = " ".join(part for part, _ in current).strip()
            if chunk:
                chunks.append(chunk)

            # Carry the trailing units (within the overlap budget) into the next chunk.
            overlap: list[tuple[str, int]] = []
            overlap_count = 0
            if overlap_tokens > 0:
                for prev, prev_tokens in reversed(current):
                    if overlap_count + prev_tokens > overlap_tokens:
                        break
                    overlap.append((prev, prev_tokens))
                    overlap_count += prev_tokens
                overlap.reverse()
            # Drop the oldest carried units until the incoming unit fits, so the
            # overlap can never push the next chunk past max_tokens.
            while overlap and overlap_count + unit_tokens > max_tokens:
                _, dropped_tokens = overlap.pop(0)
                overlap_count -= dropped_tokens

            current = overlap
            current_tokens = overlap_count

        current.append((unit, unit_tokens))
        current_tokens += unit_tokens

    if current:
        chunk = " ".join(part for part, _ in current).strip()
        if chunk:
            chunks.append(chunk)

    return chunks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_text_from_file(file_path: str, file_type: str | None) -> str:
|
|
|
|
|
file_type = (file_type or '').lower()
|
|
|
|
|
if file_type in {'txt', 'md', 'csv', 'json'}:
|
|
|
|
|
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
|
|
|
|
return f.read()
|
|
|
|
|
|
|
|
|
|
if file_type == 'pdf':
|
|
|
|
|
try:
|
|
|
|
|
from PyPDF2 import PdfReader
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise RuntimeError('PyPDF2 is required to parse PDF files') from e
|
|
|
|
|
reader = PdfReader(file_path)
|
|
|
|
|
return "\n".join(page.extract_text() or "" for page in reader.pages)
|
|
|
|
|
|
|
|
|
|
if file_type in {'docx', 'doc'}:
|
|
|
|
|
try:
|
|
|
|
|
import docx
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise RuntimeError('python-docx is required to parse DOCX files') from e
|
|
|
|
|
doc = docx.Document(file_path)
|
|
|
|
|
return "\n".join(p.text for p in doc.paragraphs)
|
|
|
|
|
|
|
|
|
|
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
|
|
|
|
return f.read()
|
|
|
|
|
|
2026-01-20 17:21:28 +00:00
|
|
|
def _send_group_event(room_group_name: str, event_type: str, content: dict):
    """Broadcast an ``mlstore_event`` message to a Channels group.

    The message carries *event_type*, the *content* payload and an ISO-8601
    timestamp taken from ``timezone.now()``. It is sent synchronously through
    the default channel layer's ``group_send``.
    """
    message = {
        "type": "mlstore_event",
        "event_type": event_type,
        "content": content,
        "timestamp": timezone.now().isoformat(),
    }
    group_send = async_to_sync(get_channel_layer().group_send)
    group_send(room_group_name, message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _persist_event(execution: AgentRun, event_type: str, content: dict):
    """Record *event_type* with its *content* as an ``AgentEvent`` row for *execution*."""
    event_fields = {
        "execution": execution,
        "event_type": event_type,
        "content": content,
    }
    AgentEvent.objects.create(**event_fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _update_agent_status(agent: Agent, status: str):
    """Set ``agent.status`` to *status*, stamp the matching timestamp and save.

    ``"running"`` stamps ``started_at``. ``"completed"`` and ``"failed"`` stamp
    ``completed_at``. Any other status leaves both timestamps untouched. The
    agent is saved in every case.
    """
    timestamp_field = {
        "running": "started_at",
        "completed": "completed_at",
        "failed": "completed_at",
    }.get(status)

    agent.status = status
    if timestamp_field is not None:
        setattr(agent, timestamp_field, timezone.now())
    agent.save()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _collect_unprocessed_training_files(organization, role_uuid):
    """Return *organization*'s ``TrainingFile`` rows that have ``is_processed=False``.

    When *role_uuid* is given and matches a ``Role`` of *organization*, only
    that role's files are returned. If the role cannot be found, a warning is
    logged and the organization-wide list is returned.
    """
    files_qs = TrainingFile.objects.filter(
        role__organization=organization,
        is_processed=False,
    ).select_related('uploaded_by', 'role')

    if role_uuid:
        try:
            role = Role.objects.get(uuid=role_uuid, organization=organization)
        except Role.DoesNotExist:
            # NOTE(review): this falls back to *all* unprocessed org files, which
            # may include other roles' data. Confirm this is intended.
            logger.warning(f"Role {role_uuid} not found for organization {organization.name}")
        else:
            files_qs = files_qs.filter(role=role)

    return list(files_qs)


def _next_model_version(name: str) -> str:
    """Return the next ``v<N>`` version label for models called *name*.

    Scans every existing ``AgentModel`` version for *name* and takes the
    highest numeric ``v<N>``; labels in any other form are ignored. Returns
    ``v<N+1>``, or ``"v1"`` when none exist. The comparison must be numeric:
    ordering the string column would rank ``"v9"`` above ``"v10"`` and hand
    out a duplicate version.
    """
    highest = 0
    versions = AgentModel.objects.filter(name=name).values_list('version', flat=True)
    for existing in versions:
        match = re.fullmatch(r"v(\d+)", existing or "")
        if match:
            highest = max(highest, int(match.group(1)))
    return f"v{highest + 1}"


def _fail_fine_tune_execution(execution: AgentRun, agent: Agent, room_group_name: str, error_payload, error_message: str):
    """Mark a fine-tune *execution* and its *agent* as failed, then notify listeners.

    Saves the execution with status ``"failed"``, *error_message* and
    ``completed_at``, and sets the agent to ``"failed"``. It then broadcasts
    and persists an ``"error"`` event carrying *error_payload*, and sends an
    ``mlstore_error`` message carrying *error_message*.
    """
    execution.status = "failed"
    execution.error_message = error_message
    execution.completed_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "failed")

    _send_group_event(room_group_name, "error", {"execution_id": str(execution.uuid), "error": error_payload})
    _persist_event(execution, "error", {"execution_id": str(execution.uuid), "error": error_payload})
    async_to_sync(get_channel_layer().group_send)(
        room_group_name,
        {
            "type": "mlstore_error",
            "execution_id": str(execution.uuid),
            "error_message": error_message,
        },
    )


@shared_task
def start_fine_tune_run_task(execution_id: str):
    """Run a fine-tuning job for the ``AgentRun`` identified by *execution_id*.

    Steps:

    1. Mark the run and its agent ``running``.
    2. Resolve the inputs from ``input_data``: base model, training files,
       hyperparameters, name and version. If no training files are given and
       the agent has an organization, use that organization's unprocessed
       ``TrainingFile`` rows, optionally limited to ``role_uuid``.
    3. Call ``services.fine_tune_model``.
    4. On a ``"completed"`` result, atomically create a new ``AgentModel``,
       attach it to the agent, mark the org files processed and save the run
       as completed.

    Progress is broadcast to the ``mlstore_agent_<agent uuid>`` Channels group
    and persisted as ``AgentEvent`` rows.

    Returns:
        A dict whose ``status`` is one of:

        - ``"completed"``, with ``model_id``.
        - ``"failed"``, with ``result``, when the service returned a
          non-completed result.
        - ``"error"``, when the run is missing or an exception occurred.
    """
    logger.info(f"Fine-tune run task started for execution: {execution_id}")
    try:
        execution = AgentRun.objects.get(uuid=execution_id)
    except AgentRun.DoesNotExist:
        logger.error(f"Execution not found: {execution_id}")
        return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}

    agent = execution.agent
    room_group_name = f"mlstore_agent_{agent.uuid}"
    logger.info(f"Agent: {agent.uuid}, User: {execution.user.email_address}")

    execution.status = "running"
    execution.started_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "running")
    logger.info(f"Execution {execution_id} status updated to 'running'")

    from apps.mlstore.services import BASE_MODEL_CACHE
    logger.info(f"Base model cache directory: {BASE_MODEL_CACHE}")

    input_data = execution.input_data or {}
    base_model = input_data.get("base_model") or agent.model.name

    training_files = input_data.get("training_files") or []
    org_training_files = []
    if not training_files and agent.organization:
        org_training_files = _collect_unprocessed_training_files(agent.organization, input_data.get("role_uuid"))
        training_files = [tf.file.path for tf in org_training_files if tf.file]
        logger.info(f"Fetched {len(training_files)} training files from organization {agent.organization.name}")

    hyperparams = input_data.get("hyperparams") or {}
    name = input_data.get("name") or agent.model.name
    version = input_data.get("version") or _next_model_version(name)
    logger.info(f"Fine-tune parameters: base_model={base_model}, name={name}, version={version}")

    _send_group_event(room_group_name, "started", {"execution_id": str(execution.uuid), "action": "fine_tune"})
    _persist_event(execution, "started", {"execution_id": str(execution.uuid), "action": "fine_tune"})

    new_model = None
    model_path = ""
    try:
        result = services.fine_tune_model(base_model, training_files, hyperparams, name, version)
        # The service is expected to return a dict; don't crash on anything else.
        status = result.get("status") if isinstance(result, dict) else None
        logger.info(f"Fine-tune result received: {status}")
        logger.debug(f"Full fine-tune result: {result}")

        if status == "completed":
            model_path = result.get("model_path") or result.get("path") or ""
            model_version = result.get("version") or version
            previous_model = agent.model
            try:
                # Persist the new model, the processed flags and the run outcome
                # together, so a failure cannot leave half of them written.
                with transaction.atomic():
                    new_model = AgentModel.objects.create(name=name, version=model_version, path=model_path)
                    agent.model = new_model
                    agent.save()
                    if org_training_files:
                        # NOTE(review): ingest_training_file_task skips files with
                        # is_processed=True, so these files will not be RAG-ingested later.
                        file_ids = [tf.id for tf in org_training_files]
                        TrainingFile.objects.filter(id__in=file_ids).update(is_processed=True)
                    execution.status = "completed"
                    execution.output_data = {
                        "result": result,
                        "model_id": new_model.id,
                        "model_uuid": str(new_model.uuid),
                    }
                    execution.completed_at = timezone.now()
                    execution.save()
                    _update_agent_status(agent, "completed")
            except Exception:
                # The rows were rolled back. Undo the in-memory assignment as
                # well, so the failure path's agent.save() does not reference a
                # model row that no longer exists.
                agent.model = previous_model
                new_model = None
                raise

            logger.info(f"Fine-tune completed. New model created: {new_model.uuid} at {model_path}")
            if org_training_files:
                logger.info(f"Marked {len(org_training_files)} training files as processed")
    except Exception as e:
        logger.error(f"Fine-tune task failed with exception for execution {execution_id}: {str(e)}", exc_info=True)
        _fail_fine_tune_execution(execution, agent, room_group_name, str(e), str(e))
        return {"status": "error", "execution_id": execution_id, "error": str(e)}

    if new_model is None:
        logger.warning(f"Fine-tune did not complete successfully. Status: {status}")
        error_payload = result if isinstance(result, dict) else str(result)
        _fail_fine_tune_execution(execution, agent, room_group_name, error_payload, str(result))
        return {"status": "failed", "execution_id": execution_id, "result": error_payload}

    logger.info(f"Execution {execution_id} completed successfully")

    # Notifications run outside the try above. A broken channel layer must not
    # flip an execution that was already saved as completed to "failed".
    completed_event = {"execution_id": str(execution.uuid), "model_id": new_model.id, "model_path": model_path}
    _send_group_event(room_group_name, "completed", completed_event)
    _persist_event(execution, "completed", dict(completed_event))
    async_to_sync(get_channel_layer().group_send)(
        room_group_name,
        {
            "type": "mlstore_completed",
            "execution_id": str(execution.uuid),
            "output_data": execution.output_data,
        },
    )

    return {"status": "completed", "execution_id": execution_id, "model_id": new_model.id}
|
|
|
|
|
|
|
|
|
|
|
2026-02-08 15:34:26 +00:00
|
|
|
@shared_task
def ingest_training_file_task(training_file_uuid: str):
    """Extract, chunk and store a ``TrainingFile`` as ``RoleRagDocument`` rows.

    Files already flagged ``is_processed`` are skipped, and files with no
    attachment are rejected. Otherwise the status moves to ``"ingesting"`` and
    the text is extracted and chunked. Inside one transaction the file's
    previous RAG documents are deleted and one document is created per chunk
    (``embedding=None``). A chunk is skipped when its SHA-256 ``content_hash``
    already exists for the same role or appeared earlier in this file. The
    file is then marked ``"chunked"`` and ``is_processed=True``, and
    ``embed_training_file_task`` is queued.

    Returns:
        ``{"status": "completed", "chunks": <chunk count>}`` on success,
        ``{"status": "skipped", ...}`` for already-processed files, or
        ``{"status": "error", "error": ...}``. On unexpected errors the file's
        status is set to ``"failed"`` (best effort).
    """
    logger.info(f"Ingest task started for training_file_uuid={training_file_uuid}")
    started_at = time.time()

    try:
        training_file = TrainingFile.objects.select_related('role').get(uuid=training_file_uuid)
    except TrainingFile.DoesNotExist:
        logger.error(f"Training file not found: {training_file_uuid}")
        return {"status": "error", "error": "training_file_not_found"}

    if training_file.is_processed:
        logger.info(f"Training file already processed: {training_file_uuid}")
        return {"status": "skipped", "reason": "already_processed"}

    if not training_file.file:
        logger.error(f"Training file has no file attached: {training_file_uuid}")
        return {"status": "error", "error": "file_missing"}

    # Diagnostics only: a failure to stat the file is logged and ignored.
    try:
        file_path = training_file.file.path
        file_size = os.path.getsize(file_path) if os.path.exists(file_path) else 0
        logger.info(
            "Ingesting file: name=%s type=%s size_bytes=%s role=%s path=%s",
            training_file.file_name,
            training_file.file_type,
            file_size,
            training_file.role_id,
            file_path,
        )
    except Exception as e:
        logger.warning(f"Failed to stat training file for {training_file_uuid}: {str(e)}")

    try:
        training_file.status = 'ingesting'
        training_file.save(update_fields=['status'])

        extract_started = time.time()
        text = _extract_text_from_file(training_file.file.path, training_file.file_type)
        logger.info(
            "Extracted text length=%s for training_file_uuid=%s in %.2fs mem=%s",
            len(text),
            training_file_uuid,
            time.time() - extract_started,
            _get_mem_info(),
        )

        chunk_started = time.time()
        chunks = _chunk_text(text)
        logger.info(
            "Chunked text into %s chunks in %.2fs (sample lengths: %s) mem=%s",
            len(chunks),
            time.time() - chunk_started,
            [len(c) for c in chunks[:5]],
            _get_mem_info(),
        )
        if not chunks:
            raise RuntimeError("No text extracted from file")

        with transaction.atomic():
            logger.info("Clearing existing RAG docs for training_file_uuid=%s mem=%s", training_file_uuid, _get_mem_info())
            RoleRagDocument.objects.filter(training_file=training_file).delete()

            logger.info("Preparing %s RAG docs for bulk_create mem=%s", len(chunks), _get_mem_info())
            existing_hashes = set(
                RoleRagDocument.objects.filter(role=training_file.role)
                .values_list('content_hash', flat=True)
            )
            documents = []
            skipped = 0
            for index, chunk in enumerate(chunks):
                content_hash = sha256(chunk.encode('utf-8')).hexdigest()
                if content_hash in existing_hashes:
                    skipped += 1
                    continue
                # Remember hashes created in this run too, so a chunk repeated
                # within the same file is stored only once.
                existing_hashes.add(content_hash)
                documents.append(
                    RoleRagDocument(
                        role=training_file.role,
                        training_file=training_file,
                        content=chunk,
                        embedding=None,
                        content_hash=content_hash,
                        metadata={
                            "file_name": training_file.file_name,
                            "file_type": training_file.file_type,
                            "chunk_size": len(chunk),
                            "source": "training_file",
                        },
                        # Position within the file's chunk list; skipped duplicates leave gaps.
                        chunk_index=index,
                    )
                )
            logger.info("Skipped %s duplicate chunks based on content_hash", skipped)
            logger.info("Bulk creating RAG docs count=%s mem=%s", len(documents), _get_mem_info())
            RoleRagDocument.objects.bulk_create(documents, batch_size=500)

            training_file.status = 'chunked'
            training_file.is_processed = True
            training_file.save(update_fields=['status', 'is_processed'])

        elapsed = time.time() - started_at
        logger.info(
            "Ingested training file %s into %s RAG chunks (%s stored, %s duplicates skipped) in %.2fs",
            training_file_uuid,
            len(chunks),
            len(documents),
            skipped,
            elapsed,
        )

        logger.info(f"Enqueueing embedding task for training_file_uuid={training_file_uuid}")
        embed_training_file_task.delay(training_file_uuid)

        return {"status": "completed", "chunks": len(chunks)}
    except Exception as e:
        elapsed = time.time() - started_at
        logger.error(f"Failed to ingest training file {training_file_uuid}: {str(e)}", exc_info=True)
        logger.error(f"Ingest task failed after {elapsed:.2f}s for training_file_uuid={training_file_uuid}")
        try:
            TrainingFile.objects.filter(uuid=training_file_uuid).update(status='failed')
        except Exception:
            # Best effort, but never silently: a stale "ingesting" status needs a trace.
            logger.exception(f"Could not mark training file {training_file_uuid} as failed")
        return {"status": "error", "error": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@shared_task
def embed_training_file_task(training_file_uuid: str):
    """Generate embeddings for all ``RoleRagDocument`` rows of a training file.

    Queued by ``ingest_training_file_task`` after chunking. The embedding work
    is delegated to ``services.batch_embed_documents`` in batches of 32.

    Status handling:

    - If nothing failed, or at least one document was embedded, the file's
      status becomes ``"embedded"``. A partial failure is logged as a warning.
    - If every document failed, or an exception occurs, the status becomes
      ``"failed"``.

    Returns:
        ``{"status": "completed", "num_embedded", "num_failed", "elapsed"}`` on
        success, ``{"status": "skipped", ...}`` when the file has no documents,
        or ``{"status": "error", ...}``.
    """
    logger.info(f"Embedding task started for training_file_uuid={training_file_uuid}")
    started_at = time.time()

    try:
        training_file = TrainingFile.objects.select_related('role').get(uuid=training_file_uuid)
    except TrainingFile.DoesNotExist:
        logger.error(f"Training file not found: {training_file_uuid}")
        return {"status": "error", "error": "training_file_not_found"}

    try:
        documents = list(RoleRagDocument.objects.filter(training_file=training_file))
        if not documents:
            logger.warning(f"No RAG documents found for training_file_uuid={training_file_uuid}")
            return {"status": "skipped", "reason": "no_documents"}

        logger.info(
            f"Starting to embed {len(documents)} documents for training_file_uuid={training_file_uuid} "
            f"mem={_get_mem_info()}"
        )

        num_embedded, num_failed = services.batch_embed_documents(documents, batch_size=32)

        if num_failed == 0 or num_embedded > 0:
            training_file.status = 'embedded'
            training_file.save(update_fields=['status'])
            if num_failed == 0:
                logger.info(f"Successfully embedded all documents for training_file_uuid={training_file_uuid}")
            else:
                logger.warning(
                    f"Partially embedded {num_embedded} documents, {num_failed} failed "
                    f"for training_file_uuid={training_file_uuid}"
                )
        else:
            training_file.status = 'failed'
            training_file.save(update_fields=['status'])
            logger.error(f"Failed to embed any documents for training_file_uuid={training_file_uuid}")
            return {"status": "error", "error": "embedding_failed", "num_failed": num_failed}

        elapsed = time.time() - started_at
        logger.info(
            f"Embedding task completed for {training_file_uuid}: "
            f"embedded={num_embedded}, failed={num_failed}, time={elapsed:.2f}s"
        )
        return {
            "status": "completed",
            "num_embedded": num_embedded,
            "num_failed": num_failed,
            "elapsed": elapsed,
        }
    except Exception as e:
        elapsed = time.time() - started_at
        logger.error(
            f"Failed to embed training file {training_file_uuid}: {str(e)}",
            exc_info=True
        )
        try:
            TrainingFile.objects.filter(uuid=training_file_uuid).update(status='failed')
        except Exception:
            # Best effort, but leave a trace instead of swallowing the error.
            logger.exception(f"Could not mark training file {training_file_uuid} as failed")
        return {"status": "error", "error": str(e), "elapsed": elapsed}
|
|
|
|
|
|
|
|
|
|
|
2026-01-20 17:21:28 +00:00
|
|
|
def _parse_rag_number(raw, cast, default, label: str):
    """Convert a user-supplied RAG setting with *cast*, falling back to *default*.

    ``None`` (a missing key or an explicit null) yields *default*. Values that
    *cast* rejects also yield *default* and log a warning, instead of crashing
    the task after the run has already been marked ``running``.
    """
    if raw is None:
        return default
    try:
        return cast(raw)
    except (TypeError, ValueError, OverflowError):
        logger.warning(f"Invalid {label}={raw!r}; using default {default!r}")
        return default


def _build_rag_prompt(context: str, question: str) -> str:
    """Wrap *question* and the retrieved *context* in the grounded-answer instruction template."""
    return (
        "You are a technical assistant.\n\n"
        "Answer the question using ONLY the information in the context.\n"
        "Do NOT:\n"
        "- ask follow-up questions\n"
        "- include hashtags\n"
        "- include references or sources\n"
        "- repeat the question\n"
        "- add headings or sections\n"
        "- add information not present in the context\n\n"
        "Answer in 3-6 concise sentences.\n"
        "If the context is insufficient, say: \"The context does not provide enough information.\"\n\n"
        "Context:\n"
        f"{context}\n\n"
        "Question:\n"
        f"{question}\n\n"
        "Answer:"
    )


def _fail_inference_execution(execution: AgentRun, agent: Agent, room_group_name: str, error: str):
    """Mark an inference *execution* and its *agent* as failed, then notify listeners.

    Saves the execution with status ``"failed"``, ``error_message=error`` and
    ``completed_at``, and sets the agent to ``"failed"``. It then broadcasts
    and persists an ``"error"`` event and sends an ``mlstore_error`` message,
    all carrying *error*.
    """
    execution.status = "failed"
    execution.error_message = error
    execution.completed_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "failed")

    _send_group_event(room_group_name, "error", {"execution_id": str(execution.uuid), "error": error})
    _persist_event(execution, "error", {"execution_id": str(execution.uuid), "error": error})
    async_to_sync(get_channel_layer().group_send)(
        room_group_name,
        {
            "type": "mlstore_error",
            "execution_id": str(execution.uuid),
            "error_message": error,
        },
    )


@shared_task
def infer_run_task(execution_id: str):
    """Run inference for the ``AgentRun`` identified by *execution_id*.

    Steps:

    1. Mark the run and its agent ``running``.
    2. Read from ``input_data``: ``prompt`` (or ``query``), ``options``,
       ``role_uuid`` (also accepted inside ``options``) and the RAG settings
       ``rag_top_k`` (default 5) and ``rag_similarity_threshold`` (default
       0.5). Invalid RAG settings fall back to their defaults.
    3. Fill in default generation options.
    4. Require a ``role_uuid``. Try to prepend RAG context from
       ``services.get_context_for_query``; a retrieval failure is logged and
       the raw prompt is used.
    5. Require a non-empty prompt, then run ``services.infer_with_model`` on
       the agent's model.

    Progress is broadcast to the ``mlstore_agent_<agent uuid>`` Channels group
    and persisted as ``AgentEvent`` rows.

    Returns:
        A dict whose ``status`` is one of:

        - ``"completed"``.
        - ``"failed"``, with ``error``.
        - ``"error"``, when the run does not exist.
    """
    logger.info(f"Inference run task started for execution: {execution_id}")
    try:
        execution = AgentRun.objects.get(uuid=execution_id)
    except AgentRun.DoesNotExist:
        logger.error(f"Execution not found: {execution_id}")
        return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}

    agent = execution.agent
    room_group_name = f"mlstore_agent_{agent.uuid}"
    logger.info(f"Agent: {agent.uuid}, User: {execution.user.email_address}")

    execution.status = "running"
    execution.started_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "running")
    logger.info(f"Execution {execution_id} status updated to 'running'")

    input_data = execution.input_data or {}
    prompt = input_data.get("prompt") or input_data.get("query") or ""
    options = dict(input_data.get("options") or {})
    role_uuid = input_data.get("role_uuid") or options.get("role_uuid")
    rag_top_k = _parse_rag_number(input_data.get("rag_top_k"), int, 5, "rag_top_k")
    rag_similarity_threshold = _parse_rag_number(
        input_data.get("rag_similarity_threshold"), float, 0.5, "rag_similarity_threshold"
    )

    # Generation defaults; explicit caller options take precedence.
    options.setdefault("temperature", 0.2)
    options.setdefault("top_p", 0.9)
    options.setdefault("max_tokens", 200)
    options.setdefault("stop", ["\n\n", "References:", "Sources:"])

    logger.info(f"Prompt length: {len(prompt)} characters")

    if not role_uuid:
        logger.warning(f"No role_uuid provided for inference run {execution_id}")
        _fail_inference_execution(execution, agent, room_group_name, "role_uuid_required")
        return {"status": "failed", "execution_id": execution_id, "error": "role_uuid_required"}

    if prompt:
        # Best-effort RAG: on any retrieval error, fall back to the raw prompt.
        try:
            context = services.get_context_for_query(
                query=prompt,
                role_uuid=str(role_uuid),
                top_k=rag_top_k,
                similarity_threshold=rag_similarity_threshold,
            )
            if context:
                logger.info(f"RAG context retrieved for role={role_uuid} (top_k={rag_top_k})")
                prompt = _build_rag_prompt(context, prompt)
            else:
                logger.info(f"No RAG context found for role={role_uuid}")
        except Exception as e:
            logger.warning(f"RAG context retrieval failed for role={role_uuid}: {e}")

    if not prompt:
        logger.warning(f"No prompt provided for inference run {execution_id}")
        _fail_inference_execution(execution, agent, room_group_name, "prompt_required")
        return {"status": "failed", "execution_id": execution_id, "error": "prompt_required"}

    _send_group_event(room_group_name, "started", {"execution_id": str(execution.uuid), "action": "infer"})
    _persist_event(execution, "started", {"execution_id": str(execution.uuid), "action": "infer"})

    try:
        # Preloading is an optimisation; inference is still attempted if it fails.
        try:
            logger.info(f"Loading model: {agent.model.path}")
            services.load_model_for_inference(agent.model.path)
        except Exception as e:
            logger.warning(f"Failed to preload model: {str(e)}")

        logger.info(f"Starting inference with model: {agent.model.path}")
        result = services.infer_with_model(agent.model.path, prompt, options)

        execution.status = "completed"
        execution.output_data = {"result": result}
        execution.completed_at = timezone.now()
        execution.save()
        _update_agent_status(agent, "completed")
    except Exception as e:
        logger.error(f"Inference task failed with exception for execution {execution_id}: {str(e)}", exc_info=True)
        _fail_inference_execution(execution, agent, room_group_name, str(e))
        return {"status": "failed", "execution_id": execution_id, "error": str(e)}

    logger.info(f"Inference execution {execution_id} completed successfully")

    # Notifications run outside the try above. A broken channel layer must not
    # flip an execution that was already saved as completed to "failed".
    _send_group_event(room_group_name, "completed", {"execution_id": str(execution.uuid), "result": result})
    _persist_event(execution, "completed", {"execution_id": str(execution.uuid), "result": result})
    async_to_sync(get_channel_layer().group_send)(
        room_group_name,
        {
            "type": "mlstore_completed",
            "execution_id": str(execution.uuid),
            "output_data": execution.output_data,
        },
    )
    return {"status": "completed", "execution_id": execution_id}
|