1+ import copy
12from typing import Iterable , Callable , List
3+ import tensorflow as tf
24from tensorflow .keras .layers import Concatenate
35from tensorflow import Tensor
6+ from tensorflow .keras .losses import Loss
7+ from tensorflow .python .keras .utils .generic_utils import custom_object_scope
48from psyki .logic .datalog .grammar import optimize_datalog_formula
59from psyki .logic import Fuzzifier , Formula
610from tensorflow .keras import Model
7- from psyki .ski import Injector
11+ from psyki .ski import Injector , EnrichedModel
12+ from psyki .utils import model_deep_copy
813
914
1015class KBANN (Injector ):
1116 """
1217 Implementation of KBANN algorithm described by G. Towell in https://doi.org/10.1016/0004-3702(94)90105-8
1318 """
1419
15- def __init__ (self , predictor : Model , feature_mapping : dict [str , int ], fuzzifier : str , omega : float = 4 ):
20+ def __init__ (self ,
21+ predictor : Model ,
22+ feature_mapping : dict [str , int ],
23+ fuzzifier : str ,
24+ omega : float = 4. ,
25+ gamma : float = 10E-3 ):
1626 """
1727 @param predictor: the predictor.
1828 @param feature_mapping: a map between variables in the logic formulae and indices of dataset features. Example:
@@ -25,9 +35,46 @@ def __init__(self, predictor: Model, feature_mapping: dict[str, int], fuzzifier:
2535 """
2636 # self.feature_mapping: dict[str, int] = feature_mapping
2737 # Use as default fuzzifiers SubNetworkBuilder.
38+ # TODO: analyse this warning that sometimes comes out, this should not be armful.
39+ tf .get_logger ().setLevel ('ERROR' )
2840 self ._predictor = predictor
2941 self ._fuzzifier = Fuzzifier .get (fuzzifier )([self ._predictor .input , feature_mapping , omega ])
3042 self ._fuzzy_functions : Iterable [Callable ] = ()
43+ self .gamma = gamma
44+
45+ class ConstrainedModel (EnrichedModel ):
46+
47+ def __init__ (self , model : Model , gamma : float , custom_objects : dict ):
48+ super ().__init__ (model , custom_objects )
49+ self .gamma = gamma
50+ self .init_weights = copy .deepcopy (self .weights )
51+
52+ class CustomLoss (Loss ):
53+
54+ def __init__ (self , original_loss : Callable , model : Model , init_weights , gamma : float ):
55+ self .original_loss = original_loss
56+ self .model = model
57+ self .init_weights = init_weights
58+ self .gamma = gamma
59+ super ().__init__ ()
60+
61+ def call (self , y_true , y_pred ):
62+ return self .original_loss (y_true , y_pred ) + self .gamma * self ._cost_factor ()
63+
64+ def _cost_factor (self ):
65+ weights_quadratic_diff = 0
66+ for init_weight , current_weight in zip (self .init_weights , self .model .weights ):
67+ weights_quadratic_diff += tf .math .reduce_sum ((init_weight - current_weight ) ** 2 )
68+ # weights_quadratic_diff = tf.math.reduce_sum((tf.ragged.constant(self.init_weights) - tf.ragged.constant(self.weights)) ** 2)
69+ return weights_quadratic_diff / (1 + weights_quadratic_diff )
70+
71+ def copy (self ) -> EnrichedModel :
72+ with custom_object_scope (self .custom_objects ):
73+ model = model_deep_copy (Model (self .input , self .output ))
74+ return KBANN .ConstrainedModel (model , self .gamma , self .custom_objects )
75+
76+ def loss_function (self , original_function : Callable ) -> Callable :
77+ return self .CustomLoss (original_function , self , self .init_weights , self .gamma )
3178
3279 def inject (self , rules : List [Formula ]) -> Model :
3380 # Prevent side effect on the original rules during optimization.
@@ -37,5 +84,5 @@ def inject(self, rules: List[Formula]) -> Model:
3784 predictor_input : Tensor = self ._predictor .input
3885 modules = self ._fuzzifier .visit (rules_copy )
3986 x = Concatenate (axis = 1 )(modules )
40- new_predictor = Model (predictor_input , x )
41- return self ._fuzzifier .enriched_model ( new_predictor )
87+ #return self._fuzzifier.enriched_model( Model(predictor_input, x) )
88+ return self .ConstrainedModel ( Model ( predictor_input , x ), self . gamma , self . _fuzzifier .custom_objects )
0 commit comments