-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy path__init__.py
More file actions
60 lines (46 loc) · 2.2 KB
/
__init__.py
File metadata and controls
60 lines (46 loc) · 2.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import importlib
from typing import Optional, Union, Dict
from src import memory_systems
# Baseline classes are loaded lazily by SolverFactory.create. This keeps one
# baseline's optional dependency (for example mem0) from blocking unrelated
# baselines such as AutoSkill at import time.
def load_class(class_type):
    """Import and return the attribute named by a dotted path.

    Args:
        class_type: Fully qualified name such as ``"package.module.ClassName"``.
            Everything before the last dot is imported as a module; the part
            after it is looked up on that module.

    Returns:
        The attribute found on the imported module (normally a class).

    Raises:
        ValueError: If `class_type` has no module part (no dot, or it starts
            with a dot).
        ImportError: If the module cannot be imported.
        AttributeError: If the module does not define the requested name.
    """
    module_path, dot, class_name = class_type.rpartition(".")
    if not dot or not module_path:
        # Fail with a message that names the offending value instead of the
        # generic tuple-unpacking / "Empty module name" errors.
        raise ValueError(
            f"Expected a dotted path like 'package.module.ClassName', got {class_type!r}"
        )
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
class SolverFactory:
    """Solver factory backed by `src.memory_systems`.

    The `method_to_class` attribute is kept as a backward-compatible view for
    external code that reads it; the source of truth is the registry. It is
    ``None`` inside the class body and is populated right after the class
    definition.
    """

    @classmethod
    def _build_method_to_class(cls):
        """Return ``{name: (solver_class, config_class)}`` for every registry entry.

        The values are taken unchanged from each registry spec; `create`
        resolves the same two attributes through `load_class`.
        """
        return {
            name: (spec.solver_class, spec.config_class)
            for name, spec in memory_systems._REGISTRY.items()
        }

    # Snapshot at class-definition time so external code that reads
    # SolverFactory.method_to_class continues to work.
    method_to_class = None  # populated below

    @staticmethod
    def _config_accepts(config_class, key: str) -> bool:
        """Return True when `config_class` can receive `key` as a keyword argument.

        Declared fields are consulted first (``model_fields``, then
        ``__fields__`` -- presumably pydantic v2 / v1 models; confirm against
        the config classes). Otherwise the parameters of ``__init__`` are
        inspected.

        Args:
            config_class: The configuration class to inspect.
            key: Candidate keyword-argument name.

        Returns:
            True if `key` is a declared field or a keyword-passable parameter
            of ``config_class.__init__``; False otherwise.
        """
        fields = getattr(config_class, "model_fields", None) or getattr(config_class, "__fields__", None)
        if isinstance(fields, dict) and key in fields:
            return True

        init = getattr(config_class, "__init__", None)
        code = getattr(init, "__code__", None)
        if code is None:
            # Built-in / slot-wrapper initialisers expose no code object.
            return False

        # co_varnames lists the parameters first and then every local variable
        # of the function, so only the parameter prefix is relevant. Skip the
        # instance receiver and positional-only parameters: neither can be
        # supplied by keyword.
        param_end = code.co_argcount + code.co_kwonlyargcount
        param_start = max(1, code.co_posonlyargcount)
        return key in code.co_varnames[param_start:param_end]

    @classmethod
    def _merge_overrides(cls, config_class, config, overrides):
        """Return a copy of `config` updated with the overrides `config_class` accepts.

        Args:
            config_class: Class used to decide which override keys are valid.
            config: Base configuration mapping (``None`` is treated as empty).
                It is never modified.
            overrides: Candidate key/value pairs; keys the config class does
                not accept are dropped.

        Returns:
            A new dict containing `config` plus the accepted overrides.
        """
        merged = dict(config or {})
        merged.update(
            (key, value)
            for key, value in overrides.items()
            if cls._config_accepts(config_class, key)
        )
        return merged

    @classmethod
    def create(cls, method_name: str, config: Dict, **kwargs):
        """Instantiate the solver registered under `method_name`.

        Args:
            method_name: Registry key passed to `memory_systems.get`.
            config: Keyword arguments for the config class. The caller's dict
                is copied, not mutated.
            **kwargs: Overrides. Each one is written into the config only if
                the config class accepts that key. `memory_cache_dir` is also
                forwarded to the solver constructor (``None`` when absent).

        Returns:
            The solver instance, constructed as
            ``solver_class(agent_config, memory_cache_dir=...)``.
        """
        spec = memory_systems.get(method_name)
        # Classes are imported here, on demand, rather than at module import.
        solver_class = load_class(spec.solver_class)
        config_class = load_class(spec.config_class)

        # A separate memory_cache_dir branch is unnecessary: it is one of the
        # kwargs, so the generic merge already applies it when accepted.
        agent_config = config_class(**cls._merge_overrides(config_class, config, kwargs))
        return solver_class(agent_config, memory_cache_dir=kwargs.get("memory_cache_dir"))
# Fill in the backward-compatible snapshot now that the class object exists
# (the class body cannot call its own classmethod while it is being defined).
# This builds a new dict once, at import time; registry entries added after
# this line runs are not reflected in it.
SolverFactory.method_to_class = SolverFactory._build_method_to_class()