mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
feat: impl backprop for erf and gelu-erf (#1258)
* impl backprop for erf anf gelu-erf * feat: unary tests added for erf and gelu-erf * fix: (clippy) remove immediately dereferenced ref * fix: improve comments with pytorch code snippet * fix: adjust comment typo in backprop impl
This commit is contained in:
@ -532,9 +532,22 @@ impl Tensor {
|
||||
+ 0.5)?;
|
||||
*sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
|
||||
Op::Unary(_, UnaryOp::GeluErf) => {
|
||||
Err(Error::BackwardNotSupported { op: "gelu-erf" })?
|
||||
Op::Unary(arg, UnaryOp::Erf) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
|
||||
let erf_grad =
|
||||
(2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
|
||||
*sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::GeluErf) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
|
||||
let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
|
||||
let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
|
||||
let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
|
||||
let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
|
||||
let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
|
||||
*sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Relu) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
Reference in New Issue
Block a user