mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
UniPC for diffusion sampling (#2684)
* feat: Add unipc multistep scheduler * chore: Clippy and formatting * chore: Update comments * chore: Avoid unsafety in float ordering * refactor: Update Scheduler::step mutability requirements * fix: Corrector img2img * chore: Update unipc ref link to latest diffusers release * chore: Deduplicate float ordering * fix: Panic when running with dev profile
This commit is contained in:
@ -127,7 +127,7 @@ impl DDIMScheduler {
|
||||
|
||||
impl Scheduler for DDIMScheduler {
|
||||
/// Performs a backward step during inference.
|
||||
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
|
||||
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
|
||||
let timestep = if timestep >= self.alphas_cumprod.len() {
|
||||
timestep - 1
|
||||
} else {
|
||||
|
@ -171,7 +171,7 @@ impl Scheduler for EulerAncestralDiscreteScheduler {
|
||||
}
|
||||
|
||||
/// Performs a backward step during inference.
|
||||
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
|
||||
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
|
||||
let step_index = self
|
||||
.timesteps
|
||||
.iter()
|
||||
|
@ -47,6 +47,7 @@ pub mod resnet;
|
||||
pub mod schedulers;
|
||||
pub mod unet_2d;
|
||||
pub mod unet_2d_blocks;
|
||||
pub mod uni_pc;
|
||||
pub mod utils;
|
||||
pub mod vae;
|
||||
|
||||
|
@ -19,7 +19,7 @@ pub trait Scheduler {
|
||||
|
||||
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;
|
||||
|
||||
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
|
||||
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
|
||||
}
|
||||
|
||||
/// This represents how beta ranges from its minimum value to the maximum
|
||||
|
1005
candle-transformers/src/models/stable_diffusion/uni_pc.rs
Normal file
1005
candle-transformers/src/models/stable_diffusion/uni_pc.rs
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user