Add RAG implementation for testing a local model
This commit is contained in:
parent
42cd79662d
commit
12e0f141fe
6 changed files with 243 additions and 109 deletions
|
|
@ -1,29 +0,0 @@
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .llm import get_llm_for_domain
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleAgent:
    """Thin agent wrapper that prefixes a system message and delegates to a local LLM."""

    def __init__(self, name: str, domain: str, system_message: str | None = None):
        """Store the agent's identity and obtain the LLM for ``domain``.

        A falsy ``system_message`` (``None`` or an empty string) is replaced
        by a generic default.
        """
        self.name = name
        self.domain = domain
        if system_message:
            self.system_message = system_message
        else:
            self.system_message = "You are an assistant."
        self._llm = get_llm_for_domain(domain)

    def run(self, prompt: str, **kwargs: Any) -> str:
        """Return the LLM's output for ``prompt``, preceded by the system message.

        Extra keyword arguments are accepted but not used.
        """
        composed = f"{self.system_message}\n\nUser: {prompt}"
        logger.debug("Agent %s running prompt: %s", self.name, prompt)
        answer = self._llm.generate(composed)
        return answer
|
|
||||||
|
|
||||||
|
|
||||||
def build_agents_for_domains(domains: list[str]) -> dict[str, SimpleAgent]:
    """Create one tutor agent per entry in ``domains``, keyed by the domain name."""
    return {
        domain: SimpleAgent(
            name=f"agent-{domain}",
            domain=domain,
            system_message=f"You are a tutor for {domain}.",
        )
        for domain in domains
    }
|
|
||||||
|
|
@ -1,63 +0,0 @@
|
||||||
"""Lightweight local LLM wrappers.
|
|
||||||
|
|
||||||
This file provides simple wrappers for `llama_cpp` and `transformers` backends.
|
|
||||||
They are intentionally minimal — adapt to your runtime and model formats.
|
|
||||||
"""
|
|
||||||
from typing import Optional
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseLLM:
    """Common interface for text-generation backends; subclasses override ``generate``."""

    def generate(self, prompt: str) -> str:
        """Return the generated text for ``prompt``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaCPPWrapper(BaseLLM):
    """LLM backend that runs a model file through ``llama_cpp.Llama``.

    If ``llama_cpp`` cannot be imported or the model fails to load, the
    failure is logged and ``generate`` raises ``RuntimeError`` when called.
    """

    def __init__(self, model_path: str):
        """Try to load the model at ``model_path``; load failures are logged, not raised."""
        try:
            # Imported here so a missing llama_cpp is handled by the except below.
            from llama_cpp import Llama

            self._llm = Llama(model_path=model_path)
        except Exception:
            # Deliberate best-effort: record the failure and defer the error
            # to the first generate() call.
            logger.exception("llama_cpp is unavailable or model failed to load")
            self._llm = None

    def generate(self, prompt: str) -> str:
        """Return the completion text for ``prompt``.

        Raises:
            RuntimeError: if the model was not loaded successfully.
        """
        if self._llm is None:
            raise RuntimeError("Llama model not available")

        resp = self._llm(prompt)
        if not isinstance(resp, dict):
            return str(resp)

        # llama-cpp-python returns an OpenAI-style completion dict whose text
        # lives under choices[0]["text"], not at the top level.
        choices = resp.get("choices")
        if choices:
            first = choices[0]
            if isinstance(first, dict) and first.get("text") is not None:
                return first["text"]

        # Backward-compatible fallback for responses with a top-level "text";
        # return an empty string rather than None to honor the str contract.
        text = resp.get("text")
        return text if text is not None else ""
|
|
||||||
|
|
||||||
|
|
||||||
class TransformersWrapper(BaseLLM):
    """LLM backend that runs a causal language model through Hugging Face ``transformers``.

    If ``transformers``/``torch`` cannot be imported or the model fails to
    load, the failure is logged and ``generate`` raises ``RuntimeError``
    when called.
    """

    def __init__(self, model_name_or_path: str):
        """Try to load tokenizer and model; load failures are logged, not raised."""
        try:
            # Imported here so missing packages are handled by the except below.
            from transformers import AutoModelForCausalLM, AutoTokenizer
            import torch

            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        except Exception:
            # Deliberate best-effort: record the failure and defer the error
            # to the first generate() call.
            logger.exception("transformers not available or model failed to load")
            self.model = None
            self.tokenizer = None

    def generate(self, prompt: str) -> str:
        """Generate up to 256 new tokens for ``prompt`` and return the decoded text.

        Raises:
            RuntimeError: if the model or tokenizer was not loaded successfully.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Transformers model not available")

        inputs = self.tokenizer(prompt, return_tensors="pt")
        # device_map="auto" may place the model on an accelerator; the input
        # tensors must live on the same device or generation fails.
        inputs = inputs.to(self.model.device)

        outputs = self.model.generate(**inputs, max_new_tokens=256)
        # NOTE(review): outputs[0] is decoded in full, which for causal LMs
        # presumably includes the prompt tokens as well — confirm whether
        # callers expect the prompt to be echoed back.
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
||||||
|
|
||||||
|
|
||||||
def get_llm_for_domain(domain: str, prefer: str | None = None) -> BaseLLM:
    """Return an LLM backend for ``domain``.

    Looks for ``models/<domain>/model.gguf`` (relative to the current working
    directory). If that file exists, a ``LlamaCPPWrapper`` is returned for
    it; otherwise a ``TransformersWrapper`` is created for the
    ``models/<domain>`` directory.

    Args:
        domain: Name of the sub-directory under ``models``.
        prefer: Accepted for interface compatibility; currently not used.
    """
    # Local import: this module's top-level imports do not include pathlib.
    from pathlib import Path

    # Path is required here: dividing two plain strings raises TypeError.
    model_dir = Path("models") / domain
    gguf = model_dir / "model.gguf"
    if gguf.exists():
        return LlamaCPPWrapper(str(gguf))

    # Fallback: let transformers try to load from the domain directory.
    return TransformersWrapper(str(model_dir))
