Dynavera/apps/mlstore/tasks.py

657 lines
26 KiB
Python

import logging
import os
import re
import time
import traceback
from hashlib import sha256
from asgiref.sync import async_to_sync
from celery import shared_task
from channels.layers import get_channel_layer
from django.utils import timezone
from django.db import transaction
from apps.orgs.models import TrainingFile, Role
from . import services
from .models import Agent, AgentEvent, AgentModel, AgentRun, RoleRagDocument
logger = logging.getLogger(__name__)
def _get_mem_info() -> str:
try:
with open('/proc/self/status', 'r', encoding='utf-8') as f:
lines = f.read().splitlines()
mem = {line.split(':', 1)[0]: line.split(':', 1)[1].strip() for line in lines if ':' in line}
return f"VmRSS={mem.get('VmRSS','?')}, VmHWM={mem.get('VmHWM','?')}, VmSize={mem.get('VmSize','?')}"
except Exception:
return "mem_info_unavailable"
def _estimate_tokens(text: str) -> int:
if not text:
return 0
return len(re.findall(r"\w+|[^\s\w]", text))
def _split_semantic_units(text: str) -> list[str]:
paragraphs = [p.strip() for p in re.split(r"\n\s*\n+", text) if p.strip()]
units: list[str] = []
for para in paragraphs:
sentences = re.split(r"(?<=[.!?])\s+", para)
for sent in sentences:
sent = sent.strip()
if sent:
units.append(sent)
return units or paragraphs
def _chunk_text(text: str, max_tokens: int = 400, overlap_tokens: int = 50) -> list[str]:
    """Group the semantic units of *text* into overlapping chunks.

    Units (see ``_split_semantic_units``) are appended to the current chunk
    until adding the next one would push the estimated token count above
    ``max_tokens``; the chunk is then emitted and the next chunk is seeded with
    as many trailing units of the previous chunk as fit in the overlap budget.

    Args:
        text: Raw text to chunk. Empty/falsy input returns ``[]``.
        max_tokens: Soft upper bound on the estimated tokens per chunk. A
            single unit larger than this is never split, so a chunk can
            exceed the bound.
        overlap_tokens: Token budget for units repeated at the start of the
            next chunk. Values ``<= 0`` disable overlap. The budget is capped
            strictly below ``max_tokens``.

    Returns:
        The chunks, in document order, each a space-joined string of units.
    """
    if not text:
        return []

    units = _split_semantic_units(text)
    logger.info(
        "Semantic chunking units=%s max_tokens=%s overlap_tokens=%s mem=%s",
        len(units),
        max_tokens,
        overlap_tokens,
        _get_mem_info(),
    )

    # An overlap budget >= max_tokens would let the carried-over window hold
    # the whole previous chunk, so each new chunk would contain everything
    # before it and chunk sizes would grow without bound. Cap it below the
    # chunk size so every chunk makes forward progress.
    overlap_budget = min(overlap_tokens, max_tokens - 1)

    chunks: list[str] = []
    # Units are stored with their token counts so the overlap pass does not
    # have to re-tokenise them.
    current: list[tuple[str, int]] = []
    current_tokens = 0

    for unit in units:
        unit_tokens = _estimate_tokens(unit)
        if unit_tokens == 0:
            continue

        if current and current_tokens + unit_tokens > max_tokens:
            chunk = " ".join(unit_text for unit_text, _ in current).strip()
            if chunk:
                chunks.append(chunk)

            carried: list[tuple[str, int]] = []
            carried_tokens = 0
            if overlap_budget > 0:
                # Walk backwards from the end of the emitted chunk, keeping
                # whole units while they still fit in the overlap budget.
                for prev_text, prev_tokens in reversed(current):
                    if carried_tokens + prev_tokens > overlap_budget:
                        break
                    carried.append((prev_text, prev_tokens))
                    carried_tokens += prev_tokens
                carried.reverse()  # restore document order
            current = carried
            current_tokens = carried_tokens

        current.append((unit, unit_tokens))
        current_tokens += unit_tokens

    if current:
        chunk = " ".join(unit_text for unit_text, _ in current).strip()
        if chunk:
            chunks.append(chunk)
    return chunks
