Commit 9dfb86b

feat(kbann): constraining version of kbann injector available
1 parent 9c0711b · commit 9dfb86b

3 files changed

Lines changed: 73 additions & 11 deletions


psyki/ski/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -45,9 +45,10 @@ def kins(model: Model,
 def kbann(model: Model,
           feature_mapping: dict[str, int],
           fuzzifier: str = 'towell',
-          omega: int = 4) -> Injector:
+          omega: float = 4.,
+          gamma: float = 10E-3) -> Injector:
     from psyki.ski.kbann import KBANN
-    return KBANN(model, feature_mapping, fuzzifier, omega)
+    return KBANN(model, feature_mapping, fuzzifier, omega, gamma)


 class EnrichedModel(Model):
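
For orientation, a hedged usage sketch of the updated factory (not part of the commit): the predictor, input width, and feature mapping below are placeholders; only the keyword arguments mirror the new signature of Injector.kbann.

# Hypothetical example: predictor and feature_mapping are placeholders.
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras import Model
from psyki.ski import Injector

input_layer = Input(shape=(240,))                 # placeholder input width
output_layer = Dense(3, activation='softmax')(Dense(32, activation='relu')(input_layer))
predictor = Model(input_layer, output_layer)
feature_mapping = {'X1': 0, 'X2': 1}              # placeholder variable-to-column map

injector = Injector.kbann(predictor, feature_mapping,
                          fuzzifier='towell',
                          omega=4.,               # fuzzification weight, unchanged default
                          gamma=10E-3)            # new: strength of the weight-constraining penalty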

psyki/ski/kbann/__init__.py

Lines changed: 51 additions & 4 deletions
@@ -1,18 +1,28 @@
+import copy
 from typing import Iterable, Callable, List
+import tensorflow as tf
 from tensorflow.keras.layers import Concatenate
 from tensorflow import Tensor
+from tensorflow.keras.losses import Loss
+from tensorflow.python.keras.utils.generic_utils import custom_object_scope
 from psyki.logic.datalog.grammar import optimize_datalog_formula
 from psyki.logic import Fuzzifier, Formula
 from tensorflow.keras import Model
-from psyki.ski import Injector
+from psyki.ski import Injector, EnrichedModel
+from psyki.utils import model_deep_copy


 class KBANN(Injector):
     """
     Implementation of KBANN algorithm described by G. Towell in https://doi.org/10.1016/0004-3702(94)90105-8
     """

-    def __init__(self, predictor: Model, feature_mapping: dict[str, int], fuzzifier: str, omega: float = 4):
+    def __init__(self,
+                 predictor: Model,
+                 feature_mapping: dict[str, int],
+                 fuzzifier: str,
+                 omega: float = 4.,
+                 gamma: float = 10E-3):
         """
         @param predictor: the predictor.
         @param feature_mapping: a map between variables in the logic formulae and indices of dataset features. Example:
@@ -25,9 +35,46 @@ def __init__(self, predictor: Model, feature_mapping: dict[str, int], fuzzifier:
         """
         # self.feature_mapping: dict[str, int] = feature_mapping
         # Use as default fuzzifiers SubNetworkBuilder.
+        # TODO: analyse this warning that sometimes comes out, this should not be harmful.
+        tf.get_logger().setLevel('ERROR')
         self._predictor = predictor
         self._fuzzifier = Fuzzifier.get(fuzzifier)([self._predictor.input, feature_mapping, omega])
         self._fuzzy_functions: Iterable[Callable] = ()
+        self.gamma = gamma
+
+    class ConstrainedModel(EnrichedModel):
+
+        def __init__(self, model: Model, gamma: float, custom_objects: dict):
+            super().__init__(model, custom_objects)
+            self.gamma = gamma
+            self.init_weights = copy.deepcopy(self.weights)
+
+        class CustomLoss(Loss):
+
+            def __init__(self, original_loss: Callable, model: Model, init_weights, gamma: float):
+                self.original_loss = original_loss
+                self.model = model
+                self.init_weights = init_weights
+                self.gamma = gamma
+                super().__init__()
+
+            def call(self, y_true, y_pred):
+                return self.original_loss(y_true, y_pred) + self.gamma * self._cost_factor()
+
+            def _cost_factor(self):
+                weights_quadratic_diff = 0
+                for init_weight, current_weight in zip(self.init_weights, self.model.weights):
+                    weights_quadratic_diff += tf.math.reduce_sum((init_weight - current_weight) ** 2)
+                # weights_quadratic_diff = tf.math.reduce_sum((tf.ragged.constant(self.init_weights) - tf.ragged.constant(self.weights)) ** 2)
+                return weights_quadratic_diff / (1 + weights_quadratic_diff)
+
+        def copy(self) -> EnrichedModel:
+            with custom_object_scope(self.custom_objects):
+                model = model_deep_copy(Model(self.input, self.output))
+            return KBANN.ConstrainedModel(model, self.gamma, self.custom_objects)
+
+        def loss_function(self, original_function: Callable) -> Callable:
+            return self.CustomLoss(original_function, self, self.init_weights, self.gamma)

     def inject(self, rules: List[Formula]) -> Model:
         # Prevent side effect on the original rules during optimization.
@@ -37,5 +84,5 @@ def inject(self, rules: List[Formula]) -> Model:
         predictor_input: Tensor = self._predictor.input
         modules = self._fuzzifier.visit(rules_copy)
         x = Concatenate(axis=1)(modules)
-        new_predictor = Model(predictor_input, x)
-        return self._fuzzifier.enriched_model(new_predictor)
+        # return self._fuzzifier.enriched_model(Model(predictor_input, x))
+        return self.ConstrainedModel(Model(predictor_input, x), self.gamma, self._fuzzifier.custom_objects)
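
To make the new penalty concrete, here is a minimal, self-contained sketch of what CustomLoss adds to the original loss; the toy tensors are placeholders and the snippet is illustrative, not code from the repository. With D the sum of squared differences between the current weights and the weights snapshotted at injection time, the total loss is original_loss + gamma * D / (1 + D), so the extra term is bounded in [0, 1) and gamma controls how strongly training is pulled back towards the injected weights.

import tensorflow as tf

# Placeholder weight snapshots standing in for init_weights and model.weights.
init_weights = [tf.constant([[1.0, 2.0], [3.0, 4.0]])]
current_weights = [tf.constant([[1.1, 2.0], [2.9, 4.2]])]

# D: accumulated squared drift from the initial (injected) weights.
d = tf.add_n([tf.math.reduce_sum((w0 - w) ** 2)
              for w0, w in zip(init_weights, current_weights)])
penalty = d / (1 + d)          # bounded cost factor, as in CustomLoss._cost_factor
gamma = 10E-3
print(float(gamma * penalty))  # contribution added on top of the original loss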

test/psyki/injectors/test_injection.py

Lines changed: 19 additions & 5 deletions
@@ -1,4 +1,5 @@
 import unittest
+from tensorflow.keras.losses import SparseCategoricalCrossentropy
 from psyki.ski import Injector
 from sklearn.datasets import load_iris
 from sklearn.model_selection import train_test_split, StratifiedKFold
@@ -93,7 +94,6 @@ class TestInjectionOnSpliceJunction(unittest.TestCase):
     x = get_binary_data(data.iloc[:, :-1], AGGREGATE_FEATURE_MAPPING)
     y.columns = [x.shape[1]]
     data = x.join(y)
-
     data, test = train_test_split(data, train_size=1000, random_state=0, stratify=data.iloc[:, -1])
     k_fold = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
     train_indices, _ = list(k_fold.split(data.iloc[:, :-1], data.iloc[:, -1:]))[0]
@@ -106,18 +106,28 @@ class TestInjectionOnSpliceJunction(unittest.TestCase):
     predictor = get_mlp(input_layer, 3, 3, [64, 32], 'relu', 'softmax', dropout=True)
     predictor = Model(input_layer, predictor)

-    def common_test_function(self, injector: Injector, batch_size: int, acceptable_accuracy: float):
+    def common_test_function(self, injector: Injector, batch_size: int, acceptable_accuracy: float, constrain=False):
         model = injector.inject(self.rules)
         # Test if clone is successful
         cloned_model = model.copy()
         del injector

-        model.compile('adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
+        if constrain:
+            loss = model.loss_function(SparseCategoricalCrossentropy())
+        else:
+            loss = 'sparse_categorical_crossentropy'
+
+        model.compile('adam', loss=loss, metrics=['accuracy'])
         model.fit(self.train_x, self.train_y, batch_size=batch_size, epochs=self.EPOCHS, verbose=self.VERBOSE, callbacks=self.early_stop)
         accuracy = model.evaluate(self.test_x, self.test_y)[1]
         del model

-        cloned_model.compile('adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
+        if constrain:
+            loss = cloned_model.loss_function(SparseCategoricalCrossentropy())
+        else:
+            loss = 'sparse_categorical_crossentropy'
+
+        cloned_model.compile('adam', loss=loss, metrics=['accuracy'])
         cloned_model.fit(self.train_x, self.train_y, batch_size=batch_size, epochs=self.EPOCHS, verbose=self.VERBOSE, callbacks=self.early_stop)
         accuracy_cm = cloned_model.evaluate(self.test_x, self.test_y)[1]
         del cloned_model
@@ -127,7 +137,11 @@ def common_test_function(self, injector: Injector, batch_size: int, acceptable_a

     def test_kbann(self):
         injector = Injector.kbann(self.predictor, get_splice_junction_extended_feature_mapping(), 'towell', 1)
-        self.common_test_function(injector, batch_size=16, acceptable_accuracy=0.95)
+        self.common_test_function(injector, batch_size=16, acceptable_accuracy=0.957)
+
+    def test_kbann_with_constraining(self):
+        injector = Injector.kbann(self.predictor, get_splice_junction_extended_feature_mapping(), 'towell', 1, gamma=10E-5)
+        self.common_test_function(injector, batch_size=16, acceptable_accuracy=0.958, constrain=True)

     def test_kins(self):
         injector = Injector.kins(self.predictor, get_splice_junction_extended_feature_mapping())
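
As a usage note, the constrained path exercised by the new test looks roughly like the hedged sketch below; the rules, training data, and epoch count are placeholders, and the point is only that loss_function wraps a standard Keras loss at compile time.

from tensorflow.keras.losses import SparseCategoricalCrossentropy

constrained = injector.inject(rules)            # a KBANN.ConstrainedModel
loss = constrained.loss_function(SparseCategoricalCrossentropy())
constrained.compile('adam', loss=loss, metrics=['accuracy'])
constrained.fit(train_x, train_y, batch_size=16, epochs=100)
# copy() snapshots the clone's own weights, so its penalty starts from zero again.
clone = constrained.copy()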
