"""Tool server for fine-tune bookkeeping and GPT4All inference.

The tools are registered on an MCP ``Server`` instance and are also exposed
over a small aiohttp HTTP API (``POST /execute`` and ``GET /health``).
"""

import asyncio
import json
import os
import sys
from datetime import datetime, timezone
from pathlib import PureWindowsPath
from typing import Any, Dict, List, Optional

from aiohttp import web
from mcp.server import Server
from mcp.types import Tool, TextContent

app = Server("mlstore-mcp-server")

# Cache of loaded models, keyed by the resolved model path. Each entry holds
# "loaded_at", "model", "model_dir" and "model_file".
LOADED_MODELS: Dict[str, Dict[str, Any]] = {}


@app.list_tools()
async def list_tools():
    """Return the tool descriptors (name, description, JSON schema)."""
    return [
        Tool(
            name="echo",
            description="Echo back the provided input",
            inputSchema={
                "type": "object",
                "properties": {
                    "message": {"type": "string"}
                },
                "required": ["message"]
            },
        ),
        Tool(
            name="fine_tune",
            description="Start fine-tuning a base model using training files",
            inputSchema={
                "type": "object",
                "properties": {
                    "base_model": {"type": "string"},
                    "training_files": {"type": "array", "items": {"type": "string"}},
                    "hyperparams": {"type": "object"},
                    "name": {"type": "string"},
                    "version": {"type": "string"}
                },
                "required": ["base_model", "training_files", "name", "version"]
            },
        ),
        Tool(
            name="load_model",
            description="Load a fine-tuned model into memory for inference",
            inputSchema={
                "type": "object",
                "properties": {
                    "model_path": {"type": "string"}
                },
                "required": ["model_path"]
            },
        ),
        Tool(
            name="infer",
            description="Run inference with a fine-tuned model",
            inputSchema={
                "type": "object",
                "properties": {
                    "model_path": {"type": "string"},
                    "prompt": {"type": "string"},
                    "options": {"type": "object"}
                },
                "required": ["model_path", "prompt"]
            },
        ),
    ]


def _now() -> str:
    """Return the current UTC time as an ISO-8601 string with a ``Z`` suffix."""
    # datetime.utcnow() is deprecated; take an aware UTC timestamp and drop
    # the tzinfo so the rendered text keeps the original "<naive iso>Z" shape.
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"


def _model_root() -> str:
    """Return the model directory.

    Uses ``MCP_MODEL_DIR``, then ``DJANGO_MODEL_DIR``, then ``<cwd>/model``.
    """
    return (
        os.getenv("MCP_MODEL_DIR")
        or os.getenv("DJANGO_MODEL_DIR")
        or os.path.join(os.getcwd(), "model")
    )


def _safe_dir_name(name: str) -> str:
    """Keep only alphanumerics, ``-``, ``_`` and ``.``; strip outer dots.

    The result may be an empty string when *name* has no permitted characters.
    """
    return "".join(c for c in name if c.isalnum() or c in ("-", "_", ".")).strip(".")


def _resolve_model_path(model_path: str) -> str:
    """Map a caller-supplied model path to an existing local path if possible.

    An existing absolute path is returned as is. Otherwise several candidate
    locations are probed in order and the first one that exists is returned.
    When nothing exists the normalised input is returned, so callers still
    have to check for existence. A falsy *model_path* is returned unchanged.
    """
    if not model_path:
        return model_path
    norm = os.path.normpath(model_path)
    if os.path.isabs(norm) and os.path.exists(norm):
        return norm

    candidates = []
    # Try relative to current working directory
    candidates.append(os.path.normpath(os.path.join(os.getcwd(), norm)))
    # Try relative to model root
    candidates.append(
        os.path.normpath(os.path.join(_model_root(), os.path.basename(norm)))
    )
    # If it looks like a Windows-style path, re-root everything from a known
    # anchor directory onward under the current working directory
    # (presumably the container's /app -- confirm against the deployment).
    if ":" in model_path or "\\" in model_path:
        parts = [str(x) for x in PureWindowsPath(model_path).parts]
        for anchor in ("notebooks", "model"):
            if anchor in parts:
                rel = os.path.join(*parts[parts.index(anchor):])
                candidates.append(os.path.normpath(os.path.join(os.getcwd(), rel)))

    for cand in candidates:
        if os.path.exists(cand):
            return cand
    return norm


