Commit Graph

215 Commits

Author SHA1 Message Date
3eb2bc6d07 Softmax numerical stability. (#267)
* Softmax numerical stability.

* Fix the flash-attn test.
2023-07-28 13:13:01 +01:00
3e89df938c Starcoder fix (#264)
* Bugfix for starcoder.

* Get some proper code generation.

* Slightly simpler softmax.
2023-07-28 11:17:49 +01:00
4f260ef025 Merge pull request #216 from LaurentMazare/llama_multiprocess2
TP sharding v2
2023-07-28 08:06:13 +01:00
952eca6b54 Fixing slice errors + comments. 2023-07-27 16:59:32 +02:00
f291065f6c Do not panic on empty ranges. (#257) 2023-07-27 09:28:47 +01:00
25a2086e8f Putting back Send + Sync 2023-07-27 09:58:47 +02:00
7c7e6ba201 Removing inner dependency on safetensors. 2023-07-27 09:58:47 +02:00
ed58de7551 Fixed TP sharded version. 2023-07-27 09:58:46 +02:00
1735e4831e TP sharding v2 2023-07-27 09:58:14 +02:00
6475bfadfe Simplify Tensor::randn. (#255)
* Simplify Tensor::randn.

* Also switch Tensor::rand to use a generic dtype.

* Support sampling for f16.

* Cleanup.
2023-07-27 07:40:36 +01:00
89ba005962 Support backprop for a few more ops. (#254) 2023-07-26 21:31:54 +01:00
1235aa2536 Use bail rather than wrapping a string where possible. (#249)
* Use bail rather than wrapping a string where possible.

* Revert the cuda default bit.
2023-07-26 15:42:46 +01:00
8b1d12bead Merge pull request #246 from LaurentMazare/rename_custom_op
Rename exposed ops.
2023-07-26 14:20:29 +01:00
1a5416ec35 Rename exposed ops. 2023-07-26 12:43:19 +02:00
fa2b64d678 Proper flash-attn parameters. (#244)
* Proper flash-attn parameters.

* Set the flash attention parameters.

* Add more validations.

* Setup the o_ flash attn parameters.

* More flash-attn support.

* Set more flash attn parameters.
2023-07-26 10:13:40 +01:00
d9f9c859af Add flash attention (#241)
* Add some flash-attn kernel, import the code for flash-attn v2 from Dao-AILab.

* More flash attn.

* Set up the flash attn parameters.

* Get things to compile locally.

* Move the flash attention files in a different directory.

* Build the static C library with nvcc.

* Add more flash attention.

* Update the build part.

* Better caching.

* Exclude flash attention from the default workspace.

* Put flash-attn behind a feature gate.

* Get the flash attn kernel to run.

* Move the flags to a more appropriate place.

* Enable flash attention in llama.

* Use flash attention in llama.
2023-07-26 07:48:10 +01:00
c97d51243c Add an abstract backprop op type (#240)
* Start adding the backprop op type.

* More backprop ops.

* Finish the backprop op.
2023-07-25 14:07:40 +01:00
be9c26180c Avoid keeping track of the copy ops when not necessary. (#239) 2023-07-25 10:06:01 +01:00
944d70bd9a Add a test for scatter add. (#238)
* Add a test for scatter add (segfaults on gpus for now).

* Bugfix for the scatter add cuda kernel.
2023-07-25 09:12:14 +01:00
18cc73954a Add some testing for index-add (#237)
* Add some testing for index-add.

* Fix the cpu implementation for index-add.
2023-07-25 08:38:33 +01:00
74a6a769dd Cuda kernels for IndexAdd/ScatterAdd. (#236)
* Skeleton methods for IndexAdd/ScatterAdd.

* Add a Map2InPlace trait.

* Add the glue code for the index-add/scatter-add kernels.

* Tweak the file name: embeddings -> indexing.

* Add the cuda kernel for indexadd.

* And add the scatter-add kernels.
2023-07-24 21:53:08 +01:00
581b104f97 Indexing cuda (#235)
* Allow using uint8_t for indexing.

* Revert the default cuda feature.

* Add a cuda-kernel for index-select.

* Add a test for gather.
2023-07-24 20:22:47 +01:00
b50f932e7c Add some cmp tests. (#233)
* Add some cmp tests.

* Add the cuda kernels for comparison operations.
2023-07-24 16:53:45 +01:00
fe87778223 Add the copy op. (#227)
* Add the copy op.

* Tweak some cat error messages.

* Handle the contiguous case in to_vec1.

* Fast variant for to_vec2.

* Add add a faster to_vec3 variant.
2023-07-23 18:06:47 +01:00
23827c49cd Cleanup some todos. (#226)
* Cleanup some todos.

* Fix more todo.

* Optimize for the contiguous case.

* Add the IntDType trait.

* Handle the intdtype trait for more ops.

* Remove a todo.

* Remove a todo.
2023-07-23 16:00:00 +01:00
e449ce53a2 Wrapping code to call the custom op. (#225)
* Wrapping code to call the custom op.

* Get the rms example to work.

* Get around rustfmt failing in the CI.

* Fix the rms computation.
2023-07-23 11:31:17 +01:00
b8a10425ad Kernel build example (#224)
* Build example kernels.

* Add some sample custom kernel.

* Get the example kernel to compile.

* Add some cuda code.

* More cuda custom op.

* More cuda custom ops.
2023-07-23 07:15:37 +01:00
43c7223292 Rename the .r functions to .dims so as to be a bit more explicit. (#220) 2023-07-22 10:39:27 +01:00
52c5d8c087 Add the gather op. (#219)
* Start adding gather.

* Gather cpu implementation + use in simple training.

* Add scatter_add for the gradient of gather.

* Simple cpu implementation of scatter_add.

* Use gather in the simple-training backprop.
2023-07-22 07:21:28 +01:00
6eeea1b04e Polish the index-add op and use it in the index-select backprop (#218)
* Add the cpu version of index-add.

* More cpu support for index-add.

* Use index-add in the backprop.
2023-07-22 05:31:46 +01:00
27174a82aa Start adding index-add. 2023-07-21 20:12:48 +01:00
5cc843550d Add binary and ternary custom ops. (#217) 2023-07-21 17:29:50 +01:00
4a100875bf Use a macro to handle the dtype pattern matching. (#215) 2023-07-21 16:03:51 +01:00
a6bcdfb269 Custom ops with a single argument (#214)
* Add the CustomOp1 trait.

* Add an example of custom op.

* Polish the custom op example.

* Add some backward pass test for custom ops.
2023-07-21 15:18:05 +01:00
b02229ce92 Add some epsilon tolerance to grad tests so that they work on cuda / mkl. (#213) 2023-07-21 12:45:14 +01:00
410654525f Refactor the reduce ops in order to introduce argmin/argmax. (#212)
* Refactor the reduce ops in order to introduce argmin/argmax.

* Clippy fixes.

* Use the newly introduced argmax.

* Fix the strided case.

* Handle the non-contiguous case.
2023-07-21 11:41:08 +01:00
c60831aad4 Add more gradient tests + bugfixes. (#211)
* Add more gradient tests + bugfixes.

* More tests and fixes.

* More tests.
2023-07-21 06:52:39 +01:00
4845d5cc64 More realistic training setup. (#210)
* More realistic training setup.

* Compute the model accuracy.

* Very inefficient backprop for index select.

* More backprop.

* Fix some backprop issues.

* Backprop fix.

* Another broadcasting backprop fix.

* Better backprop for reducing ops.

* Training again.

* Add some gradient tests.

* Get the training to work.
2023-07-20 18:25:41 +01:00
fa08fb3126 Add the index-select op. (#209)
* Add the index-select op.

* Cpu implementation of index-select.

* Add the cpu implementation for index-select.
2023-07-20 14:01:03 +01:00
2a8f28d687 Op refactor (#208)
* Add the binary and unary op enums to factorize some code.

* Bugfix.
2023-07-20 12:28:45 +01:00
e9c052bf94 Add the comparison operations. (#207)
* Add the comparison operations.

* Add the helper functions on the tensor side.

* More cmp operations.

* Cpu implementation for the comparison operations.
2023-07-20 09:40:31 +01:00
536c5e702e Cuda kernels for fast min/max reductions (#203)
* Add the min/max cuda kernels.

* Better integration of the cuda kernels.
2023-07-19 18:12:27 +01:00
9515e8ea6c Merge branch 'main' into remove_wrapper 2023-07-19 18:53:55 +02:00
ad12e20f6b Add cpu support for min and max. (#202)
* Add cpu support for min and max.

* Add min/max all.
2023-07-19 17:11:44 +01:00
cb687b4897 Add some more developed training examples. (#199)
* Use contiguous tensors for variables.

* Sketch the mnist example.

* Start adding the reduce ops.

* Renaming.

* Refactor the reduce operations.

* Bugfix for the broadcasting vectorization.
2023-07-19 15:37:52 +01:00
dfd624dbd3 [Proposal] Remove SafeTensor wrapper (allows finer control for users). 2023-07-19 16:25:44 +02:00
67e20c3792 Sum over more dims. (#197) 2023-07-19 06:46:32 +01:00
76dcc7a381 Test the broadcasting binary ops. (#196) 2023-07-19 06:18:36 +01:00
fd55fc9592 Add an optimized case when performing the softmax over the last dimension. (#195) 2023-07-18 17:59:50 +01:00
6623c227d8 Allow the compiler to vectorize some broadcasting loops. (#194)
* Allow the compiler to vectorize some broadcasting loops.

* Improve the symmetrical broadcasting case.
2023-07-18 17:12:32 +01:00