Add the SGD optimizer (#160)

* Add the nn::optim and some conversion traits.

* Add the backward_step function for SGD.

* Get the SGD optimizer to work and add a test.

* Make the test slighly simpler.
This commit is contained in:
Laurent Mazare
2023-07-13 19:05:44 +01:00
committed by GitHub
parent 5ee3c95582
commit ded93a1169
6 changed files with 168 additions and 4 deletions

View File

@ -6,6 +6,7 @@ pub mod embedding;
pub mod init;
pub mod layer_norm;
pub mod linear;
pub mod optim;
pub mod var_builder;
pub use activation::Activation;
@ -13,4 +14,5 @@ pub use conv::{Conv1d, Conv1dConfig};
pub use embedding::Embedding;
pub use layer_norm::LayerNorm;
pub use linear::Linear;
pub use optim::SGD;
pub use var_builder::VarBuilder;