
Add loss function wrapper to add float regularization to any loss function#414

Open
treo wants to merge 1 commit into master from treo/float_loss

Conversation

@treo

@treo treo commented Apr 25, 2020

Implements the float idea presented in https://arxiv.org/abs/2002.08709

…ction

Signed-off-by: Paul Dubs <paul.dubs@gmail.com>
@treo treo requested a review from AlexDBlack April 25, 2020 20:40
@AlexDBlack

AlexDBlack commented Apr 26, 2020

The implementation and test look good, but one point here is batch-level vs. example-level implementation.

Example level: loss = mean( abs(exampleLoss - b) + b)
Batch level: loss = abs( mean(exampleLoss) - b) + b

The difference being that if any example is below the flood level, we do gradient ascent on that example; vs. only doing when the average over the entire batch is below the flood level.

Page 3 of the paper (line 3) (also pg8 eq9) seems to be batch level, whereas it looks like you have implemented example level.

Consider the following example:
2 examples in minibatch, with score[0] = 0.2; score[1] = 1.0, b=0.5

Under batch level loss, avgLoss=0.6 and hence no gradient ascent is performed.
Under example level loss, we do gradient ascent on example 0 and gradient descent on example 1.
Thus, in this case, they are not equivalent.
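The two variants can be checked on exactly this example with a minimal sketch (plain Java over double arrays, not DL4J/ND4J code; the method names are illustrative):

```java
public class FloodingDemo {

    // example level: mean_i( |loss_i - b| + b )
    public static double exampleLevel(double[] losses, double b) {
        double sum = 0.0;
        for (double l : losses) sum += Math.abs(l - b) + b;
        return sum / losses.length;
    }

    // batch level: |mean_i(loss_i) - b| + b
    public static double batchLevel(double[] losses, double b) {
        double mean = 0.0;
        for (double l : losses) mean += l;
        mean /= losses.length;
        return Math.abs(mean - b) + b;
    }

    public static void main(String[] args) {
        double[] scores = {0.2, 1.0};
        double b = 0.5;
        // batch level: avg loss 0.6 is above b, so no gradient ascent anywhere
        System.out.println("batch level:   " + batchLevel(scores, b));   // ~0.6
        // example level: example 0 is below b, so it gets gradient ascent
        System.out.println("example level: " + exampleLevel(scores, b)); // ~0.9
    }
}
```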

Not sure which is better (they didn't explore this in the paper I think?) - but one option is having both and making it configurable (but doing batch level by default)

@treo
Author

treo commented Apr 26, 2020

From an intuitive point of view, the example-level approach makes the most sense to me. If we make sure that no individual example overfits, then none of them overfit.
If we apply it at the batch level, then individual examples that happen to sit in an otherwise poorly fitting batch may still overfit, because the batch average stays above the flood level.

If we look at eqn. 9 and set the batch size to 1, example level and batch level coincide. More generally, by the triangle inequality, mean(|loss_i - b|) >= |mean(loss_i) - b|, so the example-level loss is an upper bound for the batch-level loss.
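Both the batch-size-1 equivalence and the upper-bound relation can be verified numerically (again a plain-Java sketch, not DL4J code):

```java
import java.util.Random;

public class FloodingBoundCheck {

    public static double exampleLevel(double[] losses, double b) {
        double sum = 0.0;
        for (double l : losses) sum += Math.abs(l - b) + b;
        return sum / losses.length;
    }

    public static double batchLevel(double[] losses, double b) {
        double mean = 0.0;
        for (double l : losses) mean += l;
        mean /= losses.length;
        return Math.abs(mean - b) + b;
    }

    public static void main(String[] args) {
        double b = 0.5;

        // batch size 1: both variants reduce to |loss - b| + b
        double[] single = {0.3};
        System.out.println(exampleLevel(single, b) == batchLevel(single, b)); // true

        // random batches: example level >= batch level (triangle inequality)
        Random rng = new Random(42);
        boolean holds = true;
        for (int trial = 0; trial < 1000; trial++) {
            double[] losses = new double[8];
            for (int i = 0; i < losses.length; i++) losses[i] = rng.nextDouble();
            if (exampleLevel(losses, b) < batchLevel(losses, b) - 1e-12) holds = false;
        }
        System.out.println(holds); // true
    }
}
```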

Technically, the example level also makes sense because we often use the example-level loss (via computeScoreArray) in different parts of DL4J, and it would be impossible to apply the batch-level change in those cases.
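A per-example wrapper along those lines could look roughly like this (hypothetical class name and a plain-array API for illustration, not the actual DL4J ILossFunction interface):

```java
// Sketch only: FloodedLoss and the double[]-based signatures are hypothetical;
// a real DL4J implementation would wrap an ILossFunction and work on INDArrays.
public class FloodedLoss {

    private final double b; // flood level

    public FloodedLoss(double b) { this.b = b; }

    // per-example flooded scores, analogous to what computeScoreArray returns:
    // flooded_i = |loss_i - b| + b
    public double[] computeScoreArray(double[] perExampleLoss) {
        double[] out = new double[perExampleLoss.length];
        for (int i = 0; i < out.length; i++)
            out[i] = Math.abs(perExampleLoss[i] - b) + b;
        return out;
    }

    // d(flooded)/d(loss) = sign(loss - b): -1 below the flood level,
    // i.e. the wrapped loss's gradient is reversed (gradient ascent) there
    public double gradientScale(double perExampleLoss) {
        return perExampleLoss >= b ? 1.0 : -1.0;
    }

    public static void main(String[] args) {
        FloodedLoss loss = new FloodedLoss(0.5);
        double[] flooded = loss.computeScoreArray(new double[]{0.2, 1.0});
        System.out.println(flooded[0] + " " + flooded[1]); // ~0.8 and 1.0
        System.out.println(loss.gradientScale(0.2));       // -1.0: ascent on example 0
    }
}
```

The wrapper never needs the batch mean, which is why it composes cleanly with any code path that consumes per-example scores.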

I've tried the batch-level gradient reversal anyway, i.e. doing gradient ascent when the average batch loss falls below the float level, and the results on MNIST were unimpressive. I've also tried it with a float level of 0.1, and the results again speak for themselves:
After 10 epochs, no l2, Adam(lr=0.0005), batchSize=64:
No Float loss at all:
accuracy: 0.9905
average loss range in 10th epoch: 0.0000128 to 0.03

per batch:
accuracy: 0.9864
average loss range in 10th epoch: 0.08 to 0.15

per example:
accuracy: 0.9922
average loss range in 10th epoch: 0.11 to 0.13

So I thought maybe floatLevel = 0.1 is too high, and went a bit lower to 0.05:
per batch:
accuracy: 0.9848
average loss range in 10th epoch: 0.025 to 0.059

per example:
accuracy: 0.9922
average loss range in 10th epoch: 0.062 to 0.069

Overall, the per-example variant looks like the better choice: it reaches higher accuracy at both flood levels.

I've also reached out to the first author of the paper to ask how it was meant to be implemented, as the paper does indeed suggest the per-batch variant. But I don't have an answer yet.

@AlexDBlack

AlexDBlack commented Apr 27, 2020

Yeah, intuitively the example level makes more sense to me.
For a binary classifier, say, you could overfit one class and underfit the other, but still get no benefit from the float loss, depending on the average loss.

In order to get this into the release (before we have confirmation) - maybe we provide both options, make example vs. batch a required arg, and (maybe) later make one or the other the default (i.e., add another constructor)?

@treo
Author

treo commented Apr 27, 2020

The author got back to me and confirmed that they worked on a per-batch basis:

Thank you for reading our paper and implementing it for Deeplearning4J!
In our paper, we meant the batch-level flooding and this will only be equivalent to example-level flooding when the mini-batch size is 1.

We haven't tried it on a per-example basis, but have tried to reduce the training loss variance in a different way (which I think has similar effects), but it seemed to have only negative effects, at least on the synthetic experiments we tried. It's interesting that your MNIST experiments are showing the opposite results!

If the true p(y|x) can be very different for different examples (e.g. where class conditionals can have heavy overlap with each other), then I was thinking that flooding on a per-example basis can be harmful, because it would force all examples to have the same loss. If the true p(y|x) is not that different for different examples, perhaps it may not be that harmful (or even useful, as shown in your experiments). My guess is that easy image datasets like MNIST have a separable data distribution so the true p(y|x) may be similar for all examples. The synthetic experiments we tried had heavy distribution overlap so the true p(y|x) was different between examples, and that could be a possible reason it didn't work out for us. However these are only my intuition and we haven't done serious experiments in this direction yet.

I guess we'll keep this out of the beta7 release, and I'll investigate the effectiveness of the loss function further and try to reproduce their results.

At the moment I'm a bit at a loss as to how a batch-level loss would be implemented correctly in those cases where the example-level loss (via computeScoreArray) is used directly.

