a27239f3d9
Add training for the llama2.c example ( #296 )
...
* Rework the commands and run inference by default.
* Add the training module and load the training dataset.
* Random dataset iterator.
* Proper valid-loss computation.
* Compute the evaluation loss.
* Add more substance to the training loop.
2023-08-01 17:23:07 +01:00
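The validation-loss bullets above amount to averaging a cross-entropy loss over held-out batches. Below is a minimal sketch of such a helper using candle_nn::loss::cross_entropy; the forward closure and the batch layout (logits shaped (batch, vocab) against u32 class-id targets) are assumptions, not the example's actual code.

```rust
use candle_core::{Result, Tensor};
use candle_nn::loss::cross_entropy;

// Hedged sketch: average the cross-entropy loss over (input, target) batches.
fn valid_loss(
    batches: &[(Tensor, Tensor)],
    forward: impl Fn(&Tensor) -> Result<Tensor>,
) -> Result<f64> {
    let mut sum = 0f64;
    for (inputs, targets) in batches {
        let logits = forward(inputs)?;               // assumed shape: (batch, vocab)
        let loss = cross_entropy(&logits, targets)?; // mean loss as a scalar tensor
        sum += loss.to_scalar::<f32>()? as f64;
    }
    Ok(sum / batches.len() as f64)
}
```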
952eca6b54
Fix slice errors + comments.
2023-07-27 16:59:32 +02:00
1235aa2536
Use bail rather than wrapping a string where possible. ( #249 )
...
* Use bail rather than wrapping a string where possible.
* Revert the cuda default bit.
2023-07-26 15:42:46 +01:00
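For context on the change above, a toy example of the pattern: `bail!` formats and returns the error in one step instead of hand-wrapping a formatted string (the rank check itself is made up).

```rust
use candle_core::{bail, Result, Tensor};

// Illustrative only: early-return a formatted error via `bail!`.
fn check_rank_2(t: &Tensor) -> Result<()> {
    if t.rank() != 2 {
        bail!("expected a rank-2 tensor, got shape {:?}", t.shape())
    }
    Ok(())
}
```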
fa2b64d678
Proper flash-attn parameters. ( #244 )
...
* Proper flash-attn parameters.
* Set the flash attention parameters.
* Add more validations.
* Set up the o_ flash attn parameters.
* More flash-attn support.
* Set more flash attn parameters.
2023-07-26 10:13:40 +01:00
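A rough sketch of how the flash-attn wrapper ends up being called once its parameters are wired through; it assumes the `flash_attn(q, k, v, softmax_scale, causal)` entry point and (batch, seq_len, num_heads, head_dim) half-precision inputs on CUDA, which may not match every detail of the PR.

```rust
use candle_core::{Result, Tensor};

// Hedged sketch of a causal flash-attention call.
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, head_dim: usize) -> Result<Tensor> {
    let softmax_scale = 1f32 / (head_dim as f32).sqrt();
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, /* causal */ true)
}
```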
6eeea1b04e
Polish the index-add op and use it in the index-select backprop ( #218 )
...
* Add the cpu version of index-add.
* More cpu support for index-add.
* Use index-add in the backprop.
2023-07-22 05:31:46 +01:00
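To make the link between the two ops concrete: index-add scatter-adds rows of a source tensor into a destination at given indices, which is exactly the shape of the index-select backward pass (gradients flow back to the selected rows, accumulating on repeated ids). A small usage sketch, assuming the `index_add(indexes, source, dim)` signature:

```rust
use candle_core::{DType, Device, Result, Tensor};

// Usage sketch: scatter-add rows of `src` into a zero tensor at positions `ids`.
fn demo() -> Result<Tensor> {
    let dev = Device::Cpu;
    let acc = Tensor::zeros((4, 2), DType::F32, &dev)?;
    let src = Tensor::new(&[[1f32, 1.], [2., 2.], [3., 3.]], &dev)?;
    let ids = Tensor::new(&[0u32, 2, 0], &dev)?;
    // Row 0 receives src rows 0 and 2 (accumulated), row 2 receives src row 1.
    acc.index_add(&ids, &src, 0)
}
```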
410654525f
Refactor the reduce ops in order to introduce argmin/argmax. ( #212 )
...
* Refactor the reduce ops in order to introduce argmin/argmax.
* Clippy fixes.
* Use the newly introduced argmax.
* Fix the strided case.
* Handle the non-contiguous case.
2023-07-21 11:41:08 +01:00
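A small usage sketch of the reduce ops this refactor introduces: argmax/argmin return u32 indices along a dimension, with *_keepdim variants preserving the reduced dimension.

```rust
use candle_core::{Device, Result, Tensor};

fn demo() -> Result<()> {
    let t = Tensor::new(&[[3f32, 1., 4.], [1., 5., 9.]], &Device::Cpu)?;
    let am = t.argmax(1)?;          // indices [2, 2], shape (2)
    let amk = t.argmax_keepdim(1)?; // same values, shape (2, 1)
    println!("{am}\n{amk}");
    Ok(())
}
```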
fa08fb3126
Add the index-select op. ( #209 )
...
* Add the index-select op.
* Cpu implementation of index-select.
* Add the cpu implementation for index-select.
2023-07-20 14:01:03 +01:00
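A minimal usage sketch of the new op: index-select gathers slices along a dimension by an index tensor, the building block for embedding lookups.

```rust
use candle_core::{Device, Result, Tensor};

fn demo() -> Result<Tensor> {
    let dev = Device::Cpu;
    let t = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &dev)?;
    let ids = Tensor::new(&[2u32, 0, 2], &dev)?;
    t.index_select(&ids, 0) // rows 2, 0, 2 -> shape (3, 2)
}
```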
ad12e20f6b
Add cpu support for min and max. ( #202 )
...
* Add cpu support for min and max.
* Add min/max all.
2023-07-19 17:11:44 +01:00
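A short sketch of what this enables on CPU: per-dimension min/max reductions plus the `*_all` variants that reduce over every element.

```rust
use candle_core::{Device, Result, Tensor};

fn demo() -> Result<()> {
    let t = Tensor::new(&[[3f32, 1., 4.], [1., 5., 9.]], &Device::Cpu)?;
    let row_max = t.max(1)?;                           // [4, 9]
    let row_min = t.min_keepdim(1)?;                   // shape (2, 1)
    let global_max = t.max_all()?.to_scalar::<f32>()?; // 9.0
    println!("{row_max}\n{row_min}\n{global_max}");
    Ok(())
}
```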
d88b6cdca9
Add backtrace information to errors where relevant. ( #166 )
...
* Add backtrace information to errors where relevant.
* More backtrace information.
* Add to the FAQ.
2023-07-14 09:31:25 +01:00
a2f72edc0d
Simplify the parameters used by sum and sum_keepdim. ( #165 )
2023-07-14 08:22:08 +01:00
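For reference, a usage sketch of the simplified API: `sum` reduces over the given dims and squeezes them, `sum_keepdim` keeps them with size 1.

```rust
use candle_core::{Device, Result, Tensor};

fn demo() -> Result<()> {
    let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    let s = t.sum(1)?;          // shape (2): [3, 7]
    let sk = t.sum_keepdim(1)?; // shape (2, 1)
    let all = t.sum_all()?;     // rank-0 scalar: 10
    println!("{s}\n{sk}\n{all}");
    Ok(())
}
```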
5ee3c95582
Move the variable creation to the variable module. ( #159 )
...
* Move the variable creation to the variable module.
* Make it possible to set a variable.
* Add a basic gradient descent test.
* Get the gradient descent test to work.
2023-07-13 16:55:40 +01:00
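A hedged sketch of the kind of step such a gradient-descent test exercises, not the test's actual code: create a `Var`, back-propagate a loss, and update the variable in place through `Var::set`.

```rust
use candle_core::{Device, Result, Var};

fn sgd_step() -> Result<()> {
    let dev = Device::Cpu;
    let w = Var::new(&[3f32], &dev)?;
    let loss = w.as_tensor().sqr()?.sum_all()?; // loss = w^2
    let grads = loss.backward()?;
    let grad = grads.get(&w).expect("no gradient for w");
    let lr = 0.1;
    let step = (grad * lr)?;
    let new_w = (w.as_tensor() - &step)?;
    w.set(&new_w)?; // in-place variable update
    Ok(())
}
```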
21aa29ddce
Use a rwlock for inner mutability. ( #156 )
...
* Use a rw-lock.
* Make clippy happier.
2023-07-13 11:25:24 +01:00
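As a generic illustration of the pattern (not candle's actual internal types): putting shared storage behind an `Arc<RwLock<..>>` lets many readers access a buffer concurrently while mutating ops take a short write lock.

```rust
use std::sync::{Arc, RwLock};

struct Storage(Vec<f32>);

#[derive(Clone)]
struct SharedStorage(Arc<RwLock<Storage>>);

impl SharedStorage {
    fn sum(&self) -> f32 {
        // Many threads may hold read locks at the same time.
        self.0.read().unwrap().0.iter().sum()
    }
    fn fill(&self, v: f32) {
        // In-place mutation takes an exclusive write lock.
        self.0.write().unwrap().0.iter_mut().for_each(|x| *x = v);
    }
}
```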
50b0946a2d
Tensor mutability ( #154 )
...
* Working towards tensor mutability.
* Use a ref-cell to provide tensor mutability.
2023-07-13 11:04:40 +01:00
ba35d895e7
Sketch the candle-transformers crate. ( #147 )
...
* Sketch the candle-transformers crate.
* Format the empty files.
2023-07-12 13:49:31 +01:00
a76ec797da
Clean up the main crate error and add a couple of dedicated ones ( #142 )
...
* Cosmetic cleanups to the error enum.
* More error cleanup.
* Proper error handling rather than panicking.
* Add some dedicated conv1d errors.
2023-07-12 09:17:08 +01:00
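To illustrate what a dedicated error buys over a wrapped string, a hypothetical thiserror enum (the variant names and fields are made up, not candle's actual Error type): a structured variant carries the offending values and formats them on display.

```rust
// Hypothetical sketch, not the real candle error enum.
#[derive(thiserror::Error, Debug)]
enum ExampleError {
    #[error("conv1d invalid args: input shape {inp_shape:?}, kernel shape {k_shape:?}")]
    Conv1dInvalidArgs {
        inp_shape: Vec<usize>,
        k_shape: Vec<usize>,
    },
    #[error("{0}")]
    Msg(String),
}
```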
37cad85869
Resurrect the llama npy support. ( #140 )
2023-07-11 19:32:10 +01:00
64264d97c1
Modular backends ( #138 )
...
* Add some trait to formalize backends.
* Use the generic backend trait.
2023-07-11 11:17:02 +01:00
270997a055
Add the elu op. ( #113 )
2023-07-09 21:56:31 +01:00
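A minimal usage sketch of the op: elu keeps positive values unchanged and maps negative ones to `alpha * (exp(x) - 1)`.

```rust
use candle_core::{Device, Result, Tensor};

fn demo() -> Result<Tensor> {
    let t = Tensor::new(&[-1f32, 0., 2.], &Device::Cpu)?;
    t.elu(1.0) // approx [-0.632, 0.0, 2.0]
}
```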
dd60bd84bb
MKL adjustments. ( #87 )
2023-07-06 11:37:27 +01:00
a57b314780
Add a batch dimension to the bert example.
2023-07-04 06:10:52 +01:00
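Conceptually this amounts to unsqueezing a leading batch dimension so a single sequence goes through the batched forward pass; a sketch with made-up token ids.

```rust
use candle_core::{Device, Result, Tensor};

fn demo() -> Result<Tensor> {
    let token_ids = Tensor::new(&[101u32, 7592, 2088, 102], &Device::Cpu)?;
    token_ids.unsqueeze(0) // (seq_len) -> (1, seq_len)
}
```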
ad52b0377c
Add the varbuilder + check shapes.
2023-07-03 15:32:20 +01:00
899c76de75
Handle more types in safetensors.
2023-07-03 10:09:46 +01:00
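For context, a hedged sketch of loading a safetensors file into named tensors; the per-tensor dtype (f16, bf16, f32, ...) comes from the file itself, and the path here is a placeholder.

```rust
use candle_core::{Device, Result, Tensor};
use std::collections::HashMap;

fn load_weights() -> Result<HashMap<String, Tensor>> {
    candle_core::safetensors::load("model.safetensors", &Device::Cpu)
}
```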
fe2c07e368
Add the ST error.
2023-07-03 08:44:00 +01:00
19cbbc5212
Improve how we check that the dims are in bounds.
2023-06-30 09:11:00 +01:00
b50bd880ce
Only narrow when needed + deactivate the kv cache.
2023-06-29 19:07:52 +01:00
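A small usage sketch of `narrow(dim, start, len)`, the slicing primitive this commit guards: it takes `len` elements starting at `start` along a dimension, e.g. the most recent positions of a sequence when a kv cache is involved.

```rust
use candle_core::{Device, Result, Tensor};

fn demo() -> Result<Tensor> {
    let t = Tensor::new(&[[0f32, 1., 2., 3.], [4., 5., 6., 7.]], &Device::Cpu)?;
    t.narrow(1, 2, 2) // columns 2..4 -> shape (2, 2)
}
```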
d7f729fb8f
Refactor the hierarchy.
2023-06-27 11:57:27 +02:00