From 155e3f1aff9a95df155a4a7aae6950a37ad45f91 Mon Sep 17 00:00:00 2001 From: "Dave Lage (rockerBOO)" Date: Mon, 2 Jan 2023 13:12:49 -0500 Subject: [PATCH 01/10] Implement DPM Solver Singlestep Scheduler --- examples/stable-diffusion/main.rs | 8 +- src/schedulers/dpmsolver_singlestep.rs | 379 +++++++++++++++++++++++++ src/schedulers/mod.rs | 1 + 3 files changed, 386 insertions(+), 2 deletions(-) create mode 100644 src/schedulers/dpmsolver_singlestep.rs diff --git a/examples/stable-diffusion/main.rs b/examples/stable-diffusion/main.rs index f245adc..ea41b18 100644 --- a/examples/stable-diffusion/main.rs +++ b/examples/stable-diffusion/main.rs @@ -69,7 +69,7 @@ // cargo run --release --example tensor-tools cp ./data/vae.npz ./data/vae.ot // cargo run --release --example tensor-tools cp ./data/unet.npz ./data/unet.ot use clap::Parser; -use diffusers::pipelines::stable_diffusion; +use diffusers::{pipelines::stable_diffusion, schedulers}; use diffusers::transformers::clip; use tch::{nn::Module, Device, Kind, Tensor}; @@ -241,7 +241,6 @@ fn run(args: Args) -> anyhow::Result<()> { let clip_device = cpu_or_cuda("clip"); let vae_device = cpu_or_cuda("vae"); let unet_device = cpu_or_cuda("unet"); - let scheduler = sd_config.build_scheduler(n_steps); let tokenizer = clip::Tokenizer::create(vocab_file, &sd_config.clip)?; println!("Running with prompt \"{prompt}\"."); @@ -276,6 +275,11 @@ fn run(args: Args) -> anyhow::Result<()> { // scale the initial noise by the standard deviation required by the scheduler latents *= scheduler.init_noise_sigma(); + let scheduler = sd_config.build_scheduler(n_steps); + // let mut scheduler = schedulers::dpmsolver_singlestep::DPMSolverSinglestepScheduler::new(n_steps, Default::default()); + // Using this scheduler requires mutability, so change the to the following + // for (timestep_index, &timestep) in scheduler.timesteps().to_owned().iter().enumerate() { + for (timestep_index, &timestep) in scheduler.timesteps().iter().enumerate() { println!("Timestep 
{timestep_index}/{n_steps}"); let latent_model_input = Tensor::cat(&[&latents, &latents], 0); diff --git a/src/schedulers/dpmsolver_singlestep.rs b/src/schedulers/dpmsolver_singlestep.rs new file mode 100644 index 0000000..247fe15 --- /dev/null +++ b/src/schedulers/dpmsolver_singlestep.rs @@ -0,0 +1,379 @@ +use super::{betas_for_alpha_bar, BetaSchedule, PredictionType}; +use tch::{kind, Kind, Tensor}; + +/// The algorithm type for the solver. +/// +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub enum DPMSolverAlgorithmType { + /// Implements the algorithms defined in . + #[default] + DPMSolverPlusPlus, + /// Implements the algorithms defined in . + DPMSolver, +} + +/// The solver type for the second-order solver. +/// The solver type slightly affects the sample quality, especially for +/// small number of steps. +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub enum DPMSolverType { + #[default] + Midpoint, + Heun, +} + +#[derive(Debug, Clone)] +pub struct DPMSolverSinglestepSchedulerConfig { + /// The value of beta at the beginning of training. + pub beta_start: f64, + /// The value of beta at the end of training. + pub beta_end: f64, + /// How beta evolved during training. + pub beta_schedule: BetaSchedule, + /// number of diffusion steps used to train the model. + pub train_timesteps: usize, + /// the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided + /// sampling, and `solver_order=3` for unconditional sampling. + pub solver_order: usize, + /// prediction type of the scheduler function + pub prediction_type: PredictionType, + /// The threshold value for dynamic thresholding. Valid only when `thresholding: true` and + /// `algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus`. + pub sample_max_value: f32, + /// The algorithm type for the solver + pub algorithm_type: DPMSolverAlgorithmType, + /// The solver type for the second-order solver. 
+ pub solver_type: DPMSolverType, + /// Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + /// find this can stabilize the sampling of DPM-Solver for `steps < 15`, especially for steps <= 10. + pub lower_order_final: bool, +} + +impl Default for DPMSolverSinglestepSchedulerConfig { + fn default() -> Self { + Self { + train_timesteps: 1000, + beta_start: 0.0001, + beta_end: 0.02, + beta_schedule: BetaSchedule::Linear, + solver_order: 2, + prediction_type: PredictionType::Epsilon, + sample_max_value: 1.0, + algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus, + solver_type: DPMSolverType::Midpoint, + lower_order_final: true, + } + } +} + +pub struct DPMSolverSinglestepScheduler { + alphas_cumprod: Vec, + alpha_t: Vec, + sigma_t: Vec, + lambda_t: Vec, + init_noise_sigma: f64, + lower_order_nums: usize, + model_outputs: Vec, + timesteps: Vec, + pub config: DPMSolverSinglestepSchedulerConfig, +} + +impl DPMSolverSinglestepScheduler { + pub fn new(inference_steps: usize, config: DPMSolverSinglestepSchedulerConfig) -> Self { + let betas = match config.beta_schedule { + BetaSchedule::ScaledLinear => Tensor::linspace( + config.beta_start.sqrt(), + config.beta_end.sqrt(), + config.train_timesteps as i64, + kind::FLOAT_CPU, + ) + .square(), + BetaSchedule::Linear => Tensor::linspace( + config.beta_start, + config.beta_end, + config.train_timesteps as i64, + kind::FLOAT_CPU, + ), + BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999), + }; + let alphas: Tensor = 1. - betas; + let alphas_cumprod = alphas.cumprod(0, Kind::Double); + + let alpha_t = alphas_cumprod.sqrt(); + let sigma_t = ((1. 
- &alphas_cumprod) as Tensor).sqrt(); + let lambda_t = alpha_t.log() - sigma_t.log(); + + let step = (config.train_timesteps - 1) as f64 / inference_steps as f64; + // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py#L199-L204 + let timesteps: Vec = (0..inference_steps + 1) + .map(|i| (i as f64 * step).round() as usize) + // discards the 0.0 element + .skip(1) + .rev() + .collect(); + + let mut model_outputs = Vec::::new(); + for _ in 0..config.solver_order { + model_outputs.push(Tensor::new()); + } + + Self { + alphas_cumprod: Vec::::from(alphas_cumprod), + alpha_t: Vec::::from(alpha_t), + sigma_t: Vec::::from(sigma_t), + lambda_t: Vec::::from(lambda_t), + init_noise_sigma: 1., + lower_order_nums: 0, + model_outputs, + timesteps, + config, + } + } + + /// Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. + /// + /// DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to + /// discretize an integral of the data prediction model. So we need to first convert the model output to the + /// corresponding type to match the algorithm. + /// + /// Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or + /// DPM-Solver++ for both noise prediction model and data prediction model. 
+ fn convert_model_output( + &self, + model_output: &Tensor, + timestep: usize, + sample: &Tensor, + ) -> Tensor { + match self.config.algorithm_type { + DPMSolverAlgorithmType::DPMSolverPlusPlus => { + match self.config.prediction_type { + PredictionType::Epsilon => { + let alpha_t = self.alpha_t[timestep]; + let sigma_t = self.sigma_t[timestep]; + (sample - sigma_t * model_output) / alpha_t + } + PredictionType::Sample => model_output.shallow_clone(), + PredictionType::VPrediction => { + let alpha_t = self.alpha_t[timestep]; + let sigma_t = self.sigma_t[timestep]; + alpha_t * sample - sigma_t * model_output + } + } + // TODO: implement Dynamic thresholding + // https://arxiv.org/abs/2205.11487 + } + DPMSolverAlgorithmType::DPMSolver => match self.config.prediction_type { + PredictionType::Epsilon => model_output.shallow_clone(), + PredictionType::Sample => { + let alpha_t = self.alpha_t[timestep]; + let sigma_t = self.sigma_t[timestep]; + (sample - alpha_t * model_output) / sigma_t + } + PredictionType::VPrediction => { + let alpha_t = self.alpha_t[timestep]; + let sigma_t = self.sigma_t[timestep]; + alpha_t * model_output + sigma_t * sample + } + }, + } + } + + /// One step for the first-order DPM-Solver (equivalent to DDIM). + /// See https://arxiv.org/abs/2206.00927 for the detailed derivation. 
+ fn dpm_solver_first_order_update( + &self, + model_output: Tensor, + timestep: usize, + prev_timestep: usize, + sample: &Tensor, + ) -> Tensor { + let (lambda_t, lambda_s) = (self.lambda_t[prev_timestep], self.lambda_t[timestep]); + let (alpha_t, alpha_s) = (self.alpha_t[prev_timestep], self.alpha_t[timestep]); + let (sigma_t, sigma_s) = (self.sigma_t[prev_timestep], self.sigma_t[timestep]); + let h = lambda_t - lambda_s; + match self.config.algorithm_type { + DPMSolverAlgorithmType::DPMSolverPlusPlus => { + (sigma_t / sigma_s) * sample - (alpha_t * ((-h).exp() - 1.0)) * model_output + } + DPMSolverAlgorithmType::DPMSolver => { + (alpha_t / alpha_s) * sample - (sigma_t * (h.exp() - 1.0)) * model_output + } + } + } + + /// One step for the second-order multistep DPM-Solver. + fn singlestep_dpm_solver_second_order_update( + &self, + model_output_list: &Vec, + timestep_list: [usize; 2], + prev_timestep: usize, + sample: &Tensor, + ) -> Tensor { + let (t, s0, s1) = ( + prev_timestep, + timestep_list[timestep_list.len() - 1], + timestep_list[timestep_list.len() - 2], + ); + let (m0, m1) = ( + model_output_list[model_output_list.len() - 1].as_ref(), + model_output_list[model_output_list.len() - 2].as_ref(), + ); + let (lambda_t, lambda_s0, lambda_s1) = + (self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]); + let (alpha_t, alpha_s0) = (self.alpha_t[t], self.alpha_t[s0]); + let (sigma_t, sigma_s0) = (self.sigma_t[t], self.sigma_t[s0]); + let (h, h_0) = (lambda_t - lambda_s0, lambda_s0 - lambda_s1); + let r0 = h_0 / h; + let (d0, d1) = (m0, (1.0 / r0) * (m0 - m1)); + match self.config.algorithm_type { + DPMSolverAlgorithmType::DPMSolverPlusPlus => match self.config.solver_type { + // See https://arxiv.org/abs/2211.01095 for detailed derivations + DPMSolverType::Midpoint => { + (sigma_t / sigma_s0) * sample + - (alpha_t * ((-h).exp() - 1.0)) * d0 + - 0.5 * (alpha_t * ((-h).exp() - 1.0)) * d1 + } + DPMSolverType::Heun => { + (sigma_t / sigma_s0) * sample - (alpha_t * 
((-h).exp() - 1.0)) * d0 + + (alpha_t * (((-h).exp() - 1.0) / h + 1.0)) * d1 + } + }, + DPMSolverAlgorithmType::DPMSolver => match self.config.solver_type { + // See https://arxiv.org/abs/2206.00927 for detailed derivations + DPMSolverType::Midpoint => { + (alpha_t / alpha_s0) * sample + - (sigma_t * (h.exp() - 1.0)) * d0 + - 0.5 * (sigma_t * (h.exp() - 1.0)) * d1 + } + DPMSolverType::Heun => { + (alpha_t / alpha_s0) * sample + - (sigma_t * (h.exp() - 1.0)) * d0 + - (sigma_t * ((h.exp() - 1.0) / h - 1.0)) * d1 + } + }, + } + } + + /// One step for the third-order multistep DPM-Solver + fn singlestep_dpm_solver_third_order_update( + &self, + model_output_list: &Vec, + timestep_list: [usize; 3], + prev_timestep: usize, + sample: &Tensor, + ) -> Tensor { + let (t, s0, s1, s2) = ( + prev_timestep, + timestep_list[timestep_list.len() - 1], + timestep_list[timestep_list.len() - 2], + timestep_list[timestep_list.len() - 3], + ); + let (m0, m1, m2) = ( + model_output_list[model_output_list.len() - 1].as_ref(), + model_output_list[model_output_list.len() - 2].as_ref(), + model_output_list[model_output_list.len() - 3].as_ref(), + ); + let (lambda_t, lambda_s0, lambda_s1, lambda_s2) = + (self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1], self.lambda_t[s2]); + let (alpha_t, alpha_s0) = (self.alpha_t[t], self.alpha_t[s0]); + let (sigma_t, sigma_s0) = (self.sigma_t[t], self.sigma_t[s0]); + let (h, h_0, h_1) = (lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2); + let (r0, r1) = (h_0 / h, h_1 / h); + let d0 = m0; + let (d1_0, d1_1) = ((1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)); + let d1 = &d1_0 + (r0 / (r0 + r1)) * (&d1_0 - &d1_1.shallow_clone()); + let d2 = (1.0 / (r0 + r1)) * (d1_0 - d1_1.shallow_clone()); + + match self.config.algorithm_type { + DPMSolverAlgorithmType::DPMSolverPlusPlus => match self.config.solver_type { + DPMSolverType::Midpoint => { + (sigma_t / sigma_s0) * sample - (alpha_t * ((-h).exp() - 1.0)) * d0 + + (alpha_t * (((-h).exp() - 
1.0) / h + 1.0)) * d1_1 + } + DPMSolverType::Heun => { + // See https://arxiv.org/abs/2206.00927 for detailed derivations + (sigma_t / sigma_s0) * sample - (alpha_t * ((-h).exp() - 1.0)) * d0 + + (alpha_t * (((-h).exp() - 1.0) / h + 1.0)) * d1 + - (alpha_t * (((-h).exp() - 1.0 + h) / h.powi(2) - 0.5)) * d2 + } + }, + DPMSolverAlgorithmType::DPMSolver => match self.config.solver_type { + DPMSolverType::Midpoint => { + (alpha_t / alpha_s0) * sample + - (sigma_t * (h.exp() - 1.0)) * d0 + - (sigma_t * ((h.exp() - 1.0) / h - 1.0)) * d1_1 + } + DPMSolverType::Heun => { + (alpha_t / alpha_s0) * sample + - (sigma_t * (h.exp() - 1.0)) * d0 + - (sigma_t * ((h.exp() - 1.0) / h - 1.0)) * d1 + - (sigma_t * ((h.exp() - 1.0 - h) / h.powi(2) - 0.5)) * d2 + } + }, + } + } + + pub fn timesteps(&self) -> &[usize] { + self.timesteps.as_slice() + } + + pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor { + // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py#L457 + let step_index = self.timesteps.iter().position(|&t| t == timestep).unwrap(); + + let prev_timestep = + if step_index == self.timesteps.len() - 1 { 0 } else { self.timesteps[step_index + 1] }; + let lower_order_final = (step_index == self.timesteps.len() - 1) + && self.config.lower_order_final + && self.timesteps.len() < 15; + let lower_order_second = (step_index == self.timesteps.len() - 2) + && self.config.lower_order_final + && self.timesteps.len() < 15; + + let model_output = self.convert_model_output(model_output, timestep, sample); + for i in 0..self.config.solver_order - 1 { + self.model_outputs[i] = self.model_outputs[i + 1].shallow_clone(); + } + let m = self.model_outputs.len(); + self.model_outputs[m - 1] = model_output.shallow_clone(); + + let prev_sample = if self.config.solver_order == 1 + || self.lower_order_nums < 1 + || lower_order_final + { + 
self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample) + } else if self.config.solver_order == 2 || self.lower_order_nums < 2 || lower_order_second { + let timestep_list = [self.timesteps[step_index - 1], timestep]; + self.singlestep_dpm_solver_second_order_update( + &self.model_outputs, + timestep_list, + prev_timestep, + sample, + ) + } else { + let timestep_list = + [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]; + self.singlestep_dpm_solver_third_order_update( + &self.model_outputs, + timestep_list, + prev_timestep, + sample, + ) + }; + + if self.lower_order_nums < self.config.solver_order { + self.lower_order_nums += 1; + } + + prev_sample + } + + pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor { + self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned() + + (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise + } + + pub fn init_noise_sigma(&self) -> f64 { + self.init_noise_sigma + } +} diff --git a/src/schedulers/mod.rs b/src/schedulers/mod.rs index e943de1..ac22531 100644 --- a/src/schedulers/mod.rs +++ b/src/schedulers/mod.rs @@ -8,6 +8,7 @@ use tch::{Kind, Tensor}; pub mod ddim; pub mod ddpm; pub mod dpmsolver_multistep; +pub mod dpmsolver_singlestep; pub mod euler_ancestral_discrete; pub mod euler_discrete; pub mod heun_discrete; From 352192fab54efe8903c3a73ed128d5a1efee0486 Mon Sep 17 00:00:00 2001 From: "Dave Lage (rockerBOO)" Date: Tue, 3 Jan 2023 17:45:38 -0500 Subject: [PATCH 02/10] Update to use solver_order and order_list --- examples/stable-diffusion/main.rs | 4 +- src/schedulers/dpmsolver_singlestep.rs | 191 ++++++++++++++++++++----- 2 files changed, 156 insertions(+), 39 deletions(-) diff --git a/examples/stable-diffusion/main.rs b/examples/stable-diffusion/main.rs index ea41b18..531ee86 100644 --- a/examples/stable-diffusion/main.rs +++ b/examples/stable-diffusion/main.rs @@ -69,8 +69,8 @@ // cargo run --release --example tensor-tools 
cp ./data/vae.npz ./data/vae.ot // cargo run --release --example tensor-tools cp ./data/unet.npz ./data/unet.ot use clap::Parser; -use diffusers::{pipelines::stable_diffusion, schedulers}; use diffusers::transformers::clip; +use diffusers::{pipelines::stable_diffusion, schedulers}; use tch::{nn::Module, Device, Kind, Tensor}; const GUIDANCE_SCALE: f64 = 7.5; @@ -278,7 +278,7 @@ fn run(args: Args) -> anyhow::Result<()> { let scheduler = sd_config.build_scheduler(n_steps); // let mut scheduler = schedulers::dpmsolver_singlestep::DPMSolverSinglestepScheduler::new(n_steps, Default::default()); // Using this scheduler requires mutability, so change the to the following - // for (timestep_index, &timestep) in scheduler.timesteps().to_owned().iter().enumerate() { + // scheduler.timesteps().to_owned().iter().enumerate() for (timestep_index, &timestep) in scheduler.timesteps().iter().enumerate() { println!("Timestep {timestep_index}/{n_steps}"); diff --git a/src/schedulers/dpmsolver_singlestep.rs b/src/schedulers/dpmsolver_singlestep.rs index 247fe15..4945af4 100644 --- a/src/schedulers/dpmsolver_singlestep.rs +++ b/src/schedulers/dpmsolver_singlestep.rs @@ -1,3 +1,5 @@ +use std::iter::repeat; + use super::{betas_for_alpha_bar, BetaSchedule, PredictionType}; use tch::{kind, Kind, Tensor}; @@ -52,9 +54,9 @@ pub struct DPMSolverSinglestepSchedulerConfig { impl Default for DPMSolverSinglestepSchedulerConfig { fn default() -> Self { Self { - train_timesteps: 1000, beta_start: 0.0001, beta_end: 0.02, + train_timesteps: 1000, beta_schedule: BetaSchedule::Linear, solver_order: 2, prediction_type: PredictionType::Epsilon, @@ -72,9 +74,10 @@ pub struct DPMSolverSinglestepScheduler { sigma_t: Vec, lambda_t: Vec, init_noise_sigma: f64, - lower_order_nums: usize, + order_list: Vec, model_outputs: Vec, timesteps: Vec, + sample: Option, pub config: DPMSolverSinglestepSchedulerConfig, } @@ -123,10 +126,11 @@ impl DPMSolverSinglestepScheduler { sigma_t: Vec::::from(sigma_t), lambda_t: 
Vec::::from(lambda_t), init_noise_sigma: 1., - lower_order_nums: 0, + order_list: get_order_list(inference_steps, config.solver_order, false), model_outputs, timesteps, config, + sample: None, } } @@ -182,7 +186,7 @@ impl DPMSolverSinglestepScheduler { /// See https://arxiv.org/abs/2206.00927 for the detailed derivation. fn dpm_solver_first_order_update( &self, - model_output: Tensor, + model_output: &Tensor, timestep: usize, prev_timestep: usize, sample: &Tensor, @@ -319,52 +323,53 @@ impl DPMSolverSinglestepScheduler { pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor { // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py#L457 - let step_index = self.timesteps.iter().position(|&t| t == timestep).unwrap(); + let step_index: usize = self.timesteps.iter().position(|&t| t == timestep).unwrap(); let prev_timestep = if step_index == self.timesteps.len() - 1 { 0 } else { self.timesteps[step_index + 1] }; - let lower_order_final = (step_index == self.timesteps.len() - 1) - && self.config.lower_order_final - && self.timesteps.len() < 15; - let lower_order_second = (step_index == self.timesteps.len() - 2) - && self.config.lower_order_final - && self.timesteps.len() < 15; - - let model_output = self.convert_model_output(model_output, timestep, sample); + + let model_output = self.convert_model_output(model_output, timestep, &sample); for i in 0..self.config.solver_order - 1 { self.model_outputs[i] = self.model_outputs[i + 1].shallow_clone(); } let m = self.model_outputs.len(); - self.model_outputs[m - 1] = model_output.shallow_clone(); - - let prev_sample = if self.config.solver_order == 1 - || self.lower_order_nums < 1 - || lower_order_final - { - self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample) - } else if self.config.solver_order == 2 || self.lower_order_nums < 2 || lower_order_second { - let 
timestep_list = [self.timesteps[step_index - 1], timestep]; - self.singlestep_dpm_solver_second_order_update( + self.model_outputs[m - 1] = model_output; + + let order = self.order_list[step_index]; + + // For single-step solvers, we use the initial value at each time with order = 1. + if order == 1 { + self.sample = Some(sample.shallow_clone()); + }; + + let prev_sample = match order { + 1 => self.dpm_solver_first_order_update( + &self.model_outputs[self.model_outputs.len() - 1], + timestep, + prev_timestep, + &self.sample.as_ref().unwrap(), + ), + 2 => self.singlestep_dpm_solver_second_order_update( &self.model_outputs, - timestep_list, + [self.timesteps[step_index - 1], self.timesteps[step_index]], prev_timestep, - sample, - ) - } else { - let timestep_list = - [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]; - self.singlestep_dpm_solver_third_order_update( + self.sample.as_ref().unwrap(), + ), + 3 => self.singlestep_dpm_solver_third_order_update( &self.model_outputs, - timestep_list, + [ + self.timesteps[step_index - 2], + self.timesteps[step_index - 1], + self.timesteps[step_index], + ], prev_timestep, - sample, - ) + self.sample.as_ref().unwrap(), + ), + _ => { + panic!("invalid order"); + } }; - if self.lower_order_nums < self.config.solver_order { - self.lower_order_nums += 1; - } - prev_sample } @@ -377,3 +382,115 @@ impl DPMSolverSinglestepScheduler { self.init_noise_sigma } } + +/// Computes the solver order at each time step. 
+/// solver_order 1 2 3 +fn get_order_list<'a>(steps: usize, solver_order: usize, lower_order_final: bool) -> Vec { + if lower_order_final { + if solver_order == 3 { + if steps % 3 == 0 { + repeat(&[1, 2, 3][..]) + .take((steps / 3) - 1) + .chain([&[1, 2][..], &[1][..]]) + .flatten() + .map(|v| *v) + .collect() + } else if steps % 3 == 1 { + repeat(&[1, 2, 3][..]) + .take(steps / 3) + .chain([&[1][..]]) + .flatten() + .map(|v| *v) + .collect() + } else { + repeat(&[1, 2, 3][..]) + .take(steps / 3) + .chain([&[1][..], &[2][..]]) + .flatten() + .map(|v| *v) + .collect() + } + } else if solver_order == 2 { + if steps % 2 == 0 { + repeat(&[1, 2][..]).take(steps / 2).flatten().map(|v| *v).collect() + } else { + repeat(&[1, 2][..]) + .take(steps / 2) + .chain([&[1][..]]) + .flatten() + .map(|v| *v) + .collect() + } + } else if solver_order == 1 { + repeat(&[1][..]).take(steps).flatten().map(|v| *v).collect() + } else { + panic!("invalid solver_order"); + } + } else { + if solver_order == 3 { + repeat(&[1, 2, 3][..]).take(steps / 3).flatten().map(|v| *v).collect() + } else if solver_order == 2 { + repeat(&[1, 2][..]).take(steps / 2).flatten().map(|v| *v).collect() + } else if solver_order == 1 { + repeat(&[1][..]).take(steps).flatten().map(|v| *v).collect() + } else { + panic!("invalid solver_order"); + } + } +} + +#[cfg(test)] +mod tests { + use super::get_order_list; + + #[test] + fn order_list() { + let list = get_order_list(15, 2, false); + + assert_eq!(15, list.len()); + assert_eq!(vec![1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2], list); + + let list = get_order_list(16, 2, false); + + assert_eq!(16, list.len()); + assert_eq!(vec![1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2], list); + + let list = get_order_list(16, 1, false); + + assert_eq!(16, list.len()); + assert_eq!(vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], list); + + let list = get_order_list(16, 3, false); + + assert_eq!(16, list.len()); + assert_eq!(vec![1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 
3], list); + + let list = get_order_list(16, 3, true); + + assert_eq!(16, list.len()); + assert_eq!(vec![1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1], list); + + let list = get_order_list(16, 1, true); + + assert_eq!(16, list.len()); + assert_eq!(vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], list); + + let list = get_order_list(25, 1, true); + + assert_eq!(25, list.len()); + assert_eq!( + vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + list + ); + + let list = get_order_list(1, 1, true); + + assert_eq!(1, list.len()); + assert_eq!(vec![1], list); + + let list = get_order_list(2, 2, true); + + assert_eq!(2, list.len()); + assert_eq!(vec![1, 2], list); + } +} From 26bb145dd8049477b1f5acee340e92aedd411115 Mon Sep 17 00:00:00 2001 From: "Dave Lage (rockerBOO)" Date: Tue, 3 Jan 2023 20:30:59 -0500 Subject: [PATCH 03/10] Add more documentation --- src/schedulers/dpmsolver_singlestep.rs | 55 +++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 6 deletions(-) diff --git a/src/schedulers/dpmsolver_singlestep.rs b/src/schedulers/dpmsolver_singlestep.rs index 4945af4..384f0ec 100644 --- a/src/schedulers/dpmsolver_singlestep.rs +++ b/src/schedulers/dpmsolver_singlestep.rs @@ -142,6 +142,12 @@ impl DPMSolverSinglestepScheduler { /// /// Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or /// DPM-Solver++ for both noise prediction model and data prediction model. + /// + /// # Arguments + /// + /// * `model_output` - direct output from learned diffusion mode + /// * `timestep` - current discrete timestep in the diffusion chain + /// * `sample` - current instance of sample being created by diffusion process fn convert_model_output( &self, model_output: &Tensor, @@ -182,8 +188,15 @@ impl DPMSolverSinglestepScheduler { } } - /// One step for the first-order DPM-Solver (equivalent to DDIM). - /// See https://arxiv.org/abs/2206.00927 for the detailed derivation. 
+ /// One step for the first-order DPM-Solver (equivalent to DDIM). + /// See https://arxiv.org/abs/2206.00927 for the detailed derivation. + /// + /// # Arguments + /// + /// * `model_output` - direct output from learned diffusion model + /// * `timestep` - current discrete timestep in the diffusion chain + /// * `prev_timestep` - previous discrete timestep in the diffusion chain + /// * `sample` - current instance of sample being created by diffusion process fn dpm_solver_first_order_update( &self, model_output: &Tensor, @@ -205,7 +218,15 @@ impl DPMSolverSinglestepScheduler { } } - /// One step for the second-order multistep DPM-Solver. + /// One step for the second-order multistep DPM-Solver. + /// It computes the solution at time `prev_timestep` from the time `timestep_list[-2]`. + /// + /// # Arguments + /// + /// * `model_output_list` - direct outputs from learned diffusion model at current and latter timesteps + /// * `timestep_list` - current and latter discrete timestep in the diffusion chain + /// * `prev_timestep` - previous discrete timestep in the diffusion chain + /// * `sample` - current instance of sample being created by diffusion process fn singlestep_dpm_solver_second_order_update( &self, model_output_list: &Vec, @@ -259,6 +280,14 @@ impl DPMSolverSinglestepScheduler { } /// One step for the third-order multistep DPM-Solver + /// It computes the solution at time `prev_timestep` from the time `timestep_list[-3]`. 
+ /// + /// # Arguments + /// + /// * `model_output_list` - direct outputs from learned diffusion model at current and latter timesteps + /// * `timestep_list` - current and latter discrete timestep in the diffusion chain + /// * `prev_timestep` - previous discrete timestep in the diffusion chain + /// * `sample` - current instance of sample being created by diffusion process fn singlestep_dpm_solver_third_order_update( &self, model_output_list: &Vec, @@ -321,8 +350,15 @@ impl DPMSolverSinglestepScheduler { self.timesteps.as_slice() } + /// Step function propagating the sample with the singlestep DPM-Solver + /// + /// # Arguments + /// + /// * `model_output` - direct output from learned diffusion model + /// * `timestep` - current discrete timestep in the diffusion chain + /// * `sample` - current instance of sample being created by diffusion process pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor { - // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py#L457 + // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py#L535 let step_index: usize = self.timesteps.iter().position(|&t| t == timestep).unwrap(); let prev_timestep = @@ -384,8 +420,15 @@ impl DPMSolverSinglestepScheduler { } /// Computes the solver order at each time step. -/// solver_order 1 2 3 -fn get_order_list<'a>(steps: usize, solver_order: usize, lower_order_final: bool) -> Vec { +/// +/// # Arguments +/// +/// * `steps` - the number of diffusion steps used when generating samples with a pre-trained model +/// * `solver_order` - the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided +/// sampling, and `solver_order=3` for unconditional sampling. +/// * `lower_order_final` - whether to use lower-order solvers in the final steps. 
For singlestep schedulers, we recommend to enable +/// this to use up all the function evaluations. +fn get_order_list(steps: usize, solver_order: usize, lower_order_final: bool) -> Vec { if lower_order_final { if solver_order == 3 { if steps % 3 == 0 { From 2af1c2acc275c338199d6b157b485700ab20b42f Mon Sep 17 00:00:00 2001 From: "Dave Lage (rockerBOO)" Date: Sat, 7 Jan 2023 16:00:38 -0500 Subject: [PATCH 04/10] Refactor common parts of multistep/singlestep into dpmsolver --- src/schedulers/dpmsolver.rs | 67 ++++++++++++++++++++ src/schedulers/dpmsolver_multistep.rs | 74 ++--------------------- src/schedulers/dpmsolver_singlestep.rs | 84 ++++---------------------- src/schedulers/mod.rs | 1 + 4 files changed, 84 insertions(+), 142 deletions(-) create mode 100644 src/schedulers/dpmsolver.rs diff --git a/src/schedulers/dpmsolver.rs b/src/schedulers/dpmsolver.rs new file mode 100644 index 0000000..4aeb6b5 --- /dev/null +++ b/src/schedulers/dpmsolver.rs @@ -0,0 +1,67 @@ +use crate::schedulers::BetaSchedule; +use crate::schedulers::PredictionType; + +/// The algorithm type for the solver. +/// +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub enum DPMSolverAlgorithmType { + /// Implements the algorithms defined in . + #[default] + DPMSolverPlusPlus, + /// Implements the algorithms defined in . + DPMSolver, +} + +/// The solver type for the second-order solver. +/// The solver type slightly affects the sample quality, especially for +/// small number of steps. +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub enum DPMSolverType { + #[default] + Midpoint, + Heun, +} + +#[derive(Debug, Clone)] +pub struct DPMSolverSchedulerConfig { + /// The value of beta at the beginning of training. + pub beta_start: f64, + /// The value of beta at the end of training. + pub beta_end: f64, + /// How beta evolved during training. + pub beta_schedule: BetaSchedule, + /// number of diffusion steps used to train the model. 
+ pub train_timesteps: usize, + /// the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided + /// sampling, and `solver_order=3` for unconditional sampling. + pub solver_order: usize, + /// prediction type of the scheduler function + pub prediction_type: PredictionType, + /// The threshold value for dynamic thresholding. Valid only when `thresholding: true` and + /// `algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus`. + pub sample_max_value: f32, + /// The algorithm type for the solver + pub algorithm_type: DPMSolverAlgorithmType, + /// The solver type for the second-order solver. + pub solver_type: DPMSolverType, + /// Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + /// find this can stabilize the sampling of DPM-Solver for `steps < 15`, especially for steps <= 10. + pub lower_order_final: bool, +} + +impl Default for DPMSolverSchedulerConfig { + fn default() -> Self { + Self { + beta_start: 0.0001, + beta_end: 0.02, + beta_schedule: BetaSchedule::Linear, + train_timesteps: 1000, + solver_order: 2, + prediction_type: PredictionType::Epsilon, + sample_max_value: 1.0, + algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus, + solver_type: DPMSolverType::Midpoint, + lower_order_final: true, + } + } +} diff --git a/src/schedulers/dpmsolver_multistep.rs b/src/schedulers/dpmsolver_multistep.rs index 055e524..7179bef 100644 --- a/src/schedulers/dpmsolver_multistep.rs +++ b/src/schedulers/dpmsolver_multistep.rs @@ -1,72 +1,6 @@ -use super::{betas_for_alpha_bar, BetaSchedule, PredictionType}; -use std::iter; +use super::{betas_for_alpha_bar, BetaSchedule, PredictionType, dpmsolver::{DPMSolverSchedulerConfig, DPMSolverAlgorithmType, DPMSolverType}}; use tch::{kind, Kind, Tensor}; -/// The algorithm type for the solver. -/// -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub enum DPMSolverAlgorithmType { - /// Implements the algorithms defined in . 
- #[default] - DPMSolverPlusPlus, - /// Implements the algorithms defined in . - DPMSolver, -} - -/// The solver type for the second-order solver. -/// The solver type slightly affects the sample quality, especially for -/// small number of steps. -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub enum DPMSolverType { - #[default] - Midpoint, - Heun, -} - -#[derive(Debug, Clone)] -pub struct DPMSolverMultistepSchedulerConfig { - /// The value of beta at the beginning of training. - pub beta_start: f64, - /// The value of beta at the end of training. - pub beta_end: f64, - /// How beta evolved during training. - pub beta_schedule: BetaSchedule, - /// number of diffusion steps used to train the model. - pub train_timesteps: usize, - /// the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided - /// sampling, and `solver_order=3` for unconditional sampling. - pub solver_order: usize, - /// prediction type of the scheduler function - pub prediction_type: PredictionType, - /// The threshold value for dynamic thresholding. Valid only when `thresholding: true` and - /// `algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus`. - pub sample_max_value: f32, - /// The algorithm type for the solver - pub algorithm_type: DPMSolverAlgorithmType, - /// The solver type for the second-order solver. - pub solver_type: DPMSolverType, - /// Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically - /// find this can stabilize the sampling of DPM-Solver for `steps < 15`, especially for steps <= 10. 
- pub lower_order_final: bool, -} - -impl Default for DPMSolverMultistepSchedulerConfig { - fn default() -> Self { - Self { - beta_start: 0.00085, - beta_end: 0.012, - beta_schedule: BetaSchedule::ScaledLinear, - train_timesteps: 1000, - solver_order: 2, - prediction_type: PredictionType::Epsilon, - sample_max_value: 1.0, - algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus, - solver_type: DPMSolverType::Midpoint, - lower_order_final: true, - } - } -} - pub struct DPMSolverMultistepScheduler { alphas_cumprod: Vec, alpha_t: Vec, @@ -74,13 +8,15 @@ pub struct DPMSolverMultistepScheduler { lambda_t: Vec, init_noise_sigma: f64, lower_order_nums: usize, + /// Direct outputs from learned diffusion model at current and latter timesteps model_outputs: Vec, + /// List of current discrete timesteps in the diffusion chain timesteps: Vec, - pub config: DPMSolverMultistepSchedulerConfig, + pub config: DPMSolverSchedulerConfig, } impl DPMSolverMultistepScheduler { - pub fn new(inference_steps: usize, config: DPMSolverMultistepSchedulerConfig) -> Self { + pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self { let betas = match config.beta_schedule { BetaSchedule::ScaledLinear => Tensor::linspace( config.beta_start.sqrt(), diff --git a/src/schedulers/dpmsolver_singlestep.rs b/src/schedulers/dpmsolver_singlestep.rs index 384f0ec..64db142 100644 --- a/src/schedulers/dpmsolver_singlestep.rs +++ b/src/schedulers/dpmsolver_singlestep.rs @@ -1,73 +1,12 @@ use std::iter::repeat; -use super::{betas_for_alpha_bar, BetaSchedule, PredictionType}; +use super::{ + betas_for_alpha_bar, + dpmsolver::{DPMSolverAlgorithmType, DPMSolverSchedulerConfig, DPMSolverType}, + BetaSchedule, PredictionType, +}; use tch::{kind, Kind, Tensor}; -/// The algorithm type for the solver. -/// -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub enum DPMSolverAlgorithmType { - /// Implements the algorithms defined in . 
- #[default] - DPMSolverPlusPlus, - /// Implements the algorithms defined in . - DPMSolver, -} - -/// The solver type for the second-order solver. -/// The solver type slightly affects the sample quality, especially for -/// small number of steps. -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub enum DPMSolverType { - #[default] - Midpoint, - Heun, -} - -#[derive(Debug, Clone)] -pub struct DPMSolverSinglestepSchedulerConfig { - /// The value of beta at the beginning of training. - pub beta_start: f64, - /// The value of beta at the end of training. - pub beta_end: f64, - /// How beta evolved during training. - pub beta_schedule: BetaSchedule, - /// number of diffusion steps used to train the model. - pub train_timesteps: usize, - /// the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided - /// sampling, and `solver_order=3` for unconditional sampling. - pub solver_order: usize, - /// prediction type of the scheduler function - pub prediction_type: PredictionType, - /// The threshold value for dynamic thresholding. Valid only when `thresholding: true` and - /// `algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus`. - pub sample_max_value: f32, - /// The algorithm type for the solver - pub algorithm_type: DPMSolverAlgorithmType, - /// The solver type for the second-order solver. - pub solver_type: DPMSolverType, - /// Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically - /// find this can stabilize the sampling of DPM-Solver for `steps < 15`, especially for steps <= 10. 
- pub lower_order_final: bool, -} - -impl Default for DPMSolverSinglestepSchedulerConfig { - fn default() -> Self { - Self { - beta_start: 0.0001, - beta_end: 0.02, - train_timesteps: 1000, - beta_schedule: BetaSchedule::Linear, - solver_order: 2, - prediction_type: PredictionType::Epsilon, - sample_max_value: 1.0, - algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus, - solver_type: DPMSolverType::Midpoint, - lower_order_final: true, - } - } -} - pub struct DPMSolverSinglestepScheduler { alphas_cumprod: Vec, alpha_t: Vec, @@ -75,14 +14,17 @@ pub struct DPMSolverSinglestepScheduler { lambda_t: Vec, init_noise_sigma: f64, order_list: Vec, + /// Direct outputs from learned diffusion model at current and latter timesteps model_outputs: Vec, + /// List of current discrete timesteps in the diffusion chain timesteps: Vec, + /// Current instance of sample being created by diffusion process sample: Option, - pub config: DPMSolverSinglestepSchedulerConfig, + pub config: DPMSolverSchedulerConfig, } impl DPMSolverSinglestepScheduler { - pub fn new(inference_steps: usize, config: DPMSolverSinglestepSchedulerConfig) -> Self { + pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self { let betas = match config.beta_schedule { BetaSchedule::ScaledLinear => Tensor::linspace( config.beta_start.sqrt(), @@ -462,8 +404,6 @@ fn get_order_list(steps: usize, solver_order: usize, lower_order_final: bool) -> .chain([&[1][..]]) .flatten() .map(|v| *v) - .collect() - } } else if solver_order == 1 { repeat(&[1][..]).take(steps).flatten().map(|v| *v).collect() } else { @@ -473,11 +413,9 @@ fn get_order_list(steps: usize, solver_order: usize, lower_order_final: bool) -> if solver_order == 3 { repeat(&[1, 2, 3][..]).take(steps / 3).flatten().map(|v| *v).collect() } else if solver_order == 2 { - repeat(&[1, 2][..]).take(steps / 2).flatten().map(|v| *v).collect() + repeat(dbg!(&[1, 2][..])).take(dbg!(steps / 2)).flatten().map(|v| dbg!(*v)).collect() } else if 
solver_order == 1 { repeat(&[1][..]).take(steps).flatten().map(|v| *v).collect() - } else { - panic!("invalid solver_order"); } } } diff --git a/src/schedulers/mod.rs b/src/schedulers/mod.rs index ac22531..c5ff209 100644 --- a/src/schedulers/mod.rs +++ b/src/schedulers/mod.rs @@ -7,6 +7,7 @@ use tch::{Kind, Tensor}; pub mod ddim; pub mod ddpm; +pub mod dpmsolver; pub mod dpmsolver_multistep; pub mod dpmsolver_singlestep; pub mod euler_ancestral_discrete; From b3080b0092617b66bd1f6628c5064a1b7a192035 Mon Sep 17 00:00:00 2001 From: "Dave Lage (rockerBOO)" Date: Sat, 7 Jan 2023 16:05:27 -0500 Subject: [PATCH 05/10] correction --- src/schedulers/dpmsolver_singlestep.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/schedulers/dpmsolver_singlestep.rs b/src/schedulers/dpmsolver_singlestep.rs index 64db142..d118da7 100644 --- a/src/schedulers/dpmsolver_singlestep.rs +++ b/src/schedulers/dpmsolver_singlestep.rs @@ -404,6 +404,8 @@ fn get_order_list(steps: usize, solver_order: usize, lower_order_final: bool) -> .chain([&[1][..]]) .flatten() .map(|v| *v) + .collect() + } } else if solver_order == 1 { repeat(&[1][..]).take(steps).flatten().map(|v| *v).collect() } else { @@ -416,6 +418,8 @@ fn get_order_list(steps: usize, solver_order: usize, lower_order_final: bool) -> repeat(dbg!(&[1, 2][..])).take(dbg!(steps / 2)).flatten().map(|v| dbg!(*v)).collect() } else if solver_order == 1 { repeat(&[1][..]).take(steps).flatten().map(|v| *v).collect() + } else { + panic!("invalid solver_order"); } } } From 74e33bcbe852178fb24b93e1553a5bf5177a806a Mon Sep 17 00:00:00 2001 From: "Dave Lage (rockerBOO)" Date: Sat, 7 Jan 2023 23:31:33 -0500 Subject: [PATCH 06/10] update tests --- src/schedulers/dpmsolver_singlestep.rs | 62 +++++++++----------------- 1 file changed, 22 insertions(+), 40 deletions(-) diff --git a/src/schedulers/dpmsolver_singlestep.rs b/src/schedulers/dpmsolver_singlestep.rs index d118da7..70a9b27 100644 --- a/src/schedulers/dpmsolver_singlestep.rs +++ 
b/src/schedulers/dpmsolver_singlestep.rs @@ -415,7 +415,7 @@ fn get_order_list(steps: usize, solver_order: usize, lower_order_final: bool) -> if solver_order == 3 { repeat(&[1, 2, 3][..]).take(steps / 3).flatten().map(|v| *v).collect() } else if solver_order == 2 { - repeat(dbg!(&[1, 2][..])).take(dbg!(steps / 2)).flatten().map(|v| dbg!(*v)).collect() + repeat(&[1, 2][..]).take(steps / 2).flatten().map(|v| *v).collect() } else if solver_order == 1 { repeat(&[1][..]).take(steps).flatten().map(|v| *v).collect() } else { @@ -430,52 +430,34 @@ mod tests { #[test] fn order_list() { - let list = get_order_list(15, 2, false); + assert_eq!(get_order_list(15, 2, false), vec![1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2],); - assert_eq!(15, list.len()); - assert_eq!(vec![1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2], list); - - let list = get_order_list(16, 2, false); - - assert_eq!(16, list.len()); - assert_eq!(vec![1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2], list); - - let list = get_order_list(16, 1, false); - - assert_eq!(16, list.len()); - assert_eq!(vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], list); - - let list = get_order_list(16, 3, false); - - assert_eq!(16, list.len()); - assert_eq!(vec![1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3], list); - - let list = get_order_list(16, 3, true); - - assert_eq!(16, list.len()); - assert_eq!(vec![1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1], list); - - let list = get_order_list(16, 1, true); - - assert_eq!(16, list.len()); - assert_eq!(vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], list); + assert_eq!( + get_order_list(16, 2, false), + vec![1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2] + ); - let list = get_order_list(25, 1, true); + assert_eq!( + get_order_list(16, 1, false), + vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + ); - assert_eq!(25, list.len()); + assert_eq!(get_order_list(16, 3, false), vec![1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]); + assert_eq!( + get_order_list(16, 3, true), + 
vec![1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1] + ); + assert_eq!( + get_order_list(16, 1, true), + vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ); assert_eq!( + get_order_list(25, 1, true), vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - list ); - let list = get_order_list(1, 1, true); - - assert_eq!(1, list.len()); - assert_eq!(vec![1], list); - - let list = get_order_list(2, 2, true); + assert_eq!(get_order_list(1, 1, true), vec![1]); - assert_eq!(2, list.len()); - assert_eq!(vec![1, 2], list); + assert_eq!(get_order_list(2, 2, true), vec![1, 2]); } } From 9a2d1e88669af58489f38fa71f15cdf38004b15f Mon Sep 17 00:00:00 2001 From: "Dave Lage (rockerBOO)" Date: Sun, 8 Jan 2023 01:08:04 -0500 Subject: [PATCH 07/10] default lower_order_final is true --- src/schedulers/dpmsolver_singlestep.rs | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/schedulers/dpmsolver_singlestep.rs b/src/schedulers/dpmsolver_singlestep.rs index 70a9b27..356912a 100644 --- a/src/schedulers/dpmsolver_singlestep.rs +++ b/src/schedulers/dpmsolver_singlestep.rs @@ -49,7 +49,7 @@ impl DPMSolverSinglestepScheduler { let lambda_t = alpha_t.log() - sigma_t.log(); let step = (config.train_timesteps - 1) as f64 / inference_steps as f64; - // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py#L199-L204 + // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py#L172-L173 let timesteps: Vec = (0..inference_steps + 1) .map(|i| (i as f64 * step).round() as usize) // discards the 0.0 element @@ -62,13 +62,15 @@ impl DPMSolverSinglestepScheduler { model_outputs.push(Tensor::new()); } + let order_list = get_order_list(inference_steps, config.solver_order, true); + Self { alphas_cumprod: Vec::::from(alphas_cumprod), alpha_t: 
Vec::::from(alpha_t), sigma_t: Vec::::from(sigma_t), lambda_t: Vec::::from(lambda_t), init_noise_sigma: 1., - order_list: get_order_list(inference_steps, config.solver_order, false), + order_list, model_outputs, timesteps, config, @@ -292,6 +294,12 @@ impl DPMSolverSinglestepScheduler { self.timesteps.as_slice() } + /// Ensures interchangeability with schedulers that need to scale the denoising model input + /// depending on the current timestep. + pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor { + sample + } + /// Step function propagating the sample with the singlestep DPM-Solver /// /// # Arguments @@ -311,7 +319,7 @@ impl DPMSolverSinglestepScheduler { self.model_outputs[i] = self.model_outputs[i + 1].shallow_clone(); } let m = self.model_outputs.len(); - self.model_outputs[m - 1] = model_output; + self.model_outputs[m - 1] = model_output.shallow_clone(); let order = self.order_list[step_index]; @@ -320,7 +328,7 @@ impl DPMSolverSinglestepScheduler { self.sample = Some(sample.shallow_clone()); }; - let prev_sample = match order { + match order { 1 => self.dpm_solver_first_order_update( &self.model_outputs[self.model_outputs.len() - 1], timestep, @@ -331,7 +339,7 @@ impl DPMSolverSinglestepScheduler { &self.model_outputs, [self.timesteps[step_index - 1], self.timesteps[step_index]], prev_timestep, - self.sample.as_ref().unwrap(), + &self.sample.as_ref().unwrap(), ), 3 => self.singlestep_dpm_solver_third_order_update( &self.model_outputs, @@ -341,14 +349,12 @@ impl DPMSolverSinglestepScheduler { self.timesteps[step_index], ], prev_timestep, - self.sample.as_ref().unwrap(), + &self.sample.as_ref().unwrap(), ), _ => { panic!("invalid order"); } - }; - - prev_sample + } } pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor { From 84b86802a6aad1ec78ec665261844347ecafc2ae Mon Sep 17 00:00:00 2001 From: "Dave Lage (rockerBOO)" Date: Sun, 8 Jan 2023 01:15:58 -0500 Subject: [PATCH 08/10] Remove unused 
import --- examples/stable-diffusion/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/stable-diffusion/main.rs b/examples/stable-diffusion/main.rs index 531ee86..8b57704 100644 --- a/examples/stable-diffusion/main.rs +++ b/examples/stable-diffusion/main.rs @@ -70,7 +70,7 @@ // cargo run --release --example tensor-tools cp ./data/unet.npz ./data/unet.ot use clap::Parser; use diffusers::transformers::clip; -use diffusers::{pipelines::stable_diffusion, schedulers}; +use diffusers::pipelines::stable_diffusion; use tch::{nn::Module, Device, Kind, Tensor}; const GUIDANCE_SCALE: f64 = 7.5; From fd7db587a12aaf28528bc5a52ea79d787dbfa48e Mon Sep 17 00:00:00 2001 From: "Dave Lage (rockerBOO)" Date: Sun, 8 Jan 2023 02:33:26 -0500 Subject: [PATCH 09/10] Add DPMSolverScheduler trait --- src/schedulers/dpmsolver.rs | 45 ++++++++++++++++++++++++++ src/schedulers/dpmsolver_multistep.rs | 40 +++++++++++------------ src/schedulers/dpmsolver_singlestep.rs | 34 ++++++++++--------- 3 files changed, 81 insertions(+), 38 deletions(-) diff --git a/src/schedulers/dpmsolver.rs b/src/schedulers/dpmsolver.rs index 4aeb6b5..570aa54 100644 --- a/src/schedulers/dpmsolver.rs +++ b/src/schedulers/dpmsolver.rs @@ -1,3 +1,5 @@ +use tch::Tensor; + use crate::schedulers::BetaSchedule; use crate::schedulers::PredictionType; @@ -65,3 +67,46 @@ impl Default for DPMSolverSchedulerConfig { } } } + +pub trait DPMSolverScheduler { + fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self; + fn convert_model_output( + &self, + model_output: &Tensor, + timestep: usize, + sample: &Tensor, + ) -> Tensor; + + fn first_order_update( + &self, + model_output: Tensor, + timestep: usize, + prev_timestep: usize, + sample: &Tensor, + ) -> Tensor; + + fn second_order_update( + &self, + model_output_list: &Vec, + timestep_list: [usize; 2], + prev_timestep: usize, + sample: &Tensor, + ) -> Tensor; + + fn third_order_update( + &self, + model_output_list: &Vec, + 
timestep_list: [usize; 3], + prev_timestep: usize, + sample: &Tensor, + ) -> Tensor; + + fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor; + + fn timesteps(&self) -> &[usize]; + fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Tensor; + + + fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor; + fn init_noise_sigma(&self) -> f64; +} diff --git a/src/schedulers/dpmsolver_multistep.rs b/src/schedulers/dpmsolver_multistep.rs index 7179bef..1de6ce8 100644 --- a/src/schedulers/dpmsolver_multistep.rs +++ b/src/schedulers/dpmsolver_multistep.rs @@ -1,4 +1,10 @@ -use super::{betas_for_alpha_bar, BetaSchedule, PredictionType, dpmsolver::{DPMSolverSchedulerConfig, DPMSolverAlgorithmType, DPMSolverType}}; +use super::{ + betas_for_alpha_bar, + dpmsolver::{ + DPMSolverAlgorithmType, DPMSolverScheduler, DPMSolverSchedulerConfig, DPMSolverType, + }, + BetaSchedule, PredictionType, +}; use tch::{kind, Kind, Tensor}; pub struct DPMSolverMultistepScheduler { @@ -15,8 +21,8 @@ pub struct DPMSolverMultistepScheduler { pub config: DPMSolverSchedulerConfig, } -impl DPMSolverMultistepScheduler { - pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self { +impl DPMSolverScheduler for DPMSolverMultistepScheduler { + fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self { let betas = match config.beta_schedule { BetaSchedule::ScaledLinear => Tensor::linspace( config.beta_start.sqrt(), @@ -116,7 +122,7 @@ impl DPMSolverMultistepScheduler { /// One step for the first-order DPM-Solver (equivalent to DDIM). /// See https://arxiv.org/abs/2206.00927 for the detailed derivation. - fn dpm_solver_first_order_update( + fn first_order_update( &self, model_output: Tensor, timestep: usize, @@ -138,7 +144,7 @@ impl DPMSolverMultistepScheduler { } /// One step for the second-order multistep DPM-Solver. 
- fn multistep_dpm_solver_second_order_update( + fn second_order_update( &self, model_output_list: &Vec, timestep_list: [usize; 2], @@ -191,7 +197,7 @@ impl DPMSolverMultistepScheduler { } /// One step for the third-order multistep DPM-Solver - fn multistep_dpm_solver_third_order_update( + fn third_order_update( &self, model_output_list: &Vec, timestep_list: [usize; 3], @@ -236,7 +242,7 @@ impl DPMSolverMultistepScheduler { } } - pub fn timesteps(&self) -> &[usize] { + fn timesteps(&self) -> &[usize] { self.timesteps.as_slice() } @@ -271,24 +277,14 @@ impl DPMSolverMultistepScheduler { || self.lower_order_nums < 1 || lower_order_final { - self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample) + self.first_order_update(model_output, timestep, prev_timestep, sample) } else if self.config.solver_order == 2 || self.lower_order_nums < 2 || lower_order_second { let timestep_list = [self.timesteps[step_index - 1], timestep]; - self.multistep_dpm_solver_second_order_update( - &self.model_outputs, - timestep_list, - prev_timestep, - sample, - ) + self.second_order_update(&self.model_outputs, timestep_list, prev_timestep, sample) } else { let timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]; - self.multistep_dpm_solver_third_order_update( - &self.model_outputs, - timestep_list, - prev_timestep, - sample, - ) + self.third_order_update(&self.model_outputs, timestep_list, prev_timestep, sample) }; if self.lower_order_nums < self.config.solver_order { @@ -298,12 +294,12 @@ impl DPMSolverMultistepScheduler { prev_sample } - pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor { + fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor { self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned() + (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise } - pub fn init_noise_sigma(&self) -> f64 { + fn init_noise_sigma(&self) -> f64 { 
self.init_noise_sigma } } diff --git a/src/schedulers/dpmsolver_singlestep.rs b/src/schedulers/dpmsolver_singlestep.rs index 356912a..1c8c3dc 100644 --- a/src/schedulers/dpmsolver_singlestep.rs +++ b/src/schedulers/dpmsolver_singlestep.rs @@ -2,7 +2,9 @@ use std::iter::repeat; use super::{ betas_for_alpha_bar, - dpmsolver::{DPMSolverAlgorithmType, DPMSolverSchedulerConfig, DPMSolverType}, + dpmsolver::{ + DPMSolverAlgorithmType, DPMSolverScheduler, DPMSolverSchedulerConfig, DPMSolverType, + }, BetaSchedule, PredictionType, }; use tch::{kind, Kind, Tensor}; @@ -23,8 +25,8 @@ pub struct DPMSolverSinglestepScheduler { pub config: DPMSolverSchedulerConfig, } -impl DPMSolverSinglestepScheduler { - pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self { +impl DPMSolverScheduler for DPMSolverSinglestepScheduler { + fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self { let betas = match config.beta_schedule { BetaSchedule::ScaledLinear => Tensor::linspace( config.beta_start.sqrt(), @@ -141,9 +143,9 @@ impl DPMSolverSinglestepScheduler { /// * `timestep` - current discrete timestep in the diffusion chain /// * `prev_timestep` - previous discrete timestep in the diffusion chain /// * `sample` - current instance of sample being created by diffusion process - fn dpm_solver_first_order_update( + fn first_order_update( &self, - model_output: &Tensor, + model_output: Tensor, timestep: usize, prev_timestep: usize, sample: &Tensor, @@ -171,7 +173,7 @@ impl DPMSolverSinglestepScheduler { /// * `timestep_list` - current and latter discrete timestep in the diffusion chain /// * `prev_timestep` - previous discrete timestep in the diffusion chain /// * `sample` - current instance of sample being created by diffusion process - fn singlestep_dpm_solver_second_order_update( + fn second_order_update( &self, model_output_list: &Vec, timestep_list: [usize; 2], @@ -232,7 +234,7 @@ impl DPMSolverSinglestepScheduler { /// * `timestep_list` - current 
and latter discrete timestep in the diffusion chain /// * `prev_timestep` - previous discrete timestep in the diffusion chain /// * `sample` - current instance of sample being created by diffusion process - fn singlestep_dpm_solver_third_order_update( + fn third_order_update( &self, model_output_list: &Vec, timestep_list: [usize; 3], @@ -290,13 +292,13 @@ impl DPMSolverSinglestepScheduler { } } - pub fn timesteps(&self) -> &[usize] { + fn timesteps(&self) -> &[usize] { self.timesteps.as_slice() } /// Ensures interchangeability with schedulers that need to scale the denoising model input /// depending on the current timestep. - pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor { + fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor { sample } @@ -307,7 +309,7 @@ impl DPMSolverSinglestepScheduler { /// * `model_output` - direct output from learned diffusion model /// * `timestep` - current discrete timestep in the diffusion chain /// * `sample` - current instance of sample being created by diffusion process - pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor { + fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor { // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py#L535 let step_index: usize = self.timesteps.iter().position(|&t| t == timestep).unwrap(); @@ -329,19 +331,19 @@ impl DPMSolverSinglestepScheduler { }; match order { - 1 => self.dpm_solver_first_order_update( - &self.model_outputs[self.model_outputs.len() - 1], + 1 => self.first_order_update( + model_output, timestep, prev_timestep, &self.sample.as_ref().unwrap(), ), - 2 => self.singlestep_dpm_solver_second_order_update( + 2 => self.second_order_update( &self.model_outputs, [self.timesteps[step_index - 1], self.timesteps[step_index]], prev_timestep, &self.sample.as_ref().unwrap(), ), - 3 => 
self.singlestep_dpm_solver_third_order_update( + 3 => self.third_order_update( &self.model_outputs, [ self.timesteps[step_index - 2], @@ -357,12 +359,12 @@ impl DPMSolverSinglestepScheduler { } } - pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor { + fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor { self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned() + (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise } - pub fn init_noise_sigma(&self) -> f64 { + fn init_noise_sigma(&self) -> f64 { self.init_noise_sigma } } From 41e0286dba34d6c7e80d702062fc4884e18a6d53 Mon Sep 17 00:00:00 2001 From: "Dave Lage (rockerBOO)" Date: Sat, 14 Jan 2023 16:36:01 -0500 Subject: [PATCH 10/10] Revert "Add DPMSolverScheduler trait" This reverts commit 83b28b34fa174e71c79faff23d6729c05ca5b5c1. --- src/schedulers/dpmsolver.rs | 45 -------------------------- src/schedulers/dpmsolver_multistep.rs | 41 ++++++++++++----------- src/schedulers/dpmsolver_singlestep.rs | 34 +++++++++---------- 3 files changed, 39 insertions(+), 81 deletions(-) diff --git a/src/schedulers/dpmsolver.rs b/src/schedulers/dpmsolver.rs index 570aa54..4aeb6b5 100644 --- a/src/schedulers/dpmsolver.rs +++ b/src/schedulers/dpmsolver.rs @@ -1,5 +1,3 @@ -use tch::Tensor; - use crate::schedulers::BetaSchedule; use crate::schedulers::PredictionType; @@ -67,46 +65,3 @@ impl Default for DPMSolverSchedulerConfig { } } } - -pub trait DPMSolverScheduler { - fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self; - fn convert_model_output( - &self, - model_output: &Tensor, - timestep: usize, - sample: &Tensor, - ) -> Tensor; - - fn first_order_update( - &self, - model_output: Tensor, - timestep: usize, - prev_timestep: usize, - sample: &Tensor, - ) -> Tensor; - - fn second_order_update( - &self, - model_output_list: &Vec, - timestep_list: [usize; 2], - prev_timestep: usize, - sample: &Tensor, - ) -> Tensor; - - fn 
third_order_update( - &self, - model_output_list: &Vec, - timestep_list: [usize; 3], - prev_timestep: usize, - sample: &Tensor, - ) -> Tensor; - - fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor; - - fn timesteps(&self) -> &[usize]; - fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Tensor; - - - fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor; - fn init_noise_sigma(&self) -> f64; -} diff --git a/src/schedulers/dpmsolver_multistep.rs b/src/schedulers/dpmsolver_multistep.rs index 1de6ce8..f66ae60 100644 --- a/src/schedulers/dpmsolver_multistep.rs +++ b/src/schedulers/dpmsolver_multistep.rs @@ -1,10 +1,5 @@ -use super::{ - betas_for_alpha_bar, - dpmsolver::{ - DPMSolverAlgorithmType, DPMSolverScheduler, DPMSolverSchedulerConfig, DPMSolverType, - }, - BetaSchedule, PredictionType, -}; +use super::{betas_for_alpha_bar, BetaSchedule, PredictionType, dpmsolver::{DPMSolverSchedulerConfig, DPMSolverAlgorithmType, DPMSolverType}}; +use std::iter; use tch::{kind, Kind, Tensor}; pub struct DPMSolverMultistepScheduler { @@ -21,8 +16,8 @@ pub struct DPMSolverMultistepScheduler { pub config: DPMSolverSchedulerConfig, } -impl DPMSolverScheduler for DPMSolverMultistepScheduler { - fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self { +impl DPMSolverMultistepScheduler { + pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self { let betas = match config.beta_schedule { BetaSchedule::ScaledLinear => Tensor::linspace( config.beta_start.sqrt(), @@ -122,7 +117,7 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler { /// One step for the first-order DPM-Solver (equivalent to DDIM). /// See https://arxiv.org/abs/2206.00927 for the detailed derivation. 
- fn first_order_update( + fn dpm_solver_first_order_update( &self, model_output: Tensor, timestep: usize, @@ -144,7 +139,7 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler { } /// One step for the second-order multistep DPM-Solver. - fn second_order_update( + fn multistep_dpm_solver_second_order_update( &self, model_output_list: &Vec, timestep_list: [usize; 2], @@ -197,7 +192,7 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler { } /// One step for the third-order multistep DPM-Solver - fn third_order_update( + fn multistep_dpm_solver_third_order_update( &self, model_output_list: &Vec, timestep_list: [usize; 3], @@ -242,7 +237,7 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler { } } - fn timesteps(&self) -> &[usize] { + pub fn timesteps(&self) -> &[usize] { self.timesteps.as_slice() } @@ -277,14 +272,24 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler { || self.lower_order_nums < 1 || lower_order_final { - self.first_order_update(model_output, timestep, prev_timestep, sample) + self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample) } else if self.config.solver_order == 2 || self.lower_order_nums < 2 || lower_order_second { let timestep_list = [self.timesteps[step_index - 1], timestep]; - self.second_order_update(&self.model_outputs, timestep_list, prev_timestep, sample) + self.multistep_dpm_solver_second_order_update( + &self.model_outputs, + timestep_list, + prev_timestep, + sample, + ) } else { let timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]; - self.third_order_update(&self.model_outputs, timestep_list, prev_timestep, sample) + self.multistep_dpm_solver_third_order_update( + &self.model_outputs, + timestep_list, + prev_timestep, + sample, + ) }; if self.lower_order_nums < self.config.solver_order { @@ -294,12 +299,12 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler { prev_sample } - fn add_noise(&self, original_samples: &Tensor, noise: 
Tensor, timestep: usize) -> Tensor { + pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor { self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned() + (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise } - fn init_noise_sigma(&self) -> f64 { + pub fn init_noise_sigma(&self) -> f64 { self.init_noise_sigma } } diff --git a/src/schedulers/dpmsolver_singlestep.rs b/src/schedulers/dpmsolver_singlestep.rs index 1c8c3dc..356912a 100644 --- a/src/schedulers/dpmsolver_singlestep.rs +++ b/src/schedulers/dpmsolver_singlestep.rs @@ -2,9 +2,7 @@ use std::iter::repeat; use super::{ betas_for_alpha_bar, - dpmsolver::{ - DPMSolverAlgorithmType, DPMSolverScheduler, DPMSolverSchedulerConfig, DPMSolverType, - }, + dpmsolver::{DPMSolverAlgorithmType, DPMSolverSchedulerConfig, DPMSolverType}, BetaSchedule, PredictionType, }; use tch::{kind, Kind, Tensor}; @@ -25,8 +23,8 @@ pub struct DPMSolverSinglestepScheduler { pub config: DPMSolverSchedulerConfig, } -impl DPMSolverScheduler for DPMSolverSinglestepScheduler { - fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self { +impl DPMSolverSinglestepScheduler { + pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self { let betas = match config.beta_schedule { BetaSchedule::ScaledLinear => Tensor::linspace( config.beta_start.sqrt(), @@ -143,9 +141,9 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler { /// * `timestep` - current discrete timestep in the diffusion chain /// * `prev_timestep` - previous discrete timestep in the diffusion chain /// * `sample` - current instance of sample being created by diffusion process - fn first_order_update( + fn dpm_solver_first_order_update( &self, - model_output: Tensor, + model_output: &Tensor, timestep: usize, prev_timestep: usize, sample: &Tensor, @@ -173,7 +171,7 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler { /// * `timestep_list` - current and latter discrete timestep in the 
diffusion chain /// * `prev_timestep` - previous discrete timestep in the diffusion chain /// * `sample` - current instance of sample being created by diffusion process - fn second_order_update( + fn singlestep_dpm_solver_second_order_update( &self, model_output_list: &Vec, timestep_list: [usize; 2], @@ -234,7 +232,7 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler { /// * `timestep_list` - current and latter discrete timestep in the diffusion chain /// * `prev_timestep` - previous discrete timestep in the diffusion chain /// * `sample` - current instance of sample being created by diffusion process - fn third_order_update( + fn singlestep_dpm_solver_third_order_update( &self, model_output_list: &Vec, timestep_list: [usize; 3], @@ -292,13 +290,13 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler { } } - fn timesteps(&self) -> &[usize] { + pub fn timesteps(&self) -> &[usize] { self.timesteps.as_slice() } /// Ensures interchangeability with schedulers that need to scale the denoising model input /// depending on the current timestep. 
- fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor { + pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor { sample } @@ -309,7 +307,7 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler { /// * `model_output` - direct output from learned diffusion model /// * `timestep` - current discrete timestep in the diffusion chain /// * `sample` - current instance of sample being created by diffusion process - fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor { + pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor { // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py#L535 let step_index: usize = self.timesteps.iter().position(|&t| t == timestep).unwrap(); @@ -331,19 +329,19 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler { }; match order { - 1 => self.first_order_update( - model_output, + 1 => self.dpm_solver_first_order_update( + &self.model_outputs[self.model_outputs.len() - 1], timestep, prev_timestep, &self.sample.as_ref().unwrap(), ), - 2 => self.second_order_update( + 2 => self.singlestep_dpm_solver_second_order_update( &self.model_outputs, [self.timesteps[step_index - 1], self.timesteps[step_index]], prev_timestep, &self.sample.as_ref().unwrap(), ), - 3 => self.third_order_update( + 3 => self.singlestep_dpm_solver_third_order_update( &self.model_outputs, [ self.timesteps[step_index - 2], @@ -359,12 +357,12 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler { } } - fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor { + pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor { self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned() + (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise } - fn init_noise_sigma(&self) -> f64 { + pub fn 
init_noise_sigma(&self) -> f64 { self.init_noise_sigma } }