mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
feat: add silu activation function (#1706)
* feat: add silu activation function * use silu/arg in grad * update candle-nn * use node
This commit is contained in:
@ -270,6 +270,19 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
[0.7358, 2.0000, 0.2707, 1.0000]
|
||||
);
|
||||
|
||||
// testing compared to pytorch nn.Silu()
|
||||
let y = x.silu()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[2.8577, 0.7311, 3.9281, 0.0806]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[1.0881, 0.9277, 1.0527, 0.5747],
|
||||
);
|
||||
|
||||
// manually checked: see comments
|
||||
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||
|
Reference in New Issue
Block a user