64 lines
2.2 KiB
Python
64 lines
2.2 KiB
Python
|
|
"""Lightweight local LLM wrappers.
|
||
|
|
|
||
|
|
This file provides simple wrappers for `llama_cpp` and `transformers` backends.
|
||
|
|
They are intentionally minimal — adapt to your runtime and model formats.
|
||
|
|
"""
|
||
|
|
from typing import Optional
|
||
|
|
import logging
|
||
|
|
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class BaseLLM:
    """Common interface for the local LLM backends in this module.

    Subclasses override `generate` to turn a prompt into generated text.
    """

    def generate(self, prompt: str) -> str:
        """Return the model's text output for *prompt*.

        Raises:
            NotImplementedError: always, in this base class; concrete
                backends must override this method.
        """
        raise NotImplementedError()
|
||
|
|
|
||
|
|
|
||
|
|
class LlamaCPPWrapper(BaseLLM):
    """`llama_cpp` backend: loads a model file and completes prompts with it.

    Loading is best-effort: if `llama_cpp` cannot be imported or the model
    fails to load, the error is logged and `generate` raises `RuntimeError`.
    """

    def __init__(self, model_path: str):
        """Try to load the model at *model_path*.

        Failures are logged (with traceback) instead of raised; the wrapper
        is then left in an "unavailable" state with `_llm` set to ``None``.
        """
        try:
            # Imported lazily so this module can be imported without llama_cpp.
            from llama_cpp import Llama

            self._llm = Llama(model_path=model_path)
        except Exception:
            logger.exception("llama_cpp is unavailable or model failed to load")
            self._llm = None

    def generate(self, prompt: str) -> str:
        """Return the completion text produced for *prompt*.

        Raises:
            RuntimeError: if the model could not be loaded in `__init__`.
        """
        if self._llm is None:
            raise RuntimeError("Llama model not available")

        resp = self._llm(prompt)
        if not isinstance(resp, dict):
            return str(resp)

        # llama-cpp-python returns an OpenAI-style completion dict: the
        # generated text lives under choices[0]["text"], not at top level.
        choices = resp.get("choices")
        if choices:
            first = choices[0]
            text = first.get("text") if isinstance(first, dict) else first
            return "" if text is None else str(text)

        # Fallback for a flat {"text": ...} response shape; never return
        # None so the declared `-> str` contract holds.
        text = resp.get("text")
        return "" if text is None else str(text)
|
||
|
|
|
||
|
|
|
||
|
|
class TransformersWrapper(BaseLLM):
    """Hugging Face `transformers` backend for causal language models.

    Loading is best-effort: if `transformers`/`torch` cannot be imported or
    the model fails to load, the error is logged and `generate` raises
    `RuntimeError`.
    """

    def __init__(self, model_name_or_path: str):
        """Try to load the tokenizer and model from *model_name_or_path*.

        The model is requested in float16 with ``device_map="auto"``.
        Failures are logged (with traceback) instead of raised; `model` and
        `tokenizer` are then both set to ``None``.
        """
        try:
            # Imported lazily so this module can be imported without
            # transformers/torch installed.
            from transformers import AutoModelForCausalLM, AutoTokenizer
            import torch

            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        except Exception:
            logger.exception("transformers not available or model failed to load")
            self.model = None
            self.tokenizer = None

    def generate(self, prompt: str) -> str:
        """Generate up to 256 new tokens for *prompt* and return decoded text.

        NOTE(review): the whole first output sequence is decoded; for
        causal LMs this presumably includes the prompt itself -- confirm
        callers expect that.

        Raises:
            RuntimeError: if the model or tokenizer could not be loaded.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Transformers model not available")

        inputs = self.tokenizer(prompt, return_tensors="pt")
        # With device_map="auto" the weights may live on an accelerator;
        # the input tensors must be on the same device as the model.
        inputs = inputs.to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=256)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||
|
|
|
||
|
|
|
||
|
|
def get_llm_for_domain(domain: str, prefer: str | None = None) -> BaseLLM:
    """Return an LLM wrapper for the model stored under ``models/<domain>``.

    If ``models/<domain>/model.gguf`` is a file, a `LlamaCPPWrapper` for it
    is returned; otherwise a `TransformersWrapper` pointed at the
    ``models/<domain>`` directory is returned.

    Args:
        domain: Name of the sub-directory of ``models`` holding the model.
        prefer: Accepted for interface compatibility; not consulted by the
            current selection logic.
    """
    # Local import keeps this function self-contained; plain strings do not
    # support the "/" path-joining operator, so a Path is required here.
    from pathlib import Path

    model_dir = Path("models") / domain
    gguf = model_dir / "model.gguf"
    # Basic loader: choose Llama (gguf) if the file exists ...
    if gguf.is_file():
        return LlamaCPPWrapper(str(gguf))
    # ... else fall back to transformers.
    return TransformersWrapper(str(model_dir))
|