def _extract_text_from_file(file_path: str, file_type: str | None) -> str:
file_type = (file_type or '').lower()
if file_type in {'txt', 'md', 'csv', 'json'}:
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
return f.read()
if file_type == 'pdf':
try:
from PyPDF2 import PdfReader
except Exception as e:
raise RuntimeError('PyPDF2 is required to parse PDF files') from e
reader = PdfReader(file_path)
return "\n".join(page.extract_text() or "" for page in reader.pages)
if file_type in {'docx', 'doc'}:
try:
import docx
except Exception as e:
raise RuntimeError('python-docx is required to parse DOCX files') from e
doc = docx.Document(file_path)
return "\n".join(p.text for p in doc.paragraphs)
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
return f.read()
def _send_group_event(room_group_name: str, event_type: str, content: dict):
    """Send an ``mlstore_event`` message to a channel-layer group.

    The message carries ``event_type``, ``content`` and an ISO-formatted
    timestamp taken at send time. The async ``group_send`` is invoked
    synchronously through ``async_to_sync``.
    """
    layer = get_channel_layer()
    dispatch = async_to_sync(layer.group_send)
    payload = {
        "type": "mlstore_event",
        "event_type": event_type,
        "content": content,
        "timestamp": timezone.now().isoformat(),
    }
    dispatch(room_group_name, payload)
def _persist_event(execution: AgentRun, event_type: str, content: dict):
    """Store an ``AgentEvent`` row recording *event_type* and *content* for *execution*."""
    event_fields = {
        "execution": execution,
        "event_type": event_type,
        "content": content,
    }
    AgentEvent.objects.create(**event_fields)
def _update_agent_status(agent: Agent, status: str):
    """Set ``agent.status`` and stamp the matching timestamp, then save.

    ``"running"`` stamps ``started_at``; ``"completed"`` and ``"failed"``
    stamp ``completed_at``; any other status changes no timestamp.
    """
    timestamp_field_by_status = {
        "running": "started_at",
        "completed": "completed_at",
        "failed": "completed_at",
    }
    agent.status = status
    timestamp_field = timestamp_field_by_status.get(status)
    if timestamp_field is not None:
        setattr(agent, timestamp_field, timezone.now())
    agent.save()
def _next_fine_tune_version(name: str) -> str:
    """Return the next ``v<N>`` label for models named *name*.

    The highest numeric suffix among existing ``v<digits>`` versions is used.
    Sorting the version strings instead would place ``v9`` after ``v10`` and
    hand out an already-used label. With no parseable version the result is
    ``"v1"``.
    """
    highest = 0
    existing_versions = AgentModel.objects.filter(name=name).values_list('version', flat=True)
    for existing in existing_versions:
        match = re.fullmatch(r"v(\d+)", str(existing or ""))
        if match:
            highest = max(highest, int(match.group(1)))
    return f"v{highest + 1}"


def _fail_fine_tune_run(execution: AgentRun, agent: Agent, room_group_name: str, error_payload, error_message: str):
    """Record a failed fine-tune run and notify listeners.

    Marks the execution and agent as failed, emits and persists an ``error``
    event whose ``error`` field is *error_payload*, and sends an
    ``mlstore_error`` message carrying *error_message* to the group.
    """
    execution.status = "failed"
    execution.error_message = error_message
    execution.completed_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "failed")
    event_content = {"execution_id": str(execution.uuid), "error": error_payload}
    _send_group_event(room_group_name, "error", event_content)
    _persist_event(execution, "error", event_content)
    async_to_sync(get_channel_layer().group_send)(
        room_group_name,
        {
            "type": "mlstore_error",
            "execution_id": str(execution.uuid),
            "error_message": error_message,
        },
    )


