Commit Graph

22 Commits

Author SHA1 Message Date
db59816087 Add a GRU layer. (#688)
* Add a GRU layer.

* Fix the n gate computation.
2023-08-31 08:43:10 +01:00
eaf760a751 Add a python variant for the lstm test. (#682) 2023-08-30 22:32:08 +01:00
21e1c73892 Add a LSTM test. (#681)
* Add a LSTM test.

* Clippy.
2023-08-30 20:05:42 +02:00
5320aa6b7d Move the test-utils bits to a shared place. (#619) 2023-08-27 09:42:22 +01:00
11c7e7bd67 Some fixes for yolo-v3. (#529)
* Some fixes for yolo-v3.

* Use the running stats for inference in the batch-norm layer.

* Get some proper predictions for yolo.

* Avoid the quadratic insertion.
2023-08-20 23:19:15 +01:00
a1812f934f Add a yolo-v3 example. (#528)
* Add a couple functions required for yolo.

* Add the yolo-v3 example.

* Add minimum and maximum.

* Use the newly introduced maximum.

* Cuda support for min/max + add some testing.

* Allow for more tests to work with accelerate.

* Fix a typo.
2023-08-20 18:19:37 +01:00
42e1cc8062 Add a batch normalization layer (#508)
* Add BatchNormalization.

* More batch-norm.

* Add some validation of the inputs.

* More validation.
2023-08-18 20:05:56 +01:00
c78ce76501 Add a simple Module trait and implement it for the various nn layers (#500)
* Start adding the module trait.

* Use the module trait.

* Implement module for qmatmul.
2023-08-18 09:38:22 +01:00
eab54e4490 Fix the tests for mkl. (#437) 2023-08-14 08:09:27 +01:00
89d3926c9b Fixes for the stable diffusion example. (#342)
* Fixes for the stable diffusion example.

* Bugfix.

* Another fix.

* Fix for group-norm.

* More fixes to get SD to work.
2023-08-08 14:57:09 +01:00
5bb2fce998 Implement group-norm. (#334)
* Implement group-norm.

* Add some testing for group-norm.
2023-08-07 06:53:05 +01:00
0902846f25 Add the AdamW optimizer. (#307)
* Add the AdamW optimizer.

* Add some AdamW test validated against PyTorch.
2023-08-02 14:03:49 +01:00
ff876c2103 Llama more training (#297)
* Rework the var-builder to handle initializations.

* Add some helper functions for layer creation.

* Improve the layer initializations.

* Get initialized variables.

* Precompute the rot embeddings when training lamas.
2023-08-01 19:53:41 +01:00
1064b9b031 Add the cross-entropy loss. (#287) 2023-07-31 14:26:36 +01:00
ffeafbfc43 Make the nll op closer to the pytorch version + add a test. (#286) 2023-07-31 14:14:01 +01:00
3eb2bc6d07 Softmax numerical stability. (#267)
* Softmax numerical stability.

* Fix the flash-attn test.
2023-07-28 13:13:01 +01:00
a2f72edc0d Simplify the parameters used by sum and sum_keepdim. (#165) 2023-07-14 08:22:08 +01:00
2bfa791336 Use the same default as pytorch for sum. (#164) 2023-07-13 21:32:32 +01:00
57be3638d8 Add the pytorch version of the linear regression as a comment. (#163)
* Add the pytorch version of the linear regression.

* Typo.
2023-07-13 21:05:57 +01:00
23e105cd94 Add the gradient for reduce-sum. (#162)
* Add the gradient for reduce-sum.

* And add the gradient for the broadcast ops.

* Add some backprop tests.

* Add some linear regression example.
2023-07-13 20:14:10 +01:00
ded93a1169 Add the SGD optimizer (#160)
* Add the nn::optim and some conversion traits.

* Add the backward_step function for SGD.

* Get the SGD optimizer to work and add a test.

* Make the test slighly simpler.
2023-07-13 19:05:44 +01:00
71cd3745a9 Add some layer-norm tests. (#121) 2023-07-10 14:43:04 +01:00