mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
@ -471,7 +471,15 @@ impl Tensor {
|
||||
Op::Unary(_, UnaryOp::Round) => {
|
||||
Err(Error::BackwardNotSupported { op: "round" })?
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
|
||||
Op::Unary(arg, UnaryOp::Gelu) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let cube = arg.powf(3.)?;
|
||||
let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
|
||||
let gelu_grad = (((0.5 * &tanh)?
|
||||
+ (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
|
||||
+ 0.5)?;
|
||||
*sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
|
||||
}
|
||||
Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
|
||||
Op::Unary(_, UnaryOp::GeluErf) => {
|
||||
Err(Error::BackwardNotSupported { op: "gelu-erf" })?
|
||||
|
@ -192,6 +192,19 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
test_utils::to_vec1_round(grad_x, 2)?,
|
||||
[0.01, 0.42, 0.0, 0.98],
|
||||
);
|
||||
|
||||
// testing compared to pytorch nn.GELU(approximate = 'tanh')
|
||||
let y = x.gelu()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[2.9964, 0.8412, 3.9999, 0.0839]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[1.0116, 1.0830, 1.0003, 0.6188],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user