Mirror of https://github.com/huggingface/candle.git
Start adding the Wuerstchen diffusion pipeline (#843)
* Wuerstchen common bits.
* Add the prior layer.
* Start adding diffnext.
@@ -9,3 +9,4 @@ pub mod segment_anything;
 pub mod stable_diffusion;
 pub mod t5;
 pub mod whisper;
+pub mod wuerstchen;

candle-transformers/src/models/wuerstchen/common.rs (new file, 126 lines)
@@ -0,0 +1,126 @@
use candle::{Module, Result, Tensor, D};
use candle_nn::VarBuilder;

// https://github.com/huggingface/diffusers/blob/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py#L22
// Layer norm over the channel dimension of an NCHW feature map: permute to
// NHWC, normalize the last dimension, then permute back.
#[derive(Debug)]
pub struct WLayerNorm {
    inner: candle_nn::LayerNorm,
}

impl WLayerNorm {
    pub fn new(size: usize, vb: VarBuilder) -> Result<Self> {
        let cfg = candle_nn::layer_norm::LayerNormConfig {
            eps: 1e-6,
            remove_mean: true,
            affine: false,
        };
        let inner = candle_nn::layer_norm(size, cfg, vb)?;
        Ok(Self { inner })
    }
}

impl Module for WLayerNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.permute((0, 2, 3, 1))?
            .apply(&self.inner)?
            .permute((0, 3, 1, 2))
    }
}

// Conditions a feature map on a timestep embedding through a learned
// per-channel scale and shift.
#[derive(Debug)]
pub struct TimestepBlock {
    mapper: candle_nn::Linear,
}

impl TimestepBlock {
    pub fn new(c: usize, c_timestep: usize, vb: VarBuilder) -> Result<Self> {
        let mapper = candle_nn::linear(c_timestep, c * 2, vb.pp("mapper"))?;
        Ok(Self { mapper })
    }

    pub fn forward(&self, xs: &Tensor, t: &Tensor) -> Result<Tensor> {
        let ab = self
            .mapper
            .forward(t)?
            .unsqueeze(2)?
            .unsqueeze(3)?
            .chunk(2, 1)?;
        xs.broadcast_mul(&(&ab[0] + 1.)?)?.broadcast_add(&ab[1])
    }
}

// Global response normalization: rescales each channel by its aggregate
// spatial norm relative to the per-sample mean of those norms.
#[derive(Debug)]
pub struct GlobalResponseNorm {
    gamma: Tensor,
    beta: Tensor,
}

impl GlobalResponseNorm {
    pub fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
        let gamma = vb.get((1, 1, 1, 1, dim), "gamma")?;
        let beta = vb.get((1, 1, 1, 1, dim), "beta")?;
        Ok(Self { gamma, beta })
    }
}

impl Module for GlobalResponseNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?;
        let stand_div_norm =
            agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
        (xs.broadcast_mul(&stand_div_norm)?
            .broadcast_mul(&self.gamma)
            + &self.beta)?
            + xs
    }
}

// Residual block: a depthwise convolution followed by a channelwise MLP with
// global response normalization. An optional skip input is concatenated along
// the channel dimension before the depthwise conv.
#[derive(Debug)]
pub struct ResBlock {
    depthwise: candle_nn::Conv2d,
    norm: WLayerNorm,
    channelwise_lin1: candle_nn::Linear,
    channelwise_grn: GlobalResponseNorm,
    channelwise_lin2: candle_nn::Linear,
}

impl ResBlock {
    pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
        let cfg = candle_nn::Conv2dConfig {
            padding: ksize / 2,
            groups: c,
            ..Default::default()
        };
        let depthwise = candle_nn::conv2d(c + c_skip, c, ksize, cfg, vb.pp("depthwise"))?;
        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
        let channelwise_lin1 = candle_nn::linear(c, c * 4, vb.pp("channelwise.0"))?;
        let channelwise_grn = GlobalResponseNorm::new(c * 4, vb.pp("channelwise.2"))?;
        let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
        Ok(Self {
            depthwise,
            norm,
            channelwise_lin1,
            channelwise_grn,
            channelwise_lin2,
        })
    }

    pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
        let x_res = xs;
        let xs = match x_skip {
            None => xs.clone(),
            Some(x_skip) => Tensor::cat(&[xs, x_skip], 1)?,
        };
        let xs = xs
            .apply(&self.depthwise)?
            .apply(&self.norm)?
            .permute((0, 2, 3, 1))?;
        let xs = xs
            .apply(&self.channelwise_lin1)?
            .gelu()?
            .apply(&self.channelwise_grn)?
            .apply(&self.channelwise_lin2)?
            .permute((0, 3, 1, 2))?;
        xs + x_res
    }
}
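
For orientation, here is a minimal usage sketch for the two conditioning primitives above. It is not part of the commit: the zero-initialized VarBuilder, the crate path candle_transformers, and the dummy shapes are illustrative assumptions used only to show how the blocks are wired.

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::wuerstchen::common::{TimestepBlock, WLayerNorm};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Zero-initialized weights are enough to check shapes and wiring.
    let vb = VarBuilder::zeros(DType::F32, &dev);

    // WLayerNorm normalizes the channel dimension of an NCHW tensor by
    // permuting to NHWC, applying layer norm, and permuting back.
    let ln = WLayerNorm::new(16, vb.pp("ln"))?;
    let xs = Tensor::zeros((1, 16, 8, 8), DType::F32, &dev)?;
    let ys = ln.forward(&xs)?;
    assert_eq!(ys.dims(), &[1, 16, 8, 8]);

    // TimestepBlock maps a (batch, c_timestep) embedding to a per-channel
    // scale/shift pair and applies it to the feature map.
    let ts = TimestepBlock::new(16, 64, vb.pp("ts"))?;
    let t = Tensor::zeros((1, 64), DType::F32, &dev)?;
    let ys = ts.forward(&xs, &t)?;
    assert_eq!(ys.dims(), &[1, 16, 8, 8]);
    Ok(())
}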

candle-transformers/src/models/wuerstchen/diffnext.rs (new file, 55 lines)
@@ -0,0 +1,55 @@
#![allow(unused)]
use super::common::{GlobalResponseNorm, ResBlock, TimestepBlock, WLayerNorm};
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder;

#[derive(Debug)]
pub struct ResBlockStageB {
    depthwise: candle_nn::Conv2d,
    norm: WLayerNorm,
    channelwise_lin1: candle_nn::Linear,
    channelwise_grn: GlobalResponseNorm,
    channelwise_lin2: candle_nn::Linear,
}

impl ResBlockStageB {
    pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
        let cfg = candle_nn::Conv2dConfig {
            groups: c,
            padding: ksize / 2,
            ..Default::default()
        };
        let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
        let norm = WLayerNorm::new(c, vb.pp("norm"))?;
        let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
        let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
        let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
        Ok(Self {
            depthwise,
            norm,
            channelwise_lin1,
            channelwise_grn,
            channelwise_lin2,
        })
    }

    pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
        let x_res = xs;
        let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?;
        let xs = match x_skip {
            None => xs.clone(),
            Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?,
        };
        let xs = xs
            .permute((0, 2, 3, 1))?
            .apply(&self.channelwise_lin1)?
            .gelu()?
            .apply(&self.channelwise_grn)?
            .apply(&self.channelwise_lin2)?
            .permute((0, 3, 1, 2))?;
        xs + x_res
    }
}

#[derive(Debug)]
pub struct WDiffNeXt {}
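
ResBlockStageB mirrors ResBlock from common.rs but injects the skip connection at a different point: ResBlock concatenates the skip before the depthwise convolution, while the stage B variant applies the depthwise conv and norm first and widens the first channelwise linear to c + c_skip inputs instead. A construction-only sketch, not part of the commit and using illustrative dimensions and zero-initialized weights, makes the difference visible:

use candle::{DType, Device, Result};
use candle_nn::VarBuilder;
use candle_transformers::models::wuerstchen::common::ResBlock;
use candle_transformers::models::wuerstchen::diffnext::ResBlockStageB;

fn main() -> Result<()> {
    let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
    let (c, c_skip, ksize) = (32, 8, 3);

    // Prior (stage C) block: used with c_skip = 0 in prior.rs, so the
    // depthwise conv maps c -> c and any skip would be concatenated before it.
    let _stage_c = ResBlock::new(c, 0, ksize, vb.pp("stage_c"))?;

    // Stage B block: the depthwise conv stays c -> c; a non-zero c_skip widens
    // the first channelwise linear instead (c + c_skip -> 4 * c).
    let _stage_b = ResBlockStageB::new(c, c_skip, ksize, vb.pp("stage_b"))?;
    Ok(())
}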

candle-transformers/src/models/wuerstchen/mod.rs (new file, 3 lines)
@@ -0,0 +1,3 @@
pub mod common;
pub mod diffnext;
pub mod prior;

candle-transformers/src/models/wuerstchen/prior.rs (new file, 94 lines)
@@ -0,0 +1,94 @@
#![allow(unused)]
use super::common::{ResBlock, TimestepBlock};
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder;

#[derive(Debug)]
struct Block {
    res_block: ResBlock,
    ts_block: TimestepBlock,
    // TODO: attn_block: super::common::AttnBlock,
}

#[derive(Debug)]
pub struct WPrior {
    projection: candle_nn::Conv2d,
    cond_mapper_lin1: candle_nn::Linear,
    cond_mapper_lin2: candle_nn::Linear,
    blocks: Vec<Block>,
    out_ln: super::common::WLayerNorm,
    out_conv: candle_nn::Conv2d,
    c_r: usize,
}

impl WPrior {
    pub fn new(
        c_in: usize,
        c: usize,
        c_cond: usize,
        c_r: usize,
        depth: usize,
        _nhead: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
        let cond_mapper_lin1 = candle_nn::linear(c_cond, c, vb.pp("cond_mapper.0"))?;
        let cond_mapper_lin2 = candle_nn::linear(c, c, vb.pp("cond_mapper.2"))?;
        let out_ln = super::common::WLayerNorm::new(c, vb.pp("out.0"))?;
        let out_conv = candle_nn::conv2d(c, c_in * 2, 1, Default::default(), vb.pp("out.1"))?;
        let mut blocks = Vec::with_capacity(depth);
        for index in 0..depth {
            let res_block = ResBlock::new(c, 0, 3, vb.pp(format!("blocks.{}", 3 * index)))?;
            let ts_block = TimestepBlock::new(c, c_r, vb.pp(format!("blocks.{}", 3 * index + 1)))?;
            blocks.push(Block {
                res_block,
                ts_block,
            })
        }
        Ok(Self {
            projection,
            cond_mapper_lin1,
            cond_mapper_lin2,
            blocks,
            out_ln,
            out_conv,
            c_r,
        })
    }

    // Sinusoidal embedding of the noise level r, scaled by MAX_POSITIONS and
    // zero-padded by one column when c_r is odd.
    pub fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
        const MAX_POSITIONS: usize = 10000;
        let r = (r * MAX_POSITIONS as f64)?;
        let half_dim = self.c_r / 2;
        let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
        let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
            * -emb)?
            .exp()?;
        let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
        let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
        let emb = if self.c_r % 2 == 1 {
            emb.pad_with_zeros(D::Minus1, 0, 1)?
        } else {
            emb
        };
        emb.to_dtype(r.dtype())
    }

    pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> {
        let x_in = xs;
        let mut xs = xs.apply(&self.projection)?;
        // TODO: leaky relu
        let c_embed = c
            .apply(&self.cond_mapper_lin1)?
            .relu()?
            .apply(&self.cond_mapper_lin2)?;
        let r_embed = self.gen_r_embedding(r)?;
        for block in self.blocks.iter() {
            xs = block.res_block.forward(&xs, None)?;
            xs = block.ts_block.forward(&xs, &r_embed)?;
            // TODO: attn
        }
        // The output conv produces c_in * 2 channels; split them into the a/b pair.
        let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(2, 1)?;
        (x_in - &ab[0])? / ((&ab[1] - 1.)?.abs()? + 1e-5)
    }
}
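
gen_r_embedding builds a standard sinusoidal embedding of the noise level r. A small sketch of calling it through a WPrior built from zero-initialized weights; the dimensions below are illustrative assumptions, not values taken from the commit.

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::wuerstchen::prior::WPrior;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &dev);
    // (c_in, c, c_cond, c_r, depth, nhead) are illustrative values only.
    let prior = WPrior::new(16, 128, 64, 64, 2, 8, vb)?;

    // One noise level per batch element; the embedding has c_r features.
    let r = Tensor::new(&[0.5f32, 0.25], &dev)?;
    let emb = prior.gen_r_embedding(&r)?;
    assert_eq!(emb.dims(), &[2, 64]);
    Ok(())
}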