Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Fixes for the stable diffusion example. (#342)
* Fixes for the stable diffusion example.
* Bugfix.
* Another fix.
* Fix for group-norm.
* More fixes to get SD to work.

@@ -29,7 +29,7 @@ pub struct Config {
     embed_dim: usize,       // aka config.hidden_size
     activation: Activation, // aka config.hidden_act
     intermediate_size: usize,
-    max_position_embeddings: usize,
+    pub max_position_embeddings: usize,
     // The character to use for padding, use EOS when not set.
     pad_with: Option<String>,
     num_hidden_layers: usize,

@@ -90,7 +90,7 @@ impl ClipTextEmbeddings {
             vs.pp("position_embedding"),
         )?;
         let position_ids =
-            Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(1)?;
+            Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
         Ok(ClipTextEmbeddings {
             token_embedding,
             position_embedding,
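
Note on the position-id change above: `unsqueeze(1)` produces a column of shape (seq_len, 1), whereas the embedding lookup expects the (1, seq_len) row layout used by the reference CLIP implementation, hence the switch to `unsqueeze(0)`. A minimal sketch of the shape difference, assuming the published candle_core crate and an illustrative sequence length that is not taken from this commit:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let max_position_embeddings = 77usize; // illustrative CLIP-style sequence length
    // unsqueeze(1) turns the 1-d range into a (seq_len, 1) column...
    let col = Tensor::arange(0u32, max_position_embeddings as u32, &device)?.unsqueeze(1)?;
    // ...while unsqueeze(0) gives the (1, seq_len) row that the position lookup expects.
    let row = Tensor::arange(0u32, max_position_embeddings as u32, &device)?.unsqueeze(0)?;
    println!("column: {:?}, row: {:?}", col.shape(), row.shape());
    Ok(())
}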

@@ -49,7 +49,7 @@ impl Timesteps {
         let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
         let emb = exponent.exp()?;
         // emb = timesteps[:, None].float() * emb[None, :]
-        let emb = (xs.unsqueeze(D::Minus1)? * emb.unsqueeze(0)?)?;
+        let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
         let (cos, sin) = (emb.cos()?, emb.sin()?);
         let emb = if self.flip_sin_to_cos {
             Tensor::cat(&[&cos, &sin], D::Minus1)?
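
Note on the broadcast_mul change above: candle's plain `*` operator requires both operands to have the same shape, so the (n, 1) by (1, half_dim) outer product in the timestep embedding needs explicit broadcasting. A small sketch with made-up values, again assuming candle_core:

use candle_core::{D, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Illustrative timesteps of shape (n,) and frequency scales of shape (half_dim,).
    let xs = Tensor::new(&[0f32, 1., 2., 3.], &device)?;
    let emb = Tensor::new(&[1f32, 0.1, 0.01], &device)?;
    // (n, 1) broadcast-multiplied with (1, half_dim) yields the (n, half_dim) outer product.
    let outer = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
    assert_eq!(outer.dims(), &[4, 3]);
    Ok(())
}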

@@ -181,19 +181,29 @@ fn run(args: Args) -> Result<()> {
     let device = candle_examples::device(cpu)?;

     let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+    let pad_id = match tokenizer.get_padding() {
+        Some(padding) => padding.pad_id,
+        None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
+    };
     println!("Running with prompt \"{prompt}\".");
-    let tokens = tokenizer
+    let mut tokens = tokenizer
         .encode(prompt, true)
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    while tokens.len() < sd_config.clip.max_position_embeddings {
+        tokens.push(pad_id)
+    }
     let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;

-    let uncond_tokens = tokenizer
+    let mut uncond_tokens = tokenizer
         .encode(uncond_prompt, true)
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
+        uncond_tokens.push(pad_id)
+    }
     let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;

     println!("Building the Clip transformer.");
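
Note on the padding logic above: the CLIP text encoder consumes fixed-length sequences of `max_position_embeddings` tokens, so the prompt and the unconditional prompt are now padded with the tokenizer's pad id, falling back to `<|endoftext|>` when no padding is configured. The helper below is a hypothetical consolidation of that logic (the function name and the error path for a missing `<|endoftext|>` token are not part of the commit), assuming the tokenizers and anyhow crates already used by the example:

use anyhow::{Error as E, Result};
use tokenizers::Tokenizer;

// Encode a prompt and pad it to a fixed length with the tokenizer's pad id.
fn encode_padded(tokenizer: &Tokenizer, prompt: &str, max_len: usize) -> Result<Vec<u32>> {
    let pad_id = match tokenizer.get_padding() {
        Some(padding) => padding.pad_id,
        None => *tokenizer
            .get_vocab(true)
            .get("<|endoftext|>")
            .ok_or_else(|| E::msg("no <|endoftext|> token in the vocabulary"))?,
    };
    let mut tokens = tokenizer.encode(prompt, true).map_err(E::msg)?.get_ids().to_vec();
    while tokens.len() < max_len {
        tokens.push(pad_id)
    }
    Ok(tokens)
}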

@@ -202,6 +212,7 @@ fn run(args: Args) -> Result<()> {
     let uncond_embeddings = text_model.forward(&uncond_tokens)?;
     let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;

+    println!("text-embeddings: {text_embeddings:?}");
     println!("Building the autoencoder.");
     let vae = sd_config.build_vae(&vae_weights, &device)?;
     println!("Building the unet.");

@@ -118,7 +118,7 @@ impl ResnetBlock2D {
                 .forward(&nn::ops::silu(temb)?)?
                 .unsqueeze(D::Minus1)?
                 .unsqueeze(D::Minus1)?
-                .add(&xs)?,
+                .broadcast_add(&xs)?,
             _ => xs,
         };
         let xs = self
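
Note on the broadcast_add change above: after the two `unsqueeze` calls the time embedding has shape (batch, channels, 1, 1) while `xs` is (batch, channels, h, w), and candle's plain `add` requires identical shapes, so the residual block has to broadcast. A standalone sketch with illustrative shapes, assuming candle_core:

use candle_core::{D, DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Illustrative (batch, channels, h, w) feature map and (batch, channels) time embedding.
    let xs = Tensor::zeros((2, 4, 8, 8), DType::F32, &device)?;
    let temb = Tensor::ones((2, 4), DType::F32, &device)?;
    // Two unsqueezes give (batch, channels, 1, 1); broadcast_add expands it over h and w.
    let out = temb
        .unsqueeze(D::Minus1)?
        .unsqueeze(D::Minus1)?
        .broadcast_add(&xs)?;
    assert_eq!(out.dims(), xs.dims());
    Ok(())
}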

@@ -59,17 +59,21 @@ impl GroupNorm {
         let x = x.broadcast_sub(&mean_x)?;
         let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+        let mut w_dims = vec![1; x_shape.len()];
+        w_dims[1] = n_channels;
+        let weight = self.weight.reshape(w_dims.clone())?;
+        let bias = self.bias.reshape(w_dims)?;
         x_normed
             .to_dtype(x_dtype)?
-            .broadcast_mul(&self.weight)?
-            .broadcast_add(&self.bias)?
-            .reshape(x_shape)
+            .reshape(x_shape)?
+            .broadcast_mul(&weight)?
+            .broadcast_add(&bias)
     }
 }

 pub fn group_norm(
-    num_channels: usize,
     num_groups: usize,
+    num_channels: usize,
     eps: f64,
     vb: crate::VarBuilder,
 ) -> Result<GroupNorm> {
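
Note on the group-norm fix above: `weight` and `bias` are per-channel vectors of length `n_channels`, so they are first reshaped to (1, channels, 1, ..., 1) and applied only after the normalized tensor has been reshaped back to the input shape, letting them broadcast over the batch and spatial dimensions. A standalone sketch of that broadcasting pattern with made-up shapes, assuming candle_core:

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Illustrative normalized activation of shape (batch, channels, h, w)
    // and per-channel affine parameters of length `channels`.
    let x_normed = Tensor::zeros((2, 6, 4, 4), DType::F32, &device)?;
    let weight = Tensor::ones(6, DType::F32, &device)?;
    let bias = Tensor::zeros(6, DType::F32, &device)?;
    // Reshape the 1-d parameters to (1, channels, 1, 1) so they broadcast per channel.
    let mut w_dims = vec![1; x_normed.dims().len()];
    w_dims[1] = 6;
    let weight = weight.reshape(w_dims.clone())?;
    let bias = bias.reshape(w_dims)?;
    let out = x_normed.broadcast_mul(&weight)?.broadcast_add(&bias)?;
    assert_eq!(out.dims(), &[2, 6, 4, 4]);
    Ok(())
}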

@@ -30,8 +30,8 @@ use test_utils::to_vec3_round;
 #[test]
 fn group_norm() -> Result<()> {
     let device = &Device::Cpu;
-    let w = Tensor::new(&[1f32], device)?;
-    let b = Tensor::new(&[0f32], device)?;
+    let w = Tensor::from_vec(vec![1f32; 6], 6, device)?;
+    let b = Tensor::from_vec(vec![0f32; 6], 6, device)?;
     let gn2 = GroupNorm::new(w.clone(), b.clone(), 6, 2, 1e-5)?;
     let gn3 = GroupNorm::new(w, b, 6, 3, 1e-5)?;
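
With the per-channel parameters, the test's weight and bias now carry one value per channel instead of a single scalar. A hedged usage sketch, assuming the `GroupNorm::new(weight, bias, num_channels, num_groups, eps)` signature shown in the test and that candle_nn exposes the Module trait used for `forward`:

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{GroupNorm, Module};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // One weight and one bias value per channel (6 channels, 3 groups).
    let w = Tensor::from_vec(vec![1f32; 6], 6, &device)?;
    let b = Tensor::from_vec(vec![0f32; 6], 6, &device)?;
    let gn = GroupNorm::new(w, b, 6, 3, 1e-5)?;
    let xs = Tensor::zeros((1, 6, 4, 4), DType::F32, &device)?;
    let ys = gn.forward(&xs)?;
    assert_eq!(ys.dims(), xs.dims());
    Ok(())
}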