Add a GRU layer. (#688)

* Add a GRU layer.

* Fix the n gate computation.
This commit is contained in:
Laurent Mazare
2023-08-31 09:43:10 +02:00
committed by GitHub
parent d210c71d77
commit db59816087
3 changed files with 187 additions and 1 deletions

View File

@@ -25,7 +25,7 @@ pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
pub use linear::{linear, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, ParamsAdamW, SGD};
-pub use rnn::{lstm, LSTM, RNN};
+pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
pub use var_builder::VarBuilder;
pub use var_map::VarMap;