Add an erf-based gelu op (#900)

* Erf-based gelu.

* Add the erf-backed gelu.

* Test the new gelu op (which is not gelu_new).
Author: Laurent Mazare
Date: 2023-09-19 19:54:28 +01:00
Committed by: GitHub
Parent: 34f2ecbc3b
Commit: d7e48234d4
8 changed files with 851 additions and 1 deletion

@@ -442,6 +442,9 @@ impl Tensor {
             *sum_grad = sum_grad.add(&arg_grad)?
         }
         Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
+        Op::Unary(_, UnaryOp::GeluErf) => {
+            Err(Error::BackwardNotSupported { op: "gelu-erf" })?
+        }
         Op::Unary(arg, UnaryOp::Relu) => {
             let sum_grad = grads.or_insert(arg)?;
             let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;