mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add tanh. (#675)
* Add tanh. * Use tanh in the lstm block. * Add a test for tanh forward and backward passes.
This commit is contained in:
@ -159,13 +159,11 @@ impl RNN for LSTM {
|
||||
let chunks = (&w_ih + &w_hh)?.chunk(4, 1)?;
|
||||
let in_gate = crate::ops::sigmoid(&chunks[0])?;
|
||||
let forget_gate = crate::ops::sigmoid(&chunks[1])?;
|
||||
// TODO: This should be a tanh
|
||||
let cell_gate = crate::ops::sigmoid(&chunks[2])?;
|
||||
let cell_gate = chunks[2].tanh()?;
|
||||
let out_gate = crate::ops::sigmoid(&chunks[3])?;
|
||||
|
||||
let next_c = ((forget_gate * &in_state.c)? + (in_gate * cell_gate)?)?;
|
||||
// TODO: This should be another tanh
|
||||
let next_h = (out_gate * crate::ops::sigmoid(&next_c)?)?;
|
||||
let next_h = (out_gate * next_c.tanh()?)?;
|
||||
Ok(LSTMState {
|
||||
c: next_c,
|
||||
h: next_h,
|
||||
|
Reference in New Issue
Block a user