Compare commits

...

8 Commits

Author SHA1 Message Date
4114872aae Make things contiguous. 2023-09-19 14:04:56 +01:00
f2a648f313 Remove some debugging continue. 2023-09-19 13:43:41 +01:00
ec895453cd More shape fixes. 2023-09-19 13:43:19 +01:00
3769d8bf71 Print the prior shape. 2023-09-19 10:26:18 +01:00
5d8e214dfe Fix the latent shape. 2023-09-19 09:21:35 +01:00
576bf7c21f Tweaks. 2023-09-19 09:08:32 +01:00
49a4fa44bb Add another specific layer-norm structure. 2023-09-19 09:06:10 +01:00
b936e32e11 Only use classifier free guidance for the prior. 2023-09-19 08:40:02 +01:00
4 changed files with 129 additions and 76 deletions

View File

@ -16,7 +16,9 @@ use tokenizers::Tokenizer;
const PRIOR_GUIDANCE_SCALE: f64 = 8.0;
const RESOLUTION_MULTIPLE: f64 = 42.67;
const LATENT_DIM_SCALE: f64 = 10.67;
const PRIOR_CIN: usize = 16;
const DECODER_CIN: usize = 4;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
@ -156,7 +158,7 @@ fn output_filename(
fn encode_prompt(
prompt: &str,
uncond_prompt: &str,
uncond_prompt: Option<&str>,
tokenizer: std::path::PathBuf,
clip_weights: std::path::PathBuf,
clip_config: stable_diffusion::clip::Config,
@ -179,24 +181,30 @@ fn encode_prompt(
}
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
let mut uncond_tokens = tokenizer
.encode(uncond_prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let uncond_tokens_len = uncond_tokens.len();
while uncond_tokens.len() < clip_config.max_position_embeddings {
uncond_tokens.push(pad_id)
}
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
println!("Building the clip transformer.");
let text_model =
stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
Ok(text_embeddings)
match uncond_prompt {
None => Ok(text_embeddings),
Some(uncond_prompt) => {
let mut uncond_tokens = tokenizer
.encode(uncond_prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let uncond_tokens_len = uncond_tokens.len();
while uncond_tokens.len() < clip_config.max_position_embeddings {
uncond_tokens.push(pad_id)
}
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
let uncond_embeddings =
text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
Ok(text_embeddings)
}
}
}
fn run(args: Args) -> Result<()> {
@ -239,40 +247,72 @@ fn run(args: Args) -> Result<()> {
let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
encode_prompt(
&prompt,
&uncond_prompt,
Some(&uncond_prompt),
tokenizer.clone(),
weights,
stable_diffusion::clip::Config::wuerstchen_prior(),
&device,
)?
};
println!("{prior_text_embeddings}");
println!("generated prior text embeddings {prior_text_embeddings:?}");
let text_embeddings = {
let tokenizer = ModelFile::Tokenizer.get(tokenizer)?;
let weights = ModelFile::Clip.get(clip_weights)?;
encode_prompt(
&prompt,
&uncond_prompt,
None,
tokenizer.clone(),
weights,
stable_diffusion::clip::Config::wuerstchen(),
&device,
)?
};
println!("{prior_text_embeddings}");
println!("generated text embeddings {text_embeddings:?}");
println!("Building the prior.");
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
let prior = {
let prior_weights = ModelFile::Prior.get(prior_weights)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
let weights = weights.deserialize()?;
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
wuerstchen::prior::WPrior::new(
/* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
/* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
)?
let b_size = 1;
let image_embeddings = {
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
let mut latents = Tensor::randn(
0f32,
1f32,
(b_size, PRIOR_CIN, latent_height, latent_width),
&device,
)?;
let prior = {
let prior_weights = ModelFile::Prior.get(prior_weights)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
let weights = weights.deserialize()?;
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
wuerstchen::prior::WPrior::new(
/* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
/* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
)?
};
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
let timesteps = prior_scheduler.timesteps();
println!("prior denoising");
for (index, &t) in timesteps.iter().enumerate() {
let start_time = std::time::Instant::now();
if index == timesteps.len() - 1 {
continue;
}
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
let noise_pred = noise_pred.chunk(2, 0)?;
let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
let noise_pred = (noise_pred_uncond
+ ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
latents = prior_scheduler.step(&noise_pred, t, &latents)?;
let dt = start_time.elapsed().as_secs_f32();
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
}
((latents * 42.)? - 1.)?
};
println!("Building the vqgan.");
@ -293,58 +333,40 @@ fn run(args: Args) -> Result<()> {
let weights = weights.deserialize()?;
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
wuerstchen::diffnext::WDiffNeXt::new(
/* c_in */ 4, /* c_out */ 4, /* c_r */ 64, /* c_cond */ 1024,
/* clip_embd */ 1024, /* patch_size */ 2, vb,
/* c_in */ DECODER_CIN,
/* c_out */ DECODER_CIN,
/* c_r */ 64,
/* c_cond */ 1024,
/* clip_embd */ 1024,
/* patch_size */ 2,
vb,
)?
};
let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
let b_size = 1;
for idx in 0..num_samples {
// https://huggingface.co/warp-ai/wuerstchen/blob/main/model_index.json
let latent_height = (image_embeddings.dim(2)? as f64 * LATENT_DIM_SCALE) as usize;
let latent_width = (image_embeddings.dim(3)? as f64 * LATENT_DIM_SCALE) as usize;
let mut latents = Tensor::randn(
0f32,
1f32,
(b_size, PRIOR_CIN, latent_height, latent_width),
(b_size, DECODER_CIN, latent_height, latent_width),
&device,
)?;
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
let timesteps = prior_scheduler.timesteps();
println!("prior denoising");
println!("diffusion process with prior {image_embeddings:?}");
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
let timesteps = scheduler.timesteps();
for (index, &t) in timesteps.iter().enumerate() {
let start_time = std::time::Instant::now();
if index == timesteps.len() - 1 {
continue;
}
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
let noise_pred = noise_pred.chunk(2, 0)?;
let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
let noise_pred = (noise_pred_uncond
+ ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
latents = prior_scheduler.step(&noise_pred, t, &latents)?;
let dt = start_time.elapsed().as_secs_f32();
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
}
let effnet = ((latents * 42.)? - 1.)?;
let mut latents = Tensor::randn(
0f32,
1f32,
(b_size, PRIOR_CIN, latent_height, latent_width),
&device,
)?;
println!("diffusion process");
for (index, &t) in timesteps.iter().enumerate() {
let start_time = std::time::Instant::now();
if index == timesteps.len() - 1 {
continue;
}
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
let noise_pred = decoder.forward(&latents, &ratio, &effnet, Some(&text_embeddings))?;
latents = prior_scheduler.step(&noise_pred, t, &latents)?;
let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
let noise_pred =
decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;
latents = scheduler.step(&noise_pred, t, &latents)?;
let dt = start_time.elapsed().as_secs_f32();
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
}

View File

@ -34,6 +34,34 @@ impl Module for WLayerNorm {
}
}
#[derive(Debug)]
pub struct LayerNormNoWeights {
eps: f64,
}
impl LayerNormNoWeights {
pub fn new(_size: usize) -> Result<Self> {
Ok(Self { eps: 1e-6 })
}
}
impl Module for LayerNormNoWeights {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let x_dtype = xs.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = xs.dim(D::Minus1)?;
let xs = xs.to_dtype(internal_dtype)?;
let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let xs = xs.broadcast_sub(&mean_x)?;
let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
.to_dtype(x_dtype)
}
}
#[derive(Debug)]
pub struct TimestepBlock {
mapper: candle_nn::Linear,

View File

@ -1,4 +1,4 @@
use super::common::{AttnBlock, GlobalResponseNorm, TimestepBlock, WLayerNorm};
use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
@ -37,7 +37,7 @@ impl ResBlockStageB {
let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?;
let xs = match x_skip {
None => xs.clone(),
Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?,
Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?.contiguous()?,
};
let xs = xs
.permute((0, 2, 3, 1))?
@ -75,7 +75,7 @@ struct UpBlock {
pub struct WDiffNeXt {
clip_mapper: candle_nn::Linear,
effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
seq_norm: WLayerNorm,
seq_norm: LayerNormNoWeights,
embedding_conv: candle_nn::Conv2d,
embedding_ln: WLayerNorm,
down_blocks: Vec<DownBlock>,
@ -133,7 +133,7 @@ impl WDiffNeXt {
};
effnet_mappers.push(c)
}
let seq_norm = WLayerNorm::new(c_cond)?;
let seq_norm = LayerNormNoWeights::new(c_cond)?;
let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
let embedding_conv = candle_nn::conv2d(
c_in * patch_size * patch_size,
@ -335,6 +335,7 @@ impl WDiffNeXt {
level_outputs.push(xs.clone())
}
level_outputs.reverse();
let mut xs = level_outputs[0].clone();
for (i, up_block) in self.up_blocks.iter().enumerate() {
let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
@ -351,7 +352,9 @@ impl WDiffNeXt {
None
};
let skip = match (skip, effnet_c.as_ref()) {
(Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?),
(Some(skip), Some(effnet_c)) => {
Some(Tensor::cat(&[skip, effnet_c], 1)?.contiguous()?)
}
(None, Some(skip)) | (Some(skip), None) => Some(skip.clone()),
(None, None) => None,
};

View File

@ -1,12 +1,12 @@
use super::common::WLayerNorm;
use super::common::LayerNormNoWeights;
use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;
#[derive(Debug)]
pub struct MixingResidualBlock {
norm1: WLayerNorm,
norm1: LayerNormNoWeights,
depthwise_conv: candle_nn::Conv2d,
norm2: WLayerNorm,
norm2: LayerNormNoWeights,
channelwise_lin1: candle_nn::Linear,
channelwise_lin2: candle_nn::Linear,
gammas: Vec<f32>,
@ -14,8 +14,8 @@ pub struct MixingResidualBlock {
impl MixingResidualBlock {
pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
let norm1 = WLayerNorm::new(inp)?;
let norm2 = WLayerNorm::new(inp)?;
let norm1 = LayerNormNoWeights::new(inp)?;
let norm2 = LayerNormNoWeights::new(inp)?;
let cfg = candle_nn::Conv2dConfig {
groups: inp,
..Default::default()