mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Wuerstchen main (#876)
* Wuerstchen main. * More of the wuerstchen cli example. * Paella creation. * Build the prior model. * Fix the weight file names.
This commit is contained in:
@ -99,6 +99,21 @@ impl Config {
|
||||
activation: Activation::Gelu,
|
||||
}
|
||||
}
|
||||
|
||||
// https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json
|
||||
pub fn wuerstchen() -> Self {
|
||||
Self {
|
||||
vocab_size: 49408,
|
||||
embed_dim: 1024,
|
||||
intermediate_size: 4096,
|
||||
max_position_embeddings: 77,
|
||||
pad_with: Some("!".to_string()),
|
||||
num_hidden_layers: 24,
|
||||
num_attention_heads: 16,
|
||||
projection_dim: 1024,
|
||||
activation: Activation::Gelu,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CLIP Text Model
|
||||
|
@ -65,17 +65,121 @@ impl Module for MixingResidualBlock {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PaellaVQ {
|
||||
pub struct PaellaVQ {
|
||||
in_block_conv: candle_nn::Conv2d,
|
||||
out_block_conv: candle_nn::Conv2d,
|
||||
down_blocks: Vec<(Option<candle_nn::Conv2d>, MixingResidualBlock)>,
|
||||
down_blocks_conv: candle_nn::Conv2d,
|
||||
down_blocks_bn: candle_nn::BatchNorm,
|
||||
up_blocks_conv: candle_nn::Conv2d,
|
||||
up_blocks: Vec<(MixingResidualBlock, Option<candle_nn::ConvTranspose2d>)>,
|
||||
up_blocks: Vec<(Vec<MixingResidualBlock>, Option<candle_nn::ConvTranspose2d>)>,
|
||||
}
|
||||
|
||||
impl PaellaVQ {
|
||||
pub fn new(vb: VarBuilder) -> Result<Self> {
|
||||
const IN_CHANNELS: usize = 3;
|
||||
const OUT_CHANNELS: usize = 3;
|
||||
const LATENT_CHANNELS: usize = 4;
|
||||
const EMBED_DIM: usize = 384;
|
||||
const BOTTLENECK_BLOCKS: usize = 12;
|
||||
const C_LEVELS: [usize; 2] = [EMBED_DIM / 2, EMBED_DIM];
|
||||
|
||||
let in_block_conv = candle_nn::conv2d(
|
||||
IN_CHANNELS * 4,
|
||||
C_LEVELS[0],
|
||||
1,
|
||||
Default::default(),
|
||||
vb.pp("in_block.1"),
|
||||
)?;
|
||||
let out_block_conv = candle_nn::conv2d(
|
||||
C_LEVELS[0],
|
||||
OUT_CHANNELS * 4,
|
||||
1,
|
||||
Default::default(),
|
||||
vb.pp("out_block.0"),
|
||||
)?;
|
||||
|
||||
let mut down_blocks = Vec::new();
|
||||
let vb_d = vb.pp("down_blocks");
|
||||
let mut d_idx = 0;
|
||||
for (i, &c_level) in C_LEVELS.iter().enumerate() {
|
||||
let conv_block = if i > 0 {
|
||||
let cfg = candle_nn::Conv2dConfig {
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let block =
|
||||
candle_nn::conv2d_no_bias(C_LEVELS[i - 1], c_level, 4, cfg, vb_d.pp(d_idx))?;
|
||||
d_idx += 1;
|
||||
Some(block)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let res_block = MixingResidualBlock::new(c_level, c_level * 4, vb_d.pp(d_idx))?;
|
||||
d_idx += 1;
|
||||
down_blocks.push((conv_block, res_block))
|
||||
}
|
||||
let down_blocks_conv = candle_nn::conv2d_no_bias(
|
||||
C_LEVELS[1],
|
||||
LATENT_CHANNELS,
|
||||
1,
|
||||
Default::default(),
|
||||
vb_d.pp(d_idx),
|
||||
)?;
|
||||
d_idx += 1;
|
||||
let down_blocks_bn = candle_nn::batch_norm(LATENT_CHANNELS, 1e-5, vb_d.pp(d_idx))?;
|
||||
|
||||
let mut up_blocks = Vec::new();
|
||||
let vb_u = vb.pp("up_blocks");
|
||||
let mut u_idx = 0;
|
||||
let up_blocks_conv = candle_nn::conv2d_no_bias(
|
||||
LATENT_CHANNELS,
|
||||
C_LEVELS[1],
|
||||
1,
|
||||
Default::default(),
|
||||
vb_u.pp(u_idx),
|
||||
)?;
|
||||
u_idx += 1;
|
||||
for (i, &c_level) in C_LEVELS.iter().rev().enumerate() {
|
||||
let mut res_blocks = Vec::new();
|
||||
let n_bottleneck_blocks = if i == 0 { BOTTLENECK_BLOCKS } else { 1 };
|
||||
for _j in 0..n_bottleneck_blocks {
|
||||
let res_block = MixingResidualBlock::new(c_level, c_level * 4, vb_u.pp(u_idx))?;
|
||||
u_idx += 1;
|
||||
res_blocks.push(res_block)
|
||||
}
|
||||
let conv_block = if i < C_LEVELS.len() - 1 {
|
||||
let cfg = candle_nn::ConvTranspose2dConfig {
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let block = candle_nn::conv_transpose2d_no_bias(
|
||||
c_level,
|
||||
C_LEVELS[i - 1],
|
||||
4,
|
||||
cfg,
|
||||
vb_u.pp(u_idx),
|
||||
)?;
|
||||
u_idx += 1;
|
||||
Some(block)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
up_blocks.push((res_blocks, conv_block))
|
||||
}
|
||||
Ok(Self {
|
||||
in_block_conv,
|
||||
down_blocks,
|
||||
down_blocks_conv,
|
||||
down_blocks_bn,
|
||||
up_blocks,
|
||||
up_blocks_conv,
|
||||
out_block_conv,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = candle_nn::ops::pixel_unshuffle(xs, 2)?.apply(&self.in_block_conv)?;
|
||||
for down_block in self.down_blocks.iter() {
|
||||
@ -92,7 +196,9 @@ impl PaellaVQ {
|
||||
// TODO: quantizer if we want to support `force_not_quantize=False`.
|
||||
let mut xs = xs.apply(&self.up_blocks_conv)?;
|
||||
for up_block in self.up_blocks.iter() {
|
||||
xs = xs.apply(&up_block.0)?;
|
||||
for b in up_block.0.iter() {
|
||||
xs = xs.apply(b)?;
|
||||
}
|
||||
if let Some(conv) = &up_block.1 {
|
||||
xs = xs.apply(conv)?
|
||||
}
|
||||
|
Reference in New Issue
Block a user