From 4114ba54bd90e5c22566836f2a09f7990c5bd410 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 9 Nov 2025 03:10:23 +0000 Subject: [PATCH] Add confidence analysis function for explainability Added a new analyze_confidence() method to the ZeroFEC class that provides detailed insights into the correction process: - Breaks down entailment and ROUGE scores for the selected correction - Categorizes confidence levels (high/medium/low) - Ranks all candidate corrections with their scores - Provides warnings for potential issues (low confidence, significant changes, etc.) - Includes statistics across all candidates This enhances the framework's interpretability goal by making the correction decision process more transparent and explainable to users. Updated README.md with usage examples and documentation. --- README.md | 34 +++++++++++++- zerofec/zerofec.py | 112 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 144 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 146cf15..d22d39a 100644 --- a/README.md +++ b/README.md @@ -66,13 +66,45 @@ corrected_claim = zerofec.correct(sample) ``` The `corrected_claim` dictionary now contains a key `final_answer`, which is the final correction, as well as all intermediate outputs that provide interpretability. -Batch processing is supported via: +Batch processing is supported via: ```python zerofec.batch_correct(samples) ``` where `samples` is a list of dictionary. 
+### Confidence Analysis + +To gain deeper insights into the correction process and understand the model's confidence, use the `analyze_confidence` method: + +```python +# After running correction +corrected_claim = zerofec.correct(sample) + +# Analyze the confidence of the correction +confidence_analysis = zerofec.analyze_confidence(corrected_claim) + +print(f"Confidence Level: {confidence_analysis['confidence_level']}") +print(f"Entailment Score: {confidence_analysis['final_entailment_score']:.3f}") +print(f"ROUGE Score: {confidence_analysis['final_rouge_score']:.3f}") +print(f"Combined Score: {confidence_analysis['final_combined_score']:.3f}") + +# Check for warnings +if confidence_analysis['warnings']: + print("Warnings:", confidence_analysis['warnings']) + +# View all candidate corrections ranked by score +for candidate in confidence_analysis['all_candidates'][:3]: # Top 3 + print(f"{candidate['rank']}. {candidate['correction']} (score: {candidate['combined_score']:.3f})") +``` + +The `analyze_confidence` function provides: +- **Confidence level**: 'high' (≥0.8), 'medium' (0.5-0.8), or 'low' (<0.5) based on entailment score +- **Score breakdown**: Entailment, ROUGE, and combined scores for the selected correction +- **All candidates**: Ranked list of all correction candidates with their scores +- **Warnings**: Potential issues like low confidence, significant changes, or model uncertainty +- **Statistics**: Aggregated metrics across all candidates + For additional information about `model_args`, please refer to `main.py`. 
def analyze_confidence(self, sample: Dict) -> Dict:
    '''
    Analyzes the confidence scores of a corrected sample to provide explainability.

    Args:
        sample: Dict containing correction results with fields:
            - input_claim: original claim
            - final_answer: selected correction
            - correction: list of candidate corrections
            - correction_scores: entailment scores for each correction
            - rouge_scores: ROUGE similarity scores to original claim
            - evidence: list of evidence passages

    Returns:
        Dict containing confidence analysis with fields:
            - final_correction: the selected correction
            - final_entailment_score: entailment confidence (0-1)
            - final_rouge_score: ROUGE similarity to original (0-1)
            - final_combined_score: combined score used for selection
            - confidence_level: 'high', 'medium', or 'low'
            - all_candidates: list of all corrections ranked by combined score
            - is_changed: whether the claim was modified
            - warnings: list of potential issues
            - statistics: overall score statistics
    '''
    # Guard: the sample must have gone through correct()/batch_correct() first.
    if 'final_answer' not in sample:
        return {
            'error': 'Sample has not been processed through correction pipeline',
            'suggestion': 'Run correct() or batch_correct() first'
        }

    analysis = {
        'final_correction': sample['final_answer'],
        'original_claim': sample['input_claim'],
        'is_changed': sample['final_answer'] != sample['input_claim']
    }

    # Empty/missing score lists fall through to the error branch below.
    if sample.get('correction_scores') and sample.get('rouge_scores'):
        corrections = sample['correction']
        entailment_scores = sample['correction_scores']
        rouge_scores = sample['rouge_scores']
        # Selection score mirrors the pipeline: entailment + ROUGE, argmax wins.
        combined_scores = np.array(entailment_scores) + np.array(rouge_scores)
        selected_idx = int(np.argmax(combined_scores))

        # Scores of the selected correction.
        analysis['final_entailment_score'] = float(entailment_scores[selected_idx])
        analysis['final_rouge_score'] = float(rouge_scores[selected_idx])
        analysis['final_combined_score'] = float(combined_scores[selected_idx])

        # Bucket the entailment score into a coarse confidence level.
        entailment_score = analysis['final_entailment_score']
        if entailment_score >= 0.8:
            analysis['confidence_level'] = 'high'
        elif entailment_score >= 0.5:
            analysis['confidence_level'] = 'medium'
        else:
            analysis['confidence_level'] = 'low'

        # Build candidate records, sort by combined score (best first), and only
        # then assign ranks so that 'rank' matches the sorted position.
        # BUGFIX: ranks were previously assigned before sorting, so they
        # reflected the original candidate order rather than the score order
        # the README documents ("Ranked list of all correction candidates").
        candidates = [
            {
                'correction': corr,
                'entailment_score': float(ent),
                'rouge_score': float(rou),
                'combined_score': float(comb),
                'is_selected': i == selected_idx
            }
            for i, (corr, ent, rou, comb) in enumerate(
                zip(corrections, entailment_scores, rouge_scores, combined_scores))
        ]
        candidates.sort(key=lambda x: x['combined_score'], reverse=True)
        for rank, candidate in enumerate(candidates, start=1):
            candidate['rank'] = rank
        analysis['all_candidates'] = candidates

        # Aggregate statistics across all candidates.
        analysis['statistics'] = {
            'num_candidates': len(corrections),
            'avg_entailment_score': float(np.mean(entailment_scores)),
            'max_entailment_score': float(np.max(entailment_scores)),
            'min_entailment_score': float(np.min(entailment_scores)),
            'avg_rouge_score': float(np.mean(rouge_scores)),
            'score_variance': float(np.var(combined_scores))
        }

        # Heuristic warnings flagging potentially unreliable corrections.
        analysis['warnings'] = []

        if analysis['final_entailment_score'] < 0.5:
            analysis['warnings'].append('Low entailment confidence: correction may not be well-supported by evidence')

        if analysis['final_rouge_score'] > 0.9 and analysis['is_changed']:
            analysis['warnings'].append('High ROUGE score despite change: correction may be nearly identical to original')

        if analysis['final_rouge_score'] < 0.3:
            analysis['warnings'].append('Low ROUGE score: correction differs significantly from original claim')

        if analysis['statistics']['score_variance'] < 0.01:
            analysis['warnings'].append('Low score variance: model may be uncertain between multiple candidates')

        if not analysis['is_changed']:
            analysis['warnings'].append('No correction made: original claim retained (may indicate high original accuracy or low confidence in alternatives)')

    else:
        analysis['error'] = 'No correction scores found in sample'
        analysis['warnings'] = ['Sample may not have completed the entailment scoring step']

    return analysis