mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Add the upblocks. (#853)
This commit is contained in:
@ -85,10 +85,9 @@ impl WPrior {
|
||||
pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> {
|
||||
let x_in = xs;
|
||||
let mut xs = xs.apply(&self.projection)?;
|
||||
// TODO: leaky relu
|
||||
let c_embed = c
|
||||
.apply(&self.cond_mapper_lin1)?
|
||||
.relu()?
|
||||
.apply(&|xs: &_| candle_nn::ops::leaky_relu(xs, 0.2))?
|
||||
.apply(&self.cond_mapper_lin2)?;
|
||||
let r_embed = self.gen_r_embedding(r)?;
|
||||
for block in self.blocks.iter() {
|
||||
|
Reference in New Issue
Block a user