Add the upblocks. (#853)
From a mirror of https://github.com/huggingface/candle.git

@@ -110,7 +110,7 @@ impl ToUsize2 for (usize, usize) {
 }
 
 // A simple trait defining a module with forward method using a single argument.
-pub trait Module: std::fmt::Debug {
+pub trait Module {
     fn forward(&self, xs: &Tensor) -> Result<Tensor>;
 }
 
@@ -119,3 +119,9 @@ impl Module for quantized::QMatMul {
         self.forward(xs)
     }
 }
+
+impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self(xs)
+    }
+}
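
Dropping the std::fmt::Debug bound is what makes this blanket impl possible: closures do not implement Debug, but after this change any Fn(&Tensor) -> Result<Tensor> can stand in wherever a Module is expected. A minimal sketch of the pattern (assuming candle_core as the crate name and Tensor::apply accepting any Module, as the WPrior hunk further below relies on):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let xs = Tensor::new(&[-1f32, 0., 2.], &Device::Cpu)?;
    // A plain closure now satisfies Module, so it can be chained with apply.
    let ys = xs.apply(&|t: &Tensor| t.relu())?;
    println!("{ys}");
    Ok(())
}
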
@@ -44,6 +44,10 @@ pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
     (xs.neg()?.exp()? + 1.0)?.recip()
 }
 
+pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
+    xs.relu()?.minimum(&(xs * negative_slope)?)
+}
+
 pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
     // This implementation is inefficient as it stores the full mask for the backward pass.
     // Instead we could just store the seed and have a specialized kernel that would both
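
For reference, leaky ReLU is conventionally defined as f(x) = x for x >= 0 and negative_slope * x otherwise, which for slopes in (0, 1) is the same as max(x, negative_slope * x). A sketch of that formulation with the same tensor ops, purely as a point of comparison (hypothetical helper name, not part of this commit):

use candle_core::{Result, Tensor};

// max(x, negative_slope * x): matches the textbook definition when 0 < negative_slope < 1.
fn leaky_relu_ref(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
    xs.maximum(&(xs * negative_slope)?)
}
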
@@ -161,8 +161,57 @@ impl WDiffNeXt {
             down_blocks.push(down_block)
         }
 
-        // TODO: populate.
-        let up_blocks = Vec::with_capacity(C_HIDDEN.len());
+        let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
+        for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
+            let vb = vb.pp("up_blocks").pp(i);
+            let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
+            let mut layer_i = 0;
+            for j in 0..BLOCKS[i] {
+                let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
+                let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 {
+                    c_hidden + c_skip
+                } else {
+                    c_skip
+                };
+                let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?;
+                layer_i += 1;
+                let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
+                layer_i += 1;
+                let attn_block = if j == 0 {
+                    None
+                } else {
+                    let attn_block =
+                        AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
+                    layer_i += 1;
+                    Some(attn_block)
+                };
+                let sub_block = SubBlock {
+                    res_block,
+                    ts_block,
+                    attn_block,
+                };
+                sub_blocks.push(sub_block)
+            }
+            let (layer_norm, conv, start_layer_i) = if i > 0 {
+                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?;
+                layer_i += 1;
+                let cfg = candle_nn::Conv2dConfig {
+                    stride: 2,
+                    ..Default::default()
+                };
+                let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?;
+                layer_i += 1;
+                (Some(layer_norm), Some(conv), 2)
+            } else {
+                (None, None, 0)
+            };
+            let up_block = UpBlock {
+                layer_norm,
+                conv,
+                sub_blocks,
+            };
+            up_blocks.push(up_block)
+        }
 
         let clf_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("clf.0"))?;
         let clf_conv = candle_nn::conv2d(
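
The construction above walks the checkpoint by index: a single layer_i counter is threaded through vb.pp("up_blocks").pp(i), so each residual, timestep, attention, layer-norm and downscaling-conv layer of an up block reads its weights from consecutive sub-prefixes. A rough sketch of how the nested pp calls map to parameter names (zero-backed VarBuilder so it runs standalone; the real model loads a wuerstchen checkpoint, and the exact key layout is an assumption here):

use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    // Zero-initialised VarBuilder, just to illustrate that nested pp calls join with dots.
    let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
    let block_vb = vb.pp("up_blocks").pp(0usize);
    // A conv created under pp(2) would look up keys such as up_blocks.0.2.weight.
    let _conv = candle_nn::conv2d(8, 8, 3, Default::default(), block_vb.pp(2usize))?;
    Ok(())
}
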
@@ -85,10 +85,9 @@ impl WPrior {
     pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> {
         let x_in = xs;
         let mut xs = xs.apply(&self.projection)?;
-        // TODO: leaky relu
         let c_embed = c
             .apply(&self.cond_mapper_lin1)?
-            .relu()?
+            .apply(&|xs: &_| candle_nn::ops::leaky_relu(xs, 0.2))?
             .apply(&self.cond_mapper_lin2)?;
         let r_embed = self.gen_r_embedding(r)?;
         for block in self.blocks.iter() {
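
The inline closure is needed because leaky_relu takes the slope as a second argument, so the function itself does not satisfy Fn(&Tensor) -> Result<Tensor>; fixing the slope at 0.2 is what lets it go through the new blanket Module impl. The same step could also be written with a named helper (hypothetical name, shown only to make the types explicit):

use candle_core::{Result, Tensor};

// Hypothetical named equivalent of the inline closure |xs: &_| candle_nn::ops::leaky_relu(xs, 0.2).
fn leaky_relu_02(xs: &Tensor) -> Result<Tensor> {
    candle_nn::ops::leaky_relu(xs, 0.2)
}

// It would slot into the chain as .apply(&leaky_relu_02)? in place of the closure.
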