def _resolve_model_file(model_path: str) -> tuple[str, str]:
    """Return (model_dir, model_filename) for GPT4All.

    For a directory, the filename is the first ``.gguf`` entry in sorted
    order, or ``""`` when the directory has none. For any other path the
    result is its dirname and basename.
    """
    resolved = _resolve_model_path(model_path)
    if os.path.isdir(resolved):
        # os.listdir() order is arbitrary; sort so that the chosen file is
        # stable when a directory contains more than one .gguf.
        for entry in sorted(os.listdir(resolved)):
            if entry.lower().endswith(".gguf"):
                return resolved, entry
        return resolved, ""
    return os.path.dirname(resolved), os.path.basename(resolved)


def _failed(error: str, **details: Any) -> Dict[str, Any]:
    """Build a ``status: failed`` result; *details* precede the timestamp."""
    return {"status": "failed", "error": error, **details, "timestamp": _now()}


def _load_and_cache_model(
    model_path: str, **gpt4all_kwargs: Any
) -> Optional[Dict[str, Any]]:
    """Instantiate GPT4All for *model_path* and store it in ``LOADED_MODELS``.

    Returns the cache entry, or ``None`` when no model file could be
    determined. Exceptions from the import or from GPT4All propagate to the
    caller. Extra keyword arguments are forwarded to the GPT4All constructor.
    """
    # Imported lazily so the server can start without gpt4all installed; a
    # missing package is then reported as a failed tool call by the callers.
    from gpt4all import GPT4All

    model_dir, model_file = _resolve_model_file(model_path)
    if not model_file:
        return None
    model = GPT4All(
        model_file, model_path=model_dir, allow_download=False, **gpt4all_kwargs
    )
    entry = {
        "loaded_at": _now(),
        "model": model,
        "model_dir": model_dir,
        "model_file": model_file,
    }
    LOADED_MODELS[model_path] = entry
    return entry


def _tool_echo(arguments: dict) -> Dict[str, Any]:
    """Return the received arguments together with a timestamp."""
    return {"status": "ok", "received": arguments, "timestamp": _now()}


