-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathingest.py
More file actions
94 lines (78 loc) · 3.25 KB
/
ingest.py
File metadata and controls
94 lines (78 loc) · 3.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""Ingest network knowledge base docs into ChromaDB via LangChain."""
import shutil
import sys
from pathlib import Path
from dotenv import load_dotenv
load_dotenv(Path(__file__).parent / ".env")
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from tools.rag import _CHROMA_DIR, _COLLECTION, _EMBEDDING_MODEL
# Directory named "docs" next to this script; ingest() globs "*.md" from it.
DOCS_DIR = Path(__file__).parent / "docs"
# Persist directory and collection name come from tools.rag (imported above),
# so this script and that module refer to the same values.
CHROMA_DIR = Path(_CHROMA_DIR)
COLLECTION_NAME = _COLLECTION
# Passed to RecursiveCharacterTextSplitter as chunk_size / chunk_overlap.
CHUNK_SIZE = 800
CHUNK_OVERLAP = 100
# Maps an RFC id (the filename stem up to the first underscore) to the
# protocol tag stored in chunk metadata; ids missing here get "general".
_RFC_PROTOCOL_MAP = {
"rfc2328": "ospf",
"rfc3101": "ospf",
# Future: "rfc4271": "bgp", "rfc7868": "eigrp"
}
# Protocol tokens recognised as a trailing "_<protocol>" suffix on vendor
# filenames. Kept next to the parser so the vendor branch is self-contained.
_KNOWN_VENDOR_PROTOCOLS = frozenset({"ospf", "bgp", "eigrp"})


def extract_metadata(file_path: Path) -> dict:
    """Derive vendor, topic, source, and protocol metadata from a filename.

    Three filename families are recognised, based on the stem:

    * ``vendor_<vendor>[_<protocol>]`` -> topic ``"vendor_guide"``.
      If the last underscore-separated token is a known protocol
      (``ospf``, ``bgp``, ``eigrp``) it is taken as the protocol and
      everything before it is the vendor, so single-word vendors
      (``vendor_juniper_bgp``) and vendors with more than two tokens
      (``vendor_cisco_ios_xe_bgp``) parse correctly.  Otherwise the
      original positional rules apply: the first two tokens form the
      vendor, a third token (if any) is the protocol, and the protocol
      defaults to ``"ospf"``.
    * ``rfc...`` -> topic ``"rfc"``, vendor ``"all"``; the protocol is looked
      up in ``_RFC_PROTOCOL_MAP`` by the part of the stem before the first
      underscore, defaulting to ``"general"``.
    * anything else -> vendor ``"all"``, topic ``"general"``,
      protocol ``"general"``.

    Args:
        file_path: Path of the document; only its name/stem is inspected.

    Returns:
        A dict with the keys ``vendor``, ``topic``, ``source`` (the file
        name including extension) and ``protocol``.
    """
    name = file_path.stem
    source = file_path.name

    if name.startswith("vendor_"):
        parts = name[len("vendor_"):].split("_")
        if len(parts) >= 2 and parts[-1] in _KNOWN_VENDOR_PROTOCOLS:
            # Explicit protocol suffix: the vendor is everything before it,
            # regardless of how many tokens the vendor name has.
            vendor = "_".join(parts[:-1])
            protocol = parts[-1]
        else:
            # Positional fallback, identical to the previous behaviour:
            # vendor_<vendor>.md files are all OSPF guides.
            vendor = "_".join(parts[:2]) if len(parts) >= 2 else parts[0]
            protocol = parts[2] if len(parts) > 2 else "ospf"
        return {
            "vendor": vendor,
            "topic": "vendor_guide",
            "source": source,
            "protocol": protocol,
        }

    if name.startswith("rfc"):
        rfc_id = name.split("_")[0]
        protocol = _RFC_PROTOCOL_MAP.get(rfc_id, "general")
        return {"vendor": "all", "topic": "rfc", "source": source, "protocol": protocol}

    return {"vendor": "all", "topic": "general", "source": source, "protocol": "general"}
def ingest():
    """Load the Markdown docs, chunk, embed, and store them in ChromaDB.

    Steps:
      1. Read every ``*.md`` file in ``DOCS_DIR`` (sorted by path) and tag it
         with the metadata returned by ``extract_metadata``.
      2. Split each document with ``RecursiveCharacterTextSplitter`` using
         ``CHUNK_SIZE`` / ``CHUNK_OVERLAP`` and Markdown-heading-first
         separators.
      3. Prefix every chunk with a ``[Source: ... | Protocol: ...]`` header.
      4. Embed the chunks and write them to the collection
         ``COLLECTION_NAME`` persisted under ``CHROMA_DIR``.

    Each chunk is stored under a deterministic id
    (``<source>::chunk-<index within its document>``).  The intent is that
    re-running the script without ``--clean`` overwrites existing entries
    instead of appending a duplicate copy of the whole corpus; this relies on
    the vector store writing by id (upsert).  Chunks left over from a document
    that has since shrunk or been removed are not deleted -- use ``--clean``
    for a full rebuild.

    Exits the process with status 1 when ``DOCS_DIR`` contains no ``*.md``
    files.
    """
    md_files = sorted(DOCS_DIR.glob("*.md"))
    if not md_files:
        print(f"No .md files found in {DOCS_DIR}")
        sys.exit(1)

    documents = [
        Document(
            page_content=fp.read_text(encoding="utf-8"),
            metadata=extract_metadata(fp),
        )
        for fp in md_files
    ]
    print(f"Loaded {len(documents)} document(s) from {DOCS_DIR}")

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        # Prefer breaking at Markdown headings, then paragraphs, lines, words.
        separators=["\n## ", "\n### ", "\n\n", "\n", " "],
    )

    chunks = []
    chunk_ids = []
    for doc in documents:
        splits = splitter.split_documents([doc])
        for index, chunk in enumerate(splits):
            # Give each chunk its own copy so later edits never alias the
            # parent document's metadata dict.
            chunk.metadata = doc.metadata.copy()
            # File names are unique within DOCS_DIR, so source + position is
            # a stable, collision-free id across runs.
            chunk_ids.append(f"{doc.metadata['source']}::chunk-{index}")
        chunks.extend(splits)

    # Prepend contextual headers for better embedding quality
    for chunk in chunks:
        src = chunk.metadata.get("source", "unknown")
        proto = chunk.metadata.get("protocol", "general")
        chunk.page_content = f"[Source: {src} | Protocol: {proto}]\n{chunk.page_content}"
    print(f"Split into {len(chunks)} chunk(s)")

    embeddings = HuggingFaceEmbeddings(model_name=_EMBEDDING_MODEL)
    Chroma.from_documents(
        documents=chunks,
        embedding=embeddings,
        ids=chunk_ids,
        persist_directory=str(CHROMA_DIR),
        collection_name=COLLECTION_NAME,
    )
    print(f"Stored in ChromaDB at {CHROMA_DIR}")
if __name__ == "__main__":
    # "--clean" anywhere on the command line requests a rebuild from scratch:
    # the persisted store is removed (when present) before ingesting.
    clean_requested = "--clean" in sys.argv
    if clean_requested:
        if CHROMA_DIR.exists():
            shutil.rmtree(CHROMA_DIR)
            print(f"Cleaned {CHROMA_DIR}")
    ingest()