# knowledge-mcp: MCP server exposing RAG "notebooks" backed by Postgres/pgvector,
# with embeddings served by a Text Embeddings Inference (TEI) endpoint.
import os
|
|
import httpx
|
|
from mcp.server.fastmcp import FastMCP
|
|
from mcp.server.transport_security import TransportSecuritySettings
|
|
import psycopg
|
|
from pgvector.psycopg import register_vector
|
|
import uuid
|
|
import logging
|
|
import json
|
|
from pypdf import PdfReader
|
|
from contextlib import contextmanager
|
|
|
|
# Configuration
# Defaults point at in-cluster Kubernetes service DNS names; override both
# via environment variables for local development.
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://postgres:password@postgres.knowledge-mcp.svc:5432/knowledge")
TEI_URL = os.getenv("TEI_URL", "http://text-embeddings.tei.svc.cluster.local:8080")
# Must match the width of the model served by TEI; the chunks table's VECTOR
# column is created with this dimension in init_db().
EMBEDDING_DIM = 768  # BAAI/bge-base-en-v1.5

# Initialize
logging.basicConfig(level=logging.INFO)
# MCP server served over SSE on port 8000. DNS-rebinding protection is on,
# so only requests whose Host header matches the allowlist are accepted.
mcp = FastMCP(
    "knowledge-mcp",
    host="0.0.0.0",
    port=8000,
    sse_path="/sse",
    transport_security=TransportSecuritySettings(
        enable_dns_rebinding_protection=True,
        allowed_hosts=[
            "localhost:*",
            "127.0.0.1:*",
            "knowledge-mcp:*",
            "knowledge-mcp.knowledge-mcp.svc:*",
            "knowledge-mcp.knowledge-mcp.svc.cluster.local:*",
        ],
        allowed_origins=[],  # no browser Origin headers permitted
    ),
)
|
|
|
|
@contextmanager
def get_db(init=False):
    """Yield an autocommit database connection, closing it on exit.

    When ``init`` is True, pgvector type registration is skipped because
    the ``vector`` extension may not exist yet during schema creation.
    """
    connection = psycopg.connect(DATABASE_URL, autocommit=True)
    if not init:
        # Best-effort: a missing extension only degrades vector adaptation,
        # it should not prevent handing out a connection.
        try:
            register_vector(connection)
        except Exception as exc:
            logging.warning(f"Vector registration failed (ignoring if init): {exc}")
    try:
        yield connection
    finally:
        connection.close()
|
|
|
def init_db():
    """Create the pgvector extension, tables, and search index (idempotent)."""
    ddl_statements = (
        "CREATE EXTENSION IF NOT EXISTS vector",
        # Notebooks table (simple registry)
        """
            CREATE TABLE IF NOT EXISTS notebooks (
                name TEXT PRIMARY KEY,
                created_at TIMESTAMP DEFAULT NOW()
            )
        """,
        # Chunks table; embedding width must match the TEI model's output.
        f"""
            CREATE TABLE IF NOT EXISTS chunks (
                id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                notebook TEXT REFERENCES notebooks(name) ON DELETE CASCADE,
                content TEXT NOT NULL,
                embedding VECTOR({EMBEDDING_DIM}),
                source TEXT,
                metadata JSONB,
                created_at TIMESTAMP DEFAULT NOW()
            )
        """,
        # HNSW index for fast cosine-similarity search
        """
            CREATE INDEX IF NOT EXISTS chunks_embedding_idx ON chunks
            USING hnsw (embedding vector_cosine_ops)
        """,
    )
    try:
        # init=True skips vector registration, which would fail before the
        # extension exists. Order matters: extension, tables, then index.
        with get_db(init=True) as conn:
            for statement in ddl_statements:
                conn.execute(statement)
        logging.info("Database initialized successfully.")
    except Exception as e:
        logging.error(f"Database initialization failed: {e}")
|
|
|
# Run init on import (or startup)
# In a real app, this might be a separate migration step, but for MCP self-contained:
# failure is downgraded to a warning so the server can still start while the
# database is coming up; tools surface errors on first use instead.
try:
    init_db()
except Exception as e:
    logging.warning(f"Could not initialize DB immediately (might be waiting for connection): {e}")
|
|
|
|
def get_embedding(text: str) -> list[float]:
    """Embed *text* via the TEI service and return its embedding vector."""
    endpoint = f"{TEI_URL}/embed"
    try:
        resp = httpx.post(endpoint, json={"inputs": text}, timeout=10.0)
        resp.raise_for_status()
        # TEI returns one embedding per input; single input -> first element.
        return resp.json()[0]
    except Exception as e:
        # Log here so callers can simply propagate.
        logging.error(f"Embedding failed: {e}")
        raise