def _tool_fine_tune(arguments: dict) -> Dict[str, Any]:
    """Create ``<model_root>/<name>-<version>`` and write ``metadata.json``.

    No training is performed here: the function only records the request and
    reports ``status: completed``. Writing the metadata file is best effort;
    directory creation errors propagate.
    """
    base_model = arguments.get("base_model")
    training_files = arguments.get("training_files") or []
    hyperparams = arguments.get("hyperparams") or {}
    model_name = arguments.get("name") or "model"
    version = arguments.get("version") or "v1"

    model_root = _model_root()
    os.makedirs(model_root, exist_ok=True)
    # Sanitising can strip every character (e.g. name=".."); fall back to the
    # defaults so the directory name never degenerates to a bare "-".
    safe_name = _safe_dir_name(model_name) or "model"
    safe_version = _safe_dir_name(version) or "v1"
    output_dir = os.path.join(model_root, f"{safe_name}-{safe_version}")
    os.makedirs(output_dir, exist_ok=True)

    metadata = {
        "status": "completed",
        "base_model": base_model,
        "training_files": training_files,
        "hyperparams": hyperparams,
        "name": model_name,
        "version": version,
        "model_path": output_dir,
        "timestamp": _now(),
    }
    try:
        with open(os.path.join(output_dir, "metadata.json"), "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)
    except (OSError, TypeError, ValueError) as exc:
        # Still best effort, but leave a trace instead of failing silently.
        print(f"Could not write metadata.json in {output_dir}: {exc}", file=sys.stderr)
    return metadata


def _tool_load_model(arguments: dict) -> Dict[str, Any]:
    """Load (or reload) a model with ``device='gpu'`` and cache it."""
    model_path = arguments.get("model_path")
    if not model_path:
        return _failed("model_path_required")
    model_path = _resolve_model_path(model_path)
    if not os.path.exists(model_path):
        return _failed("model_not_found", model_path=model_path)
    try:
        entry = _load_and_cache_model(model_path, device="gpu")
    except Exception as e:
        # Tool boundary: report any load problem as a failed result.
        return _failed(str(e), error_type=type(e).__name__, model_path=model_path)
    if entry is None:
        return _failed("model_file_not_found", model_path=model_path)
    return {
        "status": "completed",
        "model_path": model_path,
        "loaded": True,
        "model_dir": entry["model_dir"],
        "model_file": entry["model_file"],
        "timestamp": _now(),
    }


def _tool_infer(arguments: dict) -> Dict[str, Any]:
    """Generate text for ``prompt``, loading the model first if not cached.

    Recognised ``options`` keys: ``max_tokens`` (256), ``temperature`` or
    ``temp`` (0.7), ``top_p`` (0.95) and ``top_k`` (40).
    """
    model_path = arguments.get("model_path")
    prompt = arguments.get("prompt") or ""
    options = arguments.get("options") or {}
    if not model_path:
        return _failed("model_path_required")
    model_path = _resolve_model_path(model_path)
    if not os.path.exists(model_path):
        return _failed("model_not_found", model_path=model_path)
    try:
        entry = LOADED_MODELS.get(model_path)
        if entry is None or "model" not in entry:
            # NOTE(review): this lazy load does not pass device='gpu', unlike
            # the load_model tool -- kept as is; confirm whether intended.
            entry = _load_and_cache_model(model_path)
            if entry is None:
                return _failed("model_file_not_found", model_path=model_path)
        model = entry["model"]

        max_tokens = int(options.get("max_tokens", 256))
        temp = float(options.get("temperature", options.get("temp", 0.7)))
        top_p = float(options.get("top_p", 0.95))
        top_k = int(options.get("top_k", 40))
        response_text = model.generate(
            prompt,
            max_tokens=max_tokens,
            temp=temp,
            top_p=top_p,
            top_k=top_k,
        )
        return {
            "status": "completed",
            "model_path": model_path,
            "response": response_text,
            "options": {
                "max_tokens": max_tokens,
                "temperature": temp,
                "top_p": top_p,
                "top_k": top_k,
            },
            "timestamp": _now(),
        }
    except Exception as e:
        # Tool boundary: bad option values and model errors become results.
        return _failed(str(e), error_type=type(e).__name__, model_path=model_path)


# Tool name -> synchronous handler taking the arguments dict.
_TOOL_HANDLERS = {
    "echo": _tool_echo,
    "fine_tune": _tool_fine_tune,
    "load_model": _tool_load_model,
    "infer": _tool_infer,
}


async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
    """Run the tool called *name* and return its JSON-serialisable result.

    Raises:
        ValueError: if *name* is not a known tool.
    """
    # NOTE: handlers are called synchronously, so model loading and text
    # generation occupy the event loop until they return.
    handler = _TOOL_HANDLERS.get(name) if isinstance(name, str) else None
    if handler is None:
        raise ValueError(f"Unknown tool: {name}")
    return handler(arguments)


@app.call_tool()
async def call_tool(name: str, arguments: dict):
    """MCP entry point: run the tool and wrap the result as JSON text."""
    result = await _run_tool_http(name, arguments)
    return [TextContent(type="text", text=json.dumps(result, indent=2))]


async def handle_execute(request: web.Request) -> web.Response:
    """Handle ``POST /execute`` with body ``{"tool": ..., "arguments": {...}}``.

    Responds 400 for malformed requests, 500 when the tool raises, and
    otherwise returns the tool result as JSON.
    """
    try:
        payload = await request.json()
        # Valid JSON that is not an object (list, string, ...) has no
        # .get(); reject it as a client error rather than a 500.
        if not isinstance(payload, dict):
            return web.json_response(
                {"error": "Request body must be a JSON object"}, status=400
            )
        tool = payload.get("tool")
        if not tool:
            return web.json_response(
                {"error": "Missing 'tool' field"}, status=400
            )
        arguments = payload.get("arguments")
        if arguments is None:
            # Covers both a missing key and an explicit null.
            arguments = {}
        if not isinstance(arguments, dict):
            return web.json_response(
                {"error": "'arguments' must be a JSON object"}, status=400
            )
        result = await _run_tool_http(tool, arguments)
        return web.json_response(result)
    except json.JSONDecodeError:
        return web.json_response({"error": "Invalid JSON"}, status=400)
    except Exception as e:
        return web.json_response({"error": str(e)}, status=500)


async def handle_health(request: web.Request) -> web.Response:
    """Handle ``GET /health`` with a static healthy status."""
    return web.json_response({"status": "healthy"})


async def run_http_server():
    """Serve the HTTP API until cancelled.

    Host and port come from ``MCP_HTTP_HOST`` (default ``0.0.0.0``) and
    ``MCP_HTTP_PORT`` (default ``8001``).
    """
    host = os.getenv("MCP_HTTP_HOST", "0.0.0.0")
    port = int(os.getenv("MCP_HTTP_PORT", "8001"))

    app_http = web.Application()
    app_http.router.add_post("/execute", handle_execute)
    app_http.router.add_get("/health", handle_health)

    runner = web.AppRunner(app_http)
    await runner.setup()
    try:
        site = web.TCPSite(runner, host, port)
        await site.start()
        print(f"HTTP server running on {host}:{port}", file=sys.stderr)
        # Block forever; the wait ends only when this task is cancelled.
        await asyncio.Event().wait()
    finally:
        # Release the listening socket on shutdown or startup failure.
        await runner.cleanup()


if __name__ == "__main__":
    asyncio.run(run_http_server())