feat: impl backprop for erf and gelu-erf (#1258)

* impl backprop for erf anf gelu-erf

* feat: unary tests added for erf and gelu-erf

* fix: (clippy) remove immediately dereferenced ref

* fix: improve comments with pytorch code snippet

* fix: adjust comment typo in backprop impl
This commit is contained in:
drbh
2023-11-03 16:32:30 -04:00
committed by GitHub
parent ad63f20781
commit 3173b1ce3b
3 changed files with 59 additions and 3 deletions

View File

@ -532,9 +532,22 @@ impl Tensor {
+ 0.5)?;
*sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
}
Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
Op::Unary(_, UnaryOp::GeluErf) => {
Err(Error::BackwardNotSupported { op: "gelu-erf" })?
Op::Unary(arg, UnaryOp::Erf) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
let erf_grad =
(2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
*sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
}
Op::Unary(arg, UnaryOp::GeluErf) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
*sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
}
Op::Unary(arg, UnaryOp::Relu) => {
let sum_grad = grads.or_insert(arg)?;