Skip to content

Commit f77df20

Browse files
committed
updates
1 parent fcd34e6 commit f77df20

2 files changed

Lines changed: 32 additions & 24 deletions

File tree

manuscripts/PEARC24/MNIST/3p1_APSO_W_Anchor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import matplotlib.pyplot as plt
88

99
import Adversarial_Observation as AO
10-
from Adversarial_Observation.Swarm import ParticleSwarm # ← This is your custom swarm class
10+
from Adversarial_Observation.Swarm import ParticleSwarm
1111

1212
# --- Global Config ---
1313
optimize = 3 # Target class for the attack

manuscripts/PEARC24/MNIST/5_create_SHAP.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -34,56 +34,64 @@ def save_and_plot_shap_values(dataloader, model):
3434
data, target = getData(dataloader)
3535
data = data.to(device)
3636
target = target.to(device)
37-
3837
model = model.to(device)
3938

4039
explainer = shap.DeepExplainer(model, data)
41-
shap_values = explainer.shap_values(data) # List of [class][samples, features]
40+
shap_values = explainer.shap_values(data)
41+
42+
# --- ROBUST SHAP SHAPE NORMALIZATION ---
43+
# SHAP can return a list of 10 arrays OR a list of 1 array containing all classes.
44+
# This block forces the data into a standard shape: (batch_size, num_classes, 28, 28)
45+
if isinstance(shap_values, list):
46+
if len(shap_values) == 10:
47+
# Case A: List of 10 classes. Convert to array and swap axes to (batch, class, ...)
48+
shap_tensor = np.array(shap_values).swapaxes(0, 1)
49+
elif len(shap_values) == 1:
50+
# Case B: List of 1 containing everything. Extract the array directly.
51+
shap_tensor = np.array(shap_values[0])
52+
else:
53+
shap_tensor = np.array(shap_values)
54+
else:
55+
# Case C: Returned a raw numpy array right out of the gate
56+
shap_tensor = np.array(shap_values)
57+
58+
# Flatten out the channel dimension and strictly enforce (10_images, 10_classes, 28, 28)
59+
shap_tensor = shap_tensor.reshape(len(data), 10, 28, 28)
60+
# ---------------------------------------
4261

4362
save_dir = 'SHAP'
4463
os.makedirs(save_dir, exist_ok=True)
4564

4665
# Create a 10x11 grid: 1 original + 10 SHAP values
4766
fig, axes = plt.subplots(10, 11, figsize=(20, 22))
48-
last_img = None # For colorbar
67+
last_img = None
4968

5069
for i in range(len(data)):
5170
label = target[i].item()
52-
shap_i = [class_shap[i] for class_shap in shap_values] # SHAP per class, for this image
5371

5472
# Save original image
5573
np.save(f'{save_dir}/{i}_original.npy', data[i].cpu().numpy())
5674
axes[i, 0].imshow(data[i].cpu().reshape(28, 28), cmap='gray')
5775
axes[i, 0].set_title(f'Label: {label}')
5876
axes[i, 0].axis('off')
5977

60-
for j in range(min(10, len(shap_i))):
61-
shap_array = shap_i[j]
62-
try:
63-
reshaped = shap_array.reshape(10, 28, 28)[j] # extract correct class
64-
except Exception as e:
65-
print(f"[ERROR] SHAP reshape failed for sample {i}, class {j}: {e}")
66-
continue
67-
68-
np.save(f'{save_dir}/{i}_shap_{j}.npy', shap_array)
78+
# 1. Main Grid Plotting
79+
for j in range(10):
80+
reshaped = shap_tensor[i, j] # Safely extracts the exact 28x28 grid
81+
np.save(f'{save_dir}/{i}_shap_{j}.npy', reshaped)
6982
last_img = axes[i, j+1].imshow(reshaped, cmap='jet')
7083
axes[i, j+1].axis('off')
7184

72-
73-
# Fill remaining columns
74-
for j in range(len(shap_i) + 1, 11):
75-
axes[i, j].axis('off')
76-
77-
# Save row as standalone image
85+
# 2. Save row as standalone image
7886
row_fig, row_axes = plt.subplots(1, 11, figsize=(20, 2))
7987
row_axes[0].imshow(data[i].cpu().reshape(28, 28), cmap='gray')
8088
row_axes[0].set_title(f'Label: {label}')
8189
row_axes[0].axis('off')
82-
for j in range(min(10, len(shap_i))):
83-
row_axes[j+1].imshow(shap_i[j][:784].reshape(28, 28), cmap='jet')
90+
91+
for j in range(10):
92+
row_axes[j+1].imshow(shap_tensor[i, j], cmap='jet')
8493
row_axes[j+1].axis('off')
85-
for j in range(len(shap_i) + 1, 11):
86-
row_axes[j].axis('off')
94+
8795
plt.tight_layout()
8896
row_fig.savefig(f'{save_dir}/row_{i}.png')
8997
plt.close(row_fig)

0 commit comments

Comments
 (0)