diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index da6dbb66..4fde7ea9 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -532,9 +532,22 @@ impl Tensor {
                         + 0.5)?;
                     *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
                 }
-                Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
-                Op::Unary(_, UnaryOp::GeluErf) => {
-                    Err(Error::BackwardNotSupported { op: "gelu-erf" })?
+                Op::Unary(arg, UnaryOp::Erf) => {
+                    let sum_grad = grads.or_insert(arg)?;
+                    // d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
+                    let erf_grad =
+                        (2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
+                    *sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
+                }
+                Op::Unary(arg, UnaryOp::GeluErf) => {
+                    let sum_grad = grads.or_insert(arg)?;
+                    // d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
+                    let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
+                    let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
+                    let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
+                    let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
+                    let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
+                    *sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
                 }
                 Op::Unary(arg, UnaryOp::Relu) => {
                     let sum_grad = grads.or_insert(arg)?;
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 4abe1189..02b8bd9a 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -641,6 +641,8 @@ impl UnaryOpT for Gelu {
     }
 }
 
+/// `erf` operation
+///
 impl UnaryOpT for Erf {
     const NAME: &'static str = "erf";
     const KERNEL: &'static str = "uerf";
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index bcfe639f..4a529789 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -205,6 +205,47 @@ fn unary_grad(device: &Device) -> Result<()> {
         test_utils::to_vec1_round(grad_x, 4)?,
         [1.0116, 1.0830, 1.0003, 0.6188],
     );
+
+    // Testing compared to pytorch torch.erf
+    //
+    // import torch
+    // x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
+    // y = x.erf()
+    // print(y)
+    // loss = y.sum()
+    // loss.backward()
+    // print(x.grad)
+    let y = x.erf()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(test_utils::to_vec1_round(&y, 4)?, [1.0, 0.8427, 1.0, 0.168]);
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [0.0001, 0.4151, 0.0, 1.1033],
+    );
+
+    // Testing compared to pytorch nn.GELU(approximate = 'none')
+    //
+    // import torch
+    // import torch.nn.functional as F
+    // x = torch.tensor([3.0, 1.0, 4.0, 0.15], requires_grad=True)
+    // y = F.gelu(x, approximate='none')
+    // print(y)
+    // loss = y.sum()
+    // loss.backward()
+    // print(x.grad)
+    let y = x.gelu_erf()?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [2.9960, 0.8413, 3.9999, 0.0839]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [1.0119, 1.0833, 1.0005, 0.6188],
+    );
+
     Ok(())
 }
 
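Reviewer note (not part of the patch): a quick sanity check of the closed-form gradient used in the new `GeluErf` branch, assuming the usual exact definition gelu_erf(x) = 0.5 · x · (1 + erf(x/√2)). Applying the product rule together with d/du erf(u) = (2/√π) e^(−u²) gives

\[
\frac{d}{dx}\,\mathrm{gelu\_erf}(x)
= \frac{1}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)
+ \frac{x}{\sqrt{2\pi}}\,e^{-x^{2}/2}.
\]

The literal 0.398942 in the patch is 1/√(2π) ≈ 0.3989423, so the code comment and the `scaled_exp_arg` / `erf_scaled_sqrt` terms match this expression, and the `Erf` branch's factor (2/√π) e^(−x²) follows directly from the same erf derivative.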