Layer norm tweaks (#482)

* Add some options to make layer-norm more configurable.

* Add the rms-norm variant.

* Replace the RmsNorm with the shared bits.
This commit is contained in:
Laurent Mazare
2023-08-17 10:07:13 +01:00
committed by GitHub
parent d99cac3ec3
commit d32e8199cd
7 changed files with 124 additions and 158 deletions

View File

@ -17,7 +17,7 @@ pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
pub use embedding::{embedding, Embedding};
pub use group_norm::{group_norm, GroupNorm};
pub use init::Init;
pub use layer_norm::{layer_norm, LayerNorm};
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig};
pub use linear::{linear, linear_no_bias, Linear};
pub use optim::{AdamW, ParamsAdamW, SGD};
pub use var_builder::{VarBuilder, VarMap};