Commit Graph

42 Commits

SHA1 Message Date
581b104f97 Indexing cuda (#235)
* Allow using uint8_t for indexing.

* Revert the default cuda feature.

* Add a cuda-kernel for index-select.

* Add a test for gather.
2023-07-24 20:22:47 +01:00
b50f932e7c Add some cmp tests. (#233)
* Add some cmp tests.

* Add the cuda kernels for comparison operations.
2023-07-24 16:53:45 +01:00
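The comparison ops introduced above (and their cuda kernels) are elementwise and return a 0/1 mask rather than booleans. A minimal Python sketch of that semantics (function and op names here are illustrative, not the library's API):

```python
def cmp(op, a, b):
    """Elementwise comparison returning a 0/1 mask, one entry per input pair."""
    ops = {
        "eq": lambda x, y: x == y,
        "ne": lambda x, y: x != y,
        "lt": lambda x, y: x < y,
        "le": lambda x, y: x <= y,
        "gt": lambda x, y: x > y,
        "ge": lambda x, y: x >= y,
    }
    f = ops[op]
    return [1 if f(x, y) else 0 for x, y in zip(a, b)]

mask = cmp("le", [1, 4, 3], [2, 4, 1])  # -> [1, 1, 0]
```

Returning small unsigned integers instead of booleans keeps the result usable as a regular tensor (e.g. for masking via multiplication).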
e449ce53a2 Wrapping code to call the custom op. (#225)
* Wrapping code to call the custom op.

* Get the rms example to work.

* Get around rustfmt failing in the CI.

* Fix the rms computation.
2023-07-23 11:31:17 +01:00
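The rms example mentioned above computes RMS normalization: each value is divided by the root-mean-square of the whole vector. A plain-Python sketch of the computation (the `eps` stabilizer is a common convention, assumed here):

```python
import math

def rms_norm(xs, eps=1e-5):
    """Divide every element by the root-mean-square of the vector."""
    mean_sq = sum(x * x for x in xs) / len(xs)
    rms = math.sqrt(mean_sq + eps)
    return [x / rms for x in xs]

ys = rms_norm([1.0, 2.0, 3.0, 4.0])
# After normalization, the mean of the squared outputs is ~1.
```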
b8a10425ad Kernel build example (#224)
* Build example kernels.

* Add some sample custom kernel.

* Get the example kernel to compile.

* Add some cuda code.

* More cuda custom op.

* More cuda custom ops.
2023-07-23 07:15:37 +01:00
43c7223292 Rename the .r functions to .dims so as to be a bit more explicit. (#220) 2023-07-22 10:39:27 +01:00
52c5d8c087 Add the gather op. (#219)
* Start adding gather.

* Gather cpu implementation + use in simple training.

* Add scatter_add for the gradient of gather.

* Simple cpu implementation of scatter_add.

* Use gather in the simple-training backprop.
2023-07-22 07:21:28 +01:00
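The gather/scatter_add pairing above can be sketched in plain Python for the dim-0, 2D case: `gather` reads `src[index[i][j]][j]`, and `scatter_add` routes values back through the same index map, which is exactly why it serves as gather's gradient (names are illustrative, not the crate's API):

```python
def gather(src, index):
    """Gather along dim 0: out[i][j] = src[index[i][j]][j]."""
    return [[src[index[i][j]][j] for j in range(len(index[0]))]
            for i in range(len(index))]

def scatter_add(base, index, src):
    """Scatter-add along dim 0: out[index[i][j]][j] += src[i][j]."""
    out = [row[:] for row in base]
    for i in range(len(index)):
        for j in range(len(index[0])):
            out[index[i][j]][j] += src[i][j]
    return out

src = [[1, 2], [3, 4]]
idx = [[0, 1], [1, 0]]
g = gather(src, idx)                         # -> [[1, 4], [3, 2]]
zeros = [[0, 0], [0, 0]]
back = scatter_add(zeros, idx, g)            # routes g back: [[1, 2], [3, 4]]
```

Scattering the gathered values through the same indices reconstructs `src`, which mirrors how the backward pass accumulates gradients into the positions that were read.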
27174a82aa Start adding index-add. 2023-07-21 20:12:48 +01:00
410654525f Refactor the reduce ops in order to introduce argmin/argmax. (#212)
* Refactor the reduce ops in order to introduce argmin/argmax.

* Clippy fixes.

* Use the newly introduced argmax.

* Fix the strided case.

* Handle the non-contiguous case.
2023-07-21 11:41:08 +01:00
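The argmin/argmax reductions introduced in the refactor above return the *index* of the extreme element along the reduced dimension rather than its value. A row-wise sketch in plain Python (ties resolved to the first occurrence, an assumption here):

```python
def argmax(rows):
    """Index of the max element in each row; ties go to the first occurrence."""
    out = []
    for row in rows:
        best = 0
        for i, v in enumerate(row):
            if v > row[best]:
                best = i
        out.append(best)
    return out

def argmin(rows):
    """Index of the min element in each row."""
    return [row.index(min(row)) for row in rows]

argmax([[1, 5, 3], [9, 2, 2]])  # -> [1, 0]
```

The strided/non-contiguous cases mentioned in the commit amount to the same loop, but with element addresses computed from the layout's strides instead of direct indexing.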
fa08fb3126 Add the index-select op. (#209)
* Add the index-select op.

* Cpu implementation of index-select.

* Add the cpu implementation for index-select.
2023-07-20 14:01:03 +01:00
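The index-select op added above picks whole slices along one dimension using a vector of positions. A sketch of the semantics for the 2D case (illustrative names, not the crate's API; this is also why small unsigned index types suffice, since indices only address positions):

```python
def index_select(matrix, dim, indices):
    """Select slices of a 2D matrix along `dim` (0 = rows, 1 = columns)."""
    if dim == 0:
        return [list(matrix[i]) for i in indices]
    return [[row[i] for i in indices] for row in matrix]

m = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]
rows = index_select(m, 0, [2, 0])  # -> [[7, 8, 9], [1, 2, 3]]
cols = index_select(m, 1, [1])     # -> [[2], [5], [8]]
```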
2a8f28d687 Op refactor (#208)
* Add the binary and unary op enums to factorize some code.

* Bugfix.
2023-07-20 12:28:45 +01:00
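The factorization above replaces per-op code paths with unary and binary op enums plus a single generic loop. A minimal sketch of the pattern in Python (op names are illustrative):

```python
from enum import Enum

class UnaryOp(Enum):
    NEG = "neg"
    SQR = "sqr"

class BinaryOp(Enum):
    ADD = "add"
    MUL = "mul"

# One dispatch table per arity instead of duplicated loops per op.
UNARY = {UnaryOp.NEG: lambda x: -x, UnaryOp.SQR: lambda x: x * x}
BINARY = {BinaryOp.ADD: lambda a, b: a + b, BinaryOp.MUL: lambda a, b: a * b}

def unary(op, xs):
    f = UNARY[op]
    return [f(x) for x in xs]

def binary(op, xs, ys):
    f = BINARY[op]
    return [f(a, b) for a, b in zip(xs, ys)]
```

Adding a new elementwise op then only requires a new enum variant and table entry, not a new traversal loop.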
e9c052bf94 Add the comparison operations. (#207)
* Add the comparison operations.

* Add the helper functions on the tensor side.

* More cmp operations.

* Cpu implementation for the comparison operations.
2023-07-20 09:40:31 +01:00
536c5e702e Cuda kernels for fast min/max reductions (#203)
* Add the min/max cuda kernels.

* Better integration of the cuda kernels.
2023-07-19 18:12:27 +01:00
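Fast min/max reduction kernels typically use a pairwise (tree) reduction so that independent comparisons can run in parallel. A sequential Python sketch of that shape, assuming max as the example op:

```python
def reduce_max(xs):
    """Pairwise tree reduction: halve the working set each round."""
    vals = list(xs)
    while len(vals) > 1:
        half = (len(vals) + 1) // 2
        vals = [max(vals[i], vals[i + half]) if i + half < len(vals) else vals[i]
                for i in range(half)]
    return vals[0]
```

On a GPU, each round of this loop corresponds to one synchronized step where threads combine pairs of partial results.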
cb687b4897 Add some more developed training examples. (#199)
* Use contiguous tensors for variables.

* Sketch the mnist example.

* Start adding the reduce ops.

* Renaming.

* Refactor the reduce operations.

* Bugfix for the broadcasting vectorization.
2023-07-19 15:37:52 +01:00
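The broadcasting mentioned in the bugfix above lets a smaller operand be reused across a larger one's leading dimension. A minimal sketch of the common case, adding a length-n vector to every row of an m x n matrix (illustrative helper, not the crate's API):

```python
def broadcast_add(matrix, vec):
    """Add `vec` to each row of `matrix` (broadcast over dim 0)."""
    assert all(len(row) == len(vec) for row in matrix)
    return [[a + b for a, b in zip(row, vec)] for row in matrix]

broadcast_add([[1, 2], [3, 4]], [10, 20])  # -> [[11, 22], [13, 24]]
```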
d88b6cdca9 Add backtrace information to errors where relevant. (#166)
* Add backtrace information to errors where relevant.

* More backtrace information.

* Add to the FAQ.
2023-07-14 09:31:25 +01:00
64264d97c1 Modular backends (#138)
* Add some trait to formalize backends.

* Use the generic backend trait.
2023-07-11 11:17:02 +01:00
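Formalizing backends behind a trait means tensor code is written once against an interface, with cpu and cuda supplying implementations. A Python sketch of the idea using an abstract base class (names are illustrative, not the crate's trait):

```python
from abc import ABC, abstractmethod

class Backend(ABC):
    """The interface every device backend implements."""
    @abstractmethod
    def add(self, a, b): ...
    @abstractmethod
    def mul(self, a, b): ...

class CpuBackend(Backend):
    def add(self, a, b):
        return [x + y for x, y in zip(a, b)]
    def mul(self, a, b):
        return [x * y for x, y in zip(a, b)]

def axpy(backend, a, x, y):
    # Generic code: runs unchanged on any Backend (cpu, cuda, ...).
    return backend.add(backend.mul(a, x), y)
```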
674eb35e10 Remove some dead-code pragmas. (#137) 2023-07-11 09:33:59 +01:00
ae79c00e48 Allow for uniform initialization in a single step. (#136) 2023-07-11 08:52:29 +01:00
f29b77ec19 Random initializers. (#128)
* Random initialization.

* CPU rng generation.
2023-07-10 18:26:21 +01:00
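Uniform initialization "in a single step" amounts to filling a whole buffer from one RNG pass rather than element-by-element calls. A Python sketch with the stdlib RNG (the `seed` parameter is an assumption for reproducibility, not the crate's signature):

```python
import random

def rand_uniform(n, lo=-1.0, hi=1.0, seed=None):
    """Fill a length-n buffer with uniform samples in [lo, hi] in one pass."""
    rng = random.Random(seed)
    return [rng.uniform(lo, hi) for _ in range(n)]

ws = rand_uniform(4, lo=0.0, hi=1.0, seed=0)
```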
270997a055 Add the elu op. (#113) 2023-07-09 21:56:31 +01:00
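The elu activation added above is the standard exponential linear unit, sketched here scalar-wise in Python:

```python
import math

def elu(x, alpha=1.0):
    """elu(x) = x for x > 0, alpha * (exp(x) - 1) otherwise."""
    return x if x > 0 else alpha * (math.exp(x) - 1.0)

elu(2.0)   # -> 2.0
elu(-1.0)  # -> alpha * (e^-1 - 1), about -0.632
```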
eb64ad0d4d Cuda kernel for the conv1d op (#111)
* Boilerplate code for conv1d.

* Boilerplate code for conv1d.

* More boilerplate for conv1d.

* Conv1d work.

* Get the conv1d cuda kernel to work.

* Conv1d support when no batch dim.
2023-07-08 18:13:25 +01:00
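The core of the conv1d op above is a sliding dot product over the input. A single-channel, "valid"-padding sketch in plain Python (the no-batch-dim case from the last bullet):

```python
def conv1d(xs, kernel):
    """Valid cross-correlation: out[i] = sum_k xs[i + k] * kernel[k]."""
    n = len(xs) - len(kernel) + 1
    return [sum(xs[i + k] * kernel[k] for k in range(len(kernel)))
            for i in range(n)]

conv1d([1, 2, 3, 4], [1, 0, -1])  # -> [-2, -2]
```

The full op generalizes this with batch and channel dimensions; the cuda kernel parallelizes over output positions.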
e676f85f00 Sketch a fast cuda kernel for reduce-sum. (#109)
* Sketch a fast cuda kernel for reduce-sum.

* Sketch the rust support code for the fast sum kernel.

* More work on the fast kernel.

* Add some testing ground.

* A couple fixes for the fast sum kernel.
2023-07-08 12:43:56 +01:00
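A fast reduce-sum kernel typically works in two stages: each block sums a contiguous chunk into a partial result, then the partials are combined. A sequential Python sketch of that structure (block size is illustrative):

```python
def block_sum(xs, block=4):
    """Two-stage sum: per-block partial sums, then a final combine pass."""
    partials = [sum(xs[i:i + block]) for i in range(0, len(xs), block)]
    return sum(partials)
```

On a GPU the per-block stage runs in shared memory with one thread block per chunk, and the final pass combines the per-block outputs.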
02b5c38049 Use cublas bf16. (#101) 2023-07-07 08:00:12 +01:00
dd60bd84bb MKL adjustments. (#87) 2023-07-06 11:37:27 +01:00
c297a50960 Add mkl support for matrix multiply. (#86)
* Fix some rebase issues.

* Use mkl instead.

* Use mkl in bert.

* Add the optional mkl feature.

* Conditional compilation based on the mkl feature.

* Add more mkl support.
2023-07-06 11:05:05 +01:00
a424d95473 Add more of the conv1d op. 2023-07-04 11:15:45 +01:00
3aac1047fe Sketch the conv1d op. 2023-07-04 10:52:34 +01:00
a57b314780 Add a batch dimension on the bert example. 2023-07-04 06:10:52 +01:00
86d691c74c Better handling of the batch dimension in matmul. 2023-07-03 22:51:40 +01:00
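Handling a batch dimension in matmul means applying an ordinary 2D matmul independently per leading index. A plain-Python sketch (no broadcasting of the batch dim, which is an assumption of this simplified version):

```python
def matmul(a, b):
    """Plain 2D matrix multiply."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def batched_matmul(a3, b3):
    """Apply matmul independently over a leading batch dimension."""
    return [matmul(a, b) for a, b in zip(a3, b3)]
```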
8ad47907f3 Add the kernels. 2023-06-30 10:26:56 +01:00
c9c468e1aa Use Map2 for binary ops. 2023-06-29 10:09:15 +01:00
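The Map1/Map2 helpers referenced in these commits factor the element-traversal loop out of every unary and binary op. Their essence, stripped of layout handling, is just:

```python
def map1(f, xs):
    """Map1: apply a unary function over every element (neg, sqrt, lookup, ...)."""
    return [f(x) for x in xs]

def map2(f, xs, ys):
    """Map2: apply a binary function over paired elements (add, mul, ...)."""
    return [f(a, b) for a, b in zip(xs, ys)]
```

In the actual backend these helpers also walk strided layouts, so each op only supplies the per-element function.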
83c7d660ca Add Map2. 2023-06-29 10:05:06 +01:00
367170da45 Also use Map1 for embedding. 2023-06-29 09:45:27 +01:00
8ad03a5fb6 Use Map1 on unary ops. 2023-06-29 09:37:38 +01:00
fff13dbb4e Factorize the kernel naming scheme. 2023-06-29 09:29:59 +01:00
d3c7b0d168 Use Map1 for sum. 2023-06-29 09:27:07 +01:00
122e334d0c Simplify the pattern matching logic in the cuda backend. 2023-06-29 09:21:11 +01:00
6c9e6b5a99 Get the cuda tests to pass. 2023-06-28 15:53:23 +01:00
3f0d9fbb25 Adapt the cuda bits. 2023-06-28 15:43:03 +01:00
e221d38819 Factor the slicing code in cuda. 2023-06-27 15:45:59 +01:00
07a682c2ff Run the tensor tests for the cuda backend too. 2023-06-27 15:37:01 +01:00
380d61e990 Fix two cuda bugs (matmul and where_cond). 2023-06-27 11:31:04 +01:00
d7f729fb8f Refactor the hierarchy. 2023-06-27 11:57:27 +02:00