More Wuerstchen fixes. (#882)

* More Wuerstchen fixes.

* More shape fixes.

* Add more of the prior-specific bits.

* Broadcast add.

* Fix the clip config.

* Add some masking options to the clip model.
Laurent Mazare
2023-09-17 22:08:11 +01:00
committed by GitHub
parent 06cc329e71
commit c2b866172a
4 changed files with 96 additions and 41 deletions


@@ -16,6 +16,7 @@ use tokenizers::Tokenizer;
 const GUIDANCE_SCALE: f64 = 7.5;
 const RESOLUTION_MULTIPLE: f64 = 42.67;
+const PRIOR_CIN: usize = 16;

 #[derive(Parser)]
 #[command(author, version, about, long_about = None)]
@@ -54,6 +55,10 @@ struct Args {
     #[arg(long, value_name = "FILE")]
     clip_weights: Option<String>,

+    /// The CLIP weight file used by the prior model, in .safetensors format.
+    #[arg(long, value_name = "FILE")]
+    prior_clip_weights: Option<String>,
+
     /// The prior weight file, in .safetensors format.
     #[arg(long, value_name = "FILE")]
     prior_weights: Option<String>,
@@ -66,6 +71,10 @@ struct Args {
     /// The file specifying the tokenizer to be used for tokenization.
     tokenizer: Option<String>,

+    #[arg(long, value_name = "FILE")]
+    /// The file specifying the tokenizer to be used for prior tokenization.
+    prior_tokenizer: Option<String>,
+
     /// The size of the sliced attention or 0 for automatic slicing (disabled by default)
     #[arg(long)]
     sliced_attention_size: Option<usize>,
@@ -86,7 +95,9 @@ struct Args {
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 enum ModelFile {
     Tokenizer,
+    PriorTokenizer,
     Clip,
+    PriorClip,
     Decoder,
     VqGan,
     Prior,
@@ -102,7 +113,9 @@ impl ModelFile {
         let repo_prior = "warp-ai/wuerstchen-prior";
         let (repo, path) = match self {
             Self::Tokenizer => (repo_main, "tokenizer/tokenizer.json"),
+            Self::PriorTokenizer => (repo_prior, "tokenizer/tokenizer.json"),
             Self::Clip => (repo_main, "text_encoder/model.safetensors"),
+            Self::PriorClip => (repo_prior, "text_encoder/model.safetensors"),
             Self::Decoder => (repo_main, "decoder/diffusion_pytorch_model.safetensors"),
             Self::VqGan => (repo_main, "vqgan/diffusion_pytorch_model.safetensors"),
             Self::Prior => (repo_prior, "prior/diffusion_pytorch_model.safetensors"),
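For orientation: `ModelFile::get` resolves these (repo, path) pairs through the Hugging Face Hub whenever no local path is passed on the command line. A minimal sketch of that lookup with the hf-hub crate — the helper name `fetch_prior_tokenizer` is hypothetical and error handling is simplified relative to the example:

```rust
use hf_hub::api::sync::Api;
use std::path::PathBuf;

// Fetch the prior repo's tokenizer, reusing the local HF cache when present.
fn fetch_prior_tokenizer() -> anyhow::Result<PathBuf> {
    let api = Api::new()?;
    let repo = api.model("warp-ai/wuerstchen-prior".to_string());
    Ok(repo.get("tokenizer/tokenizer.json")?)
}
```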
@@ -144,12 +157,11 @@ fn output_filename(
 fn encode_prompt(
     prompt: &str,
     uncond_prompt: &str,
-    tokenizer: Option<String>,
-    clip_weights: Option<String>,
+    tokenizer: std::path::PathBuf,
+    clip_weights: std::path::PathBuf,
     clip_config: stable_diffusion::clip::Config,
     device: &Device,
 ) -> Result<Tensor> {
-    let tokenizer = ModelFile::Tokenizer.get(tokenizer)?;
     let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
     let pad_id = match &clip_config.pad_with {
         Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
@@ -161,6 +173,7 @@ fn encode_prompt(
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    let tokens_len = tokens.len();
     while tokens.len() < clip_config.max_position_embeddings {
         tokens.push(pad_id)
     }
@@ -171,17 +184,17 @@ fn encode_prompt(
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    let uncond_tokens_len = uncond_tokens.len();
     while uncond_tokens.len() < clip_config.max_position_embeddings {
         uncond_tokens.push(pad_id)
     }
     let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;

-    println!("Building the Clip transformer.");
-    let clip_weights = ModelFile::Clip.get(clip_weights)?;
+    println!("Building the clip transformer.");
     let text_model =
         stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
-    let text_embeddings = text_model.forward(&tokens)?;
-    let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+    let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len)?;
+    let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len)?;
     let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
     Ok(text_embeddings)
 }
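The newly recorded `tokens_len` is what feeds `forward_with_mask`: the prompt is padded up to `max_position_embeddings` (77), and the mask keeps attention from flowing into the padding. A toy sketch of that padding step, with hypothetical token ids (real ids come from the tokenizer):

```rust
fn main() {
    // Hypothetical token ids for a short prompt.
    let mut tokens: Vec<u32> = vec![49406, 320, 1125, 49407];
    let tokens_len = tokens.len(); // 4 real tokens before padding.
    let (max_position_embeddings, pad_id) = (77usize, 0u32);
    while tokens.len() < max_position_embeddings {
        tokens.push(pad_id)
    }
    assert_eq!(tokens.len(), 77);
    // forward_with_mask(&tokens, tokens_len) then masks attention past tokens_len.
    println!("{} tokens, {} real", tokens.len(), tokens_len);
}
```

The same bookkeeping is done for the unconditional prompt, and the two embeddings are concatenated along the batch dimension (`Tensor::cat`) for classifier-free guidance.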
@@ -221,15 +234,19 @@ fn run(args: Args) -> Result<()> {
     let height = height.unwrap_or(1024);
     let width = width.unwrap_or(1024);

-    let text_embeddings = encode_prompt(
-        &prompt,
-        &uncond_prompt,
-        tokenizer.clone(),
-        clip_weights.clone(),
-        stable_diffusion::clip::Config::wuerstchen(),
-        &device,
-    )?;
-    println!("{text_embeddings:?}");
+    let prior_text_embeddings = {
+        let tokenizer = ModelFile::PriorTokenizer.get(args.prior_tokenizer)?;
+        let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
+        encode_prompt(
+            &prompt,
+            &uncond_prompt,
+            tokenizer.clone(),
+            weights,
+            stable_diffusion::clip::Config::wuerstchen_prior(),
+            &device,
+        )?
+    };
+    println!("{prior_text_embeddings}");

     println!("Building the prior.");
     // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
@@ -239,8 +256,8 @@ fn run(args: Args) -> Result<()> {
         let weights = weights.deserialize()?;
         let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
         wuerstchen::prior::WPrior::new(
-            /* c_in */ 16, /* c */ 1536, /* c_cond */ 1280, /* c_r */ 64,
-            /* depth */ 32, /* nhead */ 24, vb,
+            /* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
+            /* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
         )?
     };
@@ -274,12 +291,12 @@ fn run(args: Args) -> Result<()> {
     let latents = Tensor::randn(
         0f32,
         1f32,
-        (b_size, 4, latent_height, latent_width),
+        (b_size, PRIOR_CIN, latent_height, latent_width),
         &device,
     )?;

     // TODO: latents denoising loop, use the scheduler values.
     let ratio = Tensor::ones(1, DType::F32, &device)?;
-    let prior = prior.forward(&latents, &ratio, &text_embeddings)?;
+    let prior = prior.forward(&latents, &ratio, &prior_text_embeddings)?;
     let latents = ((latents * 42.)? - 1.)?;

     /*
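To make the new latent shape concrete: with `RESOLUTION_MULTIPLE = 42.67`, the 1024x1024 default maps to a 24x24 grid, and the channel dimension is now `PRIOR_CIN = 16` instead of the hard-coded 4. A sketch of the arithmetic, assuming ceiling division (the rounding is an assumption, not shown in this hunk):

```rust
const RESOLUTION_MULTIPLE: f64 = 42.67;
const PRIOR_CIN: usize = 16;

// Shape of the prior's starting noise for a b_size x height x width request.
fn prior_latent_shape(b_size: usize, height: usize, width: usize) -> (usize, usize, usize, usize) {
    let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
    let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
    (b_size, PRIOR_CIN, latent_height, latent_width)
}

fn main() {
    // 1024 / 42.67 ≈ 24.0, so the default resolution gives (1, 16, 24, 24).
    assert_eq!(prior_latent_shape(1, 1024, 1024), (1, 16, 24, 24));
}
```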


@@ -107,13 +107,28 @@ impl Config {
             embed_dim: 1024,
             intermediate_size: 4096,
             max_position_embeddings: 77,
-            pad_with: Some("!".to_string()),
+            pad_with: None,
             num_hidden_layers: 24,
             num_attention_heads: 16,
             projection_dim: 1024,
             activation: Activation::Gelu,
         }
     }
+
+    // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/text_encoder/config.json
+    pub fn wuerstchen_prior() -> Self {
+        Self {
+            vocab_size: 49408,
+            embed_dim: 1280,
+            intermediate_size: 5120,
+            max_position_embeddings: 77,
+            pad_with: None,
+            num_hidden_layers: 32,
+            num_attention_heads: 20,
+            projection_dim: 512,
+            activation: Activation::Gelu,
+        }
+    }
 }

 // CLIP Text Model
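The prior's text encoder is a different, larger CLIP than the one feeding the decoder stage: 32 layers with 1280-dim embeddings versus 24 layers with 1024-dim, per the linked config.json. A sketch of picking the matching config when building each encoder — `build_clip_transformer` is the helper already used by the example, while the wrapper name and the `candle` crate alias are assumptions following the examples' setup:

```rust
use candle::{DType, Device};
use candle_transformers::models::stable_diffusion;

// Build the text encoder for one of the two Wuerstchen stages.
fn build_text_encoder(
    for_prior: bool,
    weights: std::path::PathBuf,
    device: &Device,
) -> candle::Result<stable_diffusion::clip::ClipTextTransformer> {
    let cfg = if for_prior {
        stable_diffusion::clip::Config::wuerstchen_prior() // 32 layers, embed_dim 1280
    } else {
        stable_diffusion::clip::Config::wuerstchen() // 24 layers, embed_dim 1024
    };
    stable_diffusion::build_clip_transformer(&cfg, weights, device, DType::F32)
}
```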
@@ -334,21 +349,39 @@ impl ClipTextTransformer {
     }

     // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
-    fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
+    fn build_causal_attention_mask(
+        bsz: usize,
+        seq_len: usize,
+        mask_after: usize,
+        device: &Device,
+    ) -> Result<Tensor> {
         let mask: Vec<_> = (0..seq_len)
-            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
+            .flat_map(|i| {
+                (0..seq_len).map(move |j| {
+                    if j > i || j > mask_after {
+                        f32::MIN
+                    } else {
+                        0.
+                    }
+                })
+            })
             .collect();
         let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
         mask.broadcast_as((bsz, seq_len, seq_len))
     }
+
+    pub fn forward_with_mask(&self, xs: &Tensor, mask_after: usize) -> Result<Tensor> {
+        let (bsz, seq_len) = xs.dims2()?;
+        let xs = self.embeddings.forward(xs)?;
+        let causal_attention_mask =
+            Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;
+        let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
+        self.final_layer_norm.forward(&xs)
+    }
 }

 impl Module for ClipTextTransformer {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let (bsz, seq_len) = xs.dims2()?;
-        let xs = self.embeddings.forward(xs)?;
-        let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;
-        let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
-        self.final_layer_norm.forward(&xs)
+        self.forward_with_mask(xs, usize::MAX)
     }
 }
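The effect of `mask_after` is easiest to see on a tiny grid. Below is a standalone sketch of the same closure logic for `seq_len = 4` and `mask_after = 1`; `Module::forward` passes `usize::MAX`, so `j > mask_after` is never true there and the plain causal mask is recovered:

```rust
// Mirrors the rule in build_causal_attention_mask: mask future positions
// (j > i) and positions past the real prompt length (j > mask_after).
fn mask_value(i: usize, j: usize, mask_after: usize) -> f32 {
    if j > i || j > mask_after {
        f32::MIN
    } else {
        0.
    }
}

fn main() {
    let (seq_len, mask_after) = (4, 1);
    for i in 0..seq_len {
        let row: Vec<f32> = (0..seq_len).map(|j| mask_value(i, j, mask_after)).collect();
        println!("{row:?}");
    }
    // Row i = 3 prints [0.0, 0.0, MIN, MIN]: columns 2 and 3 stay masked
    // even though they are causally visible, because they are padding.
}
```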


@@ -75,9 +75,9 @@ impl Module for GlobalResponseNorm {
         let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?;
         let stand_div_norm =
             agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
-        (xs.broadcast_mul(&stand_div_norm)?
-            .broadcast_mul(&self.gamma)
-            + &self.beta)?
+        xs.broadcast_mul(&stand_div_norm)?
+            .broadcast_mul(&self.gamma)?
+            .broadcast_add(&self.beta)?
             + xs
     }
 }
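This is the "broadcast add" from the commit message: in candle, the plain arithmetic operators require identical shapes, while the `broadcast_*` variants expand singleton dimensions, and `gamma`/`beta` carry leading singleton dims against channels-last activations. A minimal check of that behavior (the shapes are illustrative, assuming (1, 1, 1, c) parameters):

```rust
use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    // Channels-last activations (b, h, w, c) and a (1, 1, 1, c) bias.
    let xs = Tensor::ones((2, 4, 4, 8), DType::F32, &dev)?;
    let beta = Tensor::zeros((1, 1, 1, 8), DType::F32, &dev)?;
    // `(xs + &beta)?` would fail with a shape mismatch; broadcast_add
    // expands beta's singleton dims to (2, 4, 4, 8) before adding.
    let ys = xs.broadcast_add(&beta)?;
    assert_eq!(ys.dims(), &[2, 4, 4, 8]);
    Ok(())
}
```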


@@ -68,7 +68,7 @@ struct DownBlock {
 struct UpBlock {
     sub_blocks: Vec<SubBlock>,
     layer_norm: Option<WLayerNorm>,
-    conv: Option<candle_nn::Conv2d>,
+    conv: Option<candle_nn::ConvTranspose2d>,
 }

 #[derive(Debug)]
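The type change matters because a stride-2 `Conv2d` halves spatial size, while the up path must double it; a `ConvTranspose2d` with kernel 2 and stride 2 inverts the downsampling exactly. A quick sanity check of the standard output-size formulas (padding 0, no dilation or output padding):

```rust
// Conv2d output size: (in + 2 * pad - kernel) / stride + 1
fn conv2d_out(input: usize, kernel: usize, stride: usize, pad: usize) -> usize {
    (input + 2 * pad - kernel) / stride + 1
}

// ConvTranspose2d output size: (in - 1) * stride + kernel - 2 * pad
fn conv_transpose2d_out(input: usize, kernel: usize, stride: usize, pad: usize) -> usize {
    (input - 1) * stride + kernel - 2 * pad
}

fn main() {
    assert_eq!(conv2d_out(24, 2, 2, 0), 12); // down path halves: 24 -> 12
    assert_eq!(conv_transpose2d_out(12, 2, 2, 0), 24); // up path doubles: 12 -> 24
}
```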
@@ -152,20 +152,20 @@ impl WDiffNeXt {
                     stride: 2,
                     ..Default::default()
                 };
-                let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(1))?;
-                (Some(layer_norm), Some(conv), 2)
+                let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?;
+                (Some(layer_norm), Some(conv), 1)
             } else {
                 (None, None, 0)
             };
             let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
             let mut layer_i = start_layer_i;
-            for j in 0..BLOCKS[i] {
+            for _j in 0..BLOCKS[i] {
                 let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
                 let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
                 layer_i += 1;
                 let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
                 layer_i += 1;
-                let attn_block = if j == 0 {
+                let attn_block = if i == 0 {
                     None
                 } else {
                     let attn_block =
@@ -190,7 +190,7 @@ impl WDiffNeXt {
         let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
         for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
-            let vb = vb.pp("up_blocks").pp(i);
+            let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i);
             let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
             let mut layer_i = 0;
             for j in 0..BLOCKS[i] {
@@ -204,7 +204,7 @@ impl WDiffNeXt {
                 layer_i += 1;
                 let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
                 layer_i += 1;
-                let attn_block = if j == 0 {
+                let attn_block = if i == 0 {
                     None
                 } else {
                     let attn_block =
@@ -221,12 +221,17 @@ impl WDiffNeXt {
             }
             let (layer_norm, conv) = if i > 0 {
                 let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
-                layer_i += 1;
-                let cfg = candle_nn::Conv2dConfig {
+                let cfg = candle_nn::ConvTranspose2dConfig {
                     stride: 2,
                     ..Default::default()
                 };
-                let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?;
+                let conv = candle_nn::conv_transpose2d(
+                    c_hidden,
+                    C_HIDDEN[i - 1],
+                    2,
+                    cfg,
+                    vb.pp(layer_i).pp(1),
+                )?;
                 (Some(layer_norm), Some(conv))
             } else {
                 (None, None)
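Finally, the `up_blocks` indexing fix: the loop walks `C_HIDDEN` in reverse, so `i` counts down from `C_HIDDEN.len() - 1`, but the checkpoint numbers `up_blocks` from 0 in the order they execute; `C_HIDDEN.len() - 1 - i` converts one to the other. A tiny sketch of the remapping (the `C_HIDDEN` values here are illustrative, not quoted from this file):

```rust
// Illustrative channel widths; the real constants live alongside WDiffNeXt.
const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];

fn main() {
    for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
        // i runs 3, 2, 1, 0; the checkpoint prefix runs up_blocks.0 .. up_blocks.3.
        let disk_idx = C_HIDDEN.len() - 1 - i;
        println!("i = {i} (c_hidden = {c_hidden}) -> up_blocks.{disk_idx}");
    }
}
```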