Add the upblocks. (#853)

This commit is contained in:
Laurent Mazare
2023-09-14 23:24:56 +02:00
committed by GitHub
parent 91ec546feb
commit 130fe5a087
4 changed files with 63 additions and 5 deletions

View File

@ -85,10 +85,9 @@ impl WPrior {
pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> {
let x_in = xs;
let mut xs = xs.apply(&self.projection)?;
// TODO: leaky relu
let c_embed = c
.apply(&self.cond_mapper_lin1)?
.relu()?
.apply(&|xs: &_| candle_nn::ops::leaky_relu(xs, 0.2))?
.apply(&self.cond_mapper_lin2)?;
let r_embed = self.gen_r_embedding(r)?;
for block in self.blocks.iter() {