|
|
||||||
|
|
@ -1,16 +0,0 @@
|
||||||
import logging
|
|
||||||
|
|
||||||
from .models import AgentRun
|
|
||||||
from .langgraph_adapter import SimpleAgent
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def run_agent(agent: SimpleAgent, prompt: str) -> str:
    """Execute ``agent`` on ``prompt`` and record the exchange as an ``AgentRun`` row.

    Persistence is best-effort: if creating the record fails, the error is
    logged and the agent's output is still returned.
    """
    response = agent.run(prompt)
    try:
        AgentRun.objects.create(
            agent_name=agent.name,
            input_text=prompt,
            output_text=response,
        )
    except Exception:
        logger.exception("Failed to persist agent run via Django ORM")
    return response
|
|
||||||
34
apps/domains/migrations/0001_initial.py
Normal file
34
apps/domains/migrations/0001_initial.py
Normal file
|
|
@ -0,0 +1,34 @@
|
||||||
|
# Generated by Django 5.2.8 on 2025-11-19 14:22
|
||||||
|
|
||||||
|
import django.db.models.deletion
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
    """Initial migration: creates the ``Domain`` and ``Dataset`` models."""

    initial = True

    dependencies = []

    operations = [
        migrations.CreateModel(
            name="Domain",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                ("name", models.CharField(max_length=255, unique=True)),
                ("description", models.TextField(blank=True, default="")),
            ],
        ),
        migrations.CreateModel(
            name="Dataset",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                ("name", models.CharField(max_length=255)),
                ("description", models.TextField(blank=True, default="")),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                (
                    "domain",
                    models.ForeignKey(
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="datasets",
                        to="domains.domain",
                    ),
                ),
            ],
        ),
    ]
|
||||||
|
|
@ -1,3 +1,22 @@
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
||||||
# Create your models here.
|
|
||||||
|
class Domain(models.Model):
    """A uniquely named domain with an optional free-text description."""

    # Unique across all domains; also used as the string representation.
    name = models.CharField(max_length=255, unique=True)
    # Optional; stored as an empty string when omitted.
    description = models.TextField(blank=True, default="")

    def __str__(self) -> str:  # pragma: no cover - trivial
        """Return the domain's name."""
        return self.name
|
||||||
|
|
||||||
|
|
||||||
|
class Dataset(models.Model):
    """A named dataset belonging to one ``Domain``."""

    # Deleting the domain deletes its datasets; reverse accessor is ``domain.datasets``.
    domain = models.ForeignKey(
        Domain,
        on_delete=models.CASCADE,
        related_name="datasets",
    )
    name = models.CharField(max_length=255)
    # Optional; stored as an empty string when omitted.
    description = models.TextField(blank=True, default="")
    # Set once on creation / refreshed on every save, respectively.
    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)

    def __str__(self) -> str:  # pragma: no cover - trivial
        """Return ``"<dataset name> (<domain name>)"``."""
        owner = self.domain.name
        return f"{self.name} ({owner})"
