Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
712ad60
clarify that output_orig can be a single unit or multiple units (but …
dennislwei Aug 26, 2025
7cd863a
handle segmented reference outputs (multiple output units) in ProbSca…
dennislwei Aug 26, 2025
4b5018c
update test case in test_prob_scalarized_model.py
dennislwei Aug 27, 2025
0e54a23
handle multiple output units in CLIME (fit_linear_model) with ProbSca…
dennislwei Aug 27, 2025
5c1ebd1
handle multiple output units in LSHAP with ProbScalarizedModel only
dennislwei Aug 27, 2025
a53ff76
update PerturbCurveEvaluator
dennislwei Aug 27, 2025
fb2ce85
update MExGen notebooks to save output token IDs and pass them to
dennislwei Aug 29, 2025
bfbb47c
handle multiple output units/score columns in PerturbCurveEvaluator w…
dennislwei Sep 2, 2025
d3b0128
segment output text if desired
dennislwei Sep 12, 2025
55df9af
merge non-alphanumeric output units into adjacent units
dennislwei Sep 13, 2025
d1af6ea
fix finding token boundaries of output units
dennislwei Sep 16, 2025
ae36906
filter out output units with zero length in terms of tokens
dennislwei Sep 16, 2025
ffca359
further fix finding token boundaries of output units
dennislwei Sep 29, 2025
92caefe
Revert "filter out output units with zero length in terms of tokens"
dennislwei Sep 29, 2025
fe34b3c
handle output units with zero length by skipping the mean over tokens
dennislwei Sep 29, 2025
e8110e6
squeeze importance score arrays only if num_output_units==1
dennislwei Sep 29, 2025
e15f1f7
fix if-else structure
dennislwei Sep 29, 2025
51fd0c5
find_unit_boundaries(): keep idx_token in range
dennislwei Oct 2, 2025
278ed04
squeeze only axis=1 of importance score arrays
dennislwei Oct 4, 2025
d5db662
_compute_log_probs_vllm: break large batches into multiple calls
dennislwei Nov 2, 2025
4bdfa81
MExGenExplainer.segment_output: fix indenting
dennislwei Nov 14, 2025
96cafa4
a few comment clarifications and formatting fixes
dennislwei Jan 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 98 additions & 129 deletions examples/mexgen/RAG.ipynb

Large diffs are not rendered by default.

160 changes: 80 additions & 80 deletions examples/mexgen/question_answering.ipynb

Large diffs are not rendered by default.

799 changes: 403 additions & 396 deletions examples/mexgen/summarization.ipynb

Large diffs are not rendered by default.

95 changes: 62 additions & 33 deletions icx360/algorithms/mexgen/clime.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class CLIME(MExGenExplainer):
based on the model's inputs or outputs.
"""
def explain_instance(self, input_orig, unit_types="p", output_orig=None,
ind_segment=True, segment_type="s", max_phrase_length=10,
ind_segment=True, segment_type="s", max_phrase_length=10, segment_type_output=None,
model_params={}, scalarize_params={},
oversampling_factor=10, max_units_replace=2, empty_subset=True, replacement_str="",
num_nonzeros=None, debias=True):
Expand All @@ -51,14 +51,18 @@ def explain_instance(self, input_orig, unit_types="p", output_orig=None,
"n" for not to be perturbed/attributed to.
If str, applies to all units in input_orig, otherwise unit-specific.
output_orig (str or List[str] or icx360.utils.model_wrappers.GeneratedOutput or None):
[output] Output for original input if provided, otherwise None.
[output] Output for original input.
Can be a single unit (str), segmented into units (List[str]), a GeneratedOutput object, or None.
ind_segment (bool or List[bool]):
[segmentation] Whether to segment input text.
If bool, applies to all units; if List[bool], applies to each unit individually.
segment_type (str):
[segmentation] Type of units to segment into: "s" for sentences, "w" for words, "ph" for phrases.
max_phrase_length (int):
[segmentation] Maximum phrase length in terms of spaCy tokens (default 10).
segment_type_output (str or None):
[segmentation] Type of units to segment output text into:
"s" for sentences, "ph" for phrases, None for no segmentation.
model_params (dict):
Additional keyword arguments for model generation (for the self.model.generate() method).
scalarize_params (dict):
Expand Down Expand Up @@ -101,6 +105,8 @@ def explain_instance(self, input_orig, unit_types="p", output_orig=None,

# 2) Generate output for original input or wrap provided output
output_orig = self.generate_or_wrap_output(input_orig, output_orig, model_params)
# Segment output text if needed
output_orig = self.segment_output(output_orig, segment_type_output, max_phrase_length)

# 3) Enumerate subsets of units that will be perturbed/replaced
idx_replace = (np.array(unit_types) != "n").nonzero()[0]
Expand Down Expand Up @@ -130,7 +136,7 @@ def explain_instance(self, input_orig, unit_types="p", output_orig=None,
coef[key], intercept[key], num_nonzeros_out[key] = fit_linear_model(features, target[key].cpu().numpy(), subset_weights, num_nonzeros, debias)

else:
# Single target vector
# Single target array (could contain multiple columns)
coef, intercept, num_nonzeros_out = fit_linear_model(features, target.cpu().numpy(), subset_weights, num_nonzeros, debias)

# 8) Construct output dictionary
Expand Down Expand Up @@ -186,8 +192,8 @@ def fit_linear_model(features, target, sample_weights, num_nonzeros, debias):
Args:
features ((num_perturb, num_units) np.ndarray):
Feature values.
target ((num_perturb,) np.ndarray):
Target values to predict.
target ((num_perturb,) or (num_perturb, num_output_units) np.ndarray):
Target values to predict (one column for each output unit).
sample_weights ((num_perturb,) np.ndarray):
Sample weights.
num_nonzeros (int or None):
Expand All @@ -196,51 +202,74 @@ def fit_linear_model(features, target, sample_weights, num_nonzeros, debias):
Refit linear model with no penalty after selecting features.

Returns:
coef ((num_units,) np.ndarray):
Coefficients of linear model.
intercept (float):
Intercept of linear model.
num_nonzeros (int):
Actual number of non-zero coefficients.
coef ((num_units,) or (num_units, num_output_units) np.ndarray):
Coefficients of linear model(s) (one per output unit).
intercept (float or (num_output_units,) np.ndarray):
Intercept(s) of linear model(s) (one per output unit).
num_nonzeros (List[int]):
Actual numbers of non-zero coefficients.
"""
num_units = features.shape[1]
# Promote target array to 2D if needed
target = target[:, None] if target.ndim == 1 else target
num_output_units = target.shape[1]

if num_nonzeros is None:
# Fit dense linear model over the units that were perturbed (`active`)
active = features.any(axis=0).nonzero()[0]
coef = np.zeros(num_units)
coef = np.zeros((num_units, num_output_units))
lr = LinearRegression()
lr.fit(features[:, active], target, sample_weight=sample_weights)
coef[active] = lr.coef_
coef[active, :] = lr.coef_.T
intercept = lr.intercept_

else:
# Fit sparse linear model

# Center feature and target values
features_mean = features.mean(axis=0)
target_mean = target.mean()
target_mean = target.mean(axis=0)
features_centered = features - features_mean
target_centered = target - target_mean

# Call lars_path to obtain sparse linear model with num_nonzeros coefficients
# NOTE: may return fewer than num_nonzeros if coefficients leave the active set
alphas, active, coef = lars_path(np.sqrt(sample_weights)[:, None] * features_centered, np.sqrt(sample_weights) * target_centered, max_iter=num_nonzeros, method="lasso", return_path=False)

if debias:
coef = np.zeros(num_units)
if len(active):
# Refit linear model on selected features with no penalty
lr = LinearRegression()
lr.fit(features[:, active], target, sample_weight=sample_weights)
coef[active] = lr.coef_
intercept = lr.intercept_
# Initialize outputs
coef = np.zeros((num_units, num_output_units))
intercept = np.zeros(num_output_units)
active = [None] * num_output_units

# Iterate over output units
for u in range(num_output_units):
# Call lars_path to obtain sparse linear model with num_nonzeros coefficients
# NOTE: may return fewer than num_nonzeros if coefficients leave the active set
alphas, active[u], coef[:, u] = lars_path(np.sqrt(sample_weights)[:, None] * features_centered,
np.sqrt(sample_weights) * target_centered[:, u],
max_iter=num_nonzeros,
method="lasso",
return_path=False)

if debias:
coef[:, u] = np.zeros(num_units)
if len(active[u]):
# Refit linear model on selected features with no penalty
lr = LinearRegression()
lr.fit(features[:, active[u]], target[:, u], sample_weight=sample_weights)
coef[active[u], u] = lr.coef_
intercept[u] = lr.intercept_
else:
# No active set, coefficients all zero
intercept[u] = target_mean[u]
else:
# No active set, coefficients all zero
intercept = target_mean
else:
# Compute intercept to account for centering
intercept = target_mean - coef @ features_mean

# Compute intercept to account for centering
intercept[u] = target_mean[u] - coef[:, u] @ features_mean

if num_output_units == 1:
coef, intercept = coef.squeeze(axis=1), intercept.squeeze()
# Actual number(s) of non-zero coefficients
if type(active[0]) is int:
# Single active set (single list of indices) so number of non-zeros is same for all output units
num_nonzeros = [len(active)] * num_output_units
else:
# Multiple active sets, one for each output unit
num_nonzeros = map(len, active)
# Negate coefficients so that important units have positive coefficients
return -coef, intercept, len(active)
return -coef, intercept, num_nonzeros
23 changes: 16 additions & 7 deletions icx360/algorithms/mexgen/lshap.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class LSHAP(MExGenExplainer):
based on the model's inputs or outputs.
"""
def explain_instance(self, input_orig, unit_types="p", ind_interest=None, output_orig=None,
ind_segment=True, segment_type="s", max_phrase_length=10,
ind_segment=True, segment_type="s", max_phrase_length=10, segment_type_output=None,
model_params={}, scalarize_params={},
num_neighbors=2, max_units_replace=2, replacement_str=""):
"""
Expand All @@ -53,14 +53,18 @@ def explain_instance(self, input_orig, unit_types="p", ind_interest=None, output
[input] Indicator of units to attribute to ("of interest").
Default None means np.array(unit_types) != "n".
output_orig (str or List[str] or icx360.utils.model_wrappers.GeneratedOutput or None):
[output] Output for original input if provided, otherwise None.
[output] Output for original input.
Can be a single unit (str), segmented into units (List[str]), a GeneratedOutput object, or None.
ind_segment (bool or List[bool]):
[segmentation] Whether to segment input text.
If bool, applies to all units; if List[bool], applies to each unit individually.
segment_type (str):
[segmentation] Type of units to segment into: "s" for sentences, "w" for words, "ph" for phrases.
max_phrase_length (int):
[segmentation] Maximum phrase length in terms of spaCy tokens (default 10).
segment_type_output (str or None):
[segmentation] Type of units to segment output text into:
"s" for sentences, "ph" for phrases, None for no segmentation.
model_params (dict):
Additional keyword arguments for model generation (for the self.model.generate() method).
scalarize_params (dict):
Expand Down Expand Up @@ -106,6 +110,9 @@ def explain_instance(self, input_orig, unit_types="p", ind_interest=None, output

# 2) Generate output for original input or wrap provided output
output_orig = self.generate_or_wrap_output(input_orig, output_orig, model_params)
# Segment output text if needed
output_orig = self.segment_output(output_orig, segment_type_output, max_phrase_length)
num_output_units = 1 if type(output_orig.output_text[0]) is str else len(output_orig.output_text[0])

# 3) Initialize quantities
# Initialize importance scores
Expand All @@ -115,7 +122,7 @@ def explain_instance(self, input_orig, unit_types="p", ind_interest=None, output
for key in self.scalarized_model.sim_scores:
importance_scores[key] = np.zeros(num_units)
else:
importance_scores = np.zeros(num_units)
importance_scores = np.zeros((num_units, num_output_units))

# Initialize quantities associated with units of interest
idx_replace_i = [None] * len(idx_interest)
Expand Down Expand Up @@ -187,21 +194,23 @@ def explain_instance(self, input_orig, unit_types="p", ind_interest=None, output
importance_scores[key][idx_interest[i]] = np.inner(scalar_outputs_excl_interest - scalar_outputs_incl_interest, 1 / normalization)

else:
# Extract scalarized output corresponding to original input/empty subset
scalar_output_orig = scalar_outputs[0].item()
# Extract scalarized output(s) corresponding to original input/empty subset
scalar_output_orig = scalar_outputs[[0]].cpu().numpy()
# Extract scalarized outputs for this unit of interest
scalar_outputs_excl_interest = scalar_outputs[idx_excl_interest].cpu().numpy()
scalar_outputs_incl_interest = scalar_outputs[idx_incl_interest].cpu().numpy()
# Prepend output corresponding to empty subset
scalar_outputs_excl_interest = np.append(scalar_output_orig, scalar_outputs_excl_interest)
scalar_outputs_excl_interest = np.append(scalar_output_orig, scalar_outputs_excl_interest, axis=0)

# 9) Compute Shapley values
normalization = get_normalization_constants(len(idx_replace_i[i]), max_units_replace) * (max_units_replace + 1)
importance_scores[idx_interest[i]] = np.inner(scalar_outputs_excl_interest - scalar_outputs_incl_interest, 1 / normalization)
importance_scores[idx_interest[i]] = np.dot(1 / normalization, scalar_outputs_excl_interest - scalar_outputs_incl_interest)

# 10) Construct output dictionary
if type(importance_scores) is not dict:
# Convert importance_scores to dictionary
if num_output_units == 1:
importance_scores = importance_scores.squeeze(axis=1)
if isinstance(self.scalarized_model, ProbScalarizedModel):
# Label scores with type of scalarizer
importance_scores = {"prob": importance_scores}
Expand Down
38 changes: 32 additions & 6 deletions icx360/algorithms/mexgen/mexgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from icx360.algorithms.lbbe import LocalBBExplainer
from icx360.utils.model_wrappers import GeneratedOutput, HFModel
from icx360.utils.scalarizers import ProbScalarizedModel, TextScalarizedModel
from icx360.utils.segmenters import SpaCySegmenter, exclude_non_alphanumeric
from icx360.utils.segmenters import SpaCySegmenter, exclude_non_alphanumeric, merge_non_alphanumeric


class MExGenExplainer(LocalBBExplainer):
Expand Down Expand Up @@ -115,7 +115,8 @@ def generate_or_wrap_output(self, input_orig, output_orig=None, model_params={})
input_orig (List[str]):
Original input segmented into units.
output_orig (str or List[str] or icx360.utils.model_wrappers.GeneratedOutput or None):
Output for original input if provided, otherwise None.
Output for original input.
Can be a single unit (str), segmented into units (List[str]), a GeneratedOutput object, or None.
model_params (dict):
Additional keyword arguments for model generation (for the self.model.generate() method).

Expand All @@ -130,11 +131,8 @@ def generate_or_wrap_output(self, input_orig, output_orig=None, model_params={})
# Generate output for original input
output_orig = self.model.generate([input_orig], text_only=False, **model_params)
elif type(output_orig) in (str, list):
if type(output_orig) is str:
output_orig = [output_orig]

# Wrap output text in a GeneratedOutput object
output_orig = GeneratedOutput(output_text=output_orig)
output_orig = GeneratedOutput(output_text=[output_orig])

if isinstance(self.model, HFModel):
# Also include output token IDs for HFModel
Expand All @@ -145,3 +143,31 @@ def generate_or_wrap_output(self, input_orig, output_orig=None, model_params={})
raise TypeError("output_orig must be a str, List[str], GeneratedOutput, or None.")

return output_orig

def segment_output(self, output_orig, segment_type_output=None, max_phrase_length=10):
"""
Segment output text (if needed).

Args:
output_orig (icx360.utils.model_wrappers.GeneratedOutput):
Object containing output for original input, in particular output text (output_orig.output_text).
segment_type_output (str or None):
Type of units to segment into: "s" for sentences, "ph" for phrases, None for no segmentation.
max_phrase_length (int):
Maximum phrase length in terms of spaCy tokens (default 10).

Returns:
output_orig (icx360.utils.model_wrappers.GeneratedOutput):
Output object with possibly segmented text.
"""
if type(output_orig.output_text[0]) is str and segment_type_output is not None:
# Output text not already segmented and segmentation requested, call segmenter
output_orig.output_text[0], _, _ = self.segmenter.segment_units(output_orig.output_text[0],
unit_types="p",
segment_type=segment_type_output,
max_phrase_length=max_phrase_length)

# Merge non-alphanumeric units into adjacent units
output_orig.output_text[0] = merge_non_alphanumeric(output_orig.output_text[0])

return output_orig
Loading