Add tanh. (#675)

* Add tanh.

* Use tanh in the lstm block.

* Add a test for tanh forward and backward passes.
This commit is contained in:
Laurent Mazare
2023-08-30 13:54:50 +01:00
committed by GitHub
parent f35b9f6baa
commit ad8a62dbf5
7 changed files with 26 additions and 6 deletions

View File

@ -159,13 +159,11 @@ impl RNN for LSTM {
let chunks = (&w_ih + &w_hh)?.chunk(4, 1)?;
let in_gate = crate::ops::sigmoid(&chunks[0])?;
let forget_gate = crate::ops::sigmoid(&chunks[1])?;
// TODO: This should be a tanh
let cell_gate = crate::ops::sigmoid(&chunks[2])?;
let cell_gate = chunks[2].tanh()?;
let out_gate = crate::ops::sigmoid(&chunks[3])?;
let next_c = ((forget_gate * &in_state.c)? + (in_gate * cell_gate)?)?;
// TODO: This should be another tanh
let next_h = (out_gate * crate::ops::sigmoid(&next_c)?)?;
let next_h = (out_gate * next_c.tanh()?)?;
Ok(LSTMState {
c: next_c,
h: next_h,