This guide explains how to finetune embedding models following the ZEDD paper (arXiv:2601.12359v1) to improve AgentShield's detection accuracy.
The ZEDD algorithm detects prompt injections by measuring "drift" between original and cleaned text embeddings. Finetuning the embedding model improves this drift detection.
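Here is a rough illustration of the drift signal. It is a minimal sketch that assumes drift is 1 minus the cosine similarity between the embeddings of the original and cleaned text. The ZEDD paper's exact definition may differ, and the function names are hypothetical. Toy vectors stand in for real embeddings.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def embedding_drift(original_vec, cleaned_vec):
    """Drift as 1 - cosine similarity (assumed definition, for illustration)."""
    return 1.0 - cosine_similarity(original_vec, cleaned_vec)


def is_suspicious(original_vec, cleaned_vec, threshold):
    """Flag the text when cleaning moved its embedding further than the threshold."""
    return embedding_drift(original_vec, cleaned_vec) > threshold


# Cleaning barely changes a benign text, so its embedding stays put (low drift);
# removing an injection changes the text's meaning, so the embedding moves (high drift).
benign = embedding_drift([0.9, 0.1, 0.0], [0.88, 0.12, 0.01])
injected = embedding_drift([0.2, 0.1, 0.9], [0.9, 0.1, 0.0])
```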
| Configuration | Accuracy | Cost | Latency |
|---|---|---|---|
| Base embeddings + heuristic cleaner | ~70% | Free | Fast |
| Base embeddings + LLM cleaner | ~80% | ~$0.0003/doc | Medium |
| Finetuned embeddings + LLM cleaner | ~90% | ~$0.0003/doc | Medium |
```bash
pip install datasets openai sentence-transformers transformers accelerate tqdm scikit-learn
```

The script uses GPT-4o-mini for text cleaning. You'll need ~$5-10 in OpenAI credits.
```bash
export OPENAI_API_KEY=sk-...
python scripts/finetune_local.py
```

The script will:
- Load the LLMail-Inject dataset
- Remove duplicates and filter by length
- Clean injected text with the OpenAI API
- Generate clean-clean pairs
- Finetune MPNet with CosineSimilarityLoss
- Calibrate threshold using GMM (paper Section 4.3)
- Evaluate and export
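The deduplication and length-filter step could look roughly like the sketch below. It assumes exact-match deduplication after whitespace normalization and a length measured in characters. The `dedupe_and_filter` name and the 50/2000 character bounds are hypothetical and are not the script's actual settings.

```python
def dedupe_and_filter(texts, min_chars=50, max_chars=2000):
    """Drop exact duplicates (after collapsing whitespace) and texts outside the length bounds.

    min_chars and max_chars are hypothetical defaults chosen for illustration.
    """
    seen = set()
    kept = []
    for text in texts:
        key = " ".join(text.split())  # normalize whitespace so trivial variants collide
        if key in seen:
            continue
        seen.add(key)
        if min_chars <= len(key) <= max_chars:
            kept.append(text)
    return kept
```

Normalizing whitespace before hashing means that copies differing only in spacing or line breaks are also treated as duplicates.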
After training, point AgentShield at the finetuned model:

```python
from pyagentshield import AgentShield

shield = AgentShield(config={
    "embeddings": {
        "provider": "local",
        "model": "./agentshield-embeddings-finetuned",
    },
    "cleaning": {
        "method": "llm",
        "llm_model": "gpt-4o-mini",
    },
    "zedd": {
        "threshold": None,  # Auto-loads from model's calibration.json
    },
})

result = shield.scan("Some text to scan...")
```

Training settings can be changed on the command line:

```bash
# Default settings (5000 samples, 3 epochs)
python scripts/finetune_local.py

# Custom settings
python scripts/finetune_local.py \
    --max-samples 10000 \
    --output-dir ./my-model \
    --epochs 5 \
    --batch-size 32
```

| Option | Default | Description |
|---|---|---|
| `--max-samples` | 5000 | Maximum samples to process |
| `--output-dir` | `./agentshield-embeddings-finetuned` | Where to save the model |
| `--cache-dir` | `./cache` | Cache directory for resuming |
| `--epochs` | 3 | Number of training epochs |
| `--batch-size` | 16 | Training batch size |
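These options map naturally onto `argparse`. The parser below is a hypothetical reconstruction from the table above, not the script's actual source. Only the option names and defaults come from the table.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented finetune_local.py options."""
    parser = argparse.ArgumentParser(
        description="Finetune an embedding model for ZEDD drift detection"
    )
    parser.add_argument("--max-samples", type=int, default=5000,
                        help="Maximum samples to process")
    parser.add_argument("--output-dir", default="./agentshield-embeddings-finetuned",
                        help="Where to save the model")
    parser.add_argument("--cache-dir", default="./cache",
                        help="Cache directory for resuming")
    parser.add_argument("--epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=16,
                        help="Training batch size")
    return parser
```

For example, `build_parser().parse_args(["--epochs", "5", "--batch-size", "32"])` yields `epochs=5` and `batch_size=32`. The other options keep their defaults.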
The paper uses two types of training pairs:

- Injected-Clean pairs (label = 0.0)
  - Original: Text with prompt injection
  - Cleaned: Same text with the injection removed by the LLM
  - Expected: High embedding drift
- Clean-Clean pairs (label = 1.0)
  - Original: Clean professional email
  - Cleaned: Same email rephrased
  - Expected: Low embedding drift
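The two pair types can be represented as labeled (original, cleaned) records. This is a minimal sketch. The `TrainingPair` and `build_pairs` names are hypothetical, and the inputs are assumed to be lists of `(original, cleaned)` tuples.

```python
from dataclasses import dataclass


@dataclass
class TrainingPair:
    original: str
    cleaned: str
    label: float  # target cosine similarity: 0.0 = injected vs. cleaned, 1.0 = clean vs. rephrased


def build_pairs(injected_cleaned, clean_rephrased):
    """Combine both pair types; each argument is a list of (original, cleaned) tuples."""
    pairs = [TrainingPair(orig, clean, 0.0) for orig, clean in injected_cleaned]
    pairs += [TrainingPair(orig, clean, 1.0) for orig, clean in clean_rephrased]
    return pairs
```

With sentence-transformers, each record would typically be wrapped as `InputExample(texts=[p.original, p.cleaned], label=p.label)` before it is passed to the loss.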
```python
from sentence_transformers import SentenceTransformer, losses

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

# CosineSimilarityLoss trains the model so that:
# - Similar pairs (clean-clean) have cosine similarity → 1.0
# - Dissimilar pairs (injected-clean) have cosine similarity → 0.0
train_loss = losses.CosineSimilarityLoss(model)
```

The paper uses a 2-component Gaussian mixture model (GMM) to find the optimal threshold:
- Fit the GMM on drift scores
- Identify components: lower mean = clean, higher mean = injected
- Find the intersection: the threshold $x$ where $w_{\text{clean}} \, f_{\text{clean}}(x) = w_{\text{inject}} \, f_{\text{inject}}(x)$
- Apply the FP cap: binary search for a threshold that achieves a ≤3% false positive rate
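The calibration steps can be sketched in plain Python. The sketch assumes the two Gaussian components have already been fitted and ordered by mean. The helper names, the bisection search for the intersection, and the reading of the FP cap are all assumptions, not the script's actual code. Here the FP cap is taken to mean "raise the threshold until at most 3% of known-clean drift scores exceed it".

```python
import math


def gaussian_pdf(x, mean, std):
    """Density of a 1-D normal distribution at x."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2 * math.pi))


def gmm_intersection(w_clean, mu_clean, sd_clean, w_inj, mu_inj, sd_inj, iters=100):
    """Bisection for x in [mu_clean, mu_inj] where w_clean*f_clean(x) == w_inj*f_inj(x).

    Assumes the components are already ordered so that mu_clean < mu_inj.
    """
    def gap(x):
        return (w_clean * gaussian_pdf(x, mu_clean, sd_clean)
                - w_inj * gaussian_pdf(x, mu_inj, sd_inj))

    lo, hi = mu_clean, mu_inj
    if gap(lo) * gap(hi) > 0:
        return (lo + hi) / 2  # no crossing between the means: midpoint fallback (assumption)
    for _ in range(iters):
        mid = (lo + hi) / 2
        if gap(lo) * gap(mid) <= 0:
            hi = mid  # sign change lies in [lo, mid]
        else:
            lo = mid  # sign change lies in [mid, hi]
    return (lo + hi) / 2


def false_positive_rate(clean_scores, threshold):
    """Fraction of known-clean drift scores that would be flagged (drift > threshold)."""
    return sum(score > threshold for score in clean_scores) / len(clean_scores)


def apply_fp_cap(clean_scores, threshold, max_fpr=0.03, iters=60):
    """Raise the threshold by binary search until the FPR on clean scores is <= max_fpr."""
    if false_positive_rate(clean_scores, threshold) <= max_fpr:
        return threshold
    lo, hi = threshold, max(clean_scores)  # FPR at max(clean_scores) is 0
    for _ in range(iters):
        mid = (lo + hi) / 2
        if false_positive_rate(clean_scores, mid) <= max_fpr:
            hi = mid  # hi always satisfies the cap
        else:
            lo = mid
    return hi
```

With scikit-learn, the component parameters could come from a fitted `GaussianMixture(n_components=2)`. Read `weights_`, `means_`, and `covariances_` (taking the square root of the variance to get a standard deviation), then order the components by mean.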
Configuration can also be provided in a YAML file:

```yaml
# pyagentshield.yaml
embeddings:
  provider: local
  model: ./agentshield-embeddings-finetuned
cleaning:
  method: llm
  llm_model: gpt-4o-mini
zedd:
  threshold: null  # Auto-load from calibration.json
behavior:
  on_detect: flag
```

Or through environment variables:

```bash
export AGENTSHIELD_EMBEDDINGS__MODEL=./agentshield-embeddings-finetuned
export AGENTSHIELD_CLEANING__METHOD=llm
export AGENTSHIELD_CLEANING__LLM_MODEL=gpt-4o-mini
```

The script uses `sentence-transformers/all-mpnet-base-v2`:
| Property | Value |
|---|---|
| Dimensions | 768 |
| Parameters | 110M |
| Speed | Medium |
| Quality | Best for sentence similarity |
Training:
- CPU: Works but slow
- GPU/MPS: Recommended
- RAM: 8GB+
- Time: ~1-2 hours (5K samples)
Inference:
- CPU: Works fine for small batches
- GPU/MPS: Recommended for production
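One way to prefer a GPU or Apple MPS and fall back to CPU is shown below. `pick_device` is a hypothetical helper and is not taken from the script. It returns `"cpu"` if PyTorch is not installed.

```python
def pick_device():
    """Prefer CUDA, then Apple MPS, else CPU; falls back to CPU if torch is missing."""
    try:
        import torch
    except ImportError:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    mps = getattr(torch.backends, "mps", None)  # absent on older PyTorch builds
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"
```

The result can be passed when loading the model, for example `SentenceTransformer("sentence-transformers/all-mpnet-base-v2", device=pick_device())`.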
| Phase | Cost |
|---|---|
| Cleaning injected text (~5K samples) | ~$1.50 |
| Generating clean pairs (~5K pairs) | ~$1.50 |
| Training (local) | Free |
| Total one-time | ~$3-5 |
| Inference per document | ~$0.0003 |
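The table's figures follow from a per-document LLM cost of about $0.0003 (~$1.50 / 5K samples). The sketch below assumes that clean-pair generation costs the same per item. The 100K-document scan volume is a hypothetical example.

```python
COST_PER_DOC = 0.0003  # USD per LLM call, from the cost table (~$1.50 / 5K samples)


def estimate_cost(n_docs, cost_per_doc=COST_PER_DOC):
    """Linear cost estimate for n_docs LLM cleaning calls."""
    return n_docs * cost_per_doc


# One-time: cleaning ~5K injected samples + generating ~5K clean pairs
one_time = estimate_cost(5_000) + estimate_cost(5_000)
# Hypothetical inference volume: scanning 100K documents with the LLM cleaner
scan_100k = estimate_cost(100_000)
```

This gives about $3 one-time, at the low end of the table's ~$3-5 range, and about $30 per 100K scanned documents.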
The script caches intermediate results in `--cache-dir`:

```
cache/
├── cleaned_injected.json  # Cleaned injected samples
└── clean_pairs.json       # Generated clean-clean pairs
```
If the script is interrupted, it will resume from where it left off.
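The resume behavior can be sketched as a cache that is reloaded on startup, so already-processed samples are skipped. `process_with_cache` is a hypothetical helper. The cache format (a JSON list of outputs in sample order) is an assumption. The 100-sample checkpoint interval matches the troubleshooting notes.

```python
import json
from pathlib import Path

CHECKPOINT_EVERY = 100  # the script is described as checkpointing every 100 samples


def process_with_cache(samples, process_fn, cache_path):
    """Process samples in order, persisting outputs so a rerun skips finished work.

    Assumed cache format: a JSON list of outputs, one per already-processed sample.
    """
    cache_path = Path(cache_path)
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    results = json.loads(cache_path.read_text()) if cache_path.exists() else []
    for i in range(len(results), len(samples)):  # resume after the last cached sample
        results.append(process_fn(samples[i]))
        if len(results) % CHECKPOINT_EVERY == 0:
            cache_path.write_text(json.dumps(results))  # periodic checkpoint
    cache_path.write_text(json.dumps(results))  # final save
    return results
```

Because outputs are appended in order, the number of cached entries doubles as the index of the next sample to process.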
Out of memory during training:
- Reduce `--batch-size` (try 8 or 4)
- On Mac, the script uses MPS automatically

Model fails to load:
- Ensure the output directory contains all model files
- Check that `config.json` exists in the output directory

Low detection accuracy:
- Ensure you trained with enough data (5K+ samples)
- Check that the threshold was properly calibrated
- Verify that `calibration.json` exists in the model directory

OpenAI API errors:
- Check that your OpenAI API key is valid
- Ensure you have sufficient credits
- The script checkpoints every 100 samples, so you can resume after fixing the problem
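A quick pre-flight check for the two files mentioned above might look like this. `check_model_dir` is a hypothetical helper. It only checks that the files exist and does not validate their contents.

```python
from pathlib import Path

REQUIRED_FILES = ("config.json", "calibration.json")


def check_model_dir(model_dir):
    """Return a list of problems found in a finetuned model directory (empty if none)."""
    path = Path(model_dir)
    if not path.is_dir():
        return [f"{path} is not a directory"]
    return [f"missing {name}" for name in REQUIRED_FILES if not (path / name).exists()]
```

For example, `check_model_dir("./agentshield-embeddings-finetuned")` returns `[]` when both files are present.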
```python
from pyagentshield import AgentShield

# Auto-loads threshold from model's calibration.json
shield = AgentShield(config={
    "embeddings": {
        "model": "./agentshield-embeddings-finetuned",
    },
})

# Scan documents
result = shield.scan("Document text here")
print(f"Suspicious: {result.is_suspicious}")
print(f"Confidence: {result.confidence:.2%}")
```

To override the calibrated threshold, set it explicitly:

```python
from pyagentshield import AgentShield

shield = AgentShield(config={
    "embeddings": {
        "model": "./agentshield-embeddings-finetuned",
    },
    "zedd": {
        "threshold": 0.25,  # Override calibrated threshold
    },
})
```

References:
- ZEDD Paper: arXiv:2601.12359v1 - "Zero-Shot Embedding Drift Detection"
- LLMail-Inject: https://huggingface.co/datasets/microsoft/llmail-inject-challenge
- sentence-transformers: https://www.sbert.net/
- GMM Threshold: Paper Section 4.3