Add some recurrent neural networks (#674)

* Add the rnn module.

* More LSTM.

* Implement the RNN forward pass.

* More forward pass for LSTM.
This commit is contained in:
Laurent Mazare
2023-08-30 13:27:09 +01:00
committed by GitHub
parent 618f4e4c78
commit f35b9f6baa
2 changed files with 190 additions and 0 deletions

View File

@ -10,6 +10,7 @@ pub mod linear;
pub mod loss;
pub mod ops;
pub mod optim;
pub mod rnn;
pub mod var_builder;
pub mod var_map;
@ -23,6 +24,7 @@ pub use init::Init;
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
pub use linear::{linear, linear_no_bias, Linear};
pub use optim::{AdamW, ParamsAdamW, SGD};
pub use rnn::{lstm, LSTM, RNN};
pub use var_builder::VarBuilder;
pub use var_map::VarMap;