use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder;

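/// Residual block of the stage-B denoiser: a depthwise convolution and a
/// channelwise MLP (linear -> GELU -> GlobalResponseNorm -> linear),
/// reminiscent of a ConvNeXt block, with an optional skip input concatenated
/// along the channel dimension.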
#[derive(Debug)]
pub struct ResBlockStageB {
    depthwise: candle_nn::Conv2d,
    norm: WLayerNorm,
    channelwise_lin1: candle_nn::Linear,
    channelwise_grn: GlobalResponseNorm,
    channelwise_lin2: candle_nn::Linear,
}

impl ResBlockStageB {
    pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
        let cfg = candle_nn::Conv2dConfig {
            groups: c,
            padding: ksize / 2,
            ..Default::default()
        };
        let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
        let norm = WLayerNorm::new(c)?;
        let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
        let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
        let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
        Ok(Self {
            depthwise,
            norm,
            channelwise_lin1,
            channelwise_grn,
            channelwise_lin2,
        })
    }

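    /// The channelwise layers are `Linear`, so the tensor is moved to NHWC
    /// layout before them and back to NCHW afterwards; the input is added
    /// back as a residual at the end.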
    pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
        let x_res = xs;
        let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?;
        let xs = match x_skip {
            None => xs.clone(),
            Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?,
        };
        let xs = xs
            .permute((0, 2, 3, 1))?
            .contiguous()?
            .apply(&self.channelwise_lin1)?
            .gelu()?
            .apply(&self.channelwise_grn)?
            .apply(&self.channelwise_lin2)?
            .permute((0, 3, 1, 2))?;
        xs + x_res
    }
}

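/// One step of a down/up stage: residual block, timestep conditioning, and an
/// optional attention block over the conditioning sequence.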
#[derive(Debug)]
struct SubBlock {
    res_block: ResBlockStageB,
    ts_block: TimestepBlock,
    attn_block: Option<AttnBlock>,
}

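/// Encoder stage: an optional LayerNorm + stride-2 conv downsample (absent
/// for the first stage) followed by a sequence of sub-blocks.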
#[derive(Debug)]
struct DownBlock {
    layer_norm: Option<WLayerNorm>,
    conv: Option<candle_nn::Conv2d>,
    sub_blocks: Vec<SubBlock>,
}

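/// Decoder stage: sub-blocks followed by an optional LayerNorm +
/// transposed-conv upsample (absent for the last, highest-resolution stage).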
#[derive(Debug)]
struct UpBlock {
    sub_blocks: Vec<SubBlock>,
    layer_norm: Option<WLayerNorm>,
    conv: Option<candle_nn::ConvTranspose2d>,
}

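/// UNet-style denoiser from the Würstchen architecture: the input is
/// pixel-unshuffled and embedded, passed through a stack of down blocks and a
/// mirrored stack of up blocks with skip connections, and a classifier head
/// predicts the two maps used to denoise the input. Conditioning comes from
/// CLIP embeddings (via `clip_mapper`) and EfficientNet latents (via
/// `effnet_mappers`).
///
/// A minimal usage sketch; the sizes below are illustrative placeholders
/// rather than an official configuration, and `vb` is assumed to be a
/// `VarBuilder` over pretrained weights:
///
/// ```ignore
/// let model = WDiffNeXt::new(
///     4,     // c_in: input latent channels
///     4,     // c_out: output latent channels
///     64,    // c_r: timestep embedding width
///     1024,  // c_cond: conditioning width
///     1024,  // clip_embd: width of the CLIP embeddings
///     2,     // patch_size for pixel (un)shuffling
///     false, // use_flash_attn
///     vb,
/// )?;
/// let denoised = model.forward(&xs, &r, &effnet, Some(&clip))?;
/// ```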
#[derive(Debug)]
pub struct WDiffNeXt {
    clip_mapper: candle_nn::Linear,
    effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
    seq_norm: LayerNormNoWeights,
    embedding_conv: candle_nn::Conv2d,
    embedding_ln: WLayerNorm,
    down_blocks: Vec<DownBlock>,
    up_blocks: Vec<UpBlock>,
    clf_ln: WLayerNorm,
    clf_conv: candle_nn::Conv2d,
    c_r: usize,
    patch_size: usize,
}

impl WDiffNeXt {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        c_in: usize,
        c_out: usize,
        c_r: usize,
        c_cond: usize,
        clip_embd: usize,
        patch_size: usize,
        use_flash_attn: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
        const BLOCKS: [usize; 4] = [4, 4, 14, 4];
        const NHEAD: [usize; 4] = [1, 10, 20, 20];
        const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
        const EFFNET_EMBD: usize = 16;

        let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
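        // One (optional) 1x1 conv per down block and per up block, mapping the
        // EfficientNet latents to the conditioning width; stages where
        // INJECT_EFFNET is false get `None`.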
        let mut effnet_mappers = Vec::with_capacity(2 * INJECT_EFFNET.len());
        let vb_e = vb.pp("effnet_mappers");
        for (i, &inject) in INJECT_EFFNET.iter().enumerate() {
            let c = if inject {
                Some(candle_nn::conv2d(
                    EFFNET_EMBD,
                    c_cond,
                    1,
                    Default::default(),
                    vb_e.pp(i),
                )?)
            } else {
                None
            };
            effnet_mappers.push(c)
        }
        for (i, &inject) in INJECT_EFFNET.iter().rev().enumerate() {
            let c = if inject {
                Some(candle_nn::conv2d(
                    EFFNET_EMBD,
                    c_cond,
                    1,
                    Default::default(),
                    vb_e.pp(i + INJECT_EFFNET.len()),
                )?)
            } else {
                None
            };
            effnet_mappers.push(c)
        }
        let seq_norm = LayerNormNoWeights::new(c_cond)?;
        let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
        let embedding_conv = candle_nn::conv2d(
            c_in * patch_size * patch_size,
            C_HIDDEN[0],
            1,
            Default::default(),
            vb.pp("embedding.1"),
        )?;

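        // Build the encoder path. Each stage after the first starts with a
        // LayerNorm + stride-2 conv downsample stored under the "0.1" weight
        // prefix, hence start_layer_i = 1 for the numbered layers that follow.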
        let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
        for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
            let vb = vb.pp("down_blocks").pp(i);
            let (layer_norm, conv, start_layer_i) = if i > 0 {
                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
                let cfg = candle_nn::Conv2dConfig {
                    stride: 2,
                    ..Default::default()
                };
                let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?;
                (Some(layer_norm), Some(conv), 1)
            } else {
                (None, None, 0)
            };
            let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
            let mut layer_i = start_layer_i;
            for _j in 0..BLOCKS[i] {
                let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
                let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
                layer_i += 1;
                let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
                layer_i += 1;
                let attn_block = if i == 0 {
                    None
                } else {
                    let attn_block = AttnBlock::new(
                        c_hidden,
                        c_cond,
                        NHEAD[i],
                        true,
                        use_flash_attn,
                        vb.pp(layer_i),
                    )?;
                    layer_i += 1;
                    Some(attn_block)
                };
                let sub_block = SubBlock {
                    res_block,
                    ts_block,
                    attn_block,
                };
                sub_blocks.push(sub_block)
            }
            let down_block = DownBlock {
                layer_norm,
                conv,
                sub_blocks,
            };
            down_blocks.push(down_block)
        }

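        // Build the decoder path in reverse channel order. The first sub-block
        // of each stage except the innermost also receives the matching
        // encoder output, which widens its skip channels by c_hidden.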
        let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
        for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
            let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i);
            let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
            let mut layer_i = 0;
            for j in 0..BLOCKS[i] {
                let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
                let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 {
                    c_hidden + c_skip
                } else {
                    c_skip
                };
                let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?;
                layer_i += 1;
                let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
                layer_i += 1;
                let attn_block = if i == 0 {
                    None
                } else {
                    let attn_block = AttnBlock::new(
                        c_hidden,
                        c_cond,
                        NHEAD[i],
                        true,
                        use_flash_attn,
                        vb.pp(layer_i),
                    )?;
                    layer_i += 1;
                    Some(attn_block)
                };
                let sub_block = SubBlock {
                    res_block,
                    ts_block,
                    attn_block,
                };
                sub_blocks.push(sub_block)
            }
            let (layer_norm, conv) = if i > 0 {
                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
                let cfg = candle_nn::ConvTranspose2dConfig {
                    stride: 2,
                    ..Default::default()
                };
                let conv = candle_nn::conv_transpose2d(
                    c_hidden,
                    C_HIDDEN[i - 1],
                    2,
                    cfg,
                    vb.pp(layer_i).pp(1),
                )?;
                (Some(layer_norm), Some(conv))
            } else {
                (None, None)
            };
            let up_block = UpBlock {
                layer_norm,
                conv,
                sub_blocks,
            };
            up_blocks.push(up_block)
        }

        let clf_ln = WLayerNorm::new(C_HIDDEN[0])?;
        let clf_conv = candle_nn::conv2d(
            C_HIDDEN[0],
            2 * c_out * patch_size * patch_size,
            1,
            Default::default(),
            vb.pp("clf.1"),
        )?;
        Ok(Self {
            clip_mapper,
            effnet_mappers,
            seq_norm,
            embedding_conv,
            embedding_ln,
            down_blocks,
            up_blocks,
            clf_ln,
            clf_conv,
            c_r,
            patch_size,
        })
    }

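    /// Sinusoidal embedding of the noise level `r`, in the style of
    /// transformer positional encodings: `r` is scaled by 10_000, multiplied
    /// by `exp(-ln(10_000) * i / (half_dim - 1))` for `i` in `0..c_r / 2`,
    /// and the sines and cosines are concatenated (with a zero pad when `c_r`
    /// is odd).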
    fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
        const MAX_POSITIONS: usize = 10000;
        let r = (r * MAX_POSITIONS as f64)?;
        let half_dim = self.c_r / 2;
        let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
        let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
            * -emb)?
            .exp()?;
        let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
        let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
        let emb = if self.c_r % 2 == 1 {
            emb.pad_with_zeros(D::Minus1, 0, 1)?
        } else {
            emb
        };
        emb.to_dtype(r.dtype())
    }

    fn gen_c_embeddings(&self, clip: &Tensor) -> Result<Tensor> {
        clip.apply(&self.clip_mapper)?.apply(&self.seq_norm)
    }

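    /// Runs the denoiser. `xs` is the noisy latent, `r` the noise level,
    /// `effnet` the EfficientNet conditioning latents, and `clip` the
    /// optional text embeddings. The classifier head produces two maps `a`
    /// and `b` (with `b` squashed through a sigmoid into `(EPS, 1 - EPS)`),
    /// and the result is `(xs - a) / b`.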
    pub fn forward(
        &self,
        xs: &Tensor,
        r: &Tensor,
        effnet: &Tensor,
        clip: Option<&Tensor>,
    ) -> Result<Tensor> {
        const EPS: f64 = 1e-3;

        let r_embed = self.gen_r_embedding(r)?;
        let clip = match clip {
            None => None,
            Some(clip) => Some(self.gen_c_embeddings(clip)?),
        };
        let x_in = xs;

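        // Pixel-unshuffle trades spatial resolution for channels, so the
        // embedding conv sees c_in * patch_size^2 input channels.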
        let mut xs = xs
            .apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))?
            .apply(&self.embedding_conv)?
            .apply(&self.embedding_ln)?;

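        // Down pass: each stage optionally downsamples, resizes the
        // EfficientNet latents to its resolution, and keeps its output for
        // the skip connections of the up pass.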
        let mut level_outputs = Vec::new();
        for (i, down_block) in self.down_blocks.iter().enumerate() {
            if let Some(ln) = &down_block.layer_norm {
                xs = xs.apply(ln)?
            }
            if let Some(conv) = &down_block.conv {
                xs = xs.apply(conv)?
            }
            let skip = match &self.effnet_mappers[i] {
                None => None,
                Some(m) => {
                    let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
                    Some(m.forward(&effnet)?)
                }
            };
            for block in down_block.sub_blocks.iter() {
                xs = block.res_block.forward(&xs, skip.as_ref())?;
                xs = block.ts_block.forward(&xs, &r_embed)?;
                if let Some(attn_block) = &block.attn_block {
                    xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
                }
            }
            level_outputs.push(xs.clone())
        }
        level_outputs.reverse();
        let mut xs = level_outputs[0].clone();

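        // Up pass: the first sub-block of every stage after the innermost
        // concatenates the matching encoder output with the EfficientNet
        // conditioning before the residual block.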
        for (i, up_block) in self.up_blocks.iter().enumerate() {
            let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
                None => None,
                Some(m) => {
                    let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
                    Some(m.forward(&effnet)?)
                }
            };
            for (j, block) in up_block.sub_blocks.iter().enumerate() {
                let skip = if j == 0 && i > 0 {
                    Some(&level_outputs[i])
                } else {
                    None
                };
                let skip = match (skip, effnet_c.as_ref()) {
                    (Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?),
                    (None, Some(skip)) | (Some(skip), None) => Some(skip.clone()),
                    (None, None) => None,
                };
                xs = block.res_block.forward(&xs, skip.as_ref())?;
                xs = block.ts_block.forward(&xs, &r_embed)?;
                if let Some(attn_block) = &block.attn_block {
                    xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
                }
            }
            if let Some(ln) = &up_block.layer_norm {
                xs = xs.apply(ln)?
            }
            if let Some(conv) = &up_block.conv {
                xs = xs.apply(conv)?
            }
        }

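        // Classifier head: predict `a` and `b`, keep `b` away from zero via
        // the sigmoid rescaling, and return `(x_in - a) / b`.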
        let ab = xs
            .apply(&self.clf_ln)?
            .apply(&self.clf_conv)?
            .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
            .chunk(2, 1)?;
        let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
        (x_in - &ab[0])? / b
    }
}