mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
feat: impl backprop for erf and gelu-erf (#1258)
* impl backprop for erf and gelu-erf
* feat: unary tests added for erf and gelu-erf
* fix: (clippy) remove immediately dereferenced ref
* fix: improve comments with pytorch code snippet
* fix: adjust comment typo in backprop impl
@@ -532,9 +532,22 @@ impl Tensor {
                     + 0.5)?;
                 *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
             }
-            Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
-            Op::Unary(_, UnaryOp::GeluErf) => {
-                Err(Error::BackwardNotSupported { op: "gelu-erf" })?
-            }
+            Op::Unary(arg, UnaryOp::Erf) => {
+                let sum_grad = grads.or_insert(arg)?;
+                // d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
+                let erf_grad =
+                    (2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
+                *sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
+            }
+            Op::Unary(arg, UnaryOp::GeluErf) => {
+                let sum_grad = grads.or_insert(arg)?;
+                // d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
+                let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
+                let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
+                let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
+                let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
+                let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
+                *sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
+            }
             Op::Unary(arg, UnaryOp::Relu) => {
                 let sum_grad = grads.or_insert(arg)?;
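Both new arms follow from the closed-form derivatives: d/dx erf(x) = 2/sqrt(pi) * e^(-x^2), and, since gelu_erf(x) = 0.5 * x * (1 + erf(x / sqrt(2))), the chain rule gives d/dx gelu_erf(x) = 0.5 + x / sqrt(2*pi) * e^(-x^2/2) + 0.5 * erf(x / sqrt(2)); the literal 0.398942 in the diff is 1 / sqrt(2*pi) rounded to six digits. As a sanity check, here is a minimal standalone Rust sketch (not part of the commit, and independent of candle) that compares both analytic gradients against central finite differences; the erf used here is the Abramowitz-Stegun 7.1.26 polynomial approximation:

// Standalone check of the two derivative formulas via finite differences.
// erf() is the Abramowitz-Stegun 7.1.26 approximation (error around 1.5e-7).
fn erf(x: f64) -> f64 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    let poly = ((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t - 0.284496736) * t
        + 0.254829592)
        * t;
    sign * (1.0 - poly * (-x * x).exp())
}

fn gelu_erf(x: f64) -> f64 {
    0.5 * x * (1.0 + erf(x / 2f64.sqrt()))
}

fn main() {
    let (x, h) = (0.15, 1e-6);
    // d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
    let erf_grad = (2.0 / std::f64::consts::PI.sqrt()) * (-x * x).exp();
    let erf_num = (erf(x + h) - erf(x - h)) / (2.0 * h);
    println!("erf'({x}): analytic {erf_grad:.4}, numeric {erf_num:.4}");
    // d/dx gelu_erf(x) = 0.5 + x / sqrt(2*pi) * e^(-x^2/2) + 0.5 * erf(x/sqrt(2)),
    // where 1 / sqrt(2*pi) = 0.398942..., the constant used in the diff above.
    let c = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
    let gelu_grad = 0.5 + c * (-x * x / 2.0).exp() * x + 0.5 * erf(x / 2f64.sqrt());
    let gelu_num = (gelu_erf(x + h) - gelu_erf(x - h)) / (2.0 * h);
    println!("gelu_erf'({x}): analytic {gelu_grad:.4}, numeric {gelu_num:.4}");
}

At x = 0.15 the gelu-erf gradient comes out to about 0.6188, matching the value asserted in the test hunk further down.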
@@ -641,6 +641,8 @@ impl UnaryOpT for Gelu {
     }
 }
 
+/// `erf` operation
+/// <https://en.wikipedia.org/wiki/Error_function>
 impl UnaryOpT for Erf {
     const NAME: &'static str = "erf";
     const KERNEL: &'static str = "uerf";
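The newly documented link defines erf(x) = 2/sqrt(pi) * the integral from 0 to x of e^(-t^2) dt. A small illustrative sketch (again not part of the commit) that evaluates this integral with the trapezoidal rule reproduces the 0.8427 value asserted for erf(1.0) in the test hunk below:

fn main() {
    // erf(1) = 2/sqrt(pi) * integral_0^1 e^(-t^2) dt, via the trapezoidal rule.
    let (x, n) = (1.0_f64, 100_000);
    let h = x / n as f64;
    let f = |t: f64| (-t * t).exp();
    let mut sum = 0.5 * (f(0.0) + f(x));
    for i in 1..n {
        sum += f(i as f64 * h);
    }
    println!("erf(1) ≈ {:.4}", 2.0 / std::f64::consts::PI.sqrt() * sum * h); // 0.8427
}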
@@ -205,6 +205,47 @@ fn unary_grad(device: &Device) -> Result<()> {
         test_utils::to_vec1_round(grad_x, 4)?,
         [1.0116, 1.0830, 1.0003, 0.6188],
     );
+
+    // Testing compared to pytorch torch.erf
+    //
+    // import torch
+    // x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
+    // y = x.erf()
+    // print(y)
+    // loss = y.sum()
+    // loss.backward()
+    // print(x.grad)
+    let y = x.erf()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]);
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [0.0001, 0.4151, 0.0, 1.1033],
+    );
+
+    // Testing compared to pytorch nn.GELU(approximate = 'none')
+    //
+    // import torch
+    // import torch.nn.functional as F
+    // x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
+    // y = F.gelu(x, approximate='none')
+    // print(y)
+    // loss = y.sum()
+    // loss.backward()
+    // print(x.grad)
+    let y = x.gelu_erf()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [2.9960, 0.8413, 3.9999, 0.0839]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [1.0119, 1.0833, 1.0005, 0.6188],
+    );
+
     Ok(())
 }
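For a usage sketch of the new backward passes outside the test suite, something like the following should work, assuming candle-core as a dependency; the Var::new, backward, and grads.get calls mirror those visible in the test hunk above, though the exact signatures are an assumption not confirmed by this diff:

use candle_core::{Device, Result, Var};

fn main() -> Result<()> {
    // Same input values as the test; Var tracks gradients through ops.
    let x = Var::new(&[3f32, 1., 4., 0.15], &Device::Cpu)?;
    let y = x.erf()?; // Var derefs to Tensor, so the unary op applies directly
    let grads = y.backward()?;
    let grad_x = grads.get(&x).expect("no grad for x");
    println!("{grad_x}"); // expect roughly [0.0001, 0.4151, 0.0, 1.1033]
    Ok(())
}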