|
|
|
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
    """Split *text* into overlapping sliding-window chunks.

    Args:
        text: Input text to split.
        chunk_size: Maximum characters per chunk.
        overlap: Characters shared between consecutive chunks.

    Returns:
        List of chunks; text no longer than ``chunk_size`` is returned whole
        (including the empty string, returned as ``[""]``).

    Raises:
        ValueError: If ``overlap >= chunk_size`` — the window step would be
            non-positive, which previously caused an infinite loop.
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    if len(text) <= chunk_size:
        return [text]

    step = chunk_size - overlap  # window advance per chunk
    return [text[start:start + chunk_size] for start in range(0, len(text), step)]
|
|
|
|
@mcp.tool()
def create_notebook(name: str) -> str:
    """Create a new RAG notebook.

    The name is normalized to lowercase with spaces replaced by hyphens.
    Returns a human-readable status message instead of raising, since the
    result is relayed directly to the MCP client.
    """
    clean_name = name.lower().replace(" ", "-")

    try:
        with get_db() as conn:
            # Atomic insert-if-absent: ON CONFLICT removes the check-then-insert
            # race of a separate SELECT when two clients create the same name.
            cur = conn.execute(
                "INSERT INTO notebooks (name) VALUES (%s) ON CONFLICT (name) DO NOTHING",
                (clean_name,),
            )
            if cur.rowcount == 0:
                return f"Notebook '{clean_name}' already exists."
            return f"Notebook '{clean_name}' created successfully."
    except Exception as e:
        return f"Error creating notebook: {e}"
|
|
|
|
@mcp.tool()
def add_source(notebook: str, content: str, source_name: str, format: str = "text") -> str:
    """
    Add content to a notebook.
    format: 'text' or 'pdf_path' (local path inside container)

    The content is chunked, embedded via TEI, and stored per-chunk; chunks
    that fail to embed or insert are skipped (best-effort ingestion).
    Returns a human-readable status message.
    """
    clean_name = notebook.lower().replace(" ", "-")

    try:
        with get_db() as conn:
            if not conn.execute("SELECT 1 FROM notebooks WHERE name = %s", (clean_name,)).fetchone():
                return f"Error: Notebook '{clean_name}' does not exist."
    except Exception as e:
        return f"Database error: {e}"

    text_to_process = ""

    if format == "pdf_path":
        # NOTE(review): 'content' is a client-supplied filesystem path; this
        # assumes the container filesystem is trusted. Consider restricting
        # reads to an allowlisted directory.
        try:
            reader = PdfReader(content)
            for page in reader.pages:
                # extract_text() may yield ""/None on image-only pages;
                # guard so the concatenation never sees None.
                text_to_process += (page.extract_text() or "") + "\n"
        except Exception as e:
            return f"Error reading PDF: {e}"
    else:
        text_to_process = content

    chunks = chunk_text(text_to_process)
    count = 0

    try:
        with get_db() as conn:
            for i, chunk in enumerate(chunks):
                try:
                    vector = get_embedding(chunk)
                    meta = {
                        "chunk_index": i,
                        "total_chunks": len(chunks)
                    }

                    # %s::vector matches the cast used by query_notebook, so a
                    # plain Python list is accepted for the VECTOR column even
                    # if pgvector type registration did not succeed.
                    conn.execute("""
                        INSERT INTO chunks (notebook, content, embedding, source, metadata)
                        VALUES (%s, %s, %s::vector, %s, %s)
                    """, (clean_name, chunk, vector, source_name, json.dumps(meta)))
                    count += 1
                except Exception as e:
                    # Best-effort: skip the failed chunk, keep ingesting the rest.
                    logging.error(f"Failed to process chunk {i}: {e}")
                    continue
        return f"Added {count} chunks from '{source_name}' to '{clean_name}'."

    except Exception as e:
        return f"Failed to add source: {e}"
|
|
|
|
@mcp.tool()
def query_notebook(notebook: str, query: str, limit: int = 5) -> str:
    """Query the notebook for relevant context."""
    clean_name = notebook.lower().replace(" ", "-")

    try:
        vector = get_embedding(query)

        with get_db() as conn:
            # Verify the notebook exists before searching.
            exists = conn.execute("SELECT 1 FROM notebooks WHERE name = %s", (clean_name,)).fetchone()
            if not exists:
                return f"Error: Notebook '{clean_name}' does not exist."

            # Cosine distance (<=>), ascending: closest chunks first.
            # Score is reported as 1 - distance so higher means more similar.
            rows = conn.execute("""
                SELECT content, source, (1 - (embedding <=> %s::vector)) as score
                FROM chunks
                WHERE notebook = %s
                ORDER BY embedding <=> %s::vector ASC
                LIMIT %s
            """, (vector, clean_name, vector, limit)).fetchall()

        formatted = [
            f"[{score:.2f}] {source}: {content.replace(chr(10), ' ')}..."
            for content, source, score in rows
        ]

        return "\n".join(formatted) if formatted else "No relevant matches found."

    except Exception as e:
        return f"Query failed: {e}"
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    logging.info("Starting knowledge-mcp server via python entrypoint...")

    # Strategy 1 (preferred): build an ASGI app and run via uvicorn.
    # This avoids FastMCP CLI argument parsing ambiguity.
    # NOTE(review): which accessor exists appears to vary with the installed
    # FastMCP version (sse_app vs http_app) — confirm against the pinned dep.
    app = None
    try:
        if hasattr(mcp, "sse_app"):
            app = mcp.sse_app()
            logging.info("Using mcp.sse_app() with uvicorn")
        elif hasattr(mcp, "http_app"):
            app = mcp.http_app(path="/sse")
            logging.info("Using mcp.http_app(path='/sse') with uvicorn")
    except Exception as e:
        # Non-fatal: fall through to Strategy 2 below.
        logging.warning(f"ASGI app construction failed, will fallback to mcp.run(): {e}")

    if app is not None:
        uvicorn.run(app, host="0.0.0.0", port=8000)
    else:
        # Strategy 2: programmatic FastMCP run with explicit settings
        # (works on newer FastMCP APIs).
        try:
            logging.info("Falling back to mcp.run(transport='sse', host='0.0.0.0', port=8000)")
            mcp.run(transport="sse", host="0.0.0.0", port=8000)
        except TypeError:
            # Final fallback for older signatures.
            logging.info("Fallback signature without host/port")
            mcp.run(transport="sse")