Commit Graph

16 Commits

SHA1 Message Date
121a71e01f Fix the silu cuda kernel. (#1710) 2024-02-14 11:08:18 +01:00
b60064780d feat: add silu activation function (#1706) 2024-02-14 10:27:22 +01:00
* feat: add silu activation function
* use silu/arg in grad
* update candle-nn
* use node
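
For reference, SiLU (also known as swish) is x * sigmoid(x), and the gradient used in backprop follows from differentiating that product. A minimal standalone Rust sketch of both (illustrative function names, not candle's actual kernel code):

```rust
// SiLU forward and the gradient used in backprop (illustration only).
fn silu(x: f32) -> f32 {
    let s = 1.0 / (1.0 + (-x).exp()); // sigmoid(x)
    x * s
}

fn silu_grad(x: f32) -> f32 {
    let s = 1.0 / (1.0 + (-x).exp());
    // d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    s * (1.0 + x * (1.0 - s))
}

fn main() {
    for x in [-2.0f32, 0.0, 2.0] {
        println!("silu({x}) = {:.4}, grad = {:.4}", silu(x), silu_grad(x));
    }
}
```
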
c18a856e76 Add the rounding operators. (#1030) 2023-10-04 17:58:44 +01:00
* Add the rounding operators.
* Avoid tracking gradients for the rounding operations.
* Add some rounding tests.
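
Gradients are presumably skipped here because round/floor/ceil are piecewise constant, so their derivative is zero almost everywhere. A hedged sketch of that forward/backward behaviour (hypothetical helper, not candle's API):

```rust
// Rounding forward; the backward pass contributes no gradient, since the
// derivative of a piecewise-constant function is zero almost everywhere.
fn round_fwd_bwd(x: f32, _grad_out: f32) -> (f32, f32) {
    (x.round(), 0.0)
}

fn main() {
    let (y, dx) = round_fwd_bwd(2.6, 1.0);
    println!("round(2.6) = {y}, gradient = {dx}");
}
```
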
402ddcfcb4 Add the missing kernel. (#955) 2023-09-24 17:21:37 +01:00
d7e48234d4 Add an erf based gelu op (#900) 2023-09-19 19:54:28 +01:00
* Erf based gelu.
* Add the erf backed gelu.
* Test the new gelu op (which is not gelu_new).
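
The erf-based variant computes gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), as opposed to the tanh approximation used by gelu_new. A standalone Rust sketch; since Rust's standard library has no erf, it uses the Abramowitz & Stegun polynomial approximation (illustration only, not candle's kernel):

```rust
// Abramowitz & Stegun 7.1.26 polynomial approximation of erf
// (maximum error around 1.5e-7).
fn erf(x: f32) -> f32 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    let poly = ((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t - 0.284496736) * t
        + 0.254829592)
        * t;
    sign * (1.0 - poly * (-x * x).exp())
}

// Erf-based GELU: gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
fn gelu_erf(x: f32) -> f32 {
    0.5 * x * (1.0 + erf(x / std::f32::consts::SQRT_2))
}

fn main() {
    for x in [-1.0f32, 0.0, 1.0] {
        println!("gelu_erf({x}) = {:.4}", gelu_erf(x));
    }
}
```
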
ad8a62dbf5 Add tanh. (#675) 2023-08-30 13:54:50 +01:00
* Add tanh.
* Use tanh in the lstm block.
* Add a test for tanh forward and backward passes.
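
For the backward pass mentioned above, d/dx tanh(x) = 1 - tanh(x)^2. A minimal Rust sketch (hypothetical helper, not candle's API):

```rust
// tanh forward plus the gradient used in the backward pass.
fn tanh_fwd_bwd(x: f32, grad_out: f32) -> (f32, f32) {
    let y = x.tanh();
    (y, grad_out * (1.0 - y * y)) // d/dx tanh(x) = 1 - tanh(x)^2
}

fn main() {
    let (y, dx) = tanh_fwd_bwd(0.5, 1.0);
    println!("tanh(0.5) = {y:.4}, d/dx = {dx:.4}");
}
```
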
59b731de99 Add the powf op. (#664) 2023-08-29 20:48:18 +01:00
* Add the powf op.
* Cuda kernels and backprop.
* Add a test.
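
The backprop rule for powf with respect to the input is d/dx x^p = p * x^(p-1). A minimal Rust sketch (hypothetical helper names):

```rust
// powf forward and its gradient with respect to the input.
fn powf_fwd_bwd(x: f32, p: f32, grad_out: f32) -> (f32, f32) {
    (x.powf(p), grad_out * p * x.powf(p - 1.0)) // d/dx x^p = p * x^(p-1)
}

fn main() {
    let (y, dx) = powf_fwd_bwd(3.0, 2.0, 1.0);
    println!("3^2 = {y}, gradient = {dx}");
}
```
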
166bfd5847 Add the recip op + use it in stable-diffusion. (#331) 2023-08-06 21:14:52 +01:00
* Add the recip unary op.
* Fix the cuda kernel.
* Use the recip op in sigmoid.
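
Using recip in sigmoid presumably means expressing sigmoid(x) as recip(1 + exp(-x)), so it can be built from existing unary ops. A minimal Rust sketch (illustration only):

```rust
// sigmoid expressed through the reciprocal: sigmoid(x) = recip(1 + exp(-x)).
fn sigmoid_via_recip(x: f32) -> f32 {
    (1.0 + (-x).exp()).recip()
}

fn main() {
    for x in [-4.0f32, 0.0, 4.0] {
        println!("sigmoid({x}) = {:.4}", sigmoid_via_recip(x));
    }
}
```
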
4f92420132 Add some flash attn test (#253) 2023-07-26 20:56:00 +01:00
* Add some flash-attn test.
* Add the cpu test.
* Fail when the head is not a multiple of 8.
* Polish the flash attention test.
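
The "multiple of 8" check presumably applies to the head dimension, which flash-attn requires to be divisible by 8. A hypothetical sketch of that kind of validation (not the actual candle-flash-attn API):

```rust
// Reject head sizes that flash-attn cannot handle.
fn check_head_dim(head_dim: usize) -> Result<(), String> {
    if head_dim % 8 == 0 {
        Ok(())
    } else {
        Err(format!("flash-attn expects head_dim to be a multiple of 8, got {head_dim}"))
    }
}

fn main() {
    println!("{:?}", check_head_dim(64)); // Ok(())
    println!("{:?}", check_head_dim(50)); // Err(...)
}
```
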
bc3be6f9b0 Add the elu cuda kernel. (#114) 2023-07-10 07:57:01 +01:00
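
For reference, ELU with slope alpha is x for x > 0 and alpha * (exp(x) - 1) otherwise. A standalone Rust sketch (illustration only, not the cuda kernel itself):

```rust
// ELU: x for positive inputs, alpha * (exp(x) - 1) otherwise.
fn elu(x: f32, alpha: f32) -> f32 {
    if x > 0.0 { x } else { alpha * (x.exp() - 1.0) }
}

fn main() {
    for x in [-1.0f32, 0.0, 1.0] {
        println!("elu({x}, 1.0) = {:.4}", elu(x, 1.0));
    }
}
```
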
c71a38deb7 Tweak the include order to include math.h first. (#100) 2023-07-07 06:47:25 +01:00
f114394456 Include the math.h file to get access to constants. (#99) 2023-07-07 06:42:57 +01:00
9784d1ed9f Minor tweaks. 2023-07-03 18:31:55 +01:00
ec79fc43f2 Add the bf16 cuda kernels. 2023-06-29 23:12:02 +01:00
1ce3843cab Add the relu op. 2023-06-28 09:38:54 +01:00
d7f729fb8f Refactor the hierarchy. 2023-06-27 11:57:27 +02:00