More Wuerstchen fixes. (#882)
* More Wuerstchen fixes.
* More shape fixes.
* Add more of the prior specific bits.
* Broadcast add.
* Fix the clip config.
* Add some masking options to the clip model.
@@ -107,13 +107,28 @@ impl Config {
             embed_dim: 1024,
             intermediate_size: 4096,
             max_position_embeddings: 77,
-            pad_with: Some("!".to_string()),
+            pad_with: None,
             num_hidden_layers: 24,
             num_attention_heads: 16,
             projection_dim: 1024,
             activation: Activation::Gelu,
         }
     }
+
+    // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/text_encoder/config.json
+    pub fn wuerstchen_prior() -> Self {
+        Self {
+            vocab_size: 49408,
+            embed_dim: 1280,
+            intermediate_size: 5120,
+            max_position_embeddings: 77,
+            pad_with: None,
+            num_hidden_layers: 32,
+            num_attention_heads: 20,
+            projection_dim: 512,
+            activation: Activation::Gelu,
+        }
+    }
 }
 
 // CLIP Text Model

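The added prior config widens the text encoder (1280-dim hidden states, 32 layers, 20 heads) while keeping the 77-token context and shrinking the projection to 512. Both configs work out to 64-dimensional attention heads; a standalone check of that arithmetic, using only the numbers shown above (plain Rust, no candle types):

fn main() {
    // Numbers taken from the two configs in the hunk above.
    let (existing_embed, existing_heads) = (1024u32, 16u32);
    let (prior_embed, prior_heads) = (1280u32, 20u32);
    assert_eq!(existing_embed / existing_heads, 64); // 64-dim heads
    assert_eq!(prior_embed / prior_heads, 64); // same head size, more heads
    println!("both CLIP text encoders use 64-dim attention heads");
}
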
@@ -334,21 +349,39 @@ impl ClipTextTransformer {
     }
 
     // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
-    fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
+    fn build_causal_attention_mask(
+        bsz: usize,
+        seq_len: usize,
+        mask_after: usize,
+        device: &Device,
+    ) -> Result<Tensor> {
         let mask: Vec<_> = (0..seq_len)
-            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
+            .flat_map(|i| {
+                (0..seq_len).map(move |j| {
+                    if j > i || j > mask_after {
+                        f32::MIN
+                    } else {
+                        0.
+                    }
+                })
+            })
             .collect();
         let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
         mask.broadcast_as((bsz, seq_len, seq_len))
     }
+
+    pub fn forward_with_mask(&self, xs: &Tensor, mask_after: usize) -> Result<Tensor> {
+        let (bsz, seq_len) = xs.dims2()?;
+        let xs = self.embeddings.forward(xs)?;
+        let causal_attention_mask =
+            Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;
+        let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
+        self.final_layer_norm.forward(&xs)
+    }
 }
 
 impl Module for ClipTextTransformer {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let (bsz, seq_len) = xs.dims2()?;
-        let xs = self.embeddings.forward(xs)?;
-        let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;
-        let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
-        self.final_layer_norm.forward(&xs)
+        self.forward_with_mask(xs, usize::MAX)
     }
 }

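build_causal_attention_mask now takes a mask_after index on top of the usual causal constraint: position j is blocked when it lies in the future (j > i) or past mask_after, forward_with_mask threads that index through, and the plain Module::forward keeps the old behaviour by passing usize::MAX. A standalone sketch of the same mask pattern, using a plain Vec<f32> instead of a candle Tensor (build_mask is an illustrative name, not part of the crate):

/// Illustrative re-implementation of the mask logic from the hunk above,
/// returning a flat row-major seq_len x seq_len matrix of additive biases.
fn build_mask(seq_len: usize, mask_after: usize) -> Vec<f32> {
    (0..seq_len)
        .flat_map(|i| {
            (0..seq_len).map(move |j| {
                // Future positions and positions past `mask_after` are excluded.
                if j > i || j > mask_after {
                    f32::MIN
                } else {
                    0.
                }
            })
        })
        .collect()
}

fn main() {
    let seq_len = 5;
    let mask = build_mask(seq_len, 2); // tokens after index 2 are always masked
    for row in mask.chunks(seq_len) {
        let pattern: String = row.iter().map(|&v| if v == 0. { 'o' } else { '.' }).collect();
        println!("{pattern}");
    }
    // Prints:
    // o....
    // oo...
    // ooo..
    // ooo..
    // ooo..
}
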
@@ -75,9 +75,9 @@ impl Module for GlobalResponseNorm {
         let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?;
         let stand_div_norm =
             agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
-        (xs.broadcast_mul(&stand_div_norm)?
-            .broadcast_mul(&self.gamma)
-            + &self.beta)?
+        xs.broadcast_mul(&stand_div_norm)?
+            .broadcast_mul(&self.gamma)?
+            .broadcast_add(&self.beta)?
             + xs
     }
 }

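Per the "Broadcast add" item in the commit message, GlobalResponseNorm now adds the per-channel beta with broadcast_add (and propagates the ? on the second broadcast_mul) instead of a plain +, so beta can broadcast over the feature map; the trailing + xs residual is unchanged. A scalar sketch of what the block computes, mirroring the tensor ops above on a [positions][channels] slice (global_response_norm is an illustrative helper, not crate code):

// Illustrative sketch: the same arithmetic as the candle ops above,
// written out over a [positions][channels] slice of f32 values.
fn global_response_norm(xs: &[Vec<f32>], gamma: &[f32], beta: &[f32]) -> Vec<Vec<f32>> {
    let channels = gamma.len();
    // agg_norm: xs.sqr() summed over the spatial positions, per channel.
    let agg_norm: Vec<f32> = (0..channels)
        .map(|c| xs.iter().map(|p| p[c] * p[c]).sum::<f32>())
        .collect();
    // stand_div_norm: agg_norm divided by (its mean over channels + 1e-6).
    let mean = agg_norm.iter().sum::<f32>() / channels as f32 + 1e-6;
    let stand_div_norm: Vec<f32> = agg_norm.iter().map(|n| n / mean).collect();
    // y = x * n * gamma + beta + x  (broadcast_mul, broadcast_mul, broadcast_add, + xs).
    xs.iter()
        .map(|p| {
            (0..channels)
                .map(|c| p[c] * stand_div_norm[c] * gamma[c] + beta[c] + p[c])
                .collect()
        })
        .collect()
}

fn main() {
    // Two spatial positions, three channels; gamma = 1 and beta = 0 for simplicity.
    let xs = vec![vec![1.0_f32, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
    let out = global_response_norm(&xs, &[1.0, 1.0, 1.0], &[0.0, 0.0, 0.0]);
    println!("{out:?}");
}
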
@@ -68,7 +68,7 @@ struct DownBlock {
 struct UpBlock {
     sub_blocks: Vec<SubBlock>,
     layer_norm: Option<WLayerNorm>,
-    conv: Option<candle_nn::Conv2d>,
+    conv: Option<candle_nn::ConvTranspose2d>,
 }
 
 #[derive(Debug)]

@@ -152,20 +152,20 @@ impl WDiffNeXt {
                 stride: 2,
                 ..Default::default()
             };
-            let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(1))?;
-            (Some(layer_norm), Some(conv), 2)
+            let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?;
+            (Some(layer_norm), Some(conv), 1)
         } else {
             (None, None, 0)
         };
         let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
         let mut layer_i = start_layer_i;
-        for j in 0..BLOCKS[i] {
+        for _j in 0..BLOCKS[i] {
             let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
             let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
             layer_i += 1;
             let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
             layer_i += 1;
-            let attn_block = if j == 0 {
+            let attn_block = if i == 0 {
                 None
             } else {
                 let attn_block =

@@ -190,7 +190,7 @@ impl WDiffNeXt {
 
         let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
         for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
-            let vb = vb.pp("up_blocks").pp(i);
+            let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i);
             let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
             let mut layer_i = 0;
             for j in 0..BLOCKS[i] {

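Because the up blocks are built with .enumerate().rev(), the loop sees i in descending order, so the re-mapped VarBuilder path C_HIDDEN.len() - 1 - i makes the first block built (the innermost one) read its weights from up_blocks.0, which presumably matches how the checkpoint orders them. A standalone illustration of the index mapping, with four resolution levels assumed purely as an example:

fn main() {
    // Four hidden levels are assumed here only for illustration.
    let num_levels = 4;
    for i in (0..num_levels).rev() {
        let weight_index = num_levels - 1 - i;
        println!("loop i = {i} -> weights under up_blocks.{weight_index}");
    }
    // loop i = 3 -> weights under up_blocks.0
    // loop i = 2 -> weights under up_blocks.1
    // loop i = 1 -> weights under up_blocks.2
    // loop i = 0 -> weights under up_blocks.3
}
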
@@ -204,7 +204,7 @@ impl WDiffNeXt {
                 layer_i += 1;
                 let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
                 layer_i += 1;
-                let attn_block = if j == 0 {
+                let attn_block = if i == 0 {
                     None
                 } else {
                     let attn_block =

@@ -221,12 +221,17 @@ impl WDiffNeXt {
             }
             let (layer_norm, conv) = if i > 0 {
                 let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
                 layer_i += 1;
-                let cfg = candle_nn::Conv2dConfig {
+                let cfg = candle_nn::ConvTranspose2dConfig {
                     stride: 2,
                     ..Default::default()
                 };
-                let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?;
+                let conv = candle_nn::conv_transpose2d(
+                    c_hidden,
+                    C_HIDDEN[i - 1],
+                    2,
+                    cfg,
+                    vb.pp(layer_i).pp(1),
+                )?;
                 (Some(layer_norm), Some(conv))
             } else {
                 (None, None)

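Replacing the stride-2 conv2d with a stride-2 conv_transpose2d (and swapping the in/out channel order accordingly) makes each UpBlock double the spatial resolution that the matching DownBlock halved. A quick shape sanity check, assuming the standard convolution output formulas with no padding, dilation 1, and no output padding; the input size 48 is only an example:

// out = (in - 1) * stride - 2 * padding + kernel  (dilation 1, no output padding)
fn conv_transpose2d_out(inp: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (inp - 1) * stride + kernel - 2 * padding
}

// out = (in + 2 * padding - kernel) / stride + 1
fn conv2d_out(inp: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (inp + 2 * padding - kernel) / stride + 1
}

fn main() {
    // A stride-2, kernel-2 DownBlock conv halves a 48x48 map to 24x24 ...
    assert_eq!(conv2d_out(48, 2, 2, 0), 24);
    // ... and the matching stride-2, kernel-2 ConvTranspose2d brings it back to 48x48.
    assert_eq!(conv_transpose2d_out(24, 2, 2, 0), 48);
    println!("shapes check out");
}
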