More Wuerstchen fixes. (#882)

* More Wuerstchen fixes.

* More shape fixes.

* Add more of the prior specific bits.

* Broadcast add.

* Fix the clip config.

* Add some masking options to the clip model.

Author: Laurent Mazare
Date: 2023-09-17 22:08:11 +01:00 (committed by GitHub)
Parent: 06cc329e71
Commit: c2b866172a
4 changed files with 96 additions and 41 deletions

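All of the hunks excerpted below come from the example binary; the "Broadcast add" and clip-masking bullets land in the model files among the four changed, which are not excerpted here. For reference, a broadcast-add fix in candle most likely swaps a strict elementwise `+` for the explicit broadcasting op; a minimal sketch with made-up shapes (the import path and shapes are assumptions, not taken from this commit):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // `+` requires identical shapes; `broadcast_add` right-aligns dims,
    // so a (1536, 1, 1) bias applies across a (2, 1536, 24, 24) map.
    let x = Tensor::zeros((2, 1536, 24, 24), DType::F32, &dev)?;
    let bias = Tensor::zeros((1536, 1, 1), DType::F32, &dev)?;
    let y = x.broadcast_add(&bias)?;
    println!("{:?}", y.shape()); // [2, 1536, 24, 24]
    Ok(())
}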

@@ -16,6 +16,7 @@ use tokenizers::Tokenizer;
 const GUIDANCE_SCALE: f64 = 7.5;
 const RESOLUTION_MULTIPLE: f64 = 42.67;
+const PRIOR_CIN: usize = 16;
 
 #[derive(Parser)]
 #[command(author, version, about, long_about = None)]
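The new `PRIOR_CIN` constant names the prior's 16 latent channels, reused further down for both the sampled latents and the `WPrior` constructor. `RESOLUTION_MULTIPLE` presumably maps the requested pixel size to the prior's latent grid; a hedged sketch of that derivation (the example's exact rounding is not shown in this excerpt):

const RESOLUTION_MULTIPLE: f64 = 42.67;

// Hypothetical reconstruction: 42.67 maps the default 1024px request
// to a 24x24 latent grid for the prior (1024 / 42.67 ≈ 24.0).
fn latent_size(pixels: usize) -> usize {
    (pixels as f64 / RESOLUTION_MULTIPLE).ceil() as usize
}

fn main() {
    assert_eq!(latent_size(1024), 24);
}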
@@ -54,6 +55,10 @@ struct Args {
     #[arg(long, value_name = "FILE")]
     clip_weights: Option<String>,
 
+    /// The CLIP weight file used by the prior model, in .safetensors format.
+    #[arg(long, value_name = "FILE")]
+    prior_clip_weights: Option<String>,
+
     /// The prior weight file, in .safetensors format.
     #[arg(long, value_name = "FILE")]
     prior_weights: Option<String>,
@@ -66,6 +71,10 @@ struct Args {
     /// The file specifying the tokenizer to use for tokenization.
     tokenizer: Option<String>,
 
+    #[arg(long, value_name = "FILE")]
+    /// The file specifying the tokenizer to use for prior tokenization.
+    prior_tokenizer: Option<String>,
+
     /// The size of the sliced attention or 0 for automatic slicing (disabled by default)
     #[arg(long)]
     sliced_attention_size: Option<usize>,
@@ -86,7 +95,9 @@
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 enum ModelFile {
     Tokenizer,
+    PriorTokenizer,
     Clip,
+    PriorClip,
     Decoder,
     VqGan,
     Prior,
@@ -102,7 +113,9 @@ impl ModelFile {
         let repo_prior = "warp-ai/wuerstchen-prior";
         let (repo, path) = match self {
             Self::Tokenizer => (repo_main, "tokenizer/tokenizer.json"),
+            Self::PriorTokenizer => (repo_prior, "tokenizer/tokenizer.json"),
             Self::Clip => (repo_main, "text_encoder/model.safetensors"),
+            Self::PriorClip => (repo_prior, "text_encoder/model.safetensors"),
             Self::Decoder => (repo_main, "decoder/diffusion_pytorch_model.safetensors"),
             Self::VqGan => (repo_main, "vqgan/diffusion_pytorch_model.safetensors"),
             Self::Prior => (repo_prior, "prior/diffusion_pytorch_model.safetensors"),
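The `get` helper that consumes these pairs is not part of this hunk, but the repo/path mapping mirrors the Hugging Face Hub layout: the prior has its own tokenizer and text encoder under `warp-ai/wuerstchen-prior`. A minimal sketch, assuming `get` falls back to the `hf_hub` sync API when no local `--prior-tokenizer`/`--prior-clip-weights` path is passed (the `fetch` helper is hypothetical):

use anyhow::Result;
use hf_hub::api::sync::Api;

// Download a file from a model repo, or reuse the local HF cache copy.
fn fetch(repo: &str, path: &str) -> Result<std::path::PathBuf> {
    let api = Api::new()?;
    Ok(api.model(repo.to_string()).get(path)?)
}

fn main() -> Result<()> {
    let tokenizer = fetch("warp-ai/wuerstchen-prior", "tokenizer/tokenizer.json")?;
    println!("{}", tokenizer.display());
    Ok(())
}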
@@ -144,12 +157,11 @@ fn output_filename(
 fn encode_prompt(
     prompt: &str,
     uncond_prompt: &str,
-    tokenizer: Option<String>,
-    clip_weights: Option<String>,
+    tokenizer: std::path::PathBuf,
+    clip_weights: std::path::PathBuf,
     clip_config: stable_diffusion::clip::Config,
     device: &Device,
 ) -> Result<Tensor> {
-    let tokenizer = ModelFile::Tokenizer.get(tokenizer)?;
     let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
     let pad_id = match &clip_config.pad_with {
         Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
@@ -161,6 +173,7 @@ fn encode_prompt(
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    let tokens_len = tokens.len();
     while tokens.len() < clip_config.max_position_embeddings {
         tokens.push(pad_id)
     }
@@ -171,17 +184,17 @@
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    let uncond_tokens_len = uncond_tokens.len();
     while uncond_tokens.len() < clip_config.max_position_embeddings {
         uncond_tokens.push(pad_id)
     }
     let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
 
-    println!("Building the Clip transformer.");
-    let clip_weights = ModelFile::Clip.get(clip_weights)?;
+    println!("Building the clip transformer.");
     let text_model =
         stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
-    let text_embeddings = text_model.forward(&tokens)?;
-    let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+    let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len)?;
+    let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len)?;
     let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
     Ok(text_embeddings)
 }
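The switch from `forward` to `forward_with_mask` is the caller-side half of the "masking options" bullet: passing the pre-padding token count lets the clip encoder stop attending to the pad tokens appended up to `max_position_embeddings`. The clip model changes themselves are among the four changed files but are not excerpted here; a minimal sketch of the causal-plus-padding mask such an entry point could build (the function name and layout are assumptions):

use candle_core::{Device, Result, Tensor};

// Causal mask that additionally hides positions at or past `tokens_len`,
// i.e. the pad ids; added to the attention scores before softmax.
fn causal_padding_mask(seq_len: usize, tokens_len: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<f32> = (0..seq_len)
        .flat_map(|i| {
            (0..seq_len).map(move |j| {
                if j > i || j >= tokens_len {
                    f32::NEG_INFINITY
                } else {
                    0.0
                }
            })
        })
        .collect();
    Tensor::from_vec(mask, (seq_len, seq_len), device)
}

Note also that the final `Tensor::cat` in this hunk stacks the unconditional and conditional embeddings along the batch dimension, the usual setup for classifier-free guidance downstream.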
@@ -221,15 +234,19 @@ fn run(args: Args) -> Result<()> {
     let height = height.unwrap_or(1024);
     let width = width.unwrap_or(1024);
 
-    let text_embeddings = encode_prompt(
-        &prompt,
-        &uncond_prompt,
-        tokenizer.clone(),
-        clip_weights.clone(),
-        stable_diffusion::clip::Config::wuerstchen(),
-        &device,
-    )?;
-    println!("{text_embeddings:?}");
+    let prior_text_embeddings = {
+        let tokenizer = ModelFile::PriorTokenizer.get(args.prior_tokenizer)?;
+        let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
+        encode_prompt(
+            &prompt,
+            &uncond_prompt,
+            tokenizer.clone(),
+            weights,
+            stable_diffusion::clip::Config::wuerstchen_prior(),
+            &device,
+        )?
+    };
+    println!("{prior_text_embeddings}");
 
     println!("Building the prior.");
     // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
@@ -239,8 +256,8 @@ fn run(args: Args) -> Result<()> {
         let weights = weights.deserialize()?;
         let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
         wuerstchen::prior::WPrior::new(
-            /* c_in */ 16, /* c */ 1536, /* c_cond */ 1280, /* c_r */ 64,
-            /* depth */ 32, /* nhead */ 24, vb,
+            /* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
+            /* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
         )?
     };
@@ -274,12 +291,12 @@ fn run(args: Args) -> Result<()> {
     let latents = Tensor::randn(
         0f32,
         1f32,
-        (b_size, 4, latent_height, latent_width),
+        (b_size, PRIOR_CIN, latent_height, latent_width),
         &device,
     )?;
 
     // TODO: latents denoising loop, use the scheduler values.
     let ratio = Tensor::ones(1, DType::F32, &device)?;
-    let prior = prior.forward(&latents, &ratio, &text_embeddings)?;
+    let prior = prior.forward(&latents, &ratio, &prior_text_embeddings)?;
     let latents = ((latents * 42.)? - 1.)?;
     /*
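The `ratio` tensor stands in for the scheduler timestep, and the TODO above marks the denoising loop as unimplemented at this point. A hedged sketch of the loop structure this sets up; the timestep schedule and the final update rule are placeholders rather than this commit's code, and only the cond/uncond blend follows directly from the concatenated embeddings built in `encode_prompt`:

// 60 steps is an arbitrary choice for illustration.
let timesteps: Vec<f64> = (0..60).map(|i| 1.0 - i as f64 / 60.0).collect();
for &t in timesteps.iter() {
    let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
    // Double the latents so the uncond/cond halves line up with the
    // concatenated text embeddings.
    let latent_input = Tensor::cat(&[&latents, &latents], 0)?;
    let noise_pred = prior.forward(&latent_input, &ratio, &prior_text_embeddings)?;
    let chunks = noise_pred.chunk(2, 0)?;
    let (uncond, cond) = (&chunks[0], &chunks[1]);
    let guided = (uncond + ((cond - uncond)? * GUIDANCE_SCALE)?)?;
    // A real scheduler step would update `latents` from `guided` and the
    // step's alpha/sigma values; left unimplemented, as the TODO notes.
    let _ = guided;
}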