From 7051fb8098efa291a169083333750b7a8130bea3 Mon Sep 17 00:00:00 2001
From: drbh
Date: Sat, 4 Nov 2023 16:26:41 -0400
Subject: [PATCH] feat: add backprop for elu (#1269)

* feat: add backprop for elu

* Cosmetic tweaks.

---------

Co-authored-by: Laurent
---
 candle-core/src/backprop.rs     | 11 ++++++++++-
 candle-core/tests/grad_tests.rs | 24 ++++++++++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 4fde7ea9..1448a6f4 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -554,7 +554,16 @@ impl Tensor {
                         let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
                         *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
                     }
-                    Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
+                    Op::Elu(arg, alpha) => {
+                        // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
+                        let sum_grad = grads.or_insert(arg)?;
+                        let zeros = arg.zeros_like()?;
+                        let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
+                        let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
+                        let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
+                        let combined_mask = (positive_mask + negative_exp_mask)?;
+                        *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
+                    }
                     Op::Powf(arg, e) => {
                         let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
                         let sum_grad = grads.or_insert(arg)?;
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index 4a529789..6413ea2e 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -246,6 +246,30 @@ fn unary_grad(device: &Device) -> Result<()> {
         [1.0119, 1.0833, 1.0005, 0.6188],
     );
 
+    // Testing compared to pytorch elu
+    //
+    // import torch
+    // import torch.nn.functional as F
+    // x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
+    // y = F.elu(x, alpha=2.0)
+    // print(y)
+    // loss = y.min
+    // loss = y.sum()
+    // loss.backward()
+    // print(x.grad)
+    let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
+    let y = elu_x.elu(2.)?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&elu_x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [-1.2642, 0.0000, -1.7293, 3.0000]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [0.7358, 2.0000, 0.2707, 1.0000]
+    );
+
     Ok(())
 }
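
For reference, the closed-form derivative implemented by the new Op::Elu arm
(positive_mask + alpha * exp(x) * negative_mask) can be sanity-checked outside
of candle. The sketch below is not part of the patch: it is a minimal,
dependency-free Rust program in which the `elu` and `elu_grad` helpers are
ad-hoc names introduced only for this illustration. It compares the analytic
derivative against a central finite difference on the same inputs and alpha as
the new unary_grad test case; the analytic column reproduces the expected
gradients [0.7358, 2.0000, 0.2707, 1.0000].

// Sanity-check sketch for the ELU derivative used above (not part of the patch).
//   elu(x)      = x                   for x > 0
//               = alpha * (e^x - 1)   for x <= 0
//   d/dx elu(x) = 1                   for x > 0
//               = alpha * e^x         for x <= 0
fn elu(x: f64, alpha: f64) -> f64 {
    if x > 0.0 {
        x
    } else {
        alpha * (x.exp() - 1.0)
    }
}

fn elu_grad(x: f64, alpha: f64) -> f64 {
    if x > 0.0 {
        1.0
    } else {
        alpha * x.exp()
    }
}

fn main() {
    let alpha = 2.0;
    let eps = 1e-6;
    // Same inputs and alpha as the new test in grad_tests.rs.
    for &x in &[-1.0f64, 0.0, -2.0, 3.0] {
        let analytic = elu_grad(x, alpha);
        // Central finite difference. At x == 0 the two columns differ because
        // ELU with alpha != 1 has a kink there; the backward pass (like
        // PyTorch) uses the x <= 0 branch and reports alpha at that point.
        let numeric = (elu(x + eps, alpha) - elu(x - eps, alpha)) / (2.0 * eps);
        println!("x = {x:4.1}  analytic = {analytic:.4}  numeric = {numeric:.4}");
    }
}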