Fix the W clip embeddings. (#887)

* Fix the W clip embeddings.

* Add the specialized ddpm scheduler.
This commit is contained in:
Laurent Mazare
2023-09-18 14:50:14 +01:00
committed by GitHub
parent 7dd8e12472
commit 5082954c52
3 changed files with 101 additions and 3 deletions

View File

@@ -193,9 +193,9 @@ fn encode_prompt(
 println!("Building the clip transformer.");
 let text_model =
 stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
-let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len)?;
-let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len)?;
-let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
+let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
+let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
+let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
 Ok(text_embeddings)
 }