@shared_task
def start_fine_tune_run_task(execution_id: str):
    """Run a fine-tuning job for the ``AgentRun`` identified by *execution_id*.

    Steps:
      1. Load the run, mark it and its agent as running.
      2. Resolve parameters from ``execution.input_data``: base model,
         training files (falling back to the organization's unprocessed
         ``TrainingFile`` rows, optionally narrowed to ``role_uuid``),
         hyperparameters, model name and version (auto-incremented when not
         supplied).
      3. Call ``services.fine_tune_model``. On a ``"completed"`` result a new
         ``AgentModel`` is created and attached to the agent, the organization
         training files that were used are marked processed, and completion
         events are emitted. Any other result, or an exception, marks the
         run failed and emits error events.

    Returns:
        A dict with ``status`` (``"completed"``, ``"failed"`` or ``"error"``)
        and ``execution_id``, plus ``model_id``, ``result`` or ``error``
        depending on the outcome.
    """
    logger.info(f"Fine-tune run task started for execution: {execution_id}")
    try:
        execution = AgentRun.objects.get(uuid=execution_id)
    except AgentRun.DoesNotExist:
        logger.error(f"Execution not found: {execution_id}")
        return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}

    agent = execution.agent
    room_group_name = f"mlstore_agent_{agent.uuid}"
    logger.info(f"Agent: {agent.uuid}, User: {execution.user.email_address}")

    execution.status = "running"
    execution.started_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "running")
    logger.info(f"Execution {execution_id} status updated to 'running'")

    from apps.mlstore.services import BASE_MODEL_CACHE
    logger.info(f"Base model cache directory: {BASE_MODEL_CACHE}")

    input_data = execution.input_data or {}
    base_model = input_data.get("base_model") or agent.model.name
    training_files = input_data.get("training_files") or []
    org_training_files = []
    role_uuid = input_data.get("role_uuid")

    # No explicit file list: fall back to the organization's unprocessed files.
    if not training_files and agent.organization:
        training_files_qs = TrainingFile.objects.filter(
            role__organization=agent.organization,
            is_processed=False
        ).select_related('uploaded_by', 'role')
        if role_uuid:
            try:
                role = Role.objects.get(uuid=role_uuid, organization=agent.organization)
                training_files_qs = training_files_qs.filter(role=role)
            except Role.DoesNotExist:
                # Unknown role: keep the organization-wide selection.
                logger.warning(f"Role {role_uuid} not found for organization {agent.organization.name}")
        org_training_files = list(training_files_qs)
        training_files = [tf.file.path for tf in org_training_files if tf.file]
        logger.info(f"Fetched {len(training_files)} training files from organization {agent.organization.name}")

    hyperparams = input_data.get("hyperparams") or {}
    name = input_data.get("name") or agent.model.name
    version = input_data.get("version") or _next_fine_tune_version(name)
    logger.info(f"Fine-tune parameters: base_model={base_model}, name={name}, version={version}")

    started_content = {"execution_id": str(execution.uuid), "action": "fine_tune"}
    _send_group_event(room_group_name, "started", started_content)
    _persist_event(execution, "started", started_content)

    try:
        result = services.fine_tune_model(base_model, training_files, hyperparams, name, version)
        # Guard the status lookup so a non-dict result is reported as a failed
        # run instead of surfacing as an AttributeError.
        result_status = result.get("status") if isinstance(result, dict) else None
        logger.info(f"Fine-tune result received: {result_status}")
        logger.debug(f"Full fine-tune result: {result}")

        if result_status == "completed":
            model_path = result.get("model_path") or result.get("path") or ""
            model_version = result.get("version") or version
            new_model = AgentModel.objects.create(name=name, version=model_version, path=model_path)
            agent.model = new_model
            agent.save()
            logger.info(f"Fine-tune completed. New model created: {new_model.uuid} at {model_path}")

            if org_training_files:
                file_ids = [tf.id for tf in org_training_files]
                TrainingFile.objects.filter(id__in=file_ids).update(is_processed=True)
                logger.info(f"Marked {len(org_training_files)} training files as processed")

            execution.status = "completed"
            execution.output_data = {
                "result": result,
                "model_id": new_model.id,
                "model_uuid": str(new_model.uuid),
            }
            execution.completed_at = timezone.now()
            execution.save()
            _update_agent_status(agent, "completed")
            logger.info(f"Execution {execution_id} completed successfully")

            completed_content = {
                "execution_id": str(execution.uuid),
                "model_id": new_model.id,
                "model_path": model_path,
            }
            _send_group_event(room_group_name, "completed", completed_content)
            _persist_event(execution, "completed", completed_content)
            async_to_sync(get_channel_layer().group_send)(
                room_group_name,
                {
                    "type": "mlstore_completed",
                    "execution_id": str(execution.uuid),
                    "output_data": execution.output_data,
                },
            )
            return {"status": "completed", "execution_id": execution_id, "model_id": new_model.id}

        logger.warning(f"Fine-tune did not complete successfully. Status: {result_status}")
        _fail_fine_tune_run(execution, agent, room_group_name, result, str(result))
        return {"status": "failed", "execution_id": execution_id, "result": result}
    except Exception as e:
        # logger.exception records the traceback, so no separate print is needed.
        logger.exception(f"Fine-tune task failed with exception for execution {execution_id}: {str(e)}")
        _fail_fine_tune_run(execution, agent, room_group_name, str(e), str(e))
        return {"status": "error", "execution_id": execution_id, "error": str(e)}
