Add tanh. (#675)

* Add tanh.

* Use tanh in the lstm block.

* Add a test for tanh forward and backward passes.
This commit is contained in:
Laurent Mazare
2023-08-30 13:54:50 +01:00
committed by GitHub
parent f35b9f6baa
commit ad8a62dbf5
7 changed files with 26 additions and 6 deletions

View File

@ -379,6 +379,11 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
}
Op::Unary(arg, UnaryOp::Tanh) => {
let sum_grad = grads.or_insert(arg)?;
let minus_dtanh = (node.sqr()? - 1.)?;
*sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
}
Op::Unary(arg, UnaryOp::Abs) => {
let sum_grad = grads.or_insert(arg)?;
let ones = arg.ones_like()?;

View File

@ -301,7 +301,7 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
}
#[inline]
fn vs_tanh(a: &[f32], y: &mut [f32]) {
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
@ -311,7 +311,7 @@ fn vs_tanh(a: &[f32], y: &mut [f32]) {
}
#[inline]
fn vd_tanh(a: &[f64], y: &mut [f64]) {
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {

View File

@ -59,6 +59,7 @@ pub enum UnaryOp {
Sqrt,
Gelu,
Relu,
Tanh,
}
#[derive(Clone)]
@ -324,6 +325,7 @@ pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct Relu;
pub(crate) struct Tanh;
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@ -521,6 +523,7 @@ unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());

View File

@ -460,6 +460,7 @@ impl Tensor {
unary_op!(log, Log);
unary_op!(sin, Sin);
unary_op!(cos, Cos);
unary_op!(tanh, Tanh);
unary_op!(abs, Abs);
unary_op!(sqr, Sqr);
unary_op!(sqrt, Sqrt);