From 2de25c4a0eb8e4046bee2792727e32214c7b852a Mon Sep 17 00:00:00 2001 From: Viswamedha Nalabotu Date: Tue, 27 Jan 2026 22:17:22 +0000 Subject: [PATCH] Cleaned tasks file, updated field for model --- .../0002_alter_agentrun_input_data.py | 15 +++ apps/mlstore/models.py | 2 +- apps/mlstore/tasks.py | 97 +++++++------------ 3 files changed, 53 insertions(+), 61 deletions(-) create mode 100644 apps/mlstore/migrations/0002_alter_agentrun_input_data.py diff --git a/apps/mlstore/migrations/0002_alter_agentrun_input_data.py b/apps/mlstore/migrations/0002_alter_agentrun_input_data.py new file mode 100644 index 0000000..443706e --- /dev/null +++ b/apps/mlstore/migrations/0002_alter_agentrun_input_data.py @@ -0,0 +1,15 @@ +from django.db import migrations, models + +class Migration(migrations.Migration): + + dependencies = [ + ('mlstore', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='agentrun', + name='input_data', + field=models.JSONField(blank=True, default=dict), + ), + ] diff --git a/apps/mlstore/models.py b/apps/mlstore/models.py index 846bb09..4ad9f4e 100644 --- a/apps/mlstore/models.py +++ b/apps/mlstore/models.py @@ -63,7 +63,7 @@ class AgentRun(TimeStampMixin, Model): user = ForeignKey(User, on_delete = CASCADE, related_name = 'agent_runs') status = CharField(max_length = 20, choices = RUN_CHOICES, default = 'queued') - input_data = JSONField(default = dict) + input_data = JSONField(default = dict, blank = True) output_data = JSONField(default = dict, blank = True) error_message = TextField(blank = True, default = "") started_at = DateTimeField(null = True, blank = True) diff --git a/apps/mlstore/tasks.py b/apps/mlstore/tasks.py index 89c4b54..68c992d 100644 --- a/apps/mlstore/tasks.py +++ b/apps/mlstore/tasks.py @@ -1,62 +1,16 @@ -from celery import shared_task -from django.utils import timezone -from channels.layers import get_channel_layer -from asgiref.sync import async_to_sync -from . import services -from .models import AgentModel, Agent, AgentRun, AgentEvent -import traceback import logging +import traceback +from asgiref.sync import async_to_sync +from celery import shared_task +from channels.layers import get_channel_layer +from django.utils import timezone + +from apps.orgs.models import TrainingFile +from . import services +from .models import Agent, AgentEvent, AgentModel, AgentRun logger = logging.getLogger(__name__) - -@shared_task -def start_fine_tune_task(base_model: str, training_files: list, hyperparams: dict, name: str, version: str): - """Start a fine-tune via MCP, and register the resulting model on success. - - This task calls `services.fine_tune_model`, expects a dict result with `status` and on success - `model_path` and optionally `version`. - """ - try: - result = services.fine_tune_model(base_model, training_files, hyperparams, name, version) - - if isinstance(result, dict) and result.get("status") == "completed": - model_path = result.get("model_path") or result.get("path") or "" - model_version = result.get("version") or version - m = AgentModel.objects.create(name=name, version=model_version, path=model_path) - return {"status": "ok", "model_id": m.id, "model_uuid": str(m.uuid), "model_path": model_path, "result": result} - - return {"status": "failed", "result": result} - - except Exception as e: - traceback.print_exc() - return {"status": "error", "error": str(e)} - - -@shared_task -def infer_with_model_task(model_id: int, prompt: str, options: dict = None): - """Run inference by requesting the MCP server to use the stored model. - - Looks up the `AgentModel` by `model_id`, calls `services.infer_with_model`, and returns the response. - """ - try: - model = AgentModel.objects.get(id=model_id) - except AgentModel.DoesNotExist: - return {"status": "error", "error": "model_not_found", "model_id": model_id} - - try: - services.load_model_for_inference(model.path) - except Exception: - pass - - try: - out = services.infer_with_model(model.path, prompt, options or {}) - return {"status": "completed", "model_id": model_id, "response": out} - except Exception as e: - traceback.print_exc() - return {"status": "failed", "error": str(e)} - - def _send_group_event(room_group_name: str, event_type: str, content: dict): channel_layer = get_channel_layer() async_to_sync(channel_layer.group_send)( @@ -113,18 +67,35 @@ def start_fine_tune_run_task(execution_id: str): base_model = input_data.get("base_model") or agent.model.name training_files = input_data.get("training_files") or [] + org_training_files = [] if not training_files and agent.organization: - from apps.orgs.models import TrainingFile - org_training_files = TrainingFile.objects.filter( + org_training_files = list(TrainingFile.objects.filter( organization=agent.organization, is_processed=False - ).select_related('uploaded_by') + ).select_related('uploaded_by')) training_files = [tf.file.path for tf in org_training_files if tf.file] logger.info(f"Fetched {len(training_files)} training files from organization {agent.organization.name}") hyperparams = input_data.get("hyperparams") or {} - name = input_data.get("name") or f"{agent.model.name}-ft" - version = input_data.get("version") or "v1" + name = input_data.get("name") or agent.model.name + + if not input_data.get("version"): + existing_models = AgentModel.objects.filter(name=name).order_by('-version') + if existing_models.exists(): + last_version = existing_models.first().version + try: + if last_version.startswith('v'): + num = int(last_version[1:]) + version = f"v{num + 1}" + else: + version = f"v1" + except: + version = "v1" + else: + version = "v1" + else: + version = input_data.get("version") + logger.info(f"Fine-tune parameters: base_model={base_model}, name={name}, version={version}") _send_group_event(room_group_name, "started", {"execution_id": str(execution.uuid), "action": "fine_tune"}) @@ -142,6 +113,12 @@ def start_fine_tune_run_task(execution_id: str): agent.model = new_model agent.save() logger.info(f"Fine-tune completed. New model created: {new_model.uuid} at {model_path}") + + if org_training_files: + file_ids = [tf.id for tf in org_training_files] + TrainingFile.objects.filter(id__in=file_ids).update(is_processed=True) + logger.info(f"Marked {len(org_training_files)} training files as processed") + execution.status = "completed" execution.output_data = {