ad73e93da2
Detach the tensors on batch-norm eval. ( #1702 )
...
* Detach the tensors on batch-norm eval.
* Fix pyo3 bindings.
* Black tweak.
* Formatting.
* Also update the pyo3-onnx formatting.
* Apply black.
2024-02-13 14:26:32 +01:00
16c33383eb
Improve the mnist training example. ( #276 )
...
* Improve the mnist training example.
* Add some initialization routine that can be used for nn.
* Proper initialization in the mnist example.
2023-07-29 16:28:22 +01:00
6475bfadfe
Simplify Tensor::randn. ( #255 )
...
* Simplify Tensor::randn.
* Also switch Tensor::rand to use a generic dtype.
* Support sampling for f16.
* Cleanup.
2023-07-27 07:40:36 +01:00
d88b6cdca9
Add backtrace information to errors where relevant. ( #166 )
...
* Add backtrace information to errors where relevant.
* More backtrace information.
* Add to the FAQ.
2023-07-14 09:31:25 +01:00
ded93a1169
Add the SGD optimizer ( #160 )
...
* Add the nn::optim and some conversion traits.
* Add the backward_step function for SGD.
* Get the SGD optimizer to work and add a test.
* Make the test slighly simpler.
2023-07-13 19:05:44 +01:00
5ee3c95582
Move the variable creation to the variable module. ( #159 )
...
* Move the variable creation to the variable module.
* Make it possible to set a variable.
* Add some basic gradient descent test.
* Get the gradient descent test to work.
2023-07-13 16:55:40 +01:00
6991036bc5
Introduce the variables api used for adjusting parameters during the training loop. ( #158 )
...
* Add the variable api.
* And add a comment.
2023-07-13 14:09:51 +01:00