Add a batch normalization layer (#508)

* Add BatchNormalization.

* More batch-norm.

* Add some validation of the inputs.

* More validation.
This commit is contained in:
Laurent Mazare
2023-08-18 20:05:56 +01:00
committed by GitHub
parent b64e782c2d
commit 42e1cc8062
3 changed files with 226 additions and 0 deletions

View File

@ -1,6 +1,7 @@
use candle::{Result, Tensor};
pub mod activation;
pub mod batch_norm;
pub mod conv;
pub mod embedding;
pub mod group_norm;
@ -13,6 +14,7 @@ pub mod optim;
pub mod var_builder;
pub use activation::Activation;
pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
pub use conv::{conv1d, conv2d, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
pub use embedding::{embedding, Embedding};
pub use group_norm::{group_norm, GroupNorm};