fe87778223
Add the copy op. ( #227 )
...
* Add the copy op.
* Tweak some cat error messages.
* Handle the contiguous case in to_vec1.
* Fast variant for to_vec2.
* Add a faster to_vec3 variant.
2023-07-23 18:06:47 +01:00
23827c49cd
Cleanup some todos. ( #226 )
...
* Cleanup some todos.
* Fix more todos.
* Optimize for the contiguous case.
* Add the IntDType trait.
* Handle the intdtype trait for more ops.
* Remove a todo.
* Remove a todo.
2023-07-23 16:00:00 +01:00
e449ce53a2
Wrapping code to call the custom op. ( #225 )
...
* Wrapping code to call the custom op.
* Get the rms example to work.
* Get around rustfmt failing in the CI.
* Fix the rms computation.
2023-07-23 11:31:17 +01:00
b8a10425ad
Kernel build example ( #224 )
...
* Build example kernels.
* Add some sample custom kernel.
* Get the example kernel to compile.
* Add some cuda code.
* More cuda custom op.
* More cuda custom ops.
2023-07-23 07:15:37 +01:00
43c7223292
Rename the .r functions to .dims so as to be a bit more explicit. ( #220 )
2023-07-22 10:39:27 +01:00
52c5d8c087
Add the gather op. ( #219 )
...
* Start adding gather.
* Gather cpu implementation + use in simple training.
* Add scatter_add for the gradient of gather.
* Simple cpu implementation of scatter_add.
* Use gather in the simple-training backprop.
2023-07-22 07:21:28 +01:00
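For readers unfamiliar with the pairing in the entry above: the backward pass of gather is scatter_add, because each gathered element routes its incoming gradient back to the source position it was read from, accumulating where an index appears more than once. A minimal pure-Python sketch of the 1-D case (illustrative only, not candle's implementation):

```python
def gather(src, indices):
    # Read src at each position listed in indices.
    return [src[i] for i in indices]

def scatter_add(grad_out, indices, src_len):
    # Gradient of gather: route each output gradient back to the
    # source slot it was read from, accumulating on duplicate indices.
    grad_src = [0.0] * src_len
    for g, i in zip(grad_out, indices):
        grad_src[i] += g
    return grad_src
```

For example, gathering index 2 twice means the source element at index 2 receives the sum of both output gradients.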
6eeea1b04e
Polish the index-add op and use it in the index-select backprop ( #218 )
...
* Add the cpu version of index-add.
* More cpu support for index-add.
* Use index-add in the backprop.
2023-07-22 05:31:46 +01:00
27174a82aa
Start adding index-add.
2023-07-21 20:12:48 +01:00
5cc843550d
Add binary and ternary custom ops. ( #217 )
2023-07-21 17:29:50 +01:00
4a100875bf
Use a macro to handle the dtype pattern matching. ( #215 )
2023-07-21 16:03:51 +01:00
a6bcdfb269
Custom ops with a single argument ( #214 )
...
* Add the CustomOp1 trait.
* Add an example of custom op.
* Polish the custom op example.
* Add some backward pass test for custom ops.
2023-07-21 15:18:05 +01:00
b02229ce92
Add some epsilon tolerance to grad tests so that they work on cuda / mkl. ( #213 )
2023-07-21 12:45:14 +01:00
410654525f
Refactor the reduce ops in order to introduce argmin/argmax. ( #212 )
...
* Refactor the reduce ops in order to introduce argmin/argmax.
* Clippy fixes.
* Use the newly introduced argmax.
* Fix the strided case.
* Handle the non-contiguous case.
2023-07-21 11:41:08 +01:00
c60831aad4
Add more gradient tests + bugfixes. ( #211 )
...
* Add more gradient tests + bugfixes.
* More tests and fixes.
* More tests.
2023-07-21 06:52:39 +01:00
4845d5cc64
More realistic training setup. ( #210 )
...
* More realistic training setup.
* Compute the model accuracy.
* Very inefficient backprop for index select.
* More backprop.
* Fix some backprop issues.
* Backprop fix.
* Another broadcasting backprop fix.
* Better backprop for reducing ops.
* Training again.
* Add some gradient tests.
* Get the training to work.
2023-07-20 18:25:41 +01:00
fa08fb3126
Add the index-select op. ( #209 )
...
* Add the index-select op.
* Cpu implementation of index-select.
* Add the cpu implementation for index-select.
2023-07-20 14:01:03 +01:00
2a8f28d687
Op refactor ( #208 )
...
* Add the binary and unary op enums to factorize some code.
* Bugfix.
2023-07-20 12:28:45 +01:00
e9c052bf94
Add the comparison operations. ( #207 )
...
* Add the comparison operations.
* Add the helper functions on the tensor side.
* More cmp operations.
* Cpu implementation for the comparison operations.
2023-07-20 09:40:31 +01:00
536c5e702e
Cuda kernels for fast min/max reductions ( #203 )
...
* Add the min/max cuda kernels.
* Better integration of the cuda kernels.
2023-07-19 18:12:27 +01:00
9515e8ea6c
Merge branch 'main' into remove_wrapper
2023-07-19 18:53:55 +02:00
ad12e20f6b
Add cpu support for min and max. ( #202 )
...
* Add cpu support for min and max.
* Add min/max all.
2023-07-19 17:11:44 +01:00
cb687b4897
Add some more developed training examples. ( #199 )
...
* Use contiguous tensors for variables.
* Sketch the mnist example.
* Start adding the reduce ops.
* Renaming.
* Refactor the reduce operations.
* Bugfix for the broadcasting vectorization.
2023-07-19 15:37:52 +01:00
dfd624dbd3
[Proposal] Remove SafeTensor wrapper (allows finer control for users).
2023-07-19 16:25:44 +02:00
67e20c3792
Sum over more dims. ( #197 )
2023-07-19 06:46:32 +01:00
76dcc7a381
Test the broadcasting binary ops. ( #196 )
2023-07-19 06:18:36 +01:00
fd55fc9592
Add an optimized case when performing the softmax over the last dimension. ( #195 )
2023-07-18 17:59:50 +01:00
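The optimization in the entry above exploits the fact that, for a row-major tensor, the last dimension is contiguous in memory: softmax can then process one plain slice per row instead of doing per-element stride arithmetic. A hedged sketch of the idea on a flat buffer (illustrative only, not candle's code):

```python
import math

def softmax_last_dim(data, row_len):
    # data is a flat row-major buffer; each chunk of row_len elements
    # is one contiguous row, so no stride computation is needed.
    out = []
    for start in range(0, len(data), row_len):
        row = data[start:start + row_len]
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(x - m) for x in row]
        s = sum(exps)
        out.extend(e / s for e in exps)
    return out
```

Each row of the output sums to 1, and equal inputs within a row produce equal probabilities.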
6623c227d8
Allow the compiler to vectorize some broadcasting loops. ( #194 )
...
* Allow the compiler to vectorize some broadcasting loops.
* Improve the symmetrical broadcasting case.
2023-07-18 17:12:32 +01:00
79a5b686d0
Properly use the offset when broadcasting on a narrow slice. ( #193 )
2023-07-18 16:36:23 +01:00
a45a3f0312
Optimize the sum for the contiguous case. ( #192 )
2023-07-18 14:57:06 +01:00
3307db204a
Mklize more unary ops. ( #191 )
...
* Mklize more unary ops.
* Even more unary ops.
2023-07-18 13:32:49 +01:00
ff61a42ad7
Use mkl to accelerate binary ops. ( #190 )
...
* Vectorized binary ops with mkl.
* Improve the binary op mkl support.
* Push the support for mkl binary ops.
* Proper vectorization of binary ops.
* Proper mkl'isation when broadcasting binary ops.
2023-07-18 12:04:39 +01:00
b706f32839
Add Shape try into ( #189 )
...
* Add the TryInto trait for shapes.
* Use the vectorized operations in block mode too.
2023-07-18 10:52:16 +01:00
d73df74cb2
Preliminary support for mkl based gelu. ( #187 )
...
* Preliminary support for mkl based gelu.
* Add the vectorized function for unary ops.
* Get the mkl specialized gelu to work.
2023-07-18 07:48:48 +01:00
c3a73c583e
Add support for mkl tanh. ( #185 )
2023-07-17 22:06:43 +01:00
acb2f90469
Broadcasting performance optimization (cpu) ( #182 )
...
* Avoid recomputing the index from scratch each time.
* More performance optimisations.
2023-07-17 13:41:09 +01:00
5b1c0bc9be
Performance improvement. ( #181 )
2023-07-17 11:07:14 +01:00
28e1c07304
Process unary functions per block ( #180 )
...
* Process unary functions per block.
* Add some inline hints.
2023-07-17 10:22:33 +01:00
104f89df31
Centralize the dependency versions and inherit them. ( #177 )
2023-07-16 07:47:17 +01:00
18ea92d83b
Iteration over strided blocks ( #175 )
...
* Introduce the strided blocks.
* Use the strided blocks to speed up the copy.
* Add more testing.
2023-07-15 21:30:35 +01:00
66750f9827
Add some 'cuda-if-available' helper function. ( #172 )
2023-07-15 08:25:15 +01:00
3672e1a46f
Revert "Testing fmt CI check behind cuda feature flag."
...
This reverts commit b9605310b1.
2023-07-14 15:18:14 +00:00
b9605310b1
Testing fmt CI check behind cuda feature flag.
2023-07-14 15:14:52 +00:00
dcb4a9291e
Explain how to enable cuda.
2023-07-14 17:08:05 +02:00
4ed56d7861
Removing cuda default.
...
This seems very important for many users exploring the library, who are
often on laptops without GPUs.
More README instructions will be added in a follow-up.
2023-07-14 16:52:15 +02:00
88f666781f
Wasm proof of concept. ( #167 )
...
* Wasm proof of concept.
* Run whisper inference in the browser.
* Some fixes.
* Move the wasm example.
* Change the tokenizer config.
2023-07-14 14:51:46 +01:00
d88b6cdca9
Add backtrace information to errors where relevant. ( #166 )
...
* Add backtrace information to errors where relevant.
* More backtrace information.
* Add to the FAQ.
2023-07-14 09:31:25 +01:00
a2f72edc0d
Simplify the parameters used by sum and sum_keepdim. ( #165 )
2023-07-14 08:22:08 +01:00
2bfa791336
Use the same default as pytorch for sum. ( #164 )
2023-07-13 21:32:32 +01:00
23e105cd94
Add the gradient for reduce-sum. ( #162 )
...
* Add the gradient for reduce-sum.
* And add the gradient for the broadcast ops.
* Add some backprop tests.
* Add some linear regression example.
2023-07-13 20:14:10 +01:00
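The two gradients in the entry above are mirror images of each other: the gradient of reduce-sum broadcasts the upstream gradient back to every summed position, while the gradient of a broadcast sums the incoming gradients back down. A small sketch of the scalar-to-vector case (illustrative only, not candle's code):

```python
def sum_all(xs):
    # Forward: reduce a vector to a scalar.
    return sum(xs)

def sum_all_grad(grad_out, n):
    # d(sum)/dx_i = 1 for every i, so the upstream scalar gradient
    # is broadcast to each of the n input positions.
    return [grad_out] * n

def broadcast_grad(grad_out):
    # Conversely, the gradient of broadcasting a scalar to n positions
    # is the sum of the n incoming gradients.
    return sum(grad_out)
```

This adjoint relationship is why adding the reduce-sum gradient and the broadcast-op gradients naturally landed in the same change.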
ded93a1169
Add the SGD optimizer ( #160 )
...
* Add the nn::optim and some conversion traits.
* Add the backward_step function for SGD.
* Get the SGD optimizer to work and add a test.
* Make the test slightly simpler.
2023-07-13 19:05:44 +01:00