Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"torch>=2.0.0",
"transformers>=4.36.0",
"uvicorn>=0.24.0",
"safetensors>=0.4.0",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -71,5 +72,9 @@ pythonpath = [

[dependency-groups]
dev = [
"pytest>=7.0.0",
"pytest-cov>=4.0.0",
"pytest-mock>=3.10.0",
"ruff",
"gguf>=0.6.0",
]
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ nltk>=3.8.0
python-dateutil>=2.8.0
jsonschema>=4.17.0
sentencepiece>=0.1.99
safetensors>=0.4.0

# Test dependencies
pytest>=7.0.0
Expand Down
10 changes: 5 additions & 5 deletions src/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from .schemas import (
DataSource,
ConfidenceLevel,
DataSource,
ConfidenceLevel,
ExtractionResult,
GenerateRequest,
BatchRequest,
AIBOMResponse,
GenerateRequest,
BatchRequest,
AIBOMResponse,
EnhancementReport
)
from .registry import get_field_registry_manager
Expand Down
52 changes: 52 additions & 0 deletions src/models/config_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
config.json parsing for HuggingFace model repositories.

Extracts hyperparameters using llama.cpp's find_hparam key fallback chains.
Works for any model format (safetensors, GGUF, pytorch, etc.) — the config.json
schema is format-agnostic.
"""
from typing import Dict, List, Optional, Union

# Exact key fallback order from llama.cpp convert_hf_to_gguf.py
# (see research.md section 12.4 for source line references)
# Exact key fallback order from llama.cpp convert_hf_to_gguf.py
# (see research.md section 12.4 for source line references)
HPARAM_KEYS: Dict[str, List[str]] = {
    "block_count": ["n_layers", "num_hidden_layers", "n_layer", "num_layers"],
    "context_length": ["max_position_embeddings", "n_ctx", "n_positions",
                       "max_length", "max_sequence_length", "model_max_length"],
    "embedding_length": ["hidden_size", "n_embd", "dim"],
    "feed_forward_length": ["intermediate_size", "n_inner", "hidden_dim"],
    "attention_head_count": ["num_attention_heads", "n_head", "n_heads"],
    "attention_head_count_kv": ["num_key_value_heads", "n_kv_heads"],
    "rope_dimension_count": ["rotary_dim", "rope_dim"],
    "vocab_size": ["vocab_size"],
    "architecture": ["model_type"],
}


ParsedConfig = Dict[str, Optional[Union[str, int]]]


def parse_config(config: dict) -> ParsedConfig:
    """Extract hyperparameters from config.json using llama.cpp's find_hparam key fallback chains.

    Handles VLM models that nest text params under text_config (llama.cpp L800-802).

    Returns a dict with canonical keys (block_count, embedding_length, etc.)
    and None for any fields not found in the config.
    """
    # VLM merge: text_config values override root, mirroring llama.cpp
    if "text_config" in config:
        config = {**config, **config["text_config"]}

    # For each canonical name, take the first candidate key present in the
    # config (order matters — it is llama.cpp's fallback priority).
    return {
        canonical_name: next((config[key] for key in candidates if key in config), None)
        for canonical_name, candidates in HPARAM_KEYS.items()
    }
20 changes: 18 additions & 2 deletions src/models/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,18 @@
from .schemas import DataSource, ConfidenceLevel, ExtractionResult
from .registry import get_field_registry_manager
from .model_file_extractors import ModelFileExtractor, default_extractors
from .config_parsing import parse_config as _parse_hparams_from_config

logger = logging.getLogger(__name__)


def _build_hyperparameters_from_config(config_data: dict) -> Optional[str]:
    """Build hyperparameter JSON from config.json using llama.cpp key fallback chains.

    Returns a JSON string of the non-None hyperparameters (architecture excluded,
    since it is surfaced via model_type), or None when nothing was extracted.
    """
    parsed = _parse_hparams_from_config(config_data)
    hyperparams = {
        name: value
        for name, value in parsed.items()
        if value is not None and name != "architecture"
    }
    if not hyperparams:
        return None
    return json.dumps(hyperparams)


class EnhancedExtractor:
"""
Registry-integrated enhanced extractor that automatically picks up new fields
Expand Down Expand Up @@ -512,6 +521,13 @@ def _try_model_card_extraction(self, field_name: str, context: Dict[str, Any]) -

def _try_config_extraction(self, field_name: str, context: Dict[str, Any]) -> Any:
"""Try to extract field from configuration files"""
# Hyperparameter extraction from config.json using llama.cpp key fallback chains
if field_name == "hyperparameter":
config_data = context.get("config_data")
if config_data:
return _build_hyperparameters_from_config(config_data)
return None

# Config file mappings
config_mappings = {
'model_type': ('config_data', 'model_type'),
Expand All @@ -520,13 +536,13 @@ def _try_config_extraction(self, field_name: str, context: Dict[str, Any]) -> An
'tokenizer_class': ('tokenizer_config', 'tokenizer_class'),
'typeOfModel': ('config_data', 'model_type')
}

if field_name in config_mappings:
config_type, config_key = config_mappings[field_name]
config_source = context.get(config_type)
if config_source:
return config_source.get(config_key)

return None

def _try_text_pattern_extraction(self, field_name: str, context: Dict[str, Any]) -> Any:
Expand Down
37 changes: 28 additions & 9 deletions src/models/model_file_extractors.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
import logging
from typing import Protocol, Dict, Any, List, runtime_checkable
from typing import Protocol, Dict, List, Union, runtime_checkable

from .gguf_metadata import fetch_gguf_metadata_from_repo, map_to_metadata
from huggingface_hub import list_repo_files

from .gguf_metadata import fetch_gguf_metadata_from_repo, map_to_metadata as gguf_map_to_metadata
from .safetensors_metadata import fetch_safetensors_metadata, map_to_metadata as st_map_to_metadata

logger = logging.getLogger(__name__)


@runtime_checkable
class ModelFileExtractor(Protocol):
    """Structural interface for model-file metadata extractors (GGUF, safetensors, ...).

    can_extract: cheap probe — True if this extractor can handle the repo.
    extract_metadata: returns the AIBOM metadata dict ({} on failure).
    """
    # NOTE: the stale duplicate signature returning Dict[str, Any] was removed;
    # `Any` is no longer imported by this module.
    def can_extract(self, model_id: str) -> bool: ...
    def extract_metadata(self, model_id: str) -> Dict[str, Union[str, int, dict]]: ...


class GGUFFileExtractor:

def can_extract(self, model_id: str) -> bool:
try:
from huggingface_hub import list_repo_files
return any(f.endswith(".gguf") for f in list_repo_files(model_id))
except Exception:
return False

def extract_metadata(self, model_id: str) -> Dict[str, Any]:
from huggingface_hub import list_repo_files

def extract_metadata(self, model_id: str) -> Dict[str, Union[str, int, dict]]:
try:
files = list_repo_files(model_id)
gguf_files = [f for f in files if f.endswith(".gguf")]
Expand All @@ -34,11 +34,30 @@ def extract_metadata(self, model_id: str) -> Dict[str, Any]:
if model_info is None:
return {}

return map_to_metadata(model_info)
return gguf_map_to_metadata(model_info)
except Exception as e:
logger.warning(f"GGUF extraction failed for {model_id}: {e}")
return {}


class SafetensorsFileExtractor:
    """Extracts AIBOM metadata from HuggingFace repos that ship .safetensors weights."""

    def can_extract(self, model_id: str) -> bool:
        """Return True if the repo lists at least one .safetensors file; False on any error."""
        try:
            repo_files = list_repo_files(model_id)
        except Exception:
            # Repo missing, private, or network error — simply not extractable.
            return False
        return any(name.endswith(".safetensors") for name in repo_files)

    def extract_metadata(self, model_id: str) -> Dict[str, Union[str, int, dict]]:
        """Fetch safetensors/config metadata and map it to the shared AIBOM dict format.

        Returns {} if metadata is unavailable or any step fails.
        """
        try:
            repo_info = fetch_safetensors_metadata(model_id)
            return {} if repo_info is None else st_map_to_metadata(repo_info)
        except Exception as e:
            logger.warning(f"Safetensors extraction failed for {model_id}: {e}")
            return {}


def default_extractors() -> List[ModelFileExtractor]:
    """Ordered list of extractors to try; safetensors is preferred over GGUF.

    NOTE(review): the diff artifact left two return statements in view; the
    intended (post-change) behavior is to return both extractors.
    """
    return [SafetensorsFileExtractor(), GGUFFileExtractor()]
177 changes: 177 additions & 0 deletions src/models/safetensors_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
"""
Safetensors Metadata Extraction for AIBOM Generator

Extracts hyperparameters from safetensors repos by combining:
1. config.json — all hyperparameters (mirroring llama.cpp's find_hparam approach)
2. Safetensors headers — tensor info (parameter count, dtype distribution)

See research.md section 12 for full format specification and design rationale.
"""
import json
import math
import logging
from collections import Counter
from dataclasses import dataclass, field
from typing import Dict, Optional, Union

from huggingface_hub import hf_hub_download, HfApi
from huggingface_hub.errors import EntryNotFoundError

from .config_parsing import parse_config

logger = logging.getLogger(__name__)


@dataclass
class SafetensorsModelInfo:
    """Model information extracted from safetensors repo for AIBOM.

    Parallels GGUFModelInfo but sources hyperparameters from config.json
    (not from the safetensors header, which contains only tensor definitions).

    All fields default to None/empty so partially-populated repos (e.g. no
    tokenizer_config.json, no readable safetensors headers) still yield an
    instance. Field order is part of the positional-constructor interface.
    """
    # From config.json (via parse_config, same source as llama.cpp)
    architecture: Optional[str] = None          # config.json "model_type"
    context_length: Optional[int] = None
    embedding_length: Optional[int] = None
    block_count: Optional[int] = None
    attention_head_count: Optional[int] = None
    attention_head_count_kv: Optional[int] = None
    feed_forward_length: Optional[int] = None
    rope_dimension_count: Optional[int] = None
    vocab_size: Optional[int] = None
    # From tokenizer_config.json (best-effort; stays None if file is absent)
    tokenizer_class: Optional[str] = None
    # From safetensors headers (via get_safetensors_metadata, best-effort)
    total_parameters: Optional[int] = None      # sum over all shards
    dtype_counts: Dict[str, int] = field(default_factory=dict)  # dtype -> tensor count
    user_metadata: Dict[str, str] = field(default_factory=dict)


TensorInfoResult = Dict[str, Union[int, Dict[str, int]]]


def _extract_tensor_info(tensors: dict) -> TensorInfoResult:
"""Extract parameter count and dtype distribution from safetensors tensor metadata.

tensors: dict mapping tensor name → object with .dtype and .shape attributes
(from huggingface_hub's SafetensorsFileMetadata.tensors or compatible mock).
"""
total_parameters = 0
dtype_counter: Counter = Counter()

for _, tensor in tensors.items():
shape = tensor.shape
param_count = math.prod(shape) if shape else 0
total_parameters += param_count
dtype_counter[tensor.dtype] += 1

return {
"total_parameters": total_parameters,
"dtype_counts": dict(dtype_counter),
}


MetadataValue = Union[str, int, Dict[str, int]]
MetadataDict = Dict[str, MetadataValue]


def map_to_metadata(info: SafetensorsModelInfo) -> MetadataDict:
    """Map SafetensorsModelInfo to the same dict format as gguf_metadata.map_to_metadata().

    Output structure mirrors GGUF: model_type, typeOfModel, vocab_size, context_length
    at top level; hyperparameter dict with non-None hyperparams; safetensors-specific fields.
    """
    metadata: MetadataDict = {}

    # Core fields (same as gguf_metadata._map_core_fields)
    if info.architecture is not None:
        metadata["model_type"] = info.architecture
        metadata["typeOfModel"] = info.architecture

    for key, value in (
        ("vocab_size", info.vocab_size),
        ("context_length", info.context_length),
        ("tokenizer_class", info.tokenizer_class),
    ):
        if value is not None:
            metadata[key] = value

    # Hyperparameter dict (same as gguf_metadata._map_hyperparameters)
    hyperparam_fields = (
        "context_length", "embedding_length", "block_count",
        "attention_head_count", "attention_head_count_kv",
        "feed_forward_length", "rope_dimension_count",
    )
    hyperparams: Dict[str, int] = {
        name: getattr(info, name)
        for name in hyperparam_fields
        if getattr(info, name) is not None
    }
    if hyperparams:
        metadata["hyperparameter"] = hyperparams

    # Safetensors-specific
    if info.total_parameters is not None:
        metadata["safetensors_total_parameters"] = info.total_parameters

    return metadata


def fetch_safetensors_metadata(
    repo_id: str, *, hf_token: Optional[str] = None
) -> Optional[SafetensorsModelInfo]:
    """Fetch config.json + safetensors headers from a HuggingFace repo.

    Returns None if config.json is missing (can't extract hyperparameters).
    Returns partial info if safetensors headers are unavailable.
    """
    # Step 1: config.json is mandatory — without it there are no hyperparameters.
    try:
        with open(hf_hub_download(repo_id, "config.json", token=hf_token)) as fh:
            raw_config = json.load(fh)
    except Exception as e:
        logger.warning(f"Could not fetch config.json for {repo_id}: {e}")
        return None

    hparams = parse_config(raw_config)
    info = SafetensorsModelInfo(
        architecture=hparams.get("architecture"),
        context_length=hparams.get("context_length"),
        embedding_length=hparams.get("embedding_length"),
        block_count=hparams.get("block_count"),
        attention_head_count=hparams.get("attention_head_count"),
        attention_head_count_kv=hparams.get("attention_head_count_kv"),
        feed_forward_length=hparams.get("feed_forward_length"),
        rope_dimension_count=hparams.get("rope_dimension_count"),
        vocab_size=hparams.get("vocab_size"),
    )

    # Step 2: tokenizer_config.json is best-effort — only adds tokenizer_class.
    try:
        with open(hf_hub_download(repo_id, "tokenizer_config.json", token=hf_token)) as fh:
            info.tokenizer_class = json.load(fh).get("tokenizer_class")
    except Exception:
        pass

    # Step 3: safetensors headers are best-effort — add tensor-level stats,
    # aggregating tensors across all shard files before counting.
    try:
        repo_meta = HfApi().get_safetensors_metadata(repo_id, token=hf_token)
        shard_tensors: dict = {}
        for file_meta in repo_meta.files_metadata.values():
            shard_tensors.update(file_meta.tensors)

        stats = _extract_tensor_info(shard_tensors)
        info.total_parameters = stats["total_parameters"]
        info.dtype_counts = stats["dtype_counts"]
    except Exception as e:
        logger.info(f"No safetensors metadata for {repo_id}: {e}")

    return info
Loading