derivative for GELU (#1160)

* derivative for GELU

* add tests
This commit is contained in:
KGrewal1
2023-10-23 20:23:45 +01:00
committed by GitHub
parent eae94a451b
commit 807e3f9f52
2 changed files with 22 additions and 1 deletions

View File

@ -192,6 +192,19 @@ fn unary_grad(device: &Device) -> Result<()> {
test_utils::to_vec1_round(grad_x, 2)?,
[0.01, 0.42, 0.0, 0.98],
);
// Reference values below were checked against PyTorch's nn.GELU(approximate='tanh').
let y = x.gelu()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[2.9964, 0.8412, 3.9999, 0.0839]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[1.0116, 1.0830, 1.0003, 0.6188],
);
Ok(())
}