Commit Graph

9 Commits

Author SHA1 Message Date
b0fe5e4453 Do not implement Module for BatchNorm. (#1513) 2024-01-01 10:13:13 +01:00
a0facd0e67 Small tweaks to batch-norm. (#1505) 2023-12-30 17:06:07 +01:00
4290b81244 [Breaking] Add training to batchnorm with exponential moving average (#1504)
* Add training to batchnorm with exponential moving average

* Add more checks to batch norm

* Resolve some review comments

* Add with_momentum varients of `new` methods

* Add check for range of momentum variable; update batch norm test

* Run cargo fmt

* Add back num_features parameter

* Format; tiny simplification
2023-12-30 16:42:08 +01:00
b3181455d5 Add fuse-conv-bn method for Conv2d (#1196)
* Add fuse-conv-bn method for Conv2d

* no unwrap

* run rustfmp and clippy
2023-10-27 15:56:50 +01:00
0acd16751d Expose the fields from batch-norm. (#1176) 2023-10-25 15:35:32 +01:00
7b1ddcff47 Add clone to various nn layers. (#910) 2023-09-20 11:33:51 +01:00
4c338b0cd9 VarBuilder cleanup (#627)
* VarBuilder cleanup.

* Implement the basic varbuilders.

* Add the sharded code.

* Proper support for tensor sharding.
2023-08-27 18:03:26 +01:00
11c7e7bd67 Some fixes for yolo-v3. (#529)
* Some fixes for yolo-v3.

* Use the running stats for inference in the batch-norm layer.

* Get some proper predictions for yolo.

* Avoid the quadratic insertion.
2023-08-20 23:19:15 +01:00
42e1cc8062 Add a batch normalization layer (#508)
* Add BatchNormalization.

* More batch-norm.

* Add some validation of the inputs.

* More validation.
2023-08-18 20:05:56 +01:00