Commit Graph

64 Commits

Author SHA1 Message Date
bd9ab9bc04 Add a cuda kernel for dequantizing q8_0. (#1804) 2024-03-05 09:50:37 +01:00
2c95b7394a Handle Q5_0 and Q5_1 quants in cuda. 2024-02-29 10:54:01 +01:00
badf886583 Cuda kernel for dequantizing q8k. (#1760)
* Cuda kernel for dequantizing q8k.

* Clippy lints.
2024-02-26 08:42:44 +01:00
2f22afd80e Cuda acceleration for quantized model. (#1754)
* Boilerplate for the quantized cuda support.

* More basic cuda support.

* More cuda quantization (quantize on cpu for now).

* Add the dequantization bit.

* Start adding some dedicated cuda kernels from llama.cpp.

* Move the kernel code.

* Start interfacing with the kernel.

* Tweak the kernel launch params.

* Bugfix for quantized metal.

* Fix some clippy lints.

* Tweak the launch parameters.

* Tweak cuda basics to perform a quantized matmul.

* Perform the dequantization on the cpu + use cublas for matmul.

* Add the dequantization kernel.

* Test the qmatmul.

* More kernels.

* Matmul-vec kernel.

* Add a couple kernels.

* More dequantization kernels.
2024-02-25 18:11:47 +01:00
121a71e01f Fix the silu cuda kernel. (#1710) 2024-02-14 11:08:18 +01:00
b60064780d feat: add silu activation function (#1706)
* feat: add silu activation function

* use silu/arg in grad

* update candle-nn

* use node
2024-02-14 10:27:22 +01:00
d0aa197b07 ConvTranspose1d cuda support. (#1697)
* ConvTranspose1d cuda support.

* Add the conv-transpose1d kernel.

* Remove some unused variables.
2024-02-12 15:03:18 +01:00
2fe24ac5b1 Rework the cuda casting bits. (#1112) 2023-10-17 09:44:51 +01:00
8f7973958c fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0 (#1037)
* fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0

* cargo fmt
2023-10-05 18:46:13 +01:00
c18a856e76 Add the rounding operators. (#1030)
* Add the rounding operators.

* Avoid tracking gradients for the rounding operations.

* Add some rounding tests.
2023-10-04 17:58:44 +01:00
fc59bc31bf fix: add missing gpu fill_* (#996) 2023-09-29 15:49:30 +01:00
5e1c595e00 Optimize the index-select cuda kernel. (#976) 2023-09-28 09:05:29 +01:00
402ddcfcb4 Add the missing kernel. (#955) 2023-09-24 17:21:37 +01:00
a96878f235 cuda cast i64 (#925) 2023-09-21 19:52:39 +01:00
d7e48234d4 Add an erf based gelu op (#900)
* Erf based gelu.

* Add the erf backed gelu.

* Test the new gelu op (which is not gelu_new).
2023-09-19 19:54:28 +01:00
dbd4561416 im2col version of the conv1d kernel. (#815)
* im2col version of the cuda conv1d kernel.

* im2col version of the conv1d cpu kernel.
2023-09-11 14:40:09 +01:00
98d1242b8f im2col based conv2d (#802)
* im2col implementation for conv2d.

* Fix for the im2col implementation to match the current conv2d.

* Small optimization.

* Add a cuda kernel.

* Handle arbitrary layouts.

* Im2Col cuda code.
2023-09-10 21:02:42 +01:00
94c6a8d3d3 Add a dedicated cuda kernel for softmax. (#746) 2023-09-05 17:53:20 +02:00
ad8a62dbf5 Add tanh. (#675)
* Add tanh.

* Use tanh in the lstm block.

* Add a test for tanh forward and backward passes.
2023-08-30 13:54:50 +01:00
393690387f Support dilation in conv-transpose2d. (#671) 2023-08-30 09:22:00 +01:00
59b731de99 Add the powf op. (#664)
* Add the powf op.

* Cuda kernels and backprop.

* Add a test.
2023-08-29 20:48:18 +01:00
71221559d3 Fix the dilated convolutions. (#659) 2023-08-29 16:37:42 +01:00
a044907ffc Dilated convolutions (#657)
* Add the dilation parameter.

* Restore the basic optimizer example.

* Dilation support in cudnn.

* Use the dilation parameter in the cpu backend.

* More dilation support.

* No support for dilation in transposed convolutions.

* Add dilation to a test.

* Remove a print.

* Helper function.
2023-08-29 16:12:11 +01:00
037b41c9dc Cuda conv transpose (#645)
* Cuda kernel for conv-transpose.

* Fix the cuda kernel.

* Fix the tests.
2023-08-28 20:58:49 +01:00
d4e75d5825 Let's keep the dirty code on its own. 2023-08-25 12:01:58 +00:00
be371e827c Intermediary float cast is necessary for cuda 11.8 2023-08-25 11:54:30 +00:00
1c1e34735e static_cast ? 2023-08-25 11:40:36 +00:00
db8bab8b7a Different casting ? 2023-08-25 10:49:22 +00:00
bc131b402b Repairing cast bf16/f16 2023-08-25 10:38:19 +00:00
ca318a6ec7 Add to the cuda example a reproduction of the issue. (#579)
* Add to the cuda example a reproduction of the issue.

* Tweak.

* Add a test using non-square matrixes.

* Fix the conv2d kernel.

* Display the error.

* And tweak the comment.
2023-08-24 12:07:31 +01:00
9a5c7db91a Add support for i64 (#563)
* Add the i64 dtype.

* Adapt the cuda kernels.
2023-08-23 10:42:19 +01:00
a1812f934f Add a yolo-v3 example. (#528)
* Add a couple functions required for yolo.

* Add the yolo-v3 example.

* Add minimum and maximum.

* Use the newly introduced maximum.

* Cuda support for min/max + add some testing.

* Allow for more tests to work with accelerate.

* Fix a typo.
2023-08-20 18:19:37 +01:00
c84883ecf2 Add a cuda kernel for upsampling. (#441)
* Add a cuda kernel for upsampling.

* Update for the latest tokenizers version.
2023-08-14 13:12:17 +01:00
a094dc503d Add a cuda kernel for avg-pool2d. (#440)
* Add a cuda kernel for avg-pool2d.

* Avoid running out of bounds.

* Finish wiring the avg pool kernel + add some testing.

* Support for max-pool + testing.
2023-08-14 12:32:05 +01:00
34f4b3187e Add a naive conv2d cuda kernel. (#438)
* Add a naive conv2d cuda kernel.

* Proper conv2d support on the rust side.

* Conv1d testing on gpu.

* Also use the test on gpus.

* Fix the clean-ptx target.
2023-08-14 10:34:42 +01:00
4a95d34c83 Compat windows. 2023-08-10 17:46:47 +02:00
66d1c093e0 This is duplicated code on Cuda 12.2.
Without it we can compile for 52 (but I get Operation Not supported
when actually trying to use those kernels).
2023-08-10 09:20:18 +02:00
166bfd5847 Add the recip op + use it in stable-diffusion. (#331)
* Add the recip unary op.

* Fix the cuda kernel.

* Use the recip op in sigmoid.
2023-08-06 21:14:52 +01:00
4b3bd79fbd Remove the embedding ops in favor of index-select. (#299)
* Remove the embedding ops in favor of index-select.

* Also remove the cuda kernels.
2023-08-02 05:42:11 +01:00
c950a5c6b1 Cuda support for the mnist training. (#277)
* Cuda support for the mnist training.

* min/max fix + testing.

* Add the argmin/argmax tests.

* More cuda support for argmin/argmax.

* Cuda kernels for argmin and argmax.
2023-07-29 19:48:04 +01:00
c0a8ed19eb Support for where-cond on cuda for u8 and u32. (#274) 2023-07-29 11:48:58 +01:00
4f92420132 Add some flash attn test (#253)
* Add some flash-attn test.

* Add the cpu test.

* Fail when the head is not a multiple of 8.

* Polish the flash attention test.
2023-07-26 20:56:00 +01:00
944d70bd9a Add a test for scatter add. (#238)
* Add a test for scatter add (segfaults on gpus for now).

* Bugfix for the scatter add cuda kernel.
2023-07-25 09:12:14 +01:00
74a6a769dd Cuda kernels for IndexAdd/ScatterAdd. (#236)
* Skeleton methods for IndexAdd/ScatterAdd.

* Add a Map2InPlace trait.

* Add the glue code for the index-add/scatter-add kernels.

* Tweak the file name: embeddings -> indexing.

* Add the cuda kernel for indexadd.

* And add the scatter-add kernels.
2023-07-24 21:53:08 +01:00
581b104f97 Indexing cuda (#235)
* Allow using uint8_t for indexing.

* Revert the default cuda feature.

* Add a cuda-kernel for index-select.

* Add a test for gather.
2023-07-24 20:22:47 +01:00
b50f932e7c Add some cmp tests. (#233)
* Add some cmp tests.

* Add the cuda kernels for comparison operations.
2023-07-24 16:53:45 +01:00
23827c49cd Cleanup some todos. (#226)
* Cleanup some todos.

* Fix more todo.

* Optimize for the contiguous case.

* Add the IntDType trait.

* Handle the intdtype trait for more ops.

* Remove a todo.

* Remove a todo.
2023-07-23 16:00:00 +01:00
5f20acf080 Revert "Add the layer norm files. (#222)" (#223)
This reverts commit c8459d199d.
2023-07-22 16:51:11 +01:00
c8459d199d Add the layer norm files. (#222) 2023-07-22 15:06:35 +01:00
536c5e702e Cuda kernels for fast min/max reductions (#203)
* Add the min/max cuda kernels.

* Better integration of the cuda kernels.
2023-07-19 18:12:27 +01:00