581b104f97
Indexing cuda (#235)
* Allow using uint8_t for indexing.
* Revert the default cuda feature.
* Add a cuda-kernel for index-select.
* Add a test for gather.
2023-07-24 20:22:47 +01:00
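To illustrate the indexing surface this entry touches: a minimal sketch of `index_select` with `u8` indices (the crate path and the `u8`-index support follow the bullets above; this is not the commit's actual test code).

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Device::new_cuda(0)? would exercise the new index-select kernel instead.
    let dev = Device::Cpu;
    let t = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &dev)?;
    // u8 index tensors, as allowed by this change.
    let ids = Tensor::new(&[2u8, 0], &dev)?;
    // Pick rows 2 and 0 along dimension 0.
    let rows = t.index_select(&ids, 0)?;
    println!("{:?}", rows.to_vec2::<f32>()?); // [[5.0, 6.0], [1.0, 2.0]]
    Ok(())
}
```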
b50f932e7c
Add some cmp tests. (#233)
* Add some cmp tests.
* Add the cuda kernels for comparison operations.
2023-07-24 16:53:45 +01:00
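The comparison kernels sit behind the element-wise `Tensor` helpers, which produce a `u8` mask of zeros and ones; a hedged sketch (assuming the helper names introduced in #207 below):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let a = Tensor::new(&[1f32, 2., 3.], &dev)?;
    let b = Tensor::new(&[2f32, 2., 2.], &dev)?;
    // Element-wise comparisons return a u8 tensor of 0s and 1s.
    println!("{:?}", a.eq(&b)?.to_vec1::<u8>()?); // [0, 1, 0]
    println!("{:?}", a.lt(&b)?.to_vec1::<u8>()?); // [1, 0, 0]
    Ok(())
}
```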
160ba09d30
Polish the llama2 wasm ui. (#232)
* Polish the llama2 wasm ui.
* readme update.
2023-07-24 15:28:27 +01:00
5a26cba733
Re-organize the wasm examples (#231)
* Move the whisper example.
* More renaming.
* Add llama2 as a new wasm example.
* Live generation.
* More of the llama wasm example.
* Formatting.
2023-07-24 12:36:02 +01:00
550a13a547
Use the binary decoder for llama2.c. (#230)
* Use the binary decoder for llama2.c.
* Add the temperature.
* Formatting tweak.
* Fix the rotary embeddings.
2023-07-24 10:56:08 +01:00
35b65fed88
Add llama2.c as an example. (#229)
* Start adding llama2.c.
* Model loading.
* Add the llama-v2 model.
* Start converting the weights.
* Rotary embedding tweaks.
* Get the model to generate some tokens.
2023-07-24 09:13:50 +01:00
b6f7dfb682
CPU implementation for the custom RMS example. (#228)
* CPU implementation for the custom RMS example.
* Add the eps parameter.
2023-07-23 20:04:20 +01:00
fe87778223
Add the copy op. (#227)
* Add the copy op.
* Tweak some cat error messages.
* Handle the contiguous case in to_vec1.
* Fast variant for to_vec2.
* Add a faster to_vec3 variant.
2023-07-23 18:06:47 +01:00
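For context on the `to_vec*` family: these copy tensor data back into nested Rust `Vec`s, with the fast paths added here kicking in for contiguous layouts. A small sketch:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::arange(0f32, 6., &Device::Cpu)?.reshape((2, 3))?;
    // The rank must match the helper: to_vec1 for 1D, to_vec2 for 2D, ...
    let v: Vec<Vec<f32>> = t.to_vec2()?;
    assert_eq!(v, vec![vec![0., 1., 2.], vec![3., 4., 5.]]);
    // A transposed view is non-contiguous, so the slower path handles it.
    let vt = t.t()?.to_vec2::<f32>()?;
    assert_eq!(vt, vec![vec![0., 3.], vec![1., 4.], vec![2., 5.]]);
    Ok(())
}
```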
23827c49cd
Cleanup some todos. (#226)
* Cleanup some todos.
* Fix more todos.
* Optimize for the contiguous case.
* Add the IntDType trait.
* Handle the IntDType trait for more ops.
* Remove a todo.
* Remove a todo.
2023-07-23 16:00:00 +01:00
e449ce53a2
Wrapping code to call the custom op. (#225)
* Wrapping code to call the custom op.
* Get the rms example to work.
* Get around rustfmt failing in the CI.
* Fix the rms computation.
2023-07-23 11:31:17 +01:00
b8a10425ad
Kernel build example (#224)
* Build example kernels.
* Add some sample custom kernels.
* Get the example kernel to compile.
* Add some cuda code.
* More cuda custom op.
* More cuda custom ops.
2023-07-23 07:15:37 +01:00
5f20acf080
Revert "Add the layer norm files. ( #222 )" ( #223 )
...
This reverts commit c8459d199d
.
2023-07-22 16:51:11 +01:00
c8459d199d
Add the layer norm files. (#222)
2023-07-22 15:06:35 +01:00
1f26042693
Move some shared functions to the nn module. (#221)
2023-07-22 13:25:11 +01:00
43c7223292
Rename the .r functions to .dims so as to be a bit more explicit. (#220)
2023-07-22 10:39:27 +01:00
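The renamed accessors come in two flavors: `dims` returns the full shape as a slice, while the fixed-arity `dims2`/`dims3`/... variants error out when the rank does not match. A short sketch:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    // All dimensions as a slice.
    assert_eq!(t.dims(), &[2, 3, 4]);
    // Fixed-arity variant: fails at runtime if the tensor is not rank 3.
    let (a, b, c) = t.dims3()?;
    assert_eq!((a, b, c), (2, 3, 4));
    Ok(())
}
```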
52c5d8c087
Add the gather op. (#219)
* Start adding gather.
* Gather cpu implementation + use in simple training.
* Add scatter_add for the gradient of gather.
* Simple cpu implementation of scatter_add.
* Use gather in the simple-training backprop.
2023-07-22 07:21:28 +01:00
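`gather` picks one element per output position along a dimension, and `scatter_add` is its adjoint, which is why it appears in the gradient; a sketch of both (not the commit's test code):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &dev)?;
    // The index tensor has the same rank as the input; each entry
    // selects along dim 1: out[i][j] = t[i][ids[i][j]].
    let ids = Tensor::new(&[[1u32], [0]], &dev)?;
    let g = t.gather(&ids, 1)?;
    println!("{:?}", g.to_vec2::<f32>()?); // [[2.0], [3.0]]
    // scatter_add spreads values back onto the gathered positions.
    let back = t.zeros_like()?.scatter_add(&ids, &g, 1)?;
    println!("{:?}", back.to_vec2::<f32>()?); // [[0.0, 2.0], [3.0, 0.0]]
    Ok(())
}
```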
6eeea1b04e
Polish the index-add op and use it in the index-select backprop (#218)
* Add the cpu version of index-add.
* More cpu support for index-add.
* Use index-add in the backprop.
2023-07-22 05:31:46 +01:00
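`index_add` is the accumulating counterpart of `index_select` (hence its use in that backprop): row `i` of the source is added into row `ids[i]` of the target, and repeated indices accumulate. Sketch:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let acc = Tensor::zeros((3, 2), DType::F32, &dev)?;
    let ids = Tensor::new(&[0u32, 2, 0], &dev)?;
    let src = Tensor::new(&[[1f32, 1.], [2., 2.], [3., 3.]], &dev)?;
    // Rows 0 and 2 of src both land in row 0 of acc and accumulate.
    let out = acc.index_add(&ids, &src, 0)?;
    println!("{:?}", out.to_vec2::<f32>()?); // [[4.0, 4.0], [0.0, 0.0], [2.0, 2.0]]
    Ok(())
}
```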
27174a82aa
Start adding index-add.
2023-07-21 20:12:48 +01:00
5cc843550d
Add binary and ternary custom ops. (#217)
2023-07-21 17:29:50 +01:00
4a100875bf
Use a macro to handle the dtype pattern matching. (#215)
2023-07-21 16:03:51 +01:00
a6bcdfb269
Custom ops with a single argument (#214)
* Add the CustomOp1 trait.
* Add an example of custom op.
* Polish the custom op example.
* Add some backward pass test for custom ops.
2023-07-21 15:18:05 +01:00
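A minimal sketch of what a `CustomOp1` implementation can look like, assuming the trait shape candle settled on (a `name`, a `cpu_fwd` over raw storage, and `Tensor::apply_op1` to invoke it); it handles only contiguous `f32` input to stay short:

```rust
use candle_core::{CpuStorage, CustomOp1, Device, Layout, Result, Shape, Tensor};

struct Square;

impl CustomOp1 for Square {
    fn name(&self) -> &'static str {
        "square"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        // Contiguous f32 only, to keep the sketch short.
        let slice = storage.as_slice::<f32>()?;
        let src = match layout.contiguous_offsets() {
            Some((o1, o2)) => &slice[o1..o2],
            None => candle_core::bail!("input must be contiguous"),
        };
        let dst: Vec<f32> = src.iter().map(|v| v * v).collect();
        Ok((CpuStorage::F32(dst), layout.shape().clone()))
    }
}

fn main() -> Result<()> {
    let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    let sq = t.apply_op1(Square)?;
    println!("{:?}", sq.to_vec1::<f32>()?); // [1.0, 4.0, 9.0]
    Ok(())
}
```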
b02229ce92
Add some epsilon tolerance to grad tests so that they work on cuda / mkl. (#213)
2023-07-21 12:45:14 +01:00
410654525f
Refactor the reduce ops in order to introduce argmin/argmax. (#212)
* Refactor the reduce ops in order to introduce argmin/argmax.
* Clippy fixes.
* Use the newly introduced argmax.
* Fix the strided case.
* Handle the non-contiguous case.
2023-07-21 11:41:08 +01:00
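After the refactor, `argmin`/`argmax` reduce along one dimension and return the winning positions as `u32` indices; a short sketch:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[[3f32, 1., 2.], [0., 5., 4.]], &Device::Cpu)?;
    // Index of the max along dim 1, with that dimension squeezed away.
    println!("{:?}", t.argmax(1)?.to_vec1::<u32>()?); // [0, 1]
    println!("{:?}", t.argmin(1)?.to_vec1::<u32>()?); // [1, 0]
    Ok(())
}
```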
c60831aad4
Add more gradient tests + bugfixes. (#211)
* Add more gradient tests + bugfixes.
* More tests and fixes.
* More tests.
2023-07-21 06:52:39 +01:00
4845d5cc64
More realistic training setup. (#210)
* More realistic training setup.
* Compute the model accuracy.
* Very inefficient backprop for index select.
* More backprop.
* Fix some backprop issues.
* Backprop fix.
* Another broadcasting backprop fix.
* Better backprop for reducing ops.
* Training again.
* Add some gradient tests.
* Get the training to work.
2023-07-20 18:25:41 +01:00
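The backprop machinery listed above surfaces through `Tensor::backward`, which returns a store of gradients keyed by tensor; a hedged sketch assuming the `Var`/`GradStore` interface candle exposes:

```rust
use candle_core::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Var marks a tensor as trainable so gradients are tracked for it.
    let w = Var::new(&[[1f32, 2.], [3., 4.]], &dev)?;
    let x = Tensor::new(&[[1f32], [1.]], &dev)?;
    let loss = w.matmul(&x)?.sum_all()?;
    // backward() walks the graph and fills the gradient store.
    let grads = loss.backward()?;
    let w_grad = grads.get(&w).expect("no gradient for w");
    println!("{:?}", w_grad.to_vec2::<f32>()?); // [[1.0, 1.0], [1.0, 1.0]]
    Ok(())
}
```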
fa08fb3126
Add the index-select op. (#209)
* Add the index-select op.
* Cpu implementation of index-select.
* Add the cpu implementation for index-select.
2023-07-20 14:01:03 +01:00
2a8f28d687
Op refactor (#208)
* Add the binary and unary op enums to factorize some code.
* Bugfix.
2023-07-20 12:28:45 +01:00
e9c052bf94
Add the comparison operations. (#207)
* Add the comparison operations.
* Add the helper functions on the tensor side.
* More cmp operations.
* Cpu implementation for the comparison operations.
2023-07-20 09:40:31 +01:00
dc416243a3
Bump the hf-hub dependency to 0.1.3. (#206)
2023-07-20 07:27:52 +01:00
12d6dc018d
Support for MQA for llama v2. (#205)
* Support for MQA for llama v2.
* More llama-v2.
* Move the rotary embedding precomputation in the cache.
* Add a v2 flag.
* Use the hf model.
2023-07-20 06:39:04 +01:00
c34f932319
Fix the mkl build. (#204)
* Fix the mkl build.
* Fix the build properly.
2023-07-19 19:41:11 +01:00
536c5e702e
Cuda kernels for fast min/max reductions (#203)
* Add the min/max cuda kernels.
* Better integration of the cuda kernels.
2023-07-19 18:12:27 +01:00
001f9a59ce
Merge pull request #201 from LaurentMazare/remove_wrapper
[Proposal] Remove SafeTensor wrapper (allows finer control for users).
2023-07-19 19:02:37 +02:00
9515e8ea6c
Merge branch 'main' into remove_wrapper
2023-07-19 18:53:55 +02:00
ad12e20f6b
Add cpu support for min and max. (#202)
* Add cpu support for min and max.
* Add min/max all.
2023-07-19 17:11:44 +01:00
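The new reductions mirror `sum`: a per-dimension form plus the "all" form that collapses to a scalar. Sketch:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[[3f32, 1., 2.], [0., 5., 4.]], &Device::Cpu)?;
    // Reduce along one dimension...
    println!("{:?}", t.max(1)?.to_vec1::<f32>()?); // [3.0, 5.0]
    // ...or over the whole tensor ("min/max all").
    println!("{}", t.min_all()?.to_scalar::<f32>()?); // 0
    Ok(())
}
```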
e6584476c4
Merge pull request #200 from LaurentMazare/removing_candle_hub
Removing `candle-hub` internal to extract into `hf-hub` standalone.
2023-07-19 17:27:55 +02:00
cb687b4897
Add some more developed training examples. (#199)
* Use contiguous tensors for variables.
* Sketch the mnist example.
* Start adding the reduce ops.
* Renaming.
* Refactor the reduce operations.
* Bugfix for the broadcasting vectorization.
2023-07-19 15:37:52 +01:00
dfd624dbd3
[Proposal] Remove SafeTensor wrapper (allows finer control for users).
2023-07-19 16:25:44 +02:00
439321745a
Removing `candle-hub` internal to extract into `hf-hub` standalone.
2023-07-19 15:04:38 +02:00
67e20c3792
Sum over more dims. (#197)
2023-07-19 06:46:32 +01:00
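After this change `sum` can reduce over several dimensions in one call; a sketch assuming the current tuple-based `Dims` signature (the exact signature at the time of this commit may have differed):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::arange(0f32, 24., &Device::Cpu)?.reshape((2, 3, 4))?;
    // Sum over dims 0 and 2 at once; both are squeezed away.
    let s = t.sum((0, 2))?;
    assert_eq!(s.dims(), &[3]);
    println!("{:?}", s.to_vec1::<f32>()?); // [60.0, 92.0, 124.0]
    Ok(())
}
```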
76dcc7a381
Test the broadcasting binary ops. (#196)
2023-07-19 06:18:36 +01:00
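Broadcasting in candle is explicit: the `broadcast_*` variants of the binary ops stretch the smaller operand instead of failing on a shape mismatch. Sketch:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let m = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?; // shape (2, 3)
    let row = Tensor::new(&[10f32, 20., 30.], &dev)?; // shape (3,)
    // The row is broadcast across dim 0 of the matrix.
    let out = m.broadcast_add(&row)?;
    println!("{:?}", out.to_vec2::<f32>()?); // [[11.0, 22.0, 33.0], [14.0, 25.0, 36.0]]
    Ok(())
}
```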
fd55fc9592
Add an optimized case when performing the softmax over the last dimension. (#195)
2023-07-18 17:59:50 +01:00
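The fast path applies when the softmax reduction runs over a contiguous last dimension; a hedged sketch using the `candle_nn::ops::softmax` helper (assuming the current crate layout, the function may have lived elsewhere at the time):

```rust
use candle_core::{D, Device, Result, Tensor};
use candle_nn::ops::softmax;

fn main() -> Result<()> {
    let t = Tensor::new(&[[1f32, 2., 3.], [3., 2., 1.]], &Device::Cpu)?;
    // Contiguous input + last-dim reduction hits the optimized case.
    let p = softmax(&t, D::Minus1)?;
    println!("{:?}", p.to_vec2::<f32>()?); // each row sums to 1
    Ok(())
}
```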
6623c227d8
Allow the compiler to vectorize some broadcasting loops. (#194)
* Allow the compiler to vectorize some broadcasting loops.
* Improve the symmetrical broadcasting case.
2023-07-18 17:12:32 +01:00
79a5b686d0
Properly use the offset when broadcasting on a narrow slice. (#193)
2023-07-18 16:36:23 +01:00
a45a3f0312
Optimize the sum for the contiguous case. (#192)
2023-07-18 14:57:06 +01:00
3307db204a
Mklize more unary ops. (#191)
* Mklize more unary ops.
* Even more unary ops.
2023-07-18 13:32:49 +01:00
ff61a42ad7
Use mkl to accelerate binary ops. (#190)
* Vectorized binary ops with mkl.
* Improve the binary op mkl support.
* Push the support for mkl binary ops.
* Proper vectorization of binary ops.
* Proper mkl'isation when broadcasting binary ops.
2023-07-18 12:04:39 +01:00
b706f32839
Add Shape TryInto (#189)
* Add the TryInto trait for shapes.
* Use the vectorized operations in block mode too.
2023-07-18 10:52:16 +01:00
d6313d2447
Add more tracing details to bert. (#188)
2023-07-18 08:11:05 +01:00
d73df74cb2
Preliminary support for mkl based gelu. (#187)
* Preliminary support for mkl based gelu.
* Add the vectorized function for unary ops.
* Get the mkl specialized gelu to work.
2023-07-18 07:48:48 +01:00