Skip to content

Commit 546d62f

Browse files
[Pytorch] Pytorch only schedulers (huggingface#534)
* pytorch only schedulers * fix style * remove match_shape * pytorch only ddpm * remove SchedulerMixin * remove numpy from karras_ve * fix types * remove numpy from lms_discrete * remove numpy from pndm * fix typo * remove mixin and numpy from sde_vp and ve * remove remaining tensor_format * fix style * sigmas has to be torch tensor * removed set_format in readme * remove set format from docs * remove set_format from pipelines * update tests * fix typo * continue to use mixin * fix imports * removed unsed imports * match shape instead of assuming image shapes * remove import typo * update call to add_noise * use math instead of numpy * fix t_index * removed commented out numpy tests * timesteps needs to be discrete * cast timesteps to int in flax scheduler too * fix device mismatch issue * small fix * Update src/diffusers/schedulers/scheduling_pndm.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 3e6214d commit 546d62f

20 files changed

Lines changed: 204 additions & 311 deletions

pipelines/ddim/pipeline_ddim.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ class DDIMPipeline(DiffusionPipeline):
3535

3636
def __init__(self, unet, scheduler):
3737
super().__init__()
38-
scheduler = scheduler.set_format("pt")
3938
self.register_modules(unet=unet, scheduler=scheduler)
4039

4140
@torch.no_grad()

pipelines/ddpm/pipeline_ddpm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ class DDPMPipeline(DiffusionPipeline):
3535

3636
def __init__(self, unet, scheduler):
3737
super().__init__()
38-
scheduler = scheduler.set_format("pt")
3938
self.register_modules(unet=unet, scheduler=scheduler)
4039

4140
@torch.no_grad()

pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ def __init__(
4545
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
4646
):
4747
super().__init__()
48-
scheduler = scheduler.set_format("pt")
4948
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
5049

5150
@torch.no_grad()

pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ class LDMPipeline(DiffusionPipeline):
2323

2424
def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
2525
super().__init__()
26-
scheduler = scheduler.set_format("pt")
2726
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
2827

2928
@torch.no_grad()

pipelines/pndm/pipeline_pndm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ class PNDMPipeline(DiffusionPipeline):
3939

4040
def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
4141
super().__init__()
42-
scheduler = scheduler.set_format("pt")
4342
self.register_modules(unet=unet, scheduler=scheduler)
4443

4544
@torch.no_grad()

pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ def __init__(
5757
feature_extractor: CLIPFeatureExtractor,
5858
):
5959
super().__init__()
60-
scheduler = scheduler.set_format("pt")
6160

6261
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
6362
warnings.warn(

pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def __init__(
6969
feature_extractor: CLIPFeatureExtractor,
7070
):
7171
super().__init__()
72-
scheduler = scheduler.set_format("pt")
7372

7473
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
7574
warnings.warn(

pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def __init__(
8383
feature_extractor: CLIPFeatureExtractor,
8484
):
8585
super().__init__()
86-
scheduler = scheduler.set_format("pt")
8786
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
8887

8988
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
@@ -320,11 +319,11 @@ def __call__(
320319
if isinstance(self.scheduler, LMSDiscreteScheduler):
321320
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
322321
# masking
323-
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index))
322+
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index]))
324323
else:
325324
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
326325
# masking
327-
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
326+
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t]))
328327

329328
latents = (init_latents_proper * mask) + (latents * (1 - mask))
330329

pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def __init__(
3535
feature_extractor: CLIPFeatureExtractor,
3636
):
3737
super().__init__()
38-
scheduler = scheduler.set_format("np")
3938
self.register_modules(
4039
vae_decoder=vae_decoder,
4140
text_encoder=text_encoder,

pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ class KarrasVePipeline(DiffusionPipeline):
2929

3030
def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
3131
super().__init__()
32-
scheduler = scheduler.set_format("pt")
3332
self.register_modules(unet=unet, scheduler=scheduler)
3433

3534
@torch.no_grad()

0 commit comments

Comments
 (0)