@@ -34,56 +34,64 @@ def save_and_plot_shap_values(dataloader, model):
3434 data , target = getData (dataloader )
3535 data = data .to (device )
3636 target = target .to (device )
37-
3837 model = model .to (device )
3938
4039 explainer = shap .DeepExplainer (model , data )
41- shap_values = explainer .shap_values (data ) # List of [class][samples, features]
40+ shap_values = explainer .shap_values (data )
41+
42+ # --- ROBUST SHAP SHAPE NORMALIZATION ---
43+ # SHAP can return a list of 10 arrays OR a list of 1 array containing all classes.
44+ # This block forces the data into a standard shape: (batch_size, num_classes, 28, 28)
45+ if isinstance (shap_values , list ):
46+ if len (shap_values ) == 10 :
47+ # Case A: List of 10 classes. Convert to array and swap axes to (batch, class, ...)
48+ shap_tensor = np .array (shap_values ).swapaxes (0 , 1 )
49+ elif len (shap_values ) == 1 :
50+ # Case B: List of 1 containing everything. Extract the array directly.
51+ shap_tensor = np .array (shap_values [0 ])
52+ else :
53+ shap_tensor = np .array (shap_values )
54+ else :
55+ # Case C: Returned a raw numpy array right out of the gate
56+ shap_tensor = np .array (shap_values )
57+
58+ # Flatten out the channel dimension and strictly enforce (10_images, 10_classes, 28, 28)
59+ shap_tensor = shap_tensor .reshape (len (data ), 10 , 28 , 28 )
60+ # ---------------------------------------
4261
4362 save_dir = 'SHAP'
4463 os .makedirs (save_dir , exist_ok = True )
4564
4665 # Create a 10x11 grid: 1 original + 10 SHAP values
4766 fig , axes = plt .subplots (10 , 11 , figsize = (20 , 22 ))
48- last_img = None # For colorbar
67+ last_img = None
4968
5069 for i in range (len (data )):
5170 label = target [i ].item ()
52- shap_i = [class_shap [i ] for class_shap in shap_values ] # SHAP per class, for this image
5371
5472 # Save original image
5573 np .save (f'{ save_dir } /{ i } _original.npy' , data [i ].cpu ().numpy ())
5674 axes [i , 0 ].imshow (data [i ].cpu ().reshape (28 , 28 ), cmap = 'gray' )
5775 axes [i , 0 ].set_title (f'Label: { label } ' )
5876 axes [i , 0 ].axis ('off' )
5977
60- for j in range (min (10 , len (shap_i ))):
61- shap_array = shap_i [j ]
62- try :
63- reshaped = shap_array .reshape (10 , 28 , 28 )[j ] # extract correct class
64- except Exception as e :
65- print (f"[ERROR] SHAP reshape failed for sample { i } , class { j } : { e } " )
66- continue
67-
68- np .save (f'{ save_dir } /{ i } _shap_{ j } .npy' , shap_array )
78+ # 1. Main Grid Plotting
79+ for j in range (10 ):
80+ reshaped = shap_tensor [i , j ] # Safely extracts the exact 28x28 grid
81+ np .save (f'{ save_dir } /{ i } _shap_{ j } .npy' , reshaped )
6982 last_img = axes [i , j + 1 ].imshow (reshaped , cmap = 'jet' )
7083 axes [i , j + 1 ].axis ('off' )
7184
72-
73- # Fill remaining columns
74- for j in range (len (shap_i ) + 1 , 11 ):
75- axes [i , j ].axis ('off' )
76-
77- # Save row as standalone image
85+ # 2. Save row as standalone image
7886 row_fig , row_axes = plt .subplots (1 , 11 , figsize = (20 , 2 ))
7987 row_axes [0 ].imshow (data [i ].cpu ().reshape (28 , 28 ), cmap = 'gray' )
8088 row_axes [0 ].set_title (f'Label: { label } ' )
8189 row_axes [0 ].axis ('off' )
82- for j in range (min (10 , len (shap_i ))):
83- row_axes [j + 1 ].imshow (shap_i [j ][:784 ].reshape (28 , 28 ), cmap = 'jet' )
90+
91+ for j in range (10 ):
92+ row_axes [j + 1 ].imshow (shap_tensor [i , j ], cmap = 'jet' )
8493 row_axes [j + 1 ].axis ('off' )
85- for j in range (len (shap_i ) + 1 , 11 ):
86- row_axes [j ].axis ('off' )
94+
8795 plt .tight_layout ()
8896 row_fig .savefig (f'{ save_dir } /row_{ i } .png' )
8997 plt .close (row_fig )
0 commit comments