@shared_task
def ingest_training_file_task(training_file_uuid: str):
    """Extract, chunk and store a training file as RAG documents.

    The file's text is extracted and chunked, any RAG documents previously
    created for this file are removed, and one ``RoleRagDocument`` (with no
    embedding yet) is created per chunk whose content hash is not already
    present for the role or earlier in this same file. The file is then marked
    ``chunked`` / processed and ``embed_training_file_task`` is enqueued.

    Returns:
        ``{"status": "completed", "chunks": <int>}`` on success,
        ``{"status": "skipped", ...}`` when the file was already processed, or
        ``{"status": "error", "error": <str>}`` on failure (the file's status
        is set to ``failed`` on a best-effort basis).
    """
    logger.info(f"Ingest task started for training_file_uuid={training_file_uuid}")
    started_at = time.time()

    try:
        training_file = TrainingFile.objects.select_related('role').get(uuid=training_file_uuid)
    except TrainingFile.DoesNotExist:
        logger.error(f"Training file not found: {training_file_uuid}")
        return {"status": "error", "error": "training_file_not_found"}

    if training_file.is_processed:
        logger.info(f"Training file already processed: {training_file_uuid}")
        return {"status": "skipped", "reason": "already_processed"}
    if not training_file.file:
        logger.error(f"Training file has no file attached: {training_file_uuid}")
        return {"status": "error", "error": "file_missing"}

    # Diagnostic logging only; a failure here must not abort ingestion.
    try:
        file_path = training_file.file.path
        file_size = os.path.getsize(file_path) if os.path.exists(file_path) else 0
        logger.info(
            "Ingesting file: name=%s type=%s size_bytes=%s role=%s path=%s",
            training_file.file_name,
            training_file.file_type,
            file_size,
            training_file.role_id,
            file_path,
        )
    except Exception as e:
        logger.warning(f"Failed to stat training file for {training_file_uuid}: {str(e)}")

    try:
        training_file.status = 'ingesting'
        training_file.save(update_fields=['status'])

        extract_started = time.time()
        text = _extract_text_from_file(training_file.file.path, training_file.file_type)
        logger.info(
            "Extracted text length=%s for training_file_uuid=%s in %.2fs mem=%s",
            len(text),
            training_file_uuid,
            time.time() - extract_started,
            _get_mem_info(),
        )

        chunk_started = time.time()
        chunks = _chunk_text(text)
        logger.info(
            "Chunked text into %s chunks in %.2fs (sample lengths: %s) mem=%s",
            len(chunks),
            time.time() - chunk_started,
            [len(c) for c in chunks[:5]],
            _get_mem_info(),
        )
        if not chunks:
            raise RuntimeError("No text extracted from file")

        # Replace the file's documents and flip its status atomically.
        with transaction.atomic():
            logger.info("Clearing existing RAG docs for training_file_uuid=%s mem=%s", training_file_uuid, _get_mem_info())
            RoleRagDocument.objects.filter(training_file=training_file).delete()

            logger.info("Preparing %s RAG docs for bulk_create mem=%s", len(chunks), _get_mem_info())
            seen_hashes = set(
                RoleRagDocument.objects.filter(role=training_file.role)
                .values_list('content_hash', flat=True)
            )
            documents = []
            skipped = 0
            for index, chunk in enumerate(chunks):
                content_hash = sha256(chunk.encode('utf-8')).hexdigest()
                if content_hash in seen_hashes:
                    skipped += 1
                    continue
                # Remember the hash so a chunk repeated later in this same
                # file is skipped as well, not only ones already stored.
                seen_hashes.add(content_hash)
                documents.append(
                    RoleRagDocument(
                        role=training_file.role,
                        training_file=training_file,
                        content=chunk,
                        embedding=None,
                        content_hash=content_hash,
                        metadata={
                            "file_name": training_file.file_name,
                            "file_type": training_file.file_type,
                            "chunk_size": len(chunk),
                            "source": "training_file",
                        },
                        chunk_index=index,
                    )
                )
            logger.info("Skipped %s duplicate chunks based on content_hash", skipped)
            logger.info("Bulk creating RAG docs count=%s mem=%s", len(documents), _get_mem_info())
            RoleRagDocument.objects.bulk_create(documents, batch_size=500)

            training_file.status = 'chunked'
            training_file.is_processed = True
            training_file.save(update_fields=['status', 'is_processed'])

        elapsed = time.time() - started_at
        logger.info(
            "Ingested training file %s into %s RAG chunks in %.2fs",
            training_file_uuid,
            len(chunks),
            elapsed,
        )

        # Enqueue only after the atomic block has exited, so the embedding
        # worker cannot start before the new rows are visible.
        logger.info(f"Enqueueing embedding task for training_file_uuid={training_file_uuid}")
        embed_training_file_task.delay(training_file_uuid)
        return {"status": "completed", "chunks": len(chunks)}
    except Exception as e:
        elapsed = time.time() - started_at
        logger.error(f"Failed to ingest training file {training_file_uuid}: {str(e)}", exc_info=True)
        logger.error(f"Ingest task failed after {elapsed:.2f}s for training_file_uuid={training_file_uuid}")
        try:
            TrainingFile.objects.filter(uuid=training_file_uuid).update(status='failed')
        except Exception:
            # Still best-effort, but leave a trace instead of failing silently.
            logger.exception(f"Could not mark training file {training_file_uuid} as failed")
        return {"status": "error", "error": str(e)}
