mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add tanh. (#675)
* Add tanh. * Use tanh in the lstm block. * Add a test for tanh forward and backward passes.
This commit is contained in:
@ -379,6 +379,11 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Tanh) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let minus_dtanh = (node.sqr()? - 1.)?;
|
||||
*sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Abs) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let ones = arg.ones_like()?;
|
||||
|
@ -301,7 +301,7 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
@ -311,7 +311,7 @@ fn vs_tanh(a: &[f32], y: &mut [f32]) {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn vd_tanh(a: &[f64], y: &mut [f64]) {
|
||||
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
|
||||
let a_len = a.len();
|
||||
let y_len = y.len();
|
||||
if a_len != y_len {
|
||||
|
@ -59,6 +59,7 @@ pub enum UnaryOp {
|
||||
Sqrt,
|
||||
Gelu,
|
||||
Relu,
|
||||
Tanh,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -324,6 +325,7 @@ pub(crate) struct Sqr;
|
||||
pub(crate) struct Sqrt;
|
||||
pub(crate) struct Gelu;
|
||||
pub(crate) struct Relu;
|
||||
pub(crate) struct Tanh;
|
||||
|
||||
macro_rules! bin_op {
|
||||
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
|
||||
@ -521,6 +523,7 @@ unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
|
||||
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
|
||||
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
|
||||
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
|
||||
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
|
||||
unary_op!(Abs, "abs", v, v.abs());
|
||||
unary_op!(Neg, "neg", v, -v);
|
||||
unary_op!(Recip, "recip", v, v.recip());
|
||||
|
@ -460,6 +460,7 @@ impl Tensor {
|
||||
unary_op!(log, Log);
|
||||
unary_op!(sin, Sin);
|
||||
unary_op!(cos, Cos);
|
||||
unary_op!(tanh, Tanh);
|
||||
unary_op!(abs, Abs);
|
||||
unary_op!(sqr, Sqr);
|
||||
unary_op!(sqrt, Sqrt);
|
||||
|
@ -183,6 +183,15 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
test_utils::to_vec1_round(grad_x, 2)?,
|
||||
[12.99, 2.5, 20.0, 0.15]
|
||||
);
|
||||
|
||||
let y = x.tanh()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
assert_eq!(test_utils::to_vec1_round(&y, 2)?, [1.0, 0.76, 1.0, 0.15]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 2)?,
|
||||
[0.01, 0.42, 0.0, 0.98],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -85,6 +85,7 @@ UNARY_OP(__nv_bfloat16, uexp_bf16, expg(x))
|
||||
UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))
|
||||
UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
|
||||
UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x))
|
||||
UNARY_OP(__nv_bfloat16, utanh_bf16, tanhg(x))
|
||||
UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x))
|
||||
UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
|
||||
UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
|
||||
@ -102,6 +103,7 @@ UNARY_OP(__half, uexp_f16, expg(x))
|
||||
UNARY_OP(__half, ulog_f16, logg(x))
|
||||
UNARY_OP(__half, usin_f16, sing(x))
|
||||
UNARY_OP(__half, ucos_f16, cosg(x))
|
||||
UNARY_OP(__half, utanh_f16, tanhg(x))
|
||||
UNARY_OP(__half, uabs_f16, absg(x))
|
||||
UNARY_OP(__half, usqr_f16, x*x)
|
||||
UNARY_OP(__half, usqrt_f16, sqrtg(x))
|
||||
@ -127,6 +129,8 @@ UNARY_OP(float, usin_f32, sing(x))
|
||||
UNARY_OP(double, usin_f64, sing(x))
|
||||
UNARY_OP(float, ucos_f32, cosg(x))
|
||||
UNARY_OP(double, ucos_f64, cosg(x))
|
||||
UNARY_OP(float, utanh_f32, tanhg(x))
|
||||
UNARY_OP(double, utanh_f64, tanhg(x))
|
||||
UNARY_OP(float, uabs_f32, absg(x))
|
||||
UNARY_OP(double, uabs_f64, absg(x))
|
||||
UNARY_OP(float, usqr_f32, x*x)
|
||||
|
@ -159,13 +159,11 @@ impl RNN for LSTM {
|
||||
let chunks = (&w_ih + &w_hh)?.chunk(4, 1)?;
|
||||
let in_gate = crate::ops::sigmoid(&chunks[0])?;
|
||||
let forget_gate = crate::ops::sigmoid(&chunks[1])?;
|
||||
// TODO: This should be a tanh
|
||||
let cell_gate = crate::ops::sigmoid(&chunks[2])?;
|
||||
let cell_gate = chunks[2].tanh()?;
|
||||
let out_gate = crate::ops::sigmoid(&chunks[3])?;
|
||||
|
||||
let next_c = ((forget_gate * &in_state.c)? + (in_gate * cell_gate)?)?;
|
||||
// TODO: This should be another tanh
|
||||
let next_h = (out_gate * crate::ops::sigmoid(&next_c)?)?;
|
||||
let next_h = (out_gate * next_c.tanh()?)?;
|
||||
Ok(LSTMState {
|
||||
c: next_c,
|
||||
h: next_h,
|
||||
|
Reference in New Issue
Block a user