Commit Graph

646 Commits

Author SHA1 Message Date
ed58de7551 Fixed TP sharded version. 2023-07-27 09:58:46 +02:00
1735e4831e TP sharding v2 2023-07-27 09:58:14 +02:00
209f06d7c3 Micro-cleanup. (#256) 2023-07-27 07:55:54 +01:00
6475bfadfe Simplify Tensor::randn. (#255)
* Simplify Tensor::randn.

* Also switch Tensor::rand to use a generic dtype.

* Support sampling for f16.

* Cleanup.
2023-07-27 07:40:36 +01:00
89ba005962 Support backprop for a few more ops. (#254) 2023-07-26 21:31:54 +01:00
4f92420132 Add some flash attn test (#253)
* Add some flash-attn test.

* Add the cpu test.

* Fail when the head is not a multiple of 8.

* Polish the flash attention test.
2023-07-26 20:56:00 +01:00
ded197497c Merge pull request #252 from LaurentMazare/add_book
Adding a cargo book
2023-07-26 17:35:54 +01:00
84ad558e50 Switch to using llama-v2 by default. (#251) 2023-07-26 17:18:27 +01:00
368f169c6a Permissions. 2023-07-26 18:12:02 +02:00
8da6568c20 Typo. 2023-07-26 18:11:10 +02:00
07a22fe606 Releasing within the branch to test the setup. 2023-07-26 18:08:34 +02:00
834e1b197b Adding a documentation book. 2023-07-26 18:06:31 +02:00
89fd988836 Update to the latest gemm. (#250) 2023-07-26 17:00:02 +01:00
1235aa2536 Use bail rather than wrapping a string where possible. (#249)
* Use bail rather than wrapping a string where possible.

* Revert the cuda default bit.
2023-07-26 15:42:46 +01:00
f052ba76cb Lining up the flash attn version with the non-flash one. (#248)
* Move the flash-attn function in the proper crate.

* Causality tweak.
2023-07-26 15:11:45 +01:00
46f2d9f0ac Merge pull request #247 from LaurentMazare/add_number_of_tokens
Add number of tokens.
2023-07-26 14:45:53 +01:00
81bfa46702 Updated. 2023-07-26 15:21:50 +02:00
8b1d12bead Merge pull request #246 from LaurentMazare/rename_custom_op
Rename exposed ops.
2023-07-26 14:20:29 +01:00
035372248e Simple QOL.
- Add ms/token on llama2.c (15ms/token on my personal machine)
- Hide `Run` buttons while models are not ready
- Add dummy `progress` while weights are downloading (I briefly looked
  at putting a real progressbar.. and nothing easy enough came up.)
2023-07-26 15:17:32 +02:00
2ce5f12513 Again set a few extra params in flash-attn. (#245)
* Again set a few extra params.

* Use the appropriate kernel sizes.

* Add all the kernel sizes.

* Parallel compiling.

* Reduce the amount of parallelism.

* Add the missing kernel.

* Fix a typo.

* Remove bf16 support for now.
2023-07-26 14:16:37 +01:00
97990f4afc Add number of tokens. 2023-07-26 14:57:20 +02:00
1a5416ec35 Rename exposed ops. 2023-07-26 12:43:19 +02:00
fa2b64d678 Proper flash-attn parameters. (#244)
* Proper flash-attn parameters.

* Set the flash attention parameters.

* Add more validations.

* Setup the o_ flash attn parameters.

* More flash-attn support.

* Set more flash attn parameters.
2023-07-26 10:13:40 +01:00
e40b150bbe Better handling of dtypes in llama. (#243) 2023-07-26 08:28:33 +01:00
471855e2ee Specific cache dir for the flash attn build artifacts. (#242) 2023-07-26 08:04:02 +01:00
d9f9c859af Add flash attention (#241)
* Add some flash-attn kernel, import the code for flash-attn v2 from Dao-AILab.

* More flash attn.

* Set up the flash attn parameters.

* Get things to compile locally.

* Move the flash attention files in a different directory.

* Build the static C library with nvcc.

* Add more flash attention.

* Update the build part.

* Better caching.

* Exclude flash attention from the default workspace.

* Put flash-attn behind a feature gate.

* Get the flash attn kernel to run.

* Move the flags to a more appropriate place.

* Enable flash attention in llama.

* Use flash attention in llama.
2023-07-26 07:48:10 +01:00
c97d51243c Add an abstract backprop op type (#240)
* Start adding the backprop op type.

* More backprop ops.

* Finish the backprop op.
2023-07-25 14:07:40 +01:00
be9c26180c Avoid keeping track of the copy ops when not necessary. (#239) 2023-07-25 10:06:01 +01:00
944d70bd9a Add a test for scatter add. (#238)
* Add a test for scatter add (segfaults on gpus for now).

* Bugfix for the scatter add cuda kernel.
2023-07-25 09:12:14 +01:00
18cc73954a Add some testing for index-add (#237)
* Add some testing for index-add.

* Fix the cpu implementation for index-add.
2023-07-25 08:38:33 +01:00
74a6a769dd Cuda kernels for IndexAdd/ScatterAdd. (#236)
* Skeleton methods for IndexAdd/ScatterAdd.

* Add a Map2InPlace trait.

* Add the glue code for the index-add/scatter-add kernels.

* Tweak the file name: embeddings -> indexing.

* Add the cuda kernel for indexadd.

* And add the scatter-add kernels.
2023-07-24 21:53:08 +01:00
581b104f97 Indexing cuda (#235)
* Allow using uint8_t for indexing.

* Revert the default cuda feature.

* Add a cuda-kernel for index-select.

* Add a test for gather.
2023-07-24 20:22:47 +01:00
b50f932e7c Add some cmp tests. (#233)
* Add some cmp tests.

* Add the cuda kernels for comparison operations.
2023-07-24 16:53:45 +01:00
160ba09d30 Polish the llama2 wasm ui. (#232)
* Polish the llama2 wasm ui.

* readme update.
2023-07-24 15:28:27 +01:00
5a26cba733 Re-organize the wasm examples (#231)
* Move the whisper example.

* More renaming.

* Add llama2 as a new wasm example.

* Live generation.

* More of the llama wasm example.

* Formatting.
2023-07-24 12:36:02 +01:00
550a13a547 Use the binary decoder for llama2.c. (#230)
* Use the binary decoder for llama2.c.

* Add the temperature.

* Formatting tweak.

* Fix the rotary embeddings.
2023-07-24 10:56:08 +01:00
35b65fed88 Add llama2.c as an example. (#229)
* Start adding llama2.c.

* Model loading.

* Add the llama-v2 model.

* Start converting the weights.

* Rotary embedding tweaks.

* Get the model to generate some tokens.
2023-07-24 09:13:50 +01:00
b6f7dfb682 CPU implementation for the custom RMS example. (#228)
* CPU implementation for the custom RMS example.

* Add the eps parameter.
2023-07-23 20:04:20 +01:00
fe87778223 Add the copy op. (#227)
* Add the copy op.

* Tweak some cat error messages.

* Handle the contiguous case in to_vec1.

* Fast variant for to_vec2.

* Add add a faster to_vec3 variant.
2023-07-23 18:06:47 +01:00
23827c49cd Cleanup some todos. (#226)
* Cleanup some todos.

* Fix more todo.

* Optimize for the contiguous case.

* Add the IntDType trait.

* Handle the intdtype trait for more ops.

* Remove a todo.

* Remove a todo.
2023-07-23 16:00:00 +01:00
e449ce53a2 Wrapping code to call the custom op. (#225)
* Wrapping code to call the custom op.

* Get the rms example to work.

* Get around rustfmt failing in the CI.

* Fix the rms computation.
2023-07-23 11:31:17 +01:00
b8a10425ad Kernel build example (#224)
* Build example kernels.

* Add some sample custom kernel.

* Get the example kernel to compile.

* Add some cuda code.

* More cuda custom op.

* More cuda custom ops.
2023-07-23 07:15:37 +01:00
5f20acf080 Revert "Add the layer norm files. (#222)" (#223)
This reverts commit c8459d199d.
2023-07-22 16:51:11 +01:00
c8459d199d Add the layer norm files. (#222) 2023-07-22 15:06:35 +01:00
1f26042693 Move some shared functions to the nn module. (#221) 2023-07-22 13:25:11 +01:00
43c7223292 Rename the .r functions to .dims so as to be a bit more explicit. (#220) 2023-07-22 10:39:27 +01:00
52c5d8c087 Add the gather op. (#219)
* Start adding gather.

* Gather cpu implementation + use in simple training.

* Add scatter_add for the gradient of gather.

* Simple cpu implementation of scatter_add.

* Use gather in the simple-training backprop.
2023-07-22 07:21:28 +01:00
6eeea1b04e Polish the index-add op and use it in the index-select backprop (#218)
* Add the cpu version of index-add.

* More cpu support for index-add.

* Use index-add in the backprop.
2023-07-22 05:31:46 +01:00
27174a82aa Start adding index-add. 2023-07-21 20:12:48 +01:00
5cc843550d Add binary and ternary custom ops. (#217) 2023-07-21 17:29:50 +01:00