Add loss function wrapper to add flood regularization to any loss function #414
Conversation
The implementation and test look good, but one point here is batch-level vs. example-level implementation.

Example level: `loss = mean(abs(exampleLoss - b) + b)`
Batch level: `loss = abs(mean(exampleLoss) - b) + b`

The difference is that under the example-level variant we do gradient ascent on any individual example whose loss is below the flood level b, whereas under the batch-level variant we only do gradient ascent when the average loss over the entire batch is below the flood level. Page 3 of the paper (line 3) (also pg. 8, eq. 9) seems to be batch level, whereas it looks like you have implemented example level.

Consider a batch where some per-example losses are below the flood level but avgLoss = 0.6 is above it: under batch-level loss no gradient ascent is performed at all, while the example-level loss still ascends on the low-loss examples (see the sketch below). Not sure which is better (they didn't explore this in the paper I think?) - but one option is having both and making it configurable (but doing batch level by default).
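A minimal plain-Java sketch of the difference (the loss values and the flood level `b = 0.5` here are made-up illustrative numbers, not from the PR):

```java
public class FloodingDemo {
    // Example-level flooding: flood each example's loss individually,
    // then average. Gradient ascent can happen per example.
    static double exampleLevel(double[] losses, double b) {
        double sum = 0;
        for (double l : losses) {
            sum += Math.abs(l - b) + b;
        }
        return sum / losses.length;
    }

    // Batch-level flooding (paper, eq. 9): average first, then flood.
    // Gradient ascent only happens if the batch average is below b.
    static double batchLevel(double[] losses, double b) {
        double avg = 0;
        for (double l : losses) {
            avg += l;
        }
        avg /= losses.length;
        return Math.abs(avg - b) + b;
    }

    public static void main(String[] args) {
        double b = 0.5;               // flood level (illustrative)
        double[] losses = {0.2, 1.0}; // avg = 0.6, above b

        // Batch level: |0.6 - 0.5| + 0.5 = 0.6 -> plain descent, no ascent.
        System.out.println("batch level:   " + batchLevel(losses, b));

        // Example level: mean(|0.2-0.5|+0.5, |1.0-0.5|+0.5) = mean(0.8, 1.0) = 0.9
        // -> the 0.2 example gets gradient ascent even though the batch avg is above b.
        System.out.println("example level: " + exampleLevel(losses, b));
    }
}
```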

---

From an intuitive point of view, the example-level approach makes the most sense to me. If we make sure that each individual example doesn't overfit, then none of them overfit. If we look at eq. 9 and go to batch size = 1, then example level = batch level, and more generally the example-level loss is an upper bound for the full-batch loss (derivation below). Technically, the example level also makes sense because we often use the example-level loss (via computeScoreArray) in different parts of DL4J, and it would be impossible to apply the batch-level change in those cases.

I've tried the batch-level gradient reversal anyway, i.e. doing gradient ascent if the average batch loss falls below the flood level, and the results on MNIST were unimpressive. I've also tried it with a flood level of 0.1, and the results again speak for themselves:

per batch: [training curve]

per example: [training curve]

So I thought maybe a flood level of 0.1 is too high, and went a bit lower, to 0.05:

per example: [training curve]

Overall, the per-example variant does work better in these runs. I've also reached out to the first author of the paper to ask how it was meant to be implemented, as the paper does indeed suggest that the per-batch variant is meant. But I don't have an answer yet.
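To make the upper-bound claim precise, here is a short derivation in my own notation (per-example losses ℓ_i, flood level b; this is my reconstruction, not taken verbatim from the paper):

```latex
% Example-level flooding upper-bounds batch-level flooding.
\begin{align*}
\underbrace{\frac{1}{n}\sum_{i=1}^{n} \bigl(|\ell_i - b| + b\bigr)}_{\text{example level}}
  &= \frac{1}{n}\sum_{i=1}^{n} |\ell_i - b| + b \\
  &\ge \Bigl|\frac{1}{n}\sum_{i=1}^{n} (\ell_i - b)\Bigr| + b
     && \text{(triangle inequality; $|\cdot|$ is convex)} \\
  &= \underbrace{\Bigl|\frac{1}{n}\sum_{i=1}^{n} \ell_i - b\Bigr| + b}_{\text{batch level (eq.~9)}}
\end{align*}
% With n = 1 the two sides coincide, matching the batch-size-1 observation above.
```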

---

Yeah, intuitively the example level makes more sense to me. In order to get this into the release (before we have confirmation) - maybe we provide both options, make example vs. batch a required arg, and (maybe) later make one or the other the default (i.e., add another constructor)?

---

The author got back to me, and he has confirmed that they worked on a per-batch basis.

I guess we'll keep this out of the beta7 release, and I'll investigate the effectiveness of the loss function further and try to reproduce their results. At the moment I'm a bit at a loss as to how a batch-level loss would be implemented correctly in those cases where the example-level loss (via computeScoreArray) is used directly.
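To illustrate why that's awkward: batch-level flooding flips the sign of the gradient based on the mean loss of the whole batch, so a per-example score array alone doesn't carry enough information. A plain-Java sketch of the gradient-side difference (the method names are hypothetical, not DL4J's actual ILossFunction API):

```java
import java.util.Arrays;

public class FloodGradientSketch {
    /**
     * Example-level flooding gradient: each example's gradient sign
     * depends only on that example's own loss, so it can be derived
     * from a per-example score array (as computeScoreArray provides).
     */
    static double[] exampleLevelGrads(double[] losses, double[] grads, double b) {
        double[] out = new double[grads.length];
        for (int i = 0; i < grads.length; i++) {
            // sign(loss_i - b): descend if above the flood level, ascend if below
            out[i] = Math.signum(losses[i] - b) * grads[i];
        }
        return out;
    }

    /**
     * Batch-level flooding gradient: ONE sign for the whole batch,
     * decided by the batch-mean loss. This needs the entire batch at
     * once, so it cannot be expressed purely per example.
     */
    static double[] batchLevelGrads(double[] losses, double[] grads, double b) {
        double mean = Arrays.stream(losses).average().orElse(0.0);
        double sign = Math.signum(mean - b);
        double[] out = new double[grads.length];
        for (int i = 0; i < grads.length; i++) {
            out[i] = sign * grads[i];
        }
        return out;
    }
}
```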
Implements the flooding regularization idea presented in https://arxiv.org/abs/2002.08709