mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
UniPC for diffusion sampling (#2684)
* feat: Add unipc multistep scheduler * chore: Clippy and formatting * chore: Update comments * chore: Avoid unsafety in float ordering * refactor: Update Scheduler::step mutability requirements * fix: Corrector img2img * chore: Update unipc ref link to latest diffusers release * chore: Deduplicate float ordering * fix: Panic when running with dev profile
This commit is contained in:
@ -477,7 +477,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
),
|
),
|
||||||
};
|
};
|
||||||
|
|
||||||
let scheduler = sd_config.build_scheduler(n_steps)?;
|
let mut scheduler = sd_config.build_scheduler(n_steps)?;
|
||||||
let device = candle_examples::device(cpu)?;
|
let device = candle_examples::device(cpu)?;
|
||||||
if let Some(seed) = seed {
|
if let Some(seed) = seed {
|
||||||
device.set_seed(seed)?;
|
device.set_seed(seed)?;
|
||||||
@ -539,7 +539,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
for idx in 0..num_samples {
|
for idx in 0..num_samples {
|
||||||
let timesteps = scheduler.timesteps();
|
let timesteps = scheduler.timesteps().to_vec();
|
||||||
let latents = match &init_latent_dist {
|
let latents = match &init_latent_dist {
|
||||||
Some(init_latent_dist) => {
|
Some(init_latent_dist) => {
|
||||||
let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
|
let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
|
||||||
|
@ -127,7 +127,7 @@ impl DDIMScheduler {
|
|||||||
|
|
||||||
impl Scheduler for DDIMScheduler {
|
impl Scheduler for DDIMScheduler {
|
||||||
/// Performs a backward step during inference.
|
/// Performs a backward step during inference.
|
||||||
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
|
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
|
||||||
let timestep = if timestep >= self.alphas_cumprod.len() {
|
let timestep = if timestep >= self.alphas_cumprod.len() {
|
||||||
timestep - 1
|
timestep - 1
|
||||||
} else {
|
} else {
|
||||||
|
@ -171,7 +171,7 @@ impl Scheduler for EulerAncestralDiscreteScheduler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Performs a backward step during inference.
|
/// Performs a backward step during inference.
|
||||||
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
|
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
|
||||||
let step_index = self
|
let step_index = self
|
||||||
.timesteps
|
.timesteps
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -47,6 +47,7 @@ pub mod resnet;
|
|||||||
pub mod schedulers;
|
pub mod schedulers;
|
||||||
pub mod unet_2d;
|
pub mod unet_2d;
|
||||||
pub mod unet_2d_blocks;
|
pub mod unet_2d_blocks;
|
||||||
|
pub mod uni_pc;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
pub mod vae;
|
pub mod vae;
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ pub trait Scheduler {
|
|||||||
|
|
||||||
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;
|
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;
|
||||||
|
|
||||||
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
|
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This represents how beta ranges from its minimum value to the maximum
|
/// This represents how beta ranges from its minimum value to the maximum
|
||||||
|
1005
candle-transformers/src/models/stable_diffusion/uni_pc.rs
Normal file
1005
candle-transformers/src/models/stable_diffusion/uni_pc.rs
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user