UniPC for diffusion sampling (#2684)

* feat: Add unipc multistep scheduler

* chore: Clippy and formatting

* chore: Update comments

* chore: Avoid unsafety in float ordering

* refactor: Update Scheduler::step mutability requirements

* fix: Corrector img2img

* chore: Update unipc ref link to latest diffusers release

* chore: Deduplicate float ordering

* fix: Panic when running with dev profile
This commit is contained in:
Nick Senger
2025-01-01 12:34:17 -08:00
committed by GitHub
parent b12c7c2888
commit cbaa0ad46f
6 changed files with 1011 additions and 5 deletions

View File

@ -477,7 +477,7 @@ fn run(args: Args) -> Result<()> {
),
};
let scheduler = sd_config.build_scheduler(n_steps)?;
let mut scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
if let Some(seed) = seed {
device.set_seed(seed)?;
@ -539,7 +539,7 @@ fn run(args: Args) -> Result<()> {
};
for idx in 0..num_samples {
let timesteps = scheduler.timesteps();
let timesteps = scheduler.timesteps().to_vec();
let latents = match &init_latent_dist {
Some(init_latent_dist) => {
let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;

View File

@ -127,7 +127,7 @@ impl DDIMScheduler {
impl Scheduler for DDIMScheduler {
/// Performs a backward step during inference.
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {

View File

@ -171,7 +171,7 @@ impl Scheduler for EulerAncestralDiscreteScheduler {
}
/// Performs a backward step during inference.
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let step_index = self
.timesteps
.iter()

View File

@ -47,6 +47,7 @@ pub mod resnet;
pub mod schedulers;
pub mod unet_2d;
pub mod unet_2d_blocks;
pub mod uni_pc;
pub mod utils;
pub mod vae;

View File

@ -19,7 +19,7 @@ pub trait Scheduler {
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
}
/// This represents how beta ranges from its minimum value to the maximum

File diff suppressed because it is too large Load Diff