
[Bug Fix] Fix Train-Inference Mismatch#61

Open
Jayce-Ping wants to merge 4 commits into main from Fix_train-inference_mismatch

Conversation

@Jayce-Ping
Collaborator

There were small mismatches of `next_latents_mean`, `next_latents`, and `log_prob` between training and inference. The causes are:
1. Different precision of `next_latents` was used for the `log_prob` computation: during inference it was effectively `float32`, while during training `next_latents` was passed directly into `scheduler.step` as `bfloat16`. Casting `next_latents` to the actual `input_dtype` fixes the issue.
2. A small precision difference between `timesteps` and `sigmas`. Passing `next_t` to the `forward` function during sampling fixes it.
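The dtype mismatch in cause 1 can be illustrated with a minimal sketch in plain Python. It is not the repository's code: `to_bfloat16` is a hand-rolled stand-in for the tensor cast (truncating a float32 bit pattern to bfloat16 precision), and `gaussian_log_prob` is a hypothetical stand-in for the scheduler's per-step log-probability. The point is only that evaluating the same Gaussian log-density on a `bfloat16`-rounded sample versus the `float32` sample gives different values, and that casting both sides to the same dtype removes the mismatch:

```python
import math
import struct


def to_bfloat16(x: float) -> float:
    """Round a float to bfloat16 precision by keeping only the top 16 bits
    of its IEEE-754 float32 bit pattern (sign, exponent, 7 mantissa bits)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def gaussian_log_prob(x: float, mean: float, std: float) -> float:
    """Log-density of N(mean, std^2) at x; stand-in for the per-step log_prob."""
    return (
        -((x - mean) ** 2) / (2 * std**2)
        - math.log(std)
        - 0.5 * math.log(2 * math.pi)
    )


# Hypothetical values standing in for one element of next_latents.
next_latent_f32 = 0.1234567  # inference side: effectively float32
mean, std = 0.12, 0.05

# Training side: the same sample arrives rounded to bfloat16.
next_latent_bf16 = to_bfloat16(next_latent_f32)

lp_inference = gaussian_log_prob(next_latent_f32, mean, std)
lp_train_bug = gaussian_log_prob(next_latent_bf16, mean, std)

# The fix: cast next_latents to the dtype actually used at inference
# before computing log_prob, so both sides see identical inputs.
lp_train_fixed = gaussian_log_prob(next_latent_f32, mean, std)

print(lp_inference - lp_train_bug)   # small but nonzero mismatch
print(lp_train_fixed == lp_inference)
```

The same logic applies elementwise to the latent tensors; in the actual code the cast is a tensor `dtype` conversion rather than a bit-twiddling helper.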
