diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs index ac9843f7..12f482fd 100644 --- a/candle-examples/examples/stable-diffusion/clip.rs +++ b/candle-examples/examples/stable-diffusion/clip.rs @@ -29,7 +29,7 @@ pub struct Config { embed_dim: usize, // aka config.hidden_size activation: Activation, // aka config.hidden_act intermediate_size: usize, - max_position_embeddings: usize, + pub max_position_embeddings: usize, // The character to use for padding, use EOS when not set. pad_with: Option<String>, num_hidden_layers: usize, @@ -90,7 +90,7 @@ impl ClipTextEmbeddings { vs.pp("position_embedding"), )?; let position_ids = - Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(1)?; + Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?; Ok(ClipTextEmbeddings { token_embedding, position_embedding, diff --git a/candle-examples/examples/stable-diffusion/embeddings.rs b/candle-examples/examples/stable-diffusion/embeddings.rs index 848f1760..e3a339f5 100644 --- a/candle-examples/examples/stable-diffusion/embeddings.rs +++ b/candle-examples/examples/stable-diffusion/embeddings.rs @@ -49,7 +49,7 @@ impl Timesteps { let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?; let emb = exponent.exp()?; // emb = timesteps[:, None].float() * emb[None, :] - let emb = (xs.unsqueeze(D::Minus1)? * emb.unsqueeze(0)?)?; + let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?; let (cos, sin) = (emb.cos()?, emb.sin()?); let emb = if self.flip_sin_to_cos { Tensor::cat(&[&cos, &sin], D::Minus1)? 
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs index d8327c0e..0e0330d7 100644 --- a/candle-examples/examples/stable-diffusion/main.rs +++ b/candle-examples/examples/stable-diffusion/main.rs @@ -181,19 +181,29 @@ fn run(args: Args) -> Result<()> { let device = candle_examples::device(cpu)?; let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?; + let pad_id = match tokenizer.get_padding() { + Some(padding) => padding.pad_id, + None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(), + }; println!("Running with prompt \"{prompt}\"."); - let tokens = tokenizer + let mut tokens = tokenizer .encode(prompt, true) .map_err(E::msg)? .get_ids() .to_vec(); + while tokens.len() < sd_config.clip.max_position_embeddings { + tokens.push(pad_id) + } let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?; - let uncond_tokens = tokenizer + let mut uncond_tokens = tokenizer .encode(uncond_prompt, true) .map_err(E::msg)? .get_ids() .to_vec(); + while uncond_tokens.len() < sd_config.clip.max_position_embeddings { + uncond_tokens.push(pad_id) + } let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?; println!("Building the Clip transformer."); @@ -202,6 +212,7 @@ fn run(args: Args) -> Result<()> { let uncond_embeddings = text_model.forward(&uncond_tokens)?; let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?; + println!("text-embeddings: {text_embeddings:?}"); println!("Building the autoencoder."); let vae = sd_config.build_vae(&vae_weights, &device)?; println!("Building the unet."); diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-examples/examples/stable-diffusion/resnet.rs index b6696083..7790dcf9 100644 --- a/candle-examples/examples/stable-diffusion/resnet.rs +++ b/candle-examples/examples/stable-diffusion/resnet.rs @@ -118,7 +118,7 @@ impl ResnetBlock2D { .forward(&nn::ops::silu(temb)?)? 
.unsqueeze(D::Minus1)? .unsqueeze(D::Minus1)? - .add(&xs)?, + .broadcast_add(&xs)?, _ => xs, }; let xs = self diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs index e277ae85..ac77db4b 100644 --- a/candle-nn/src/group_norm.rs +++ b/candle-nn/src/group_norm.rs @@ -59,17 +59,21 @@ impl GroupNorm { let x = x.broadcast_sub(&mean_x)?; let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; + let mut w_dims = vec![1; x_shape.len()]; + w_dims[1] = n_channels; + let weight = self.weight.reshape(w_dims.clone())?; + let bias = self.bias.reshape(w_dims)?; x_normed .to_dtype(x_dtype)? - .broadcast_mul(&self.weight)? - .broadcast_add(&self.bias)? - .reshape(x_shape) + .reshape(x_shape)? + .broadcast_mul(&weight)? + .broadcast_add(&bias) } } pub fn group_norm( - num_channels: usize, num_groups: usize, + num_channels: usize, eps: f64, vb: crate::VarBuilder, ) -> Result<GroupNorm> { diff --git a/candle-nn/tests/group_norm.rs b/candle-nn/tests/group_norm.rs index d48b69f6..f3ef2455 100644 --- a/candle-nn/tests/group_norm.rs +++ b/candle-nn/tests/group_norm.rs @@ -30,8 +30,8 @@ use test_utils::to_vec3_round; #[test] fn group_norm() -> Result<()> { let device = &Device::Cpu; - let w = Tensor::new(&[1f32], device)?; - let b = Tensor::new(&[0f32], device)?; + let w = Tensor::from_vec(vec![1f32; 6], 6, device)?; + let b = Tensor::from_vec(vec![0f32; 6], 6, device)?; let gn2 = GroupNorm::new(w.clone(), b.clone(), 6, 2, 1e-5)?; let gn3 = GroupNorm::new(w, b, 6, 3, 1e-5)?;