89d3926c9b
Fixes for the stable diffusion example. ( #342 )
...
* Fixes for the stable diffusion example.
* Bugfix.
* Another fix.
* Fix for group-norm.
* More fixes to get SD to work.
2023-08-08 14:57:09 +01:00
5bb2fce998
Implement group-norm. ( #334 )
...
* Implement group-norm.
* Add some testing for group-norm.
2023-08-07 06:53:05 +01:00
0902846f25
Add the AdamW optimizer. ( #307 )
...
* Add the AdamW optimizer.
* Add some AdamW test validated against PyTorch.
2023-08-02 14:03:49 +01:00
ff876c2103
Llama more training ( #297 )
...
* Rework the var-builder to handle initializations.
* Add some helper functions for layer creation.
* Improve the layer initializations.
* Get initialized variables.
* Precompute the rot embeddings when training lamas.
2023-08-01 19:53:41 +01:00
1064b9b031
Add the cross-entropy loss. ( #287 )
2023-07-31 14:26:36 +01:00
ffeafbfc43
Make the nll op closer to the pytorch version + add a test. ( #286 )
2023-07-31 14:14:01 +01:00
3eb2bc6d07
Softmax numerical stability. ( #267 )
...
* Softmax numerical stability.
* Fix the flash-attn test.
2023-07-28 13:13:01 +01:00
a2f72edc0d
Simplify the parameters used by sum and sum_keepdim. ( #165 )
2023-07-14 08:22:08 +01:00
2bfa791336
Use the same default as pytorch for sum. ( #164 )
2023-07-13 21:32:32 +01:00
57be3638d8
Add the pytorch version of the linear regression as a comment. ( #163 )
...
* Add the pytorch version of the linear regression.
* Typo.
2023-07-13 21:05:57 +01:00
23e105cd94
Add the gradient for reduce-sum. ( #162 )
...
* Add the gradient for reduce-sum.
* And add the gradient for the broadcast ops.
* Add some backprop tests.
* Add some linear regression example.
2023-07-13 20:14:10 +01:00
ded93a1169
Add the SGD optimizer ( #160 )
...
* Add the nn::optim and some conversion traits.
* Add the backward_step function for SGD.
* Get the SGD optimizer to work and add a test.
* Make the test slighly simpler.
2023-07-13 19:05:44 +01:00
71cd3745a9
Add some layer-norm tests. ( #121 )
2023-07-10 14:43:04 +01:00