@shared_task
def embed_training_file_task(training_file_uuid: str):
    """Generate embeddings for all RAG documents of a training file.

    This task is enqueued after chunking. It loads every ``RoleRagDocument``
    attached to the file and passes them to ``services.batch_embed_documents``
    in batches of 32.

    Status handling:
      * at least one document embedded, or nothing failed -> ``embedded``
        (a warning is logged when some documents failed);
      * nothing embedded and some failed -> ``failed``;
      * unexpected exception -> ``failed`` (best effort).

    Returns:
        A dict with ``status`` of ``"completed"`` (with ``num_embedded``,
        ``num_failed``, ``elapsed``), ``"skipped"`` (no documents) or
        ``"error"``.
    """
    logger.info(f"Embedding task started for training_file_uuid={training_file_uuid}")
    started_at = time.time()

    try:
        training_file = TrainingFile.objects.select_related('role').get(uuid=training_file_uuid)
    except TrainingFile.DoesNotExist:
        logger.error(f"Training file not found: {training_file_uuid}")
        return {"status": "error", "error": "training_file_not_found"}

    try:
        documents = list(RoleRagDocument.objects.filter(training_file=training_file))
        if not documents:
            logger.warning(f"No RAG documents found for training_file_uuid={training_file_uuid}")
            return {"status": "skipped", "reason": "no_documents"}

        logger.info(
            f"Starting to embed {len(documents)} documents for training_file_uuid={training_file_uuid} "
            f"mem={_get_mem_info()}"
        )
        num_embedded, num_failed = services.batch_embed_documents(documents, batch_size=32)

        if num_failed != 0 and num_embedded <= 0:
            training_file.status = 'failed'
            training_file.save(update_fields=['status'])
            logger.error(f"Failed to embed any documents for training_file_uuid={training_file_uuid}")
            return {"status": "error", "error": "embedding_failed", "num_failed": num_failed}

        # Full and partial success are both recorded as 'embedded'; only the
        # log level differs.
        training_file.status = 'embedded'
        training_file.save(update_fields=['status'])
        if num_failed == 0:
            logger.info(f"Successfully embedded all documents for training_file_uuid={training_file_uuid}")
        else:
            logger.warning(
                f"Partially embedded {num_embedded} documents, {num_failed} failed "
                f"for training_file_uuid={training_file_uuid}"
            )

        elapsed = time.time() - started_at
        logger.info(
            f"Embedding task completed for {training_file_uuid}: "
            f"embedded={num_embedded}, failed={num_failed}, time={elapsed:.2f}s"
        )
        return {
            "status": "completed",
            "num_embedded": num_embedded,
            "num_failed": num_failed,
            "elapsed": elapsed,
        }
    except Exception as e:
        elapsed = time.time() - started_at
        logger.error(
            f"Failed to embed training file {training_file_uuid}: {str(e)}",
            exc_info=True
        )
        try:
            TrainingFile.objects.filter(uuid=training_file_uuid).update(status='failed')
        except Exception:
            # Still best-effort, but leave a trace instead of failing silently.
            logger.exception(f"Could not mark training file {training_file_uuid} as failed")
        return {"status": "error", "error": str(e), "elapsed": elapsed}
