diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index adb3e1dd..d2099df7 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -379,6 +379,11 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)? } + Op::Unary(arg, UnaryOp::Tanh) => { + let sum_grad = grads.or_insert(arg)?; + let minus_dtanh = (node.sqr()? - 1.)?; + *sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)? + } Op::Unary(arg, UnaryOp::Abs) => { let sum_grad = grads.or_insert(arg)?; let ones = arg.ones_like()?; diff --git a/candle-core/src/mkl.rs b/candle-core/src/mkl.rs index 64f83cc6..26167e86 100644 --- a/candle-core/src/mkl.rs +++ b/candle-core/src/mkl.rs @@ -301,7 +301,7 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) { } #[inline] -fn vs_tanh(a: &[f32], y: &mut [f32]) { +pub fn vs_tanh(a: &[f32], y: &mut [f32]) { let a_len = a.len(); let y_len = y.len(); if a_len != y_len { @@ -311,7 +311,7 @@ fn vs_tanh(a: &[f32], y: &mut [f32]) { } #[inline] -fn vd_tanh(a: &[f64], y: &mut [f64]) { +pub fn vd_tanh(a: &[f64], y: &mut [f64]) { let a_len = a.len(); let y_len = y.len(); if a_len != y_len { diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 213ae2c8..fbfc9c1a 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -59,6 +59,7 @@ pub enum UnaryOp { Sqrt, Gelu, Relu, + Tanh, } #[derive(Clone)] @@ -324,6 +325,7 @@ pub(crate) struct Sqr; pub(crate) struct Sqrt; pub(crate) struct Gelu; pub(crate) struct Relu; +pub(crate) struct Tanh; macro_rules! bin_op { ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => { @@ -521,6 +523,7 @@ unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp); unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln); unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin); unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos); +unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh); unary_op!(Abs, "abs", v, v.abs()); unary_op!(Neg, "neg", v, -v); unary_op!(Recip, "recip", v, v.recip()); diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index ec89af12..f23907dd 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -460,6 +460,7 @@ impl Tensor { unary_op!(log, Log); unary_op!(sin, Sin); unary_op!(cos, Cos); + unary_op!(tanh, Tanh); unary_op!(abs, Abs); unary_op!(sqr, Sqr); unary_op!(sqrt, Sqrt); diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs index 26e05b68..ad09c90f 100644 --- a/candle-core/tests/grad_tests.rs +++ b/candle-core/tests/grad_tests.rs @@ -183,6 +183,15 @@ fn unary_grad(device: &Device) -> Result<()> { test_utils::to_vec1_round(grad_x, 2)?, [12.99, 2.5, 20.0, 0.15] ); + + let y = x.tanh()?; + let grads = y.backward()?; + let grad_x = grads.get(&x).context("no grad for x")?; + assert_eq!(test_utils::to_vec1_round(&y, 2)?, [1.0, 0.76, 1.0, 0.15]); + assert_eq!( + test_utils::to_vec1_round(grad_x, 2)?, + [0.01, 0.42, 0.0, 0.98], + ); Ok(()) } diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu index c5b18461..c6142a03 100644 --- a/candle-kernels/src/unary.cu +++ b/candle-kernels/src/unary.cu @@ -85,6 +85,7 @@ UNARY_OP(__nv_bfloat16, uexp_bf16, expg(x)) UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x)) UNARY_OP(__nv_bfloat16, usin_bf16, sing(x)) UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x)) +UNARY_OP(__nv_bfloat16, utanh_bf16, tanhg(x)) UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x)) UNARY_OP(__nv_bfloat16, usqr_bf16, x*x) UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x)) @@ -102,6 +103,7 @@ UNARY_OP(__half, uexp_f16, expg(x)) UNARY_OP(__half, ulog_f16, logg(x)) UNARY_OP(__half, usin_f16, sing(x)) UNARY_OP(__half, ucos_f16, cosg(x)) +UNARY_OP(__half, utanh_f16, tanhg(x)) UNARY_OP(__half, uabs_f16, absg(x)) UNARY_OP(__half, usqr_f16, x*x) UNARY_OP(__half, usqrt_f16, sqrtg(x)) @@ -127,6 +129,8 @@ UNARY_OP(float, usin_f32, sing(x)) UNARY_OP(double, usin_f64, sing(x)) UNARY_OP(float, ucos_f32, cosg(x)) UNARY_OP(double, ucos_f64, cosg(x)) +UNARY_OP(float, utanh_f32, tanhg(x)) +UNARY_OP(double, utanh_f64, tanhg(x)) UNARY_OP(float, uabs_f32, absg(x)) UNARY_OP(double, uabs_f64, absg(x)) UNARY_OP(float, usqr_f32, x*x) diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs index 4b116081..3c82e794 100644 --- a/candle-nn/src/rnn.rs +++ b/candle-nn/src/rnn.rs @@ -159,13 +159,11 @@ impl RNN for LSTM { let chunks = (&w_ih + &w_hh)?.chunk(4, 1)?; let in_gate = crate::ops::sigmoid(&chunks[0])?; let forget_gate = crate::ops::sigmoid(&chunks[1])?; - // TODO: This should be a tanh - let cell_gate = crate::ops::sigmoid(&chunks[2])?; + let cell_gate = chunks[2].tanh()?; let out_gate = crate::ops::sigmoid(&chunks[3])?; let next_c = ((forget_gate * &in_state.c)? + (in_gate * cell_gate)?)?; - // TODO: This should be another tanh - let next_h = (out_gate * crate::ops::sigmoid(&next_c)?)?; + let next_h = (out_gate * next_c.tanh()?)?; Ok(LSTMState { c: next_c, h: next_h,