diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs
index 7e4d2360..b3231360 100644
--- a/candle-examples/examples/wuerstchen/main.rs
+++ b/candle-examples/examples/wuerstchen/main.rs
@@ -16,6 +16,7 @@ use tokenizers::Tokenizer;
 
 const GUIDANCE_SCALE: f64 = 7.5;
 const RESOLUTION_MULTIPLE: f64 = 42.67;
+const PRIOR_CIN: usize = 16;
 
 #[derive(Parser)]
 #[command(author, version, about, long_about = None)]
@@ -54,6 +55,10 @@ struct Args {
     #[arg(long, value_name = "FILE")]
     clip_weights: Option<String>,
 
+    /// The CLIP weight file used by the prior model, in .safetensors format.
+    #[arg(long, value_name = "FILE")]
+    prior_clip_weights: Option<String>,
+
     /// The prior weight file, in .safetensors format.
     #[arg(long, value_name = "FILE")]
     prior_weights: Option<String>,
@@ -66,6 +71,10 @@ struct Args {
     /// The file specifying the tokenizer to be used for tokenization.
     tokenizer: Option<String>,
 
+    #[arg(long, value_name = "FILE")]
+    /// The file specifying the tokenizer to be used for prior tokenization.
+    prior_tokenizer: Option<String>,
+
     /// The size of the sliced attention or 0 for automatic slicing (disabled by default)
     #[arg(long)]
     sliced_attention_size: Option<usize>,
@@ -86,7 +95,9 @@ struct Args {
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 enum ModelFile {
     Tokenizer,
+    PriorTokenizer,
     Clip,
+    PriorClip,
     Decoder,
     VqGan,
     Prior,
@@ -102,7 +113,9 @@ impl ModelFile {
                 let repo_prior = "warp-ai/wuerstchen-prior";
                 let (repo, path) = match self {
                     Self::Tokenizer => (repo_main, "tokenizer/tokenizer.json"),
+                    Self::PriorTokenizer => (repo_prior, "tokenizer/tokenizer.json"),
                     Self::Clip => (repo_main, "text_encoder/model.safetensors"),
+                    Self::PriorClip => (repo_prior, "text_encoder/model.safetensors"),
                     Self::Decoder => (repo_main, "decoder/diffusion_pytorch_model.safetensors"),
                     Self::VqGan => (repo_main, "vqgan/diffusion_pytorch_model.safetensors"),
                     Self::Prior => (repo_prior, "prior/diffusion_pytorch_model.safetensors"),
@@ -144,12 +157,11 @@ fn output_filename(
 fn encode_prompt(
     prompt: &str,
     uncond_prompt: &str,
-    tokenizer: Option<String>,
-    clip_weights: Option<String>,
+    tokenizer: std::path::PathBuf,
+    clip_weights: std::path::PathBuf,
     clip_config: stable_diffusion::clip::Config,
     device: &Device,
 ) -> Result<Tensor> {
-    let tokenizer = ModelFile::Tokenizer.get(tokenizer)?;
     let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
     let pad_id = match &clip_config.pad_with {
         Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
@@ -161,6 +173,7 @@ fn encode_prompt(
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    let tokens_len = tokens.len();
     while tokens.len() < clip_config.max_position_embeddings {
         tokens.push(pad_id)
     }
@@ -171,17 +184,17 @@ fn encode_prompt(
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    let uncond_tokens_len = uncond_tokens.len();
     while uncond_tokens.len() < clip_config.max_position_embeddings {
         uncond_tokens.push(pad_id)
     }
     let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
 
-    println!("Building the Clip transformer.");
-    let clip_weights = ModelFile::Clip.get(clip_weights)?;
+    println!("Building the clip transformer.");
     let text_model =
         stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
-    let text_embeddings = text_model.forward(&tokens)?;
-    let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+    let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len)?;
+    let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len)?;
     let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
     Ok(text_embeddings)
 }
@@ -221,15 +234,19 @@ fn run(args: Args) -> Result<()> {
     let height = height.unwrap_or(1024);
     let width = width.unwrap_or(1024);
 
-    let text_embeddings = encode_prompt(
-        &prompt,
-        &uncond_prompt,
-        tokenizer.clone(),
-        clip_weights.clone(),
-        stable_diffusion::clip::Config::wuerstchen(),
-        &device,
-    )?;
-    println!("{text_embeddings:?}");
+    let prior_text_embeddings = {
+        let tokenizer = ModelFile::PriorTokenizer.get(args.prior_tokenizer)?;
+        let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
+        encode_prompt(
+            &prompt,
+            &uncond_prompt,
+            tokenizer.clone(),
+            weights,
+            stable_diffusion::clip::Config::wuerstchen_prior(),
+            &device,
+        )?
+    };
+    println!("{prior_text_embeddings}");
 
     println!("Building the prior.");
     // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
@@ -239,8 +256,8 @@ fn run(args: Args) -> Result<()> {
         let weights = weights.deserialize()?;
         let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
         wuerstchen::prior::WPrior::new(
-            /* c_in */ 16, /* c */ 1536, /* c_cond */ 1280, /* c_r */ 64,
-            /* depth */ 32, /* nhead */ 24, vb,
+            /* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
+            /* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
         )?
     };
 
@@ -274,12 +291,12 @@ fn run(args: Args) -> Result<()> {
     let latents = Tensor::randn(
         0f32,
         1f32,
-        (b_size, 4, latent_height, latent_width),
+        (b_size, PRIOR_CIN, latent_height, latent_width),
         &device,
     )?;
 
     // TODO: latents denoising loop, use the scheduler values.
     let ratio = Tensor::ones(1, DType::F32, &device)?;
-    let prior = prior.forward(&latents, &ratio, &text_embeddings)?;
+    let prior = prior.forward(&latents, &ratio, &prior_text_embeddings)?;
     let latents = ((latents * 42.)? - 1.)?;
     /*
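The main.rs changes hang together as follows: both prompts are padded up to CLIP's 77-token context while the pre-padding lengths (`tokens_len`, `uncond_tokens_len`) are kept so the encoder can mask the padding, and the prior now denoises a `PRIOR_CIN`-channel latent instead of a 4-channel one. Below is a dependency-free sketch of that sizing and padding logic; the `ceil` rounding in `latent_size` and the 77-token constant are assumptions about the surrounding example code, not something this patch shows.

```rust
// Sketch of the sizing and padding logic around the prior (plain Rust, no
// candle). RESOLUTION_MULTIPLE and PRIOR_CIN come from the diff; the `ceil`
// rounding and MAX_POSITION_EMBEDDINGS = 77 are assumptions.
const RESOLUTION_MULTIPLE: f64 = 42.67;
const PRIOR_CIN: usize = 16;
const MAX_POSITION_EMBEDDINGS: usize = 77;

/// Spatial size of the prior latent for a given image dimension.
fn latent_size(pixels: usize) -> usize {
    (pixels as f64 / RESOLUTION_MULTIPLE).ceil() as usize
}

/// Pad a token sequence to the CLIP context length, keeping the original
/// length around: this is the `tokens_len` later passed to `forward_with_mask`
/// so that attention stops at the real tokens.
fn pad_tokens(mut tokens: Vec<u32>, pad_id: u32) -> (Vec<u32>, usize) {
    let tokens_len = tokens.len();
    while tokens.len() < MAX_POSITION_EMBEDDINGS {
        tokens.push(pad_id)
    }
    (tokens, tokens_len)
}

fn main() {
    // The prior denoises a (b_size, PRIOR_CIN, latent_height, latent_width)
    // tensor, matching the corrected `Tensor::randn` shape in the diff.
    let (height, width) = (1024, 1024);
    println!(
        "prior latents: (1, {PRIOR_CIN}, {}, {})",
        latent_size(height),
        latent_size(width)
    );
    let (padded, real_len) = pad_tokens(vec![320, 1125, 539], 0);
    println!("{} tokens after padding, {real_len} real", padded.len());
}
```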
diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs
index 31d025b3..7f86cf31 100644
--- a/candle-transformers/src/models/stable_diffusion/clip.rs
+++ b/candle-transformers/src/models/stable_diffusion/clip.rs
@@ -107,13 +107,28 @@ impl Config {
             embed_dim: 1024,
             intermediate_size: 4096,
             max_position_embeddings: 77,
-            pad_with: Some("!".to_string()),
+            pad_with: None,
             num_hidden_layers: 24,
             num_attention_heads: 16,
             projection_dim: 1024,
             activation: Activation::Gelu,
         }
     }
+
+    // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/text_encoder/config.json
+    pub fn wuerstchen_prior() -> Self {
+        Self {
+            vocab_size: 49408,
+            embed_dim: 1280,
+            intermediate_size: 5120,
+            max_position_embeddings: 77,
+            pad_with: None,
+            num_hidden_layers: 32,
+            num_attention_heads: 20,
+            projection_dim: 512,
+            activation: Activation::Gelu,
+        }
+    }
 }
 
 // CLIP Text Model
@@ -334,21 +349,39 @@ impl ClipTextTransformer {
     }
 
     // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
-    fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
+    fn build_causal_attention_mask(
+        bsz: usize,
+        seq_len: usize,
+        mask_after: usize,
+        device: &Device,
+    ) -> Result<Tensor> {
         let mask: Vec<_> = (0..seq_len)
-            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
+            .flat_map(|i| {
+                (0..seq_len).map(move |j| {
+                    if j > i || j > mask_after {
+                        f32::MIN
+                    } else {
+                        0.
+                    }
+                })
+            })
             .collect();
         let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
         mask.broadcast_as((bsz, seq_len, seq_len))
     }
+
+    pub fn forward_with_mask(&self, xs: &Tensor, mask_after: usize) -> Result<Tensor> {
+        let (bsz, seq_len) = xs.dims2()?;
+        let xs = self.embeddings.forward(xs)?;
+        let causal_attention_mask =
+            Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;
+        let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
+        self.final_layer_norm.forward(&xs)
+    }
 }
 
 impl Module for ClipTextTransformer {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let (bsz, seq_len) = xs.dims2()?;
-        let xs = self.embeddings.forward(xs)?;
-        let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;
-        let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
-        self.final_layer_norm.forward(&xs)
+        self.forward_with_mask(xs, usize::MAX)
     }
 }
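The clip.rs change threads a `mask_after` position through the causal mask so that no query can attend into the padding appended by `encode_prompt`, while the plain `Module::forward` keeps its old behaviour by passing `usize::MAX`. A self-contained sketch mirroring the mask construction (the ASCII rendering in `main` is illustrative only):

```rust
// Mirror of the mask built by `build_causal_attention_mask` after this change:
// position i may attend to j only if j <= i (causal) and j <= mask_after
// (no attention into padding). `f32::MIN` stands in for -inf.
fn causal_mask(seq_len: usize, mask_after: usize) -> Vec<f32> {
    (0..seq_len)
        .flat_map(|i| {
            (0..seq_len).map(move |j| if j > i || j > mask_after { f32::MIN } else { 0. })
        })
        .collect()
}

fn main() {
    // 5 positions of which the first 3 are real tokens (mask_after = 2):
    for row in causal_mask(5, 2).chunks(5) {
        let cells: Vec<&str> = row.iter().map(|&v| if v == 0. { "." } else { "x" }).collect();
        println!("{}", cells.join(" "));
    }
    // Rows 3 and 4 keep columns 3 and 4 masked, so padded positions are never
    // attended to; `forward(xs)` recovers the purely causal behaviour by
    // delegating to `forward_with_mask(xs, usize::MAX)`.
}
```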
diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs
index ee318d27..5337fdc6 100644
--- a/candle-transformers/src/models/wuerstchen/common.rs
+++ b/candle-transformers/src/models/wuerstchen/common.rs
@@ -75,9 +75,9 @@ impl Module for GlobalResponseNorm {
         let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?;
         let stand_div_norm =
             agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
-        (xs.broadcast_mul(&stand_div_norm)?
-            .broadcast_mul(&self.gamma)
-            + &self.beta)?
+        xs.broadcast_mul(&stand_div_norm)?
+            .broadcast_mul(&self.gamma)?
+            .broadcast_add(&self.beta)?
             + xs
     }
 }
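The common.rs fix is about broadcasting: `gamma` and `beta` are per-channel parameters, so combining them with the full activation tensor needs `broadcast_mul`/`broadcast_add` rather than plain operators, and the `?` placement changes accordingly. A scalar sketch of the same computation, assuming a `[h][w][c]` layout with the batch dimension elided (that layout is an assumption, not shown in the patch):

```rust
// Scalar sketch of the fixed GlobalResponseNorm on nested Vecs. With scalars,
// "broadcasting" the per-channel gamma/beta is just indexing by channel.
fn global_response_norm(xs: &[Vec<Vec<f32>>], gamma: &[f32], beta: &[f32]) -> Vec<Vec<Vec<f32>>> {
    let c = gamma.len();
    // agg_norm: per-channel sum of squares over the spatial dims (dims 1, 2).
    let mut agg_norm = vec![0f32; c];
    for row in xs {
        for px in row {
            for (k, v) in px.iter().enumerate() {
                agg_norm[k] += v * v;
            }
        }
    }
    // stand_div_norm: agg_norm divided by its mean over channels (+ epsilon).
    let mean = agg_norm.iter().sum::<f32>() / c as f32;
    let stand_div_norm: Vec<f32> = agg_norm.iter().map(|v| v / (mean + 1e-6)).collect();
    // y = x * stand_div_norm * gamma + beta + x, i.e. the residual form the
    // diff rewrites with explicit broadcasts.
    xs.iter()
        .map(|row| {
            row.iter()
                .map(|px| {
                    px.iter()
                        .enumerate()
                        .map(|(k, &x)| x * stand_div_norm[k] * gamma[k] + beta[k] + x)
                        .collect()
                })
                .collect()
        })
        .collect()
}

fn main() {
    let xs = vec![vec![vec![1.0, 2.0]; 2]; 2]; // 2x2 grid, 2 channels
    let y = global_response_norm(&xs, &[1.0, 1.0], &[0.0, 0.0]);
    println!("{:?}", y[0][0]); // approx. [1.4, 5.2]
}
```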
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 70e4ba34..664251ed 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -68,7 +68,7 @@ struct DownBlock {
 struct UpBlock {
     sub_blocks: Vec<SubBlock>,
     layer_norm: Option<WLayerNorm>,
-    conv: Option<candle_nn::Conv2d>,
+    conv: Option<candle_nn::ConvTranspose2d>,
 }
 
 #[derive(Debug)]
@@ -152,20 +152,20 @@ impl WDiffNeXt {
                     stride: 2,
                     ..Default::default()
                 };
-                let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(1))?;
-                (Some(layer_norm), Some(conv), 2)
+                let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?;
+                (Some(layer_norm), Some(conv), 1)
             } else {
                 (None, None, 0)
             };
             let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
             let mut layer_i = start_layer_i;
-            for j in 0..BLOCKS[i] {
+            for _j in 0..BLOCKS[i] {
                 let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
                 let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
                 layer_i += 1;
                 let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
                 layer_i += 1;
-                let attn_block = if j == 0 {
+                let attn_block = if i == 0 {
                     None
                 } else {
                     let attn_block =
@@ -190,7 +190,7 @@ impl WDiffNeXt {
 
         let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
         for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
-            let vb = vb.pp("up_blocks").pp(i);
+            let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i);
             let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
             let mut layer_i = 0;
             for j in 0..BLOCKS[i] {
@@ -204,7 +204,7 @@ impl WDiffNeXt {
                 layer_i += 1;
                 let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
                 layer_i += 1;
-                let attn_block = if j == 0 {
+                let attn_block = if i == 0 {
                     None
                 } else {
                     let attn_block =
@@ -221,12 +221,17 @@ impl WDiffNeXt {
             }
             let (layer_norm, conv) = if i > 0 {
                 let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
-                layer_i += 1;
-                let cfg = candle_nn::Conv2dConfig {
+                let cfg = candle_nn::ConvTranspose2dConfig {
                     stride: 2,
                     ..Default::default()
                 };
-                let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?;
+                let conv = candle_nn::conv_transpose2d(
+                    c_hidden,
+                    C_HIDDEN[i - 1],
+                    2,
+                    cfg,
+                    vb.pp(layer_i).pp(1),
+                )?;
                 (Some(layer_norm), Some(conv))
             } else {
                 (None, None)
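The diffnext.rs fix replaces the up-path's stride-2 `Conv2d` (which would halve the grid again) with a stride-2 `ConvTranspose2d` that doubles it, swaps the channel order back from `c_hidden` to `C_HIDDEN[i - 1]`, and corrects the `VarBuilder` paths (`"0.1"`, `.pp(layer_i).pp(1)`, and the reversed `up_blocks` index). A quick shape check with the standard convolution formulas (a sketch, not candle code):

```rust
// Shape arithmetic behind the up-block fix, using the standard no-padding,
// no-dilation formulas: a stride-2, kernel-2 Conv2d halves the spatial size,
// while the matching ConvTranspose2d exactly doubles it, so an UpBlock
// mirrors what the corresponding DownBlock did.
fn conv2d_out(size: usize, kernel: usize, stride: usize) -> usize {
    (size - kernel) / stride + 1
}

fn conv_transpose2d_out(size: usize, kernel: usize, stride: usize) -> usize {
    (size - 1) * stride + kernel
}

fn main() {
    let latent = 24;
    let down = conv2d_out(latent, 2, 2); // 12: after the DownBlock conv
    let up = conv_transpose2d_out(down, 2, 2); // 24: restored by the UpBlock
    assert_eq!(up, latent);
    println!("{latent} -> {down} -> {up}");
    // Note the swapped channel arguments in the diff: the transpose conv maps
    // c_hidden back to C_HIDDEN[i - 1], the reverse of the downsampling conv,
    // and its weights live under the reversed block index, hence the
    // `C_HIDDEN.len() - 1 - i` in the VarBuilder path.
}
```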