mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
More Wuerstchen fixes. (#882)
* More Wuerstchen fixes. * More shape fixes. * Add more of the prior specific bits. * Broadcast add. * Fix the clip config. * Add some masking options to the clip model.
This commit is contained in:
@ -16,6 +16,7 @@ use tokenizers::Tokenizer;
|
|||||||
|
|
||||||
const GUIDANCE_SCALE: f64 = 7.5;
|
const GUIDANCE_SCALE: f64 = 7.5;
|
||||||
const RESOLUTION_MULTIPLE: f64 = 42.67;
|
const RESOLUTION_MULTIPLE: f64 = 42.67;
|
||||||
|
const PRIOR_CIN: usize = 16;
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
@ -54,6 +55,10 @@ struct Args {
|
|||||||
#[arg(long, value_name = "FILE")]
|
#[arg(long, value_name = "FILE")]
|
||||||
clip_weights: Option<String>,
|
clip_weights: Option<String>,
|
||||||
|
|
||||||
|
/// The CLIP weight file used by the prior model, in .safetensors format.
|
||||||
|
#[arg(long, value_name = "FILE")]
|
||||||
|
prior_clip_weights: Option<String>,
|
||||||
|
|
||||||
/// The prior weight file, in .safetensors format.
|
/// The prior weight file, in .safetensors format.
|
||||||
#[arg(long, value_name = "FILE")]
|
#[arg(long, value_name = "FILE")]
|
||||||
prior_weights: Option<String>,
|
prior_weights: Option<String>,
|
||||||
@ -66,6 +71,10 @@ struct Args {
|
|||||||
/// The file specifying the tokenizer to used for tokenization.
|
/// The file specifying the tokenizer to used for tokenization.
|
||||||
tokenizer: Option<String>,
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, value_name = "FILE")]
|
||||||
|
/// The file specifying the tokenizer to used for prior tokenization.
|
||||||
|
prior_tokenizer: Option<String>,
|
||||||
|
|
||||||
/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
|
/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
sliced_attention_size: Option<usize>,
|
sliced_attention_size: Option<usize>,
|
||||||
@ -86,7 +95,9 @@ struct Args {
|
|||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
enum ModelFile {
|
enum ModelFile {
|
||||||
Tokenizer,
|
Tokenizer,
|
||||||
|
PriorTokenizer,
|
||||||
Clip,
|
Clip,
|
||||||
|
PriorClip,
|
||||||
Decoder,
|
Decoder,
|
||||||
VqGan,
|
VqGan,
|
||||||
Prior,
|
Prior,
|
||||||
@ -102,7 +113,9 @@ impl ModelFile {
|
|||||||
let repo_prior = "warp-ai/wuerstchen-prior";
|
let repo_prior = "warp-ai/wuerstchen-prior";
|
||||||
let (repo, path) = match self {
|
let (repo, path) = match self {
|
||||||
Self::Tokenizer => (repo_main, "tokenizer/tokenizer.json"),
|
Self::Tokenizer => (repo_main, "tokenizer/tokenizer.json"),
|
||||||
|
Self::PriorTokenizer => (repo_prior, "tokenizer/tokenizer.json"),
|
||||||
Self::Clip => (repo_main, "text_encoder/model.safetensors"),
|
Self::Clip => (repo_main, "text_encoder/model.safetensors"),
|
||||||
|
Self::PriorClip => (repo_prior, "text_encoder/model.safetensors"),
|
||||||
Self::Decoder => (repo_main, "decoder/diffusion_pytorch_model.safetensors"),
|
Self::Decoder => (repo_main, "decoder/diffusion_pytorch_model.safetensors"),
|
||||||
Self::VqGan => (repo_main, "vqgan/diffusion_pytorch_model.safetensors"),
|
Self::VqGan => (repo_main, "vqgan/diffusion_pytorch_model.safetensors"),
|
||||||
Self::Prior => (repo_prior, "prior/diffusion_pytorch_model.safetensors"),
|
Self::Prior => (repo_prior, "prior/diffusion_pytorch_model.safetensors"),
|
||||||
@ -144,12 +157,11 @@ fn output_filename(
|
|||||||
fn encode_prompt(
|
fn encode_prompt(
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
uncond_prompt: &str,
|
uncond_prompt: &str,
|
||||||
tokenizer: Option<String>,
|
tokenizer: std::path::PathBuf,
|
||||||
clip_weights: Option<String>,
|
clip_weights: std::path::PathBuf,
|
||||||
clip_config: stable_diffusion::clip::Config,
|
clip_config: stable_diffusion::clip::Config,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
let tokenizer = ModelFile::Tokenizer.get(tokenizer)?;
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||||
let pad_id = match &clip_config.pad_with {
|
let pad_id = match &clip_config.pad_with {
|
||||||
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
|
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
|
||||||
@ -161,6 +173,7 @@ fn encode_prompt(
|
|||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
|
let tokens_len = tokens.len();
|
||||||
while tokens.len() < clip_config.max_position_embeddings {
|
while tokens.len() < clip_config.max_position_embeddings {
|
||||||
tokens.push(pad_id)
|
tokens.push(pad_id)
|
||||||
}
|
}
|
||||||
@ -171,17 +184,17 @@ fn encode_prompt(
|
|||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
|
let uncond_tokens_len = uncond_tokens.len();
|
||||||
while uncond_tokens.len() < clip_config.max_position_embeddings {
|
while uncond_tokens.len() < clip_config.max_position_embeddings {
|
||||||
uncond_tokens.push(pad_id)
|
uncond_tokens.push(pad_id)
|
||||||
}
|
}
|
||||||
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
|
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
|
||||||
|
|
||||||
println!("Building the Clip transformer.");
|
println!("Building the clip transformer.");
|
||||||
let clip_weights = ModelFile::Clip.get(clip_weights)?;
|
|
||||||
let text_model =
|
let text_model =
|
||||||
stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
|
stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
|
||||||
let text_embeddings = text_model.forward(&tokens)?;
|
let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len)?;
|
||||||
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
|
let uncond_embeddings = text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len)?;
|
||||||
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
|
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
|
||||||
Ok(text_embeddings)
|
Ok(text_embeddings)
|
||||||
}
|
}
|
||||||
@ -221,15 +234,19 @@ fn run(args: Args) -> Result<()> {
|
|||||||
let height = height.unwrap_or(1024);
|
let height = height.unwrap_or(1024);
|
||||||
let width = width.unwrap_or(1024);
|
let width = width.unwrap_or(1024);
|
||||||
|
|
||||||
let text_embeddings = encode_prompt(
|
let prior_text_embeddings = {
|
||||||
&prompt,
|
let tokenizer = ModelFile::PriorTokenizer.get(args.prior_tokenizer)?;
|
||||||
&uncond_prompt,
|
let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
|
||||||
tokenizer.clone(),
|
encode_prompt(
|
||||||
clip_weights.clone(),
|
&prompt,
|
||||||
stable_diffusion::clip::Config::wuerstchen(),
|
&uncond_prompt,
|
||||||
&device,
|
tokenizer.clone(),
|
||||||
)?;
|
weights,
|
||||||
println!("{text_embeddings:?}");
|
stable_diffusion::clip::Config::wuerstchen_prior(),
|
||||||
|
&device,
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
println!("{prior_text_embeddings}");
|
||||||
|
|
||||||
println!("Building the prior.");
|
println!("Building the prior.");
|
||||||
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
|
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
|
||||||
@ -239,8 +256,8 @@ fn run(args: Args) -> Result<()> {
|
|||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||||
wuerstchen::prior::WPrior::new(
|
wuerstchen::prior::WPrior::new(
|
||||||
/* c_in */ 16, /* c */ 1536, /* c_cond */ 1280, /* c_r */ 64,
|
/* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
|
||||||
/* depth */ 32, /* nhead */ 24, vb,
|
/* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
|
||||||
)?
|
)?
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -274,12 +291,12 @@ fn run(args: Args) -> Result<()> {
|
|||||||
let latents = Tensor::randn(
|
let latents = Tensor::randn(
|
||||||
0f32,
|
0f32,
|
||||||
1f32,
|
1f32,
|
||||||
(b_size, 4, latent_height, latent_width),
|
(b_size, PRIOR_CIN, latent_height, latent_width),
|
||||||
&device,
|
&device,
|
||||||
)?;
|
)?;
|
||||||
// TODO: latents denoising loop, use the scheduler values.
|
// TODO: latents denoising loop, use the scheduler values.
|
||||||
let ratio = Tensor::ones(1, DType::F32, &device)?;
|
let ratio = Tensor::ones(1, DType::F32, &device)?;
|
||||||
let prior = prior.forward(&latents, &ratio, &text_embeddings)?;
|
let prior = prior.forward(&latents, &ratio, &prior_text_embeddings)?;
|
||||||
|
|
||||||
let latents = ((latents * 42.)? - 1.)?;
|
let latents = ((latents * 42.)? - 1.)?;
|
||||||
/*
|
/*
|
||||||
|
@ -107,13 +107,28 @@ impl Config {
|
|||||||
embed_dim: 1024,
|
embed_dim: 1024,
|
||||||
intermediate_size: 4096,
|
intermediate_size: 4096,
|
||||||
max_position_embeddings: 77,
|
max_position_embeddings: 77,
|
||||||
pad_with: Some("!".to_string()),
|
pad_with: None,
|
||||||
num_hidden_layers: 24,
|
num_hidden_layers: 24,
|
||||||
num_attention_heads: 16,
|
num_attention_heads: 16,
|
||||||
projection_dim: 1024,
|
projection_dim: 1024,
|
||||||
activation: Activation::Gelu,
|
activation: Activation::Gelu,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/text_encoder/config.json
|
||||||
|
pub fn wuerstchen_prior() -> Self {
|
||||||
|
Self {
|
||||||
|
vocab_size: 49408,
|
||||||
|
embed_dim: 1280,
|
||||||
|
intermediate_size: 5120,
|
||||||
|
max_position_embeddings: 77,
|
||||||
|
pad_with: None,
|
||||||
|
num_hidden_layers: 32,
|
||||||
|
num_attention_heads: 20,
|
||||||
|
projection_dim: 512,
|
||||||
|
activation: Activation::Gelu,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CLIP Text Model
|
// CLIP Text Model
|
||||||
@ -334,21 +349,39 @@ impl ClipTextTransformer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
|
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
|
||||||
fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
|
fn build_causal_attention_mask(
|
||||||
|
bsz: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
mask_after: usize,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<Tensor> {
|
||||||
let mask: Vec<_> = (0..seq_len)
|
let mask: Vec<_> = (0..seq_len)
|
||||||
.flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
|
.flat_map(|i| {
|
||||||
|
(0..seq_len).map(move |j| {
|
||||||
|
if j > i || j > mask_after {
|
||||||
|
f32::MIN
|
||||||
|
} else {
|
||||||
|
0.
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
|
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
|
||||||
mask.broadcast_as((bsz, seq_len, seq_len))
|
mask.broadcast_as((bsz, seq_len, seq_len))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn forward_with_mask(&self, xs: &Tensor, mask_after: usize) -> Result<Tensor> {
|
||||||
|
let (bsz, seq_len) = xs.dims2()?;
|
||||||
|
let xs = self.embeddings.forward(xs)?;
|
||||||
|
let causal_attention_mask =
|
||||||
|
Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;
|
||||||
|
let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
|
||||||
|
self.final_layer_norm.forward(&xs)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for ClipTextTransformer {
|
impl Module for ClipTextTransformer {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
let (bsz, seq_len) = xs.dims2()?;
|
self.forward_with_mask(xs, usize::MAX)
|
||||||
let xs = self.embeddings.forward(xs)?;
|
|
||||||
let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;
|
|
||||||
let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
|
|
||||||
self.final_layer_norm.forward(&xs)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -75,9 +75,9 @@ impl Module for GlobalResponseNorm {
|
|||||||
let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?;
|
let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?;
|
||||||
let stand_div_norm =
|
let stand_div_norm =
|
||||||
agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
|
agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
|
||||||
(xs.broadcast_mul(&stand_div_norm)?
|
xs.broadcast_mul(&stand_div_norm)?
|
||||||
.broadcast_mul(&self.gamma)
|
.broadcast_mul(&self.gamma)?
|
||||||
+ &self.beta)?
|
.broadcast_add(&self.beta)?
|
||||||
+ xs
|
+ xs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -68,7 +68,7 @@ struct DownBlock {
|
|||||||
struct UpBlock {
|
struct UpBlock {
|
||||||
sub_blocks: Vec<SubBlock>,
|
sub_blocks: Vec<SubBlock>,
|
||||||
layer_norm: Option<WLayerNorm>,
|
layer_norm: Option<WLayerNorm>,
|
||||||
conv: Option<candle_nn::Conv2d>,
|
conv: Option<candle_nn::ConvTranspose2d>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -152,20 +152,20 @@ impl WDiffNeXt {
|
|||||||
stride: 2,
|
stride: 2,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(1))?;
|
let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?;
|
||||||
(Some(layer_norm), Some(conv), 2)
|
(Some(layer_norm), Some(conv), 1)
|
||||||
} else {
|
} else {
|
||||||
(None, None, 0)
|
(None, None, 0)
|
||||||
};
|
};
|
||||||
let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
|
let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
|
||||||
let mut layer_i = start_layer_i;
|
let mut layer_i = start_layer_i;
|
||||||
for j in 0..BLOCKS[i] {
|
for _j in 0..BLOCKS[i] {
|
||||||
let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
|
let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
|
||||||
let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
|
let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
|
||||||
layer_i += 1;
|
layer_i += 1;
|
||||||
let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
|
let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
|
||||||
layer_i += 1;
|
layer_i += 1;
|
||||||
let attn_block = if j == 0 {
|
let attn_block = if i == 0 {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
let attn_block =
|
let attn_block =
|
||||||
@ -190,7 +190,7 @@ impl WDiffNeXt {
|
|||||||
|
|
||||||
let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
|
let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
|
||||||
for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
|
for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
|
||||||
let vb = vb.pp("up_blocks").pp(i);
|
let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i);
|
||||||
let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
|
let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
|
||||||
let mut layer_i = 0;
|
let mut layer_i = 0;
|
||||||
for j in 0..BLOCKS[i] {
|
for j in 0..BLOCKS[i] {
|
||||||
@ -204,7 +204,7 @@ impl WDiffNeXt {
|
|||||||
layer_i += 1;
|
layer_i += 1;
|
||||||
let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
|
let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
|
||||||
layer_i += 1;
|
layer_i += 1;
|
||||||
let attn_block = if j == 0 {
|
let attn_block = if i == 0 {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
let attn_block =
|
let attn_block =
|
||||||
@ -221,12 +221,17 @@ impl WDiffNeXt {
|
|||||||
}
|
}
|
||||||
let (layer_norm, conv) = if i > 0 {
|
let (layer_norm, conv) = if i > 0 {
|
||||||
let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
|
let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
|
||||||
layer_i += 1;
|
let cfg = candle_nn::ConvTranspose2dConfig {
|
||||||
let cfg = candle_nn::Conv2dConfig {
|
|
||||||
stride: 2,
|
stride: 2,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?;
|
let conv = candle_nn::conv_transpose2d(
|
||||||
|
c_hidden,
|
||||||
|
C_HIDDEN[i - 1],
|
||||||
|
2,
|
||||||
|
cfg,
|
||||||
|
vb.pp(layer_i).pp(1),
|
||||||
|
)?;
|
||||||
(Some(layer_norm), Some(conv))
|
(Some(layer_norm), Some(conv))
|
||||||
} else {
|
} else {
|
||||||
(None, None)
|
(None, None)
|
||||||
|
Reference in New Issue
Block a user