mirror of https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Fixes for the stable diffusion example. (#342)
* Fixes for the stable diffusion example.
* Bugfix.
* Another fix.
* Fix for group-norm.
* More fixes to get SD to work.
@@ -29,7 +29,7 @@ pub struct Config {
     embed_dim: usize,       // aka config.hidden_size
     activation: Activation, // aka config.hidden_act
     intermediate_size: usize,
-    max_position_embeddings: usize,
+    pub max_position_embeddings: usize,
     // The character to use for padding, use EOS when not set.
     pad_with: Option<String>,
     num_hidden_layers: usize,
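Note: the only change in this hunk is visibility. `max_position_embeddings` becomes `pub` so that the example's `run` function (see the hunk further below) can read the CLIP context length via `sd_config.clip.max_position_embeddings` when padding prompt tokens.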
@@ -90,7 +90,7 @@ impl ClipTextEmbeddings {
             vs.pp("position_embedding"),
         )?;
         let position_ids =
-            Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(1)?;
+            Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
         Ok(ClipTextEmbeddings {
             token_embedding,
             position_embedding,
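The `unsqueeze` fix above changes the position ids from shape (max_position_embeddings, 1) to (1, max_position_embeddings), so they line up with token ids of shape (batch, seq_len). A minimal shape check, assuming the candle core crate is imported as `candle_core`:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let seq_len = 77usize; // CLIP's max_position_embeddings in SD
    let ids = Tensor::arange(0u32, seq_len as u32, &Device::Cpu)?;
    // The fixed version: a (1, seq_len) row vector that broadcasts over the batch.
    assert_eq!(ids.unsqueeze(0)?.dims(), &[1, seq_len]);
    // The old version produced a (seq_len, 1) column vector instead.
    assert_eq!(ids.unsqueeze(1)?.dims(), &[seq_len, 1]);
    Ok(())
}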
@@ -49,7 +49,7 @@ impl Timesteps {
         let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
         let emb = exponent.exp()?;
         // emb = timesteps[:, None].float() * emb[None, :]
-        let emb = (xs.unsqueeze(D::Minus1)? * emb.unsqueeze(0)?)?;
+        let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
         let (cos, sin) = (emb.cos()?, emb.sin()?);
         let emb = if self.flip_sin_to_cos {
             Tensor::cat(&[&cos, &sin], D::Minus1)?
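In candle the `*` operator is an elementwise multiply that expects matching shapes, so the (N, 1) x (1, half_dim) outer product described by the Python comment needs an explicit `broadcast_mul`. A small sketch of the broadcasting behavior, with the same crate assumption as above:

use candle_core::{Device, Tensor, D};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::new(&[0f32, 1., 2.], &dev)?; // timesteps, shape (3,)
    let emb = Tensor::new(&[1f32, 0.1], &dev)?; // frequencies, shape (2,)
    // (3, 1) broadcast-multiplied by (1, 2) gives the (3, 2) outer product,
    // matching `timesteps[:, None] * emb[None, :]`.
    let out = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
    assert_eq!(out.dims(), &[3, 2]);
    Ok(())
}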
@@ -181,19 +181,29 @@ fn run(args: Args) -> Result<()> {
     let device = candle_examples::device(cpu)?;
 
     let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+    let pad_id = match tokenizer.get_padding() {
+        Some(padding) => padding.pad_id,
+        None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
+    };
     println!("Running with prompt \"{prompt}\".");
-    let tokens = tokenizer
+    let mut tokens = tokenizer
         .encode(prompt, true)
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    while tokens.len() < sd_config.clip.max_position_embeddings {
+        tokens.push(pad_id)
+    }
     let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
 
-    let uncond_tokens = tokenizer
+    let mut uncond_tokens = tokenizer
         .encode(uncond_prompt, true)
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
+        uncond_tokens.push(pad_id)
+    }
     let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
 
     println!("Building the Clip transformer.");
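The CLIP text encoder consumes fixed-length sequences, so both the prompt and the unconditional prompt are padded up to `max_position_embeddings` (77 for SD) with the tokenizer's pad id, falling back to the `<|endoftext|>` token when no padding is configured. The loops above could equally be written as a small helper; `pad_to_len` is a hypothetical name, not part of the example:

// Hypothetical helper mirroring the padding loops added in this hunk.
fn pad_to_len(mut tokens: Vec<u32>, pad_id: u32, max_len: usize) -> Vec<u32> {
    while tokens.len() < max_len {
        tokens.push(pad_id);
    }
    tokens
}

// Usage, e.g.: let tokens = pad_to_len(tokens, pad_id, sd_config.clip.max_position_embeddings);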
@@ -202,6 +212,7 @@ fn run(args: Args) -> Result<()> {
     let uncond_embeddings = text_model.forward(&uncond_tokens)?;
     let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
 
+    println!("text-embeddings: {text_embeddings:?}");
     println!("Building the autoencoder.");
     let vae = sd_config.build_vae(&vae_weights, &device)?;
     println!("Building the unet.");
@@ -118,7 +118,7 @@ impl ResnetBlock2D {
                     .forward(&nn::ops::silu(temb)?)?
                     .unsqueeze(D::Minus1)?
                     .unsqueeze(D::Minus1)?
-                    .add(&xs)?,
+                    .broadcast_add(&xs)?,
             _ => xs,
         };
         let xs = self
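Same story as the `Timesteps` fix: `add` expects identical shapes, while the time embedding here is (batch, channels, 1, 1) after the two `unsqueeze` calls and must be broadcast across the (batch, channels, height, width) feature map. A shape-only sketch, with the same crate assumption as above:

use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let temb = Tensor::zeros((2, 4, 1, 1), DType::F32, &dev)?; // time embedding
    let xs = Tensor::zeros((2, 4, 8, 8), DType::F32, &dev)?; // feature map
    // `broadcast_add` expands the 1-sized spatial dims of `temb` over H and W;
    // a plain `add` would fail on the shape mismatch.
    let out = temb.broadcast_add(&xs)?;
    assert_eq!(out.dims(), &[2, 4, 8, 8]);
    Ok(())
}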