README.md (33 additions, 1 deletion)
@@ -66,13 +66,45 @@
```
The `corrected_claim` dictionary now contains a key `final_answer`, which is the final correction, as well as all intermediate outputs that provide interpretability.
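
For example, the candidate-level fields can be inspected directly (a sketch; the field names below follow the `analyze_confidence` docstring and are not an exhaustive list):

```python
# Inspect the selected correction and the intermediate outputs.
print(corrected_claim['final_answer'])       # final corrected claim
print(corrected_claim['correction'])         # all candidate corrections
print(corrected_claim['correction_scores'])  # entailment score per candidate
print(corrected_claim['rouge_scores'])       # ROUGE similarity to the input claim
```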

Batch processing is supported via:

```python
zerofec.batch_correct(samples)
```
where `samples` is a list of dictionaries.
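
A minimal sketch of batch usage (the claims are illustrative, and `input_claim` and `evidence` follow the field names used in `analyze_confidence`; see `main.py` for the full input format):

```python
# Hypothetical input samples.
samples = [
    {'input_claim': 'The Eiffel Tower is located in Berlin.',
     'evidence': ['The Eiffel Tower is a wrought-iron lattice tower in Paris, France.']},
    {'input_claim': 'Water boils at 50 degrees Celsius at sea level.',
     'evidence': ['At sea level, water boils at 100 degrees Celsius.']},
]

corrected = zerofec.batch_correct(samples)
print(corrected[0]['final_answer'])
```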

### Confidence Analysis

To gain deeper insights into the correction process and understand the model's confidence, use the `analyze_confidence` method:

```python
# After running correction
corrected_claim = zerofec.correct(sample)

# Analyze the confidence of the correction
confidence_analysis = zerofec.analyze_confidence(corrected_claim)

print(f"Confidence Level: {confidence_analysis['confidence_level']}")
print(f"Entailment Score: {confidence_analysis['final_entailment_score']:.3f}")
print(f"ROUGE Score: {confidence_analysis['final_rouge_score']:.3f}")
print(f"Combined Score: {confidence_analysis['final_combined_score']:.3f}")

# Check for warnings
if confidence_analysis['warnings']:
    print("Warnings:", confidence_analysis['warnings'])

# View all candidate corrections ranked by score
for candidate in confidence_analysis['all_candidates'][:3]:  # Top 3
    print(f"{candidate['rank']}. {candidate['correction']} (score: {candidate['combined_score']:.3f})")
```

The `analyze_confidence` function provides:
- **Confidence level**: 'high' (≥0.8), 'medium' (≥0.5 and <0.8), or 'low' (<0.5), based on the entailment score
- **Score breakdown**: Entailment, ROUGE, and combined scores for the selected correction
- **All candidates**: Ranked list of all correction candidates with their scores
- **Warnings**: Potential issues like low confidence, significant changes, or model uncertainty
- **Statistics**: Aggregated metrics across all candidates
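
One practical use of this analysis is to triage a batch of corrections and flag low-confidence ones for manual review (a sketch; the threshold and routing logic are illustrative, not part of the library):

```python
# Flag corrections that are not high-confidence or that carry warnings.
results = zerofec.batch_correct(samples)

needs_review = []
for result in results:
    report = zerofec.analyze_confidence(result)
    if report.get('confidence_level') != 'high' or report.get('warnings'):
        needs_review.append((result['input_claim'], report))

print(f"{len(needs_review)} of {len(results)} corrections flagged for review")
```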

For additional information about `model_args`, please refer to `main.py`.

To run prediction on the two datasets we used for our experiments, use the following
zerofec/zerofec.py (111 additions, 1 deletion)
@@ -5,6 +5,7 @@
from .models.entailment_model import EntailmentModel
from tqdm import tqdm
from typing import List, Dict
import numpy as np

import nltk
nltk.download('punkt')
@@ -44,4 +45,113 @@ def correct(self, sample: Dict):
    def batch_correct(self, samples: List[Dict]):

        return [self.correct(sample) for sample in tqdm(samples, total=len(samples))]


    def analyze_confidence(self, sample: Dict) -> Dict:
        '''
        Analyzes the confidence scores of a corrected sample to provide explainability.

        Args:
            sample: Dict containing correction results with fields:
                - input_claim: original claim
                - final_answer: selected correction
                - correction: list of candidate corrections
                - correction_scores: entailment scores for each correction
                - rouge_scores: ROUGE similarity scores to original claim
                - evidence: list of evidence passages

        Returns:
            Dict containing confidence analysis with fields:
                - final_correction: the selected correction
                - final_entailment_score: entailment confidence (0-1)
                - final_rouge_score: ROUGE similarity to original (0-1)
                - final_combined_score: combined score used for selection
                - confidence_level: 'high', 'medium', or 'low'
                - all_candidates: list of all corrections with their scores
                - is_changed: whether the claim was modified
                - warnings: list of potential issues
                - statistics: overall score statistics
        '''

        if 'final_answer' not in sample:
            return {
                'error': 'Sample has not been processed through correction pipeline',
                'suggestion': 'Run correct() or batch_correct() first'
            }

        analysis = {
            'final_correction': sample['final_answer'],
            'original_claim': sample['input_claim'],
            'is_changed': sample['final_answer'] != sample['input_claim']
        }

        # Find which correction was selected
        if sample.get('correction_scores') and sample.get('rouge_scores'):
            corrections = sample['correction']
            entailment_scores = sample['correction_scores']
            rouge_scores = sample['rouge_scores']
            combined_scores = np.array(entailment_scores) + np.array(rouge_scores)

            selected_idx = int(np.argmax(combined_scores))

            # Final scores
            analysis['final_entailment_score'] = float(entailment_scores[selected_idx])
            analysis['final_rouge_score'] = float(rouge_scores[selected_idx])
            analysis['final_combined_score'] = float(combined_scores[selected_idx])

            # Confidence level assessment
            entailment_score = analysis['final_entailment_score']
            if entailment_score >= 0.8:
                analysis['confidence_level'] = 'high'
            elif entailment_score >= 0.5:
                analysis['confidence_level'] = 'medium'
            else:
                analysis['confidence_level'] = 'low'

            # All candidates with scores
            analysis['all_candidates'] = []
            for i, (corr, ent, rou, comb) in enumerate(zip(corrections, entailment_scores, rouge_scores, combined_scores)):
                analysis['all_candidates'].append({
                    'correction': corr,
                    'entailment_score': float(ent),
                    'rouge_score': float(rou),
                    'combined_score': float(comb),
                    'is_selected': i == selected_idx
                })

            # Sort by combined score, then assign ranks so that rank 1 is the
            # highest-scoring candidate (ranks must be set after sorting)
            analysis['all_candidates'].sort(key=lambda x: x['combined_score'], reverse=True)
            for rank, candidate in enumerate(analysis['all_candidates'], start=1):
                candidate['rank'] = rank

            # Statistics
            analysis['statistics'] = {
                'num_candidates': len(corrections),
                'avg_entailment_score': float(np.mean(entailment_scores)),
                'max_entailment_score': float(np.max(entailment_scores)),
                'min_entailment_score': float(np.min(entailment_scores)),
                'avg_rouge_score': float(np.mean(rouge_scores)),
                'score_variance': float(np.var(combined_scores))
            }

            # Warnings
            analysis['warnings'] = []

            if analysis['final_entailment_score'] < 0.5:
                analysis['warnings'].append('Low entailment confidence: correction may not be well-supported by evidence')

            if analysis['final_rouge_score'] > 0.9 and analysis['is_changed']:
                analysis['warnings'].append('High ROUGE score despite change: correction may be nearly identical to original')

            if analysis['final_rouge_score'] < 0.3:
                analysis['warnings'].append('Low ROUGE score: correction differs significantly from original claim')

            if analysis['statistics']['score_variance'] < 0.01:
                analysis['warnings'].append('Low score variance: candidates score similarly, so the model may be uncertain between them')

            if not analysis['is_changed']:
                analysis['warnings'].append('No correction made: original claim retained (may indicate high original accuracy or low confidence in alternatives)')

        else:
            analysis['error'] = 'No correction scores found in sample'
            analysis['warnings'] = ['Sample may not have completed the entailment scoring step']

        return analysis