Describe the bug
I discovered this when testing #659. It is currently possible to import diffusers with neither PyTorch nor Flax installed; in that case the dummy classes are loaded. But when Flax is installed and PyTorch isn't, the import fails.
This happens for two reasons:
- modeling_flax_utils.py imports load_state_dict, which pulls in PyTorch. This import exists so we can convert weights from a PyTorch checkpoint.
- Flax scheduler outputs are subclasses of SchedulerOutput, which declares the sample as a PyTorch tensor. I think we should create a FlaxSchedulerOutput instead.
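A minimal sketch of both proposed fixes, in plain Python. The helper names `is_torch_available` and `load_pytorch_state_dict` and the field name `prev_sample` are illustrative assumptions, not the current diffusers code. torch is imported only inside the conversion helper, so importing the Flax module no longer needs PyTorch, and `FlaxSchedulerOutput` annotates its sample as a JAX array without referencing torch at all.

```python
from __future__ import annotations  # annotations stay strings, so jnp is never resolved at class creation

import importlib.util
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict

if TYPE_CHECKING:  # only seen by type checkers, never executed at runtime
    import jax.numpy as jnp


def is_torch_available() -> bool:
    """Hypothetical helper: True if PyTorch is installed, without importing it."""
    return importlib.util.find_spec("torch") is not None


def load_pytorch_state_dict(path: str) -> Dict[str, Any]:
    """Hypothetical helper: load a PyTorch checkpoint for Flax weight conversion.

    torch is imported inside the function, so only an actual conversion
    requires PyTorch; importing the module that defines this does not.
    """
    if not is_torch_available():
        raise ImportError(
            "Converting a PyTorch checkpoint to Flax requires PyTorch; "
            "install it with `pip install torch`."
        )
    import torch  # deferred import

    return torch.load(path, map_location="cpu")


@dataclass
class FlaxSchedulerOutput:
    """Flax counterpart of SchedulerOutput with no dependency on torch.

    prev_sample: the sample computed by the scheduler step (assumed field name).
    """

    prev_sample: jnp.ndarray
```

With this layout, a Flax-only environment can import the module and build scheduler outputs; it only hits an ImportError if it actually tries to convert a PyTorch checkpoint.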
Reproduction
Doesn't work
pip uninstall torch
pip install flax
>>> import diffusers
Works
pip uninstall torch
pip uninstall flax
>>> import diffusers
Logs
No response
System Info
Diffusers @ 877bec8