mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Compare commits
8 Commits
0.9.0-alph
...
w-uncond
Author | SHA1 | Date | |
---|---|---|---|
4114872aae | |||
f2a648f313 | |||
ec895453cd | |||
3769d8bf71 | |||
5d8e214dfe | |||
576bf7c21f | |||
49a4fa44bb | |||
b936e32e11 |
@ -16,7 +16,9 @@ use tokenizers::Tokenizer;
|
||||
|
||||
const PRIOR_GUIDANCE_SCALE: f64 = 8.0;
|
||||
const RESOLUTION_MULTIPLE: f64 = 42.67;
|
||||
const LATENT_DIM_SCALE: f64 = 10.67;
|
||||
const PRIOR_CIN: usize = 16;
|
||||
const DECODER_CIN: usize = 4;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
@ -156,7 +158,7 @@ fn output_filename(
|
||||
|
||||
fn encode_prompt(
|
||||
prompt: &str,
|
||||
uncond_prompt: &str,
|
||||
uncond_prompt: Option<&str>,
|
||||
tokenizer: std::path::PathBuf,
|
||||
clip_weights: std::path::PathBuf,
|
||||
clip_config: stable_diffusion::clip::Config,
|
||||
@ -179,6 +181,13 @@ fn encode_prompt(
|
||||
}
|
||||
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||
|
||||
println!("Building the clip transformer.");
|
||||
let text_model =
|
||||
stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
|
||||
let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
|
||||
match uncond_prompt {
|
||||
None => Ok(text_embeddings),
|
||||
Some(uncond_prompt) => {
|
||||
let mut uncond_tokens = tokenizer
|
||||
.encode(uncond_prompt, true)
|
||||
.map_err(E::msg)?
|
||||
@ -190,14 +199,13 @@ fn encode_prompt(
|
||||
}
|
||||
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||
|
||||
println!("Building the clip transformer.");
|
||||
let text_model =
|
||||
stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
|
||||
let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
|
||||
let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
|
||||
let uncond_embeddings =
|
||||
text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
|
||||
let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
|
||||
Ok(text_embeddings)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn run(args: Args) -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
@ -239,31 +247,42 @@ fn run(args: Args) -> Result<()> {
|
||||
let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
|
||||
encode_prompt(
|
||||
&prompt,
|
||||
&uncond_prompt,
|
||||
Some(&uncond_prompt),
|
||||
tokenizer.clone(),
|
||||
weights,
|
||||
stable_diffusion::clip::Config::wuerstchen_prior(),
|
||||
&device,
|
||||
)?
|
||||
};
|
||||
println!("{prior_text_embeddings}");
|
||||
println!("generated prior text embeddings {prior_text_embeddings:?}");
|
||||
|
||||
let text_embeddings = {
|
||||
let tokenizer = ModelFile::Tokenizer.get(tokenizer)?;
|
||||
let weights = ModelFile::Clip.get(clip_weights)?;
|
||||
encode_prompt(
|
||||
&prompt,
|
||||
&uncond_prompt,
|
||||
None,
|
||||
tokenizer.clone(),
|
||||
weights,
|
||||
stable_diffusion::clip::Config::wuerstchen(),
|
||||
&device,
|
||||
)?
|
||||
};
|
||||
println!("{prior_text_embeddings}");
|
||||
println!("generated text embeddings {text_embeddings:?}");
|
||||
|
||||
println!("Building the prior.");
|
||||
let b_size = 1;
|
||||
let image_embeddings = {
|
||||
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
|
||||
let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
||||
let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
||||
let mut latents = Tensor::randn(
|
||||
0f32,
|
||||
1f32,
|
||||
(b_size, PRIOR_CIN, latent_height, latent_width),
|
||||
&device,
|
||||
)?;
|
||||
|
||||
let prior = {
|
||||
let prior_weights = ModelFile::Prior.get(prior_weights)?;
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
|
||||
@ -274,6 +293,27 @@ fn run(args: Args) -> Result<()> {
|
||||
/* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
|
||||
)?
|
||||
};
|
||||
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||
let timesteps = prior_scheduler.timesteps();
|
||||
println!("prior denoising");
|
||||
for (index, &t) in timesteps.iter().enumerate() {
|
||||
let start_time = std::time::Instant::now();
|
||||
if index == timesteps.len() - 1 {
|
||||
continue;
|
||||
}
|
||||
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
|
||||
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
||||
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
|
||||
let noise_pred = noise_pred.chunk(2, 0)?;
|
||||
let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
|
||||
let noise_pred = (noise_pred_uncond
|
||||
+ ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
|
||||
latents = prior_scheduler.step(&noise_pred, t, &latents)?;
|
||||
let dt = start_time.elapsed().as_secs_f32();
|
||||
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
|
||||
}
|
||||
((latents * 42.)? - 1.)?
|
||||
};
|
||||
|
||||
println!("Building the vqgan.");
|
||||
let vqgan = {
|
||||
@ -293,58 +333,40 @@ fn run(args: Args) -> Result<()> {
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
wuerstchen::diffnext::WDiffNeXt::new(
|
||||
/* c_in */ 4, /* c_out */ 4, /* c_r */ 64, /* c_cond */ 1024,
|
||||
/* clip_embd */ 1024, /* patch_size */ 2, vb,
|
||||
/* c_in */ DECODER_CIN,
|
||||
/* c_out */ DECODER_CIN,
|
||||
/* c_r */ 64,
|
||||
/* c_cond */ 1024,
|
||||
/* clip_embd */ 1024,
|
||||
/* patch_size */ 2,
|
||||
vb,
|
||||
)?
|
||||
};
|
||||
|
||||
let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
||||
let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
||||
let b_size = 1;
|
||||
for idx in 0..num_samples {
|
||||
// https://huggingface.co/warp-ai/wuerstchen/blob/main/model_index.json
|
||||
let latent_height = (image_embeddings.dim(2)? as f64 * LATENT_DIM_SCALE) as usize;
|
||||
let latent_width = (image_embeddings.dim(3)? as f64 * LATENT_DIM_SCALE) as usize;
|
||||
|
||||
let mut latents = Tensor::randn(
|
||||
0f32,
|
||||
1f32,
|
||||
(b_size, PRIOR_CIN, latent_height, latent_width),
|
||||
(b_size, DECODER_CIN, latent_height, latent_width),
|
||||
&device,
|
||||
)?;
|
||||
|
||||
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||
let timesteps = prior_scheduler.timesteps();
|
||||
println!("prior denoising");
|
||||
println!("diffusion process with prior {image_embeddings:?}");
|
||||
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||
let timesteps = scheduler.timesteps();
|
||||
for (index, &t) in timesteps.iter().enumerate() {
|
||||
let start_time = std::time::Instant::now();
|
||||
if index == timesteps.len() - 1 {
|
||||
continue;
|
||||
}
|
||||
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
|
||||
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
||||
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
|
||||
let noise_pred = noise_pred.chunk(2, 0)?;
|
||||
let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
|
||||
let noise_pred = (noise_pred_uncond
|
||||
+ ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
|
||||
latents = prior_scheduler.step(&noise_pred, t, &latents)?;
|
||||
let dt = start_time.elapsed().as_secs_f32();
|
||||
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
|
||||
}
|
||||
let effnet = ((latents * 42.)? - 1.)?;
|
||||
let mut latents = Tensor::randn(
|
||||
0f32,
|
||||
1f32,
|
||||
(b_size, PRIOR_CIN, latent_height, latent_width),
|
||||
&device,
|
||||
)?;
|
||||
|
||||
println!("diffusion process");
|
||||
for (index, &t) in timesteps.iter().enumerate() {
|
||||
let start_time = std::time::Instant::now();
|
||||
if index == timesteps.len() - 1 {
|
||||
continue;
|
||||
}
|
||||
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
||||
let noise_pred = decoder.forward(&latents, &ratio, &effnet, Some(&text_embeddings))?;
|
||||
latents = prior_scheduler.step(&noise_pred, t, &latents)?;
|
||||
let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
|
||||
let noise_pred =
|
||||
decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;
|
||||
latents = scheduler.step(&noise_pred, t, &latents)?;
|
||||
let dt = start_time.elapsed().as_secs_f32();
|
||||
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
|
||||
}
|
||||
|
@ -34,6 +34,34 @@ impl Module for WLayerNorm {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LayerNormNoWeights {
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl LayerNormNoWeights {
|
||||
pub fn new(_size: usize) -> Result<Self> {
|
||||
Ok(Self { eps: 1e-6 })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for LayerNormNoWeights {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let x_dtype = xs.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
DType::F16 | DType::BF16 => DType::F32,
|
||||
d => d,
|
||||
};
|
||||
let hidden_size = xs.dim(D::Minus1)?;
|
||||
let xs = xs.to_dtype(internal_dtype)?;
|
||||
let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||
let xs = xs.broadcast_sub(&mean_x)?;
|
||||
let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||
xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
|
||||
.to_dtype(x_dtype)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TimestepBlock {
|
||||
mapper: candle_nn::Linear,
|
||||
|
@ -1,4 +1,4 @@
|
||||
use super::common::{AttnBlock, GlobalResponseNorm, TimestepBlock, WLayerNorm};
|
||||
use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
|
||||
use candle::{DType, Module, Result, Tensor, D};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
@ -37,7 +37,7 @@ impl ResBlockStageB {
|
||||
let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?;
|
||||
let xs = match x_skip {
|
||||
None => xs.clone(),
|
||||
Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?,
|
||||
Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?.contiguous()?,
|
||||
};
|
||||
let xs = xs
|
||||
.permute((0, 2, 3, 1))?
|
||||
@ -75,7 +75,7 @@ struct UpBlock {
|
||||
pub struct WDiffNeXt {
|
||||
clip_mapper: candle_nn::Linear,
|
||||
effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
|
||||
seq_norm: WLayerNorm,
|
||||
seq_norm: LayerNormNoWeights,
|
||||
embedding_conv: candle_nn::Conv2d,
|
||||
embedding_ln: WLayerNorm,
|
||||
down_blocks: Vec<DownBlock>,
|
||||
@ -133,7 +133,7 @@ impl WDiffNeXt {
|
||||
};
|
||||
effnet_mappers.push(c)
|
||||
}
|
||||
let seq_norm = WLayerNorm::new(c_cond)?;
|
||||
let seq_norm = LayerNormNoWeights::new(c_cond)?;
|
||||
let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
|
||||
let embedding_conv = candle_nn::conv2d(
|
||||
c_in * patch_size * patch_size,
|
||||
@ -335,6 +335,7 @@ impl WDiffNeXt {
|
||||
level_outputs.push(xs.clone())
|
||||
}
|
||||
level_outputs.reverse();
|
||||
let mut xs = level_outputs[0].clone();
|
||||
|
||||
for (i, up_block) in self.up_blocks.iter().enumerate() {
|
||||
let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
|
||||
@ -351,7 +352,9 @@ impl WDiffNeXt {
|
||||
None
|
||||
};
|
||||
let skip = match (skip, effnet_c.as_ref()) {
|
||||
(Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?),
|
||||
(Some(skip), Some(effnet_c)) => {
|
||||
Some(Tensor::cat(&[skip, effnet_c], 1)?.contiguous()?)
|
||||
}
|
||||
(None, Some(skip)) | (Some(skip), None) => Some(skip.clone()),
|
||||
(None, None) => None,
|
||||
};
|
||||
|
@ -1,12 +1,12 @@
|
||||
use super::common::WLayerNorm;
|
||||
use super::common::LayerNormNoWeights;
|
||||
use candle::{Module, Result, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MixingResidualBlock {
|
||||
norm1: WLayerNorm,
|
||||
norm1: LayerNormNoWeights,
|
||||
depthwise_conv: candle_nn::Conv2d,
|
||||
norm2: WLayerNorm,
|
||||
norm2: LayerNormNoWeights,
|
||||
channelwise_lin1: candle_nn::Linear,
|
||||
channelwise_lin2: candle_nn::Linear,
|
||||
gammas: Vec<f32>,
|
||||
@ -14,8 +14,8 @@ pub struct MixingResidualBlock {
|
||||
|
||||
impl MixingResidualBlock {
|
||||
pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let norm1 = WLayerNorm::new(inp)?;
|
||||
let norm2 = WLayerNorm::new(inp)?;
|
||||
let norm1 = LayerNormNoWeights::new(inp)?;
|
||||
let norm2 = LayerNormNoWeights::new(inp)?;
|
||||
let cfg = candle_nn::Conv2dConfig {
|
||||
groups: inp,
|
||||
..Default::default()
|
||||
|
Reference in New Issue
Block a user