Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
feat: add backprop for elu (#1269)
* feat: add backprop for elu

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
@@ -554,7 +554,16 @@ impl Tensor {
                 let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
                 *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
             }
-            Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
+            Op::Elu(arg, alpha) => {
+                // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
+                let sum_grad = grads.or_insert(arg)?;
+                let zeros = arg.zeros_like()?;
+                let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
+                let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
+                let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
+                let combined_mask = (positive_mask + negative_exp_mask)?;
+                *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
+            }
             Op::Powf(arg, e) => {
                 let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
                 let sum_grad = grads.or_insert(arg)?;
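As a reader's note (not part of the commit): the derivative in the comment above follows from the standard ELU definition. Because the implementation masks the second branch with `le`, the x = 0 point takes the exponential branch, so the gradient at zero is alpha (2.0 in the test below):

\operatorname{elu}_\alpha(x) =
\begin{cases}
x, & x > 0,\\
\alpha\,(e^{x} - 1), & x \le 0,
\end{cases}
\qquad
\frac{d}{dx}\operatorname{elu}_\alpha(x) =
\begin{cases}
1, & x > 0,\\
\alpha\, e^{x}, & x \le 0.
\end{cases}

This is exactly what the masks compute: positive_mask contributes the 1 branch, negative_exp_mask contributes the alpha * e^x branch, and their sum scales the incoming grad.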
@@ -246,6 +246,29 @@ fn unary_grad(device: &Device) -> Result<()> {
         [1.0119, 1.0833, 1.0005, 0.6188],
     );
+
+    // Testing compared to pytorch elu
+    //
+    // import torch
+    // import torch.nn.functional as F
+    // x = torch.tensor([-1.0, 0.0, -2.0, 3.0], requires_grad=True)
+    // y = F.elu(x, alpha=2.0)
+    // print(y)
+    // loss = y.sum()
+    // loss.backward()
+    // print(x.grad)
+    let elu_x = Var::new(&[-1.0f32, 0., -2., 3.], device)?;
+    let y = elu_x.elu(2.)?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&elu_x).context("no grad for x")?;
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [-1.2642, 0.0000, -1.7293, 3.0000]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [0.7358, 2.0000, 0.2707, 1.0000]
+    );
 
     Ok(())
 }
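For anyone who wants to exercise the new backward pass outside the test suite, here is a minimal standalone sketch. It assumes candle-core as a dependency and runs on CPU; the elu_scalar helper, the chosen point x0, and the tolerance are illustrative additions, not part of the commit. It cross-checks the analytic gradient against a central finite difference of the scalar ELU definition.

// Numerical sanity check for the Op::Elu backward pass (sketch).
use candle_core::{Device, Result, Var};

// Scalar ELU, written out directly for the finite-difference reference.
fn elu_scalar(x: f64, alpha: f64) -> f64 {
    if x > 0.0 { x } else { alpha * (x.exp() - 1.0) }
}

fn main() -> Result<()> {
    let (x0, alpha, eps) = (-1.0f64, 2.0f64, 1e-4f64);
    // Analytic gradient via candle's autodiff.
    let x = Var::new(&[x0 as f32], &Device::Cpu)?;
    let grads = x.elu(alpha)?.backward()?;
    let analytic = grads.get(&x).expect("no grad for x").to_vec1::<f32>()?[0] as f64;
    // Central finite difference: (f(x + eps) - f(x - eps)) / (2 * eps).
    let numeric = (elu_scalar(x0 + eps, alpha) - elu_scalar(x0 - eps, alpha)) / (2.0 * eps);
    assert!((analytic - numeric).abs() < 1e-3);
    Ok(())
}

At x0 = -1 with alpha = 2, both values should land near alpha * e^{-1} ≈ 0.7358, which is the first entry of x.grad in the test above.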