Mirror of https://github.com/huggingface/candle.git
Fix the W clip embeddings. (#887)

* Fix the W clip embeddings.
* Add the specialized ddpm scheduler.
candle-examples/examples/wuerstchen/main.rs
@@ -193,9 +193,9 @@ fn encode_prompt(
     println!("Building the clip transformer.");
     let text_model =
         stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
-    let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len)?;
-    let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len)?;
-    let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
+    let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
+    let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
+    let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
     Ok(text_embeddings)
 }
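After this change the batch returned by encode_prompt is laid out as [conditional; unconditional] along dim 0. For orientation, here is a minimal sketch (not part of this commit) of how such a concatenated batch is typically consumed for classifier-free guidance; apply_cfg, noise_pred, and guidance_scale are illustrative names, not candle API:

use candle::{Result, Tensor};

/// Split a prediction made on a cat([cond, uncond], 0) batch and blend the
/// two halves: move away from the unconditional output, towards the
/// conditional one.
fn apply_cfg(noise_pred: &Tensor, guidance_scale: f64) -> Result<Tensor> {
    let chunks = noise_pred.chunk(2, 0)?;
    let (cond, uncond) = (&chunks[0], &chunks[1]);
    let diff = (cond - uncond)?;
    uncond + (diff * guidance_scale)?
}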
candle-transformers/src/models/wuerstchen/ddpm.rs (new file, 97 lines)
@@ -0,0 +1,97 @@
use candle::{Result, Tensor};

#[derive(Debug, Clone)]
pub struct DDPMWSchedulerConfig {
    /// Exponent used to warp the schedule: > 1 compresses the start, < 1 the end.
    scaler: f64,
    /// Small offset `s` from the cosine schedule of Nichol & Dhariwal.
    s: f64,
}

impl Default for DDPMWSchedulerConfig {
    fn default() -> Self {
        Self {
            scaler: 1f64,
            s: 0.008f64,
        }
    }
}

pub struct DDPMWScheduler {
    init_alpha_cumprod: f64,
    init_noise_sigma: f64,
    timesteps: Vec<f64>,
    pub config: DDPMWSchedulerConfig,
}

impl DDPMWScheduler {
    pub fn new(inference_steps: usize, config: DDPMWSchedulerConfig) -> Result<Self> {
        // Normalization constant: the value of the cosine schedule at t = 0.
        let init_alpha_cumprod = (config.s / (1. + config.s) * std::f64::consts::PI * 0.5)
            .cos()
            .powi(2);
        // Continuous timesteps running from 1.0 down to 0.0, inclusive.
        let timesteps = (0..=inference_steps)
            .map(|i| 1. - i as f64 / inference_steps as f64)
            .collect::<Vec<_>>();
        Ok(Self {
            init_alpha_cumprod,
            init_noise_sigma: 1.0,
            timesteps,
            config,
        })
    }

    /// Cosine cumulative-product schedule, normalized by its value at t = 0 and
    /// clamped away from 0 and 1 for numerical stability.
    fn alpha_cumprod(&self, t: f64) -> f64 {
        let scaler = self.config.scaler;
        let s = self.config.s;
        let t = if scaler > 1. {
            1. - (1. - t).powf(scaler)
        } else if scaler < 1. {
            t.powf(scaler)
        } else {
            t
        };
        let alpha_cumprod = ((t + s) / (1. + s) * std::f64::consts::PI * 0.5).cos().powi(2)
            / self.init_alpha_cumprod;
        alpha_cumprod.clamp(0.0001, 0.9999)
    }

    /// Returns the timestep that follows the one closest to `ts`.
    fn previous_timestep(&self, ts: f64) -> f64 {
        let index = self
            .timesteps
            .iter()
            .enumerate()
            .map(|(idx, v)| (idx, (v - ts).abs()))
            .min_by(|x, y| x.1.total_cmp(&y.1))
            .unwrap()
            .0;
        self.timesteps[index + 1]
    }

    /// Ensures interchangeability with schedulers that need to scale the denoising
    /// model input depending on the current timestep.
    pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
        sample
    }

    pub fn step(&self, model_output: &Tensor, ts: f64, sample: &Tensor) -> Result<Tensor> {
        let prev_t = self.previous_timestep(ts);

        let alpha_cumprod = self.alpha_cumprod(ts);
        let alpha_cumprod_prev = self.alpha_cumprod(prev_t);
        let alpha = alpha_cumprod / alpha_cumprod_prev;

        // Posterior mean: (sample - (1 - alpha) / sqrt(1 - alpha_cumprod) * eps) / sqrt(alpha).
        let mu = (sample - model_output * ((1. - alpha) / (1. - alpha_cumprod).sqrt()))?;
        let mu = (mu * (1. / alpha).sqrt())?;

        // Posterior noise, added on all but the final step.
        let std_noise = mu.randn_like(0., 1.)?;
        let std =
            (std_noise * ((1. - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt())?;
        if prev_t == 0. {
            Ok(mu)
        } else {
            mu + std
        }
    }

    pub fn init_noise_sigma(&self) -> f64 {
        self.init_noise_sigma
    }
}
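Not part of the commit, but for orientation: a minimal sketch of a denoising loop driven by this scheduler, using the module path registered in the mod.rs change below. The latent shape and the identity "denoiser" are placeholders, and the timestep grid is rebuilt locally since the timesteps field is private:

use candle::{Device, Result, Tensor};
use candle_transformers::models::wuerstchen::ddpm::{DDPMWScheduler, DDPMWSchedulerConfig};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let inference_steps = 30;
    let scheduler = DDPMWScheduler::new(inference_steps, DDPMWSchedulerConfig::default())?;

    // Start from pure noise, scaled by the scheduler's initial sigma (1.0 here).
    let mut latents = (Tensor::randn(0f32, 1f32, (1, 16, 24, 24), &device)?
        * scheduler.init_noise_sigma())?;

    // Same grid the scheduler builds internally: inference_steps + 1 values
    // from 1.0 down to 0.0; the trailing 0.0 only serves as the final target.
    let timesteps: Vec<f64> = (0..=inference_steps)
        .map(|i| 1. - i as f64 / inference_steps as f64)
        .collect();

    for &ts in &timesteps[..inference_steps] {
        // Stand-in for the real denoiser (e.g. the Wuerstchen prior): predict
        // "noise" as the latents themselves, just to keep the sketch runnable.
        let model_output = latents.clone();
        latents = scheduler.step(&model_output, ts, &latents)?;
    }
    println!("final latents shape: {:?}", latents.shape());
    Ok(())
}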
candle-transformers/src/models/wuerstchen/mod.rs
@@ -1,4 +1,5 @@
 pub mod common;
+pub mod ddpm;
 pub mod diffnext;
 pub mod paella_vq;
 pub mod prior;