@@ -224,6 +224,9 @@ def run_step(self, gradient, kl_new):
224224 # Enlarge the step
225225 if not self .fixed_step :
226226 self .step *= self .increment_step
227+
228+ # Perform the minimization step for new direction
229+ self .current_x = self .old_x - self .step * self .direction
227230 else :
228231 # Proceed with the line minimization
229232
@@ -243,14 +246,16 @@ def run_step(self, gradient, kl_new):
243246 print ("Step too large (scalar = {} | kl_ratio = {}), reducing to {}" .format (scalar , kl_ratio , self .step ))
244247 #print("Direction: ", self.direction)
245248 #print("Gradient: ", gradient)
249+
250+ # Try again with reduced step
251+ self .current_x = self .old_x - self .step * self .direction
246252 else :
247253 # The step is good, therefore next step perform a new direction
248254 self .new_direction = True
249255 if self .verbose :
250256 print ("Good step found with {}, try increment" .format (self .step ))
251-
252- # Perform the minimiziation step
253- self .current_x = self .old_x - self .step * self .direction
257+ # DO NOT update current_x - we accept the current position
258+ # (current_x was already updated in the previous step)
254259
255260
256261 def update_dyn (self , new_kl_ratio , dyn_gradient , structure_gradient = None ):
0 commit comments