4a95d34c83
Compat windows.
2023-08-10 17:46:47 +02:00
66d1c093e0
This is duplicated code on Cuda 12.2.
Without it we can compile for 52 (but I get Operation Not supported
when actually trying to use those kernels).
2023-08-10 09:20:18 +02:00
166bfd5847
Add the recip op + use it in stable-diffusion. ( #331 )
* Add the recip unary op.
* Fix the cuda kernel.
* Use the recip op in sigmoid.
2023-08-06 21:14:52 +01:00
4b3bd79fbd
Remove the embedding ops in favor of index-select. ( #299 )
* Remove the embedding ops in favor of index-select.
* Also remove the cuda kernels.
2023-08-02 05:42:11 +01:00
c950a5c6b1
Cuda support for the mnist training. ( #277 )
* Cuda support for the mnist training.
* min/max fix + testing.
* Add the argmin/argmax tests.
* More cuda support for argmin/argmax.
* Cuda kernels for argmin and argmax.
2023-07-29 19:48:04 +01:00
c0a8ed19eb
Support for where-cond on cuda for u8 and u32. ( #274 )
2023-07-29 11:48:58 +01:00
4f92420132
Add some flash attn test ( #253 )
* Add some flash-attn test.
* Add the cpu test.
* Fail when the head is not a multiple of 8.
* Polish the flash attention test.
2023-07-26 20:56:00 +01:00
944d70bd9a
Add a test for scatter add. ( #238 )
* Add a test for scatter add (segfaults on gpus for now).
* Bugfix for the scatter add cuda kernel.
2023-07-25 09:12:14 +01:00
74a6a769dd
Cuda kernels for IndexAdd/ScatterAdd. ( #236 )
* Skeleton methods for IndexAdd/ScatterAdd.
* Add a Map2InPlace trait.
* Add the glue code for the index-add/scatter-add kernels.
* Tweak the file name: embeddings -> indexing.
* Add the cuda kernel for indexadd.
* And add the scatter-add kernels.
2023-07-24 21:53:08 +01:00
581b104f97
Indexing cuda ( #235 )
* Allow using uint8_t for indexing.
* Revert the default cuda feature.
* Add a cuda-kernel for index-select.
* Add a test for gather.
2023-07-24 20:22:47 +01:00
b50f932e7c
Add some cmp tests. ( #233 )
* Add some cmp tests.
* Add the cuda kernels for comparison operations.
2023-07-24 16:53:45 +01:00
23827c49cd
Cleanup some todos. ( #226 )
* Cleanup some todos.
* Fix more todo.
* Optimize for the contiguous case.
* Add the IntDType trait.
* Handle the intdtype trait for more ops.
* Remove a todo.
* Remove a todo.
2023-07-23 16:00:00 +01:00
5f20acf080
Revert "Add the layer norm files. ( #222 )" ( #223 )
This reverts commit c8459d199d.
2023-07-22 16:51:11 +01:00
c8459d199d
Add the layer norm files. ( #222 )
2023-07-22 15:06:35 +01:00
536c5e702e
Cuda kernels for fast min/max reductions ( #203 )
* Add the min/max cuda kernels.
* Better integration of the cuda kernels.
2023-07-19 18:12:27 +01:00
bc3be6f9b0
Add the elu cuda kernel. ( #114 )
2023-07-10 07:57:01 +01:00
c187f347bf
Make it easier to use whisper samples from the repo. ( #112 )
* Make it easier to use samples from the repo.
* Use f32 for accumulation in the f16/bf16 kernels.
2023-07-08 18:48:27 +01:00
eb64ad0d4d
Cuda kernel for the conv1d op ( #111 )
* Boilerplate code for conv1d.
* Boilerplate code for conv1d.
* More boilerplate for conv1d.
* Conv1d work.
* Get the conv1d cuda kernel to work.
* Conv1d support when no batch dim.
2023-07-08 18:13:25 +01:00
e676f85f00
Sketch a fast cuda kernel for reduce-sum. ( #109 )
* Sketch a fast cuda kernel for reduce-sum.
* Sketch the rust support code for the fast sum kernel.
* More work on the fast kernel.
* Add some testing ground.
* A couple fixes for the fast sum kernel.
2023-07-08 12:43:56 +01:00
c71a38deb7
Tweak the include order to include math.h first. ( #100 )
2023-07-07 06:47:25 +01:00
f114394456
Include the math.h file to get access to constants. ( #99 )
2023-07-07 06:42:57 +01:00
9784d1ed9f
Minor tweaks.
2023-07-03 18:31:55 +01:00
313fa022a5
Bugfix: remove the u8/bf16 conversion kernel as it is ambiguous.
2023-06-30 10:43:32 +01:00
8ad47907f3
Add the kernels.
2023-06-30 10:26:56 +01:00
6486a6d7b2
Avoid some cast kernels.
2023-06-29 23:23:44 +01:00
ec79fc43f2
Add the bf16 cuda kernels.
2023-06-29 23:12:02 +01:00
1ce3843cab
Add the relu op.
2023-06-28 09:38:54 +01:00
380d61e990
Fix two cuda bugs (matmul and where_cond).
2023-06-27 11:31:04 +01:00
d7f729fb8f
Refactor the hierarchy.
2023-06-27 11:57:27 +02:00