|
||||||
189
notebooks/local-model-rag-implementation.ipynb
Normal file
189
notebooks/local-model-rag-implementation.ipynb
Normal file
|
|
@ -0,0 +1,189 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "45d62106",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Basic RAG Implementation with a local LLM"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 26,
|
||||||
|
"id": "4c312410",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from gpt4all import GPT4All\n",
|
||||||
|
"from sentence_transformers import SentenceTransformer\n",
|
||||||
|
"from chromadb import PersistentClient\n",
|
||||||
|
"from docx import Document\n",
|
||||||
|
"\n",
|
||||||
|
"MODEL = \"Meta-Llama-3-8B-Instruct.Q4_0.gguf\"\n",
|
||||||
|
"CONTEXT_SIZE = 8192\n",
|
||||||
|
"EMBEDDER = \"all-MiniLM-L6-v2\"\n",
|
||||||
|
"RAG_PATH = \"./build/rag_db\"\n",
|
||||||
|
"DOCS_PATH = \"C:\\\\Users\\\\nalab\\\\Downloads\\\\fNIRS_Glossary_Hardware.docx\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 27,
|
||||||
|
"id": "90bae527",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"\n",
|
||||||
|
"model = GPT4All(model_name = MODEL, n_ctx = CONTEXT_SIZE, allow_download = True)\n",
|
||||||
|
"embedder = SentenceTransformer(EMBEDDER)\n",
|
||||||
|
"client = PersistentClient(path = RAG_PATH)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"class EmbeddingFunctionWrapper:\n",
|
||||||
|
" def __init__(self, model):\n",
|
||||||
|
" self.model = model\n",
|
||||||
|
"\n",
|
||||||
|
" def name(self):\n",
|
||||||
|
" return \"sentence-transformers\"\n",
|
||||||
|
"\n",
|
||||||
|
" def __call__(self, input):\n",
|
||||||
|
" if isinstance(input, str):\n",
|
||||||
|
" texts = [input]\n",
|
||||||
|
" embs = self.model.encode(texts).tolist()\n",
|
||||||
|
" return embs[0]\n",
|
||||||
|
" else:\n",
|
||||||
|
" texts = list(input)\n",
|
||||||
|
" return self.model.encode(texts).tolist()\n",
|
||||||
|
"\n",
|
||||||
|
"embedding_fn = EmbeddingFunctionWrapper(embedder)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "34efbc7c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"doc = Document(DOCS_PATH)\n",
|
||||||
|
"docx_content = \"\\n\".join([paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()])\n",
|
||||||
|
"chunk_size = 1000\n",
|
||||||
|
"documents = [docx_content[i:i+chunk_size] for i in range(0, len(docx_content), chunk_size) if docx_content[i:i+chunk_size].strip()]\n",
|
||||||
|
"embeddings = embedder.encode(documents).tolist()\n",
|
||||||
|
"collection = client.get_or_create_collection(\n",
|
||||||
|
" name = \"knowledge_base\",\n",
|
||||||
|
" embedding_function = embedding_fn,\n",
|
||||||
|
")\n",
|
||||||
|
"collection.add(\n",
|
||||||
|
" documents=documents,\n",
|
||||||
|
" embeddings=embeddings,\n",
|
||||||
|
" ids=[f\"doc{i}\" for i in range(len(documents))]\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "ed2cc1ff",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def retrieve(query, top_k = 1):\n",
|
||||||
|
" query_embedding = embedder.encode([query]).tolist()[0]\n",
|
||||||
|
" try:\n",
|
||||||
|
" results = collection.query(query_texts=[query], n_results=top_k)\n",
|
||||||
|
" return results[\"documents\"][0]\n",
|
||||||
|
" except Exception:\n",
|
||||||
|
" results = collection.query(query_embeddings=[query_embedding], n_results=top_k)\n",
|
||||||
|
" return results[\"documents\"][0]\n",
|
||||||
|
"\n",
|
||||||
|
"def rag_answer(query):\n",
|
||||||
|
" retrieved_docs = retrieve(query)\n",
|
||||||
|
" context = \"\\n\\n\".join(retrieved_docs)\n",
|
||||||
|
" max_context_length = 500\n",
|
||||||
|
" if len(context) > max_context_length:\n",
|
||||||
|
" context = context[:max_context_length] + \"...\"\n",
|
||||||
|
"\n",
|
||||||
|
" prompt = f\"\"\"\n",
|
||||||
|
"Use the context to answer the question.\n",
|
||||||
|
"Context:\n",
|
||||||
|
"{context}\n",
|
||||||
|
"Question:\n",
|
||||||
|
"{query}\n",
|
||||||
|
"Answer:\n",
|
||||||
|
"\"\"\"\n",
|
||||||
|
" print(f\"Prompt length: {len(prompt)}\")\n",
|
||||||
|
" return model.generate(prompt, max_tokens=200)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6fa9fd10",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Number of documents: 68\n",
|
||||||
|
"Document lengths: [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 63]\n",
|
||||||
|
"Retrieved docs length: 1\n",
|
||||||
|
"Prompt length: 630\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"query = \"What can Frequency domain multidistance NIRS estimate?\"\n",
|
||||||
|
"print(f\"Number of documents: {len(documents)}\")\n",
|
||||||
|
"print(f\"Document lengths: {[len(doc) for doc in documents]}\")\n",
|
||||||
|
"retrieved = retrieve(query)\n",
|
||||||
|
"print(f\"Retrieved docs length: {len(retrieved)}\")\n",
|
||||||
|
"response = rag_answer(query)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 32,
|
||||||
|
"id": "5a82353e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'Frequency-domain (FD) multidistance NIRS technique can estimate absolute values of absorption and scattering of the medium, and subsequently chromophore concentrations.'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 32,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"response"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": ".venv",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.13.9"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue