Mirror of https://github.com/huggingface/candle.git, synced 2025-06-20 12:06:35 +00:00
More Wuerstchen fixes. (#882)
* More Wuerstchen fixes.
* More shape fixes.
* Add more of the prior specific bits.
* Broadcast add.
* Fix the clip config.
* Add some masking options to the clip model.
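The main functional change is that prompt encoding now tracks the real token count and passes it to the CLIP text model so padding positions can be masked out. Below is a minimal, standalone sketch (not part of the commit) of that masked encoding path; the helper name `encode_masked` is made up, the crate paths mirror the example's own imports, and the API calls are assumed to behave exactly as they appear in the diff. The pad-token lookup is simplified to a placeholder.

```rust
// Sketch only: masked CLIP prompt encoding in the style of this commit.
// Assumes the candle / candle-transformers / tokenizers APIs as shown in the diff;
// crate aliases follow the example's own imports.
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_transformers::models::stable_diffusion;
use tokenizers::Tokenizer;

fn encode_masked(
    prompt: &str,
    tokenizer: std::path::PathBuf,
    clip_weights: std::path::PathBuf,
    clip_config: stable_diffusion::clip::Config,
    device: &Device,
) -> Result<Tensor> {
    let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
    // Tokenize and remember the unpadded length.
    let mut tokens = tokenizer
        .encode(prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let tokens_len = tokens.len();
    // Pad up to the model's sequence length (pad id simplified to 0 here;
    // the example resolves it from clip_config.pad_with).
    while tokens.len() < clip_config.max_position_embeddings {
        tokens.push(0)
    }
    let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
    let text_model =
        stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
    // forward_with_mask lets the encoder ignore positions past tokens_len.
    let embeddings = text_model.forward_with_mask(&tokens, tokens_len)?;
    Ok(embeddings)
}
```

In the example itself, the unconditional prompt goes through the same path with its own length, and the two embeddings are concatenated for classifier-free guidance, as the hunk at old line 171 below shows.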
@@ -16,6 +16,7 @@ use tokenizers::Tokenizer;
 
 const GUIDANCE_SCALE: f64 = 7.5;
 const RESOLUTION_MULTIPLE: f64 = 42.67;
+const PRIOR_CIN: usize = 16;
 
 #[derive(Parser)]
 #[command(author, version, about, long_about = None)]
@@ -54,6 +55,10 @@ struct Args {
     #[arg(long, value_name = "FILE")]
     clip_weights: Option<String>,
 
+    /// The CLIP weight file used by the prior model, in .safetensors format.
+    #[arg(long, value_name = "FILE")]
+    prior_clip_weights: Option<String>,
+
     /// The prior weight file, in .safetensors format.
     #[arg(long, value_name = "FILE")]
     prior_weights: Option<String>,
@@ -66,6 +71,10 @@ struct Args {
     /// The file specifying the tokenizer to used for tokenization.
     tokenizer: Option<String>,
 
+    #[arg(long, value_name = "FILE")]
+    /// The file specifying the tokenizer to used for prior tokenization.
+    prior_tokenizer: Option<String>,
+
     /// The size of the sliced attention or 0 for automatic slicing (disabled by default)
     #[arg(long)]
     sliced_attention_size: Option<usize>,
@@ -86,7 +95,9 @@ struct Args {
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 enum ModelFile {
     Tokenizer,
+    PriorTokenizer,
     Clip,
+    PriorClip,
     Decoder,
     VqGan,
     Prior,
@@ -102,7 +113,9 @@ impl ModelFile {
         let repo_prior = "warp-ai/wuerstchen-prior";
         let (repo, path) = match self {
             Self::Tokenizer => (repo_main, "tokenizer/tokenizer.json"),
+            Self::PriorTokenizer => (repo_prior, "tokenizer/tokenizer.json"),
             Self::Clip => (repo_main, "text_encoder/model.safetensors"),
+            Self::PriorClip => (repo_prior, "text_encoder/model.safetensors"),
             Self::Decoder => (repo_main, "decoder/diffusion_pytorch_model.safetensors"),
             Self::VqGan => (repo_main, "vqgan/diffusion_pytorch_model.safetensors"),
             Self::Prior => (repo_prior, "prior/diffusion_pytorch_model.safetensors"),
@@ -144,12 +157,11 @@ fn output_filename(
 fn encode_prompt(
     prompt: &str,
     uncond_prompt: &str,
-    tokenizer: Option<String>,
-    clip_weights: Option<String>,
+    tokenizer: std::path::PathBuf,
+    clip_weights: std::path::PathBuf,
     clip_config: stable_diffusion::clip::Config,
     device: &Device,
 ) -> Result<Tensor> {
-    let tokenizer = ModelFile::Tokenizer.get(tokenizer)?;
     let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
     let pad_id = match &clip_config.pad_with {
         Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
@@ -161,6 +173,7 @@ fn encode_prompt(
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    let tokens_len = tokens.len();
     while tokens.len() < clip_config.max_position_embeddings {
         tokens.push(pad_id)
     }
@@ -171,17 +184,17 @@ fn encode_prompt(
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    let uncond_tokens_len = uncond_tokens.len();
     while uncond_tokens.len() < clip_config.max_position_embeddings {
         uncond_tokens.push(pad_id)
     }
     let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
 
-    println!("Building the Clip transformer.");
-    let clip_weights = ModelFile::Clip.get(clip_weights)?;
+    println!("Building the clip transformer.");
     let text_model =
         stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
-    let text_embeddings = text_model.forward(&tokens)?;
-    let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+    let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len)?;
+    let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len)?;
     let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
     Ok(text_embeddings)
 }
@@ -221,15 +234,19 @@ fn run(args: Args) -> Result<()> {
     let height = height.unwrap_or(1024);
     let width = width.unwrap_or(1024);
 
-    let text_embeddings = encode_prompt(
-        &prompt,
-        &uncond_prompt,
-        tokenizer.clone(),
-        clip_weights.clone(),
-        stable_diffusion::clip::Config::wuerstchen(),
-        &device,
-    )?;
-    println!("{text_embeddings:?}");
+    let prior_text_embeddings = {
+        let tokenizer = ModelFile::PriorTokenizer.get(args.prior_tokenizer)?;
+        let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
+        encode_prompt(
+            &prompt,
+            &uncond_prompt,
+            tokenizer.clone(),
+            weights,
+            stable_diffusion::clip::Config::wuerstchen_prior(),
+            &device,
+        )?
+    };
+    println!("{prior_text_embeddings}");
 
     println!("Building the prior.");
     // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
@@ -239,8 +256,8 @@ fn run(args: Args) -> Result<()> {
         let weights = weights.deserialize()?;
         let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
         wuerstchen::prior::WPrior::new(
-            /* c_in */ 16, /* c */ 1536, /* c_cond */ 1280, /* c_r */ 64,
-            /* depth */ 32, /* nhead */ 24, vb,
+            /* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
+            /* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
         )?
     };
 
@@ -274,12 +291,12 @@ fn run(args: Args) -> Result<()> {
     let latents = Tensor::randn(
         0f32,
         1f32,
-        (b_size, 4, latent_height, latent_width),
+        (b_size, PRIOR_CIN, latent_height, latent_width),
         &device,
     )?;
     // TODO: latents denoising loop, use the scheduler values.
     let ratio = Tensor::ones(1, DType::F32, &device)?;
-    let prior = prior.forward(&latents, &ratio, &text_embeddings)?;
+    let prior = prior.forward(&latents, &ratio, &prior_text_embeddings)?;
 
     let latents = ((latents * 42.)? - 1.)?;
     /*