def _coerce_rag_setting(raw, cast, default):
    """Convert *raw* with *cast*, falling back to *default* when it is missing or invalid."""
    if raw is None:
        return default
    try:
        return cast(raw)
    except (TypeError, ValueError, OverflowError):
        logger.warning(f"Invalid RAG setting {raw!r}; using default {default!r}")
        return default


def _fail_inference_run(execution: AgentRun, agent: Agent, room_group_name: str, error_message: str):
    """Record a failed inference run and notify listeners.

    Marks the execution and agent as failed, emits and persists an ``error``
    event, and sends an ``mlstore_error`` message to the group.
    """
    execution.status = "failed"
    execution.error_message = error_message
    execution.completed_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "failed")
    event_content = {"execution_id": str(execution.uuid), "error": error_message}
    _send_group_event(room_group_name, "error", event_content)
    _persist_event(execution, "error", event_content)
    async_to_sync(get_channel_layer().group_send)(
        room_group_name,
        {
            "type": "mlstore_error",
            "execution_id": str(execution.uuid),
            "error_message": error_message,
        },
    )


@shared_task
def infer_run_task(execution_id: str):
    """Run RAG-augmented inference for the ``AgentRun`` identified by *execution_id*.

    Reads ``prompt`` (or ``query``), ``options``, ``role_uuid``, ``rag_top_k``
    and ``rag_similarity_threshold`` from ``execution.input_data``. Both a role
    and a prompt are required; otherwise the run is failed with
    ``role_uuid_required`` or ``prompt_required``. When context is found for
    the role, the prompt is wrapped in an answer-from-context-only template
    before ``services.infer_with_model`` is called with the agent's model.

    Returns:
        ``{"status": "completed", "execution_id": ...}`` on success, or a dict
        with ``status`` ``"failed"`` and an ``error`` field on validation
        failure or exception. A missing execution yields ``status`` ``"error"``.
    """
    logger.info(f"Inference run task started for execution: {execution_id}")
    try:
        execution = AgentRun.objects.get(uuid=execution_id)
    except AgentRun.DoesNotExist:
        logger.error(f"Execution not found: {execution_id}")
        return {"status": "error", "error": "execution_not_found", "execution_id": execution_id}

    agent = execution.agent
    room_group_name = f"mlstore_agent_{agent.uuid}"
    logger.info(f"Agent: {agent.uuid}, User: {execution.user.email_address}")

    execution.status = "running"
    execution.started_at = timezone.now()
    execution.save()
    _update_agent_status(agent, "running")
    logger.info(f"Execution {execution_id} status updated to 'running'")

    input_data = execution.input_data or {}
    prompt = input_data.get("prompt") or input_data.get("query") or ""
    options = dict(input_data.get("options") or {})
    role_uuid = input_data.get("role_uuid") or options.get("role_uuid")
    # A null or malformed value must not crash the task here: that would leave
    # the run stuck in "running" with no error event.
    rag_top_k = _coerce_rag_setting(input_data.get("rag_top_k"), int, 5)
    rag_similarity_threshold = _coerce_rag_setting(input_data.get("rag_similarity_threshold"), float, 0.5)

    # Generation defaults; caller-supplied options take precedence.
    options.setdefault("temperature", 0.2)
    options.setdefault("top_p", 0.9)
    options.setdefault("max_tokens", 200)
    options.setdefault("stop", ["\n\n", "References:", "Sources:"])
    logger.info(f"Prompt length: {len(prompt)} characters")

    if not role_uuid:
        logger.warning(f"No role_uuid provided for inference run {execution_id}")
        _fail_inference_run(execution, agent, room_group_name, "role_uuid_required")
        return {"status": "failed", "execution_id": execution_id, "error": "role_uuid_required"}

    if not prompt:
        logger.warning(f"No prompt provided for inference run {execution_id}")
        _fail_inference_run(execution, agent, room_group_name, "prompt_required")
        return {"status": "failed", "execution_id": execution_id, "error": "prompt_required"}

    # Context retrieval is best-effort: on failure the raw prompt is used.
    try:
        context = services.get_context_for_query(
            query=prompt,
            role_uuid=str(role_uuid),
            top_k=rag_top_k,
            similarity_threshold=rag_similarity_threshold,
        )
        if context:
            logger.info(f"RAG context retrieved for role={role_uuid} (top_k={rag_top_k})")
            prompt = (
                "You are a technical assistant.\n\n"
                "Answer the question using ONLY the information in the context.\n"
                "Do NOT:\n"
                "- ask follow-up questions\n"
                "- include hashtags\n"
                "- include references or sources\n"
                "- repeat the question\n"
                "- add headings or sections\n"
                "- add information not present in the context\n\n"
                "Answer in 3-6 concise sentences.\n"
                "If the context is insufficient, say: \"The context does not provide enough information.\"\n\n"
                "Context:\n"
                f"{context}\n\n"
                "Question:\n"
                f"{prompt}\n\n"
                "Answer:"
            )
        else:
            logger.info(f"No RAG context found for role={role_uuid}")
    except Exception as e:
        logger.warning(f"RAG context retrieval failed for role={role_uuid}: {e}")

    started_content = {"execution_id": str(execution.uuid), "action": "infer"}
    _send_group_event(room_group_name, "started", started_content)
    _persist_event(execution, "started", started_content)

    try:
        # Preloading is best-effort; inference is attempted even if it fails.
        try:
            logger.info(f"Loading model: {agent.model.path}")
            services.load_model_for_inference(agent.model.path)
        except Exception as e:
            logger.warning(f"Failed to preload model: {str(e)}")

        logger.info(f"Starting inference with model: {agent.model.path}")
        result = services.infer_with_model(agent.model.path, prompt, options)

        execution.status = "completed"
        execution.output_data = {"result": result}
        execution.completed_at = timezone.now()
        execution.save()
        _update_agent_status(agent, "completed")
        logger.info(f"Inference execution {execution_id} completed successfully")

        completed_content = {"execution_id": str(execution.uuid), "result": result}
        _send_group_event(room_group_name, "completed", completed_content)
        _persist_event(execution, "completed", completed_content)
        async_to_sync(get_channel_layer().group_send)(
            room_group_name,
            {
                "type": "mlstore_completed",
                "execution_id": str(execution.uuid),
                "output_data": execution.output_data,
            },
        )
        return {"status": "completed", "execution_id": execution_id}
    except Exception as e:
        # logger.exception records the traceback, so no separate print is needed.
        logger.exception(f"Inference task failed with exception for execution {execution_id}: {str(e)}")
        _fail_inference_run(execution, agent, room_group_name, str(e))
        return {"status": "failed", "execution_id": execution_id, "error": str(e)}