-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_cache.py
More file actions
155 lines (124 loc) · 5.5 KB
/
model_cache.py
File metadata and controls
155 lines (124 loc) · 5.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
Per-process model cache for the Basil model.
In parallel generation, each worker process loads the model once and reuses
it across all sessions within that process. This avoids reloading from disk
on every session (~2-5 seconds saved per session).
The cache is keyed on the model path, so if the path ever changes between
sessions (shouldn't happen during a generation run), it reloads automatically.
Thread safety: within a single worker process, sessions run sequentially,
so no concurrent access to the cached model.
"""
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from config import LORA_ADAPTER_SUBDIR, LORA_ACTIVATION_AGE_BAND
# Module-level cache (per-process due to multiprocessing spawn)
# The cached model/tokenizer (None until the first get_cached_model call)
_cached_model = None
_cached_tokenizer = None
# Device string the cached model lives on ("cuda" or "cpu")
_cached_device = None
# Path the cached model was loaded from; a mismatch triggers a reload
_cached_model_path = None
_cached_has_lora = False  # Whether the cached model has a LoRA adapter loaded
def get_cached_model(model_path: str, verbose: bool = False, age_band: int = 0):
    """
    Fetch (model, tokenizer, device), hitting disk only on the first call
    or whenever the requested model path differs from the cached one.

    When age_band < LORA_ACTIVATION_AGE_BAND, the LoRA adapter (if loaded)
    is disabled so the trunk runs clean during the "absorption" phase.

    Args:
        model_path: Path to the Basil model directory.
        verbose: Print loading messages.
        age_band: Basil's current developmental age band (0-7).

    Returns:
        Tuple of (model, tokenizer, device_str).
    """
    global _cached_model, _cached_tokenizer, _cached_device, _cached_model_path, _cached_has_lora

    # Fast path: same path as last time, model already resident in this process.
    cache_hit = _cached_model is not None and _cached_model_path == model_path
    if cache_hit:
        if verbose:
            print(f"[ModelCache] Reusing cached model: {model_path}")
        # Re-apply the LoRA toggle: age_band may have changed since the last call.
        _toggle_lora(_cached_model, age_band, verbose)
        return _cached_model, _cached_tokenizer, _cached_device

    if verbose:
        print(f"[ModelCache] Loading Basil model: {model_path}")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(model_path)

    # Wrap with the LoRA adapter when one exists on disk (Basil-policy objective).
    adapter_path = os.path.join(model_path, LORA_ADAPTER_SUBDIR)
    has_lora = os.path.exists(adapter_path)
    if has_lora:
        if verbose:
            print(f"[ModelCache] Loading LoRA adapter from {adapter_path}")
        from peft import PeftModel
        model = PeftModel.from_pretrained(base_model, adapter_path)
    else:
        model = base_model
    model.eval()

    # Null out generation_config limits so they cannot conflict with kwargs
    # passed directly to generate().
    if hasattr(model, 'generation_config'):
        gen_cfg = model.generation_config
        gen_cfg.max_length = None
        gen_cfg.max_new_tokens = None
        gen_cfg.top_p = None

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        model.to("cuda")

    # Populate the per-process cache for subsequent sessions in this worker.
    _cached_model = model
    _cached_tokenizer = tokenizer
    _cached_device = device
    _cached_model_path = model_path
    _cached_has_lora = has_lora

    # Apply the age-band-dependent LoRA toggle before handing the model out.
    _toggle_lora(model, age_band, verbose)
    return model, tokenizer, device
def _toggle_lora(model, age_band: int, verbose: bool = False):
    """Enable or disable LoRA adapter layers based on age_band."""
    global _cached_has_lora
    if not _cached_has_lora:
        return  # No LoRA adapter loaded, nothing to toggle

    lora_active = age_band >= LORA_ACTIVATION_AGE_BAND
    if lora_active:
        # LoRA active phase
        if hasattr(model, 'enable_adapter_layers'):
            model.enable_adapter_layers()
        if verbose:
            print(f"[ModelCache] LoRA ENABLED (age_band={age_band} >= {LORA_ACTIVATION_AGE_BAND})")
    else:
        # Absorption phase: disable LoRA so trunk runs clean
        if hasattr(model, 'disable_adapter_layers'):
            model.disable_adapter_layers()
        if verbose:
            print(f"[ModelCache] LoRA DISABLED (age_band={age_band} < {LORA_ACTIVATION_AGE_BAND})")
def set_lora_strength(model, strength: float, verbose: bool = False):
    """
    Scale the LoRA adapter contribution using PEFT's set_scale API.

    The PEFT set_scale method sets: scaling[adapter] = strength * (alpha / r).
    So strength=1.0 restores the default LoRA contribution, and strength=0.0
    zeroes it out (equivalent to trunk-only inference).

    Args:
        model: A PeftModel (or plain model — no-op if no LoRA layers).
        strength: 0.0 = trunk only, 1.0 = full LoRA, 0.5 = half LoRA.
        verbose: Print diagnostics.
    """
    # NOTE: read-only access to the module cache flag — no `global` needed
    # (the original declared `global _cached_has_lora` without ever assigning it).
    if not _cached_has_lora:
        if verbose:
            print(f"[ModelCache] set_lora_strength({strength}) — no LoRA adapter loaded, skipping")
        return

    # Walk all LoRA layers and set their scale. Best-effort: a failure is
    # reported but never raised, so callers keep the previous scaling.
    n_layers = 0
    try:
        for module in model.modules():
            # Duck-typed detection of a PEFT LoRA layer: it exposes both a
            # per-adapter `scaling` dict and a `set_scale` method.
            if hasattr(module, 'set_scale') and hasattr(module, 'scaling'):
                for adapter_name in list(module.scaling.keys()):
                    module.set_scale(adapter_name, strength)
                    n_layers += 1
    except Exception as e:
        print(f"[ModelCache] WARNING: Failed to set LoRA strength: {e}")
        return
    if verbose:
        print(f"[ModelCache] LoRA strength set to {strength} across {n_layers} layers")