{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "45d62106",
   "metadata": {},
   "source": [
    "# Basic RAG Implementation with a Local LLM\n",
    "\n",
    "A minimal retrieval-augmented generation (RAG) pipeline that runs on local resources:\n",
    "\n",
    "1. Extract the text of a Word document and split it into fixed-size character chunks.\n",
    "2. Embed the chunks with a Sentence-Transformers model and store them in a persistent Chroma collection.\n",
    "3. For each question, retrieve the most similar chunk(s) and pass them as context to a local Llama 3 model run through GPT4All.\n",
    "\n",
    "Run the notebook with *Restart Kernel and Run All*. The first run may download model weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c312410",
   "metadata": {},
   "outputs": [],
   "source": [
    "from chromadb import PersistentClient\n",
    "from docx import Document\n",
    "from gpt4all import GPT4All\n",
    "from sentence_transformers import SentenceTransformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e2b9c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Models and storage\n",
    "MODEL = \"Meta-Llama-3-8B-Instruct.Q4_0.gguf\"\n",
    "CONTEXT_SIZE = 8192  # LLM context window, in tokens\n",
    "EMBEDDER = \"all-MiniLM-L6-v2\"\n",
    "RAG_PATH = \"./build/rag_db\"\n",
    "DOCS_PATH = \"fNIRS_Glossary_Hardware.docx\"\n",
    "COLLECTION_NAME = \"knowledge_base\"\n",
    "\n",
    "# Retrieval and generation settings\n",
    "CHUNK_SIZE = 1000  # characters per indexed chunk\n",
    "TOP_K = 1  # chunks retrieved per question\n",
    "MAX_CONTEXT_CHARS = 4000  # cap on retrieved text in the prompt; keep >= TOP_K * CHUNK_SIZE or chunks get cut\n",
    "MAX_TOKENS = 200  # cap on the length of the generated answer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9d1a2e3",
   "metadata": {},
   "source": [
    "## Load the models and the vector store"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90bae527",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = GPT4All(model_name=MODEL, n_ctx=CONTEXT_SIZE, allow_download=True)\n",
    "embedder = SentenceTransformer(EMBEDDER)\n",
    "client = PersistentClient(path=RAG_PATH)\n",
    "\n",
    "\n",
    "class EmbeddingFunctionWrapper:\n",
    "    \"\"\"Adapt a SentenceTransformer to Chroma's embedding-function interface.\n",
    "\n",
    "    Chroma calls the function with a batch of texts and expects one embedding\n",
    "    (a list of floats) per text in return.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        self.model = model\n",
    "\n",
    "    def name(self):\n",
    "        # Identifier Chroma may record for the collection's embedding function;\n",
    "        # keep it stable so an already-persisted collection still matches.\n",
    "        return \"sentence-transformers\"\n",
    "\n",
    "    def __call__(self, input):\n",
    "        # The parameter must be called `input`, which is the name Chroma's interface uses.\n",
    "        # A bare string is treated as a one-element batch, so the result is\n",
    "        # always a list of embeddings, never a single flat vector.\n",
    "        texts = [input] if isinstance(input, str) else list(input)\n",
    "        return self.model.encode(texts).tolist()\n",
    "\n",
    "\n",
    "embedding_fn = EmbeddingFunctionWrapper(embedder)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4e5f6a7",
   "metadata": {},
   "source": [
    "## Chunk, embed, and index the document"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34efbc7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chunk_text(text: str, chunk_size: int) -> list[str]:\n",
    "    \"\"\"Split *text* into consecutive, non-overlapping chunks of *chunk_size* characters.\n",
    "\n",
    "    The last chunk may be shorter. Chunks containing only whitespace are dropped.\n",
    "    \"\"\"\n",
    "    if chunk_size <= 0:\n",
    "        raise ValueError(f\"chunk_size must be positive, got {chunk_size}\")\n",
    "    chunks = (text[start:start + chunk_size] for start in range(0, len(text), chunk_size))\n",
    "    return [chunk for chunk in chunks if chunk.strip()]\n",
    "\n",
    "\n",
    "docx_document = Document(DOCS_PATH)\n",
    "docx_content = \"\\n\".join(p.text for p in docx_document.paragraphs if p.text.strip())\n",
    "documents = chunk_text(docx_content, CHUNK_SIZE)\n",
    "assert documents, f\"No text extracted from {DOCS_PATH}\"\n",
    "\n",
    "embeddings = embedder.encode(documents).tolist()\n",
    "chunk_ids = [f\"doc{i}\" for i in range(len(documents))]\n",
    "\n",
    "collection = client.get_or_create_collection(\n",
    "    name=COLLECTION_NAME,\n",
    "    embedding_function=embedding_fn,\n",
    ")\n",
    "# The collection persists on disk between runs, and `add` ignores IDs that already exist.\n",
    "# Upsert the current chunks instead, and drop IDs left over from a previous,\n",
    "# longer version of the document, so re-running the cell is idempotent.\n",
    "collection.upsert(documents=documents, embeddings=embeddings, ids=chunk_ids)\n",
    "stale_ids = sorted(set(collection.get(include=[])[\"ids\"]) - set(chunk_ids))\n",
    "if stale_ids:\n",
    "    collection.delete(ids=stale_ids)\n",
    "collection.count()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8f9a0b1",
   "metadata": {},
   "source": [
    "## Retrieval and answer generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed2cc1ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "def retrieve(query: str, top_k: int = TOP_K) -> list[str]:\n",
    "    \"\"\"Return the texts of the *top_k* indexed chunks most similar to *query*.\n",
    "\n",
    "    The query is embedded with the same SentenceTransformer that embedded the\n",
    "    chunks, so both live in the same vector space. Errors raised by Chroma are\n",
    "    propagated, not swallowed.\n",
    "    \"\"\"\n",
    "    query_embeddings = embedder.encode([query]).tolist()\n",
    "    results = collection.query(query_embeddings=query_embeddings, n_results=top_k)\n",
    "    return results[\"documents\"][0]\n",
    "\n",
    "\n",
    "def rag_answer(\n",
    "    query: str,\n",
    "    top_k: int = TOP_K,\n",
    "    max_context_chars: int = MAX_CONTEXT_CHARS,\n",
    "    max_tokens: int = MAX_TOKENS,\n",
    ") -> str:\n",
    "    \"\"\"Answer *query* with the local LLM, grounded on the retrieved chunks.\n",
    "\n",
    "    Args:\n",
    "        query: The user question.\n",
    "        top_k: Number of chunks to retrieve as context.\n",
    "        max_context_chars: Retrieved text longer than this is truncated and\n",
    "            suffixed with an ellipsis.\n",
    "        max_tokens: Upper bound on the number of generated tokens.\n",
    "\n",
    "    Returns:\n",
    "        The model's generated answer text.\n",
    "    \"\"\"\n",
    "    context = \"\\n\\n\".join(retrieve(query, top_k))\n",
    "    if len(context) > max_context_chars:\n",
    "        context = context[:max_context_chars] + \"...\"\n",
    "\n",
    "    prompt = f\"\"\"\n",
    "Use the context to answer the question.\n",
    "Context:\n",
    "{context}\n",
    "Question:\n",
    "{query}\n",
    "Answer:\n",
    "\"\"\"\n",
    "    return model.generate(prompt, max_tokens=max_tokens)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2a3b4c5",
   "metadata": {},
   "source": [
    "## Ask a question\n",
    "\n",
    "First check the chunking and retrieval, then generate the answer. For reference, a run of an earlier version of this notebook, which cut the retrieved context to 500 characters, answered:\n",
    "\n",
    "> Frequency-domain (FD) multidistance NIRS technique can estimate absolute values of absorption and scattering of the medium, and subsequently chromophore concentrations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fa9fd10",
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"What can Frequency domain multidistance NIRS estimate?\"\n",
    "retrieved = retrieve(query)\n",
    "print(f\"Number of documents: {len(documents)}\")\n",
    "print(f\"Document lengths: {[len(chunk) for chunk in documents]}\")\n",
    "print(f\"Retrieved docs length: {len(retrieved)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a82353e",
   "metadata": {},
   "outputs": [],
   "source": [
    "response = rag_answer(query)\n",
    "response"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}