b60064780d
feat: add silu activation function (#1706)
* feat: add silu activation function
* use silu/arg in grad
* update candle-nn
* use node
2024-02-14 10:27:22 +01:00
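
A minimal usage sketch for the new unary op (the `silu` method name is taken from this entry; the snippet itself is illustrative, not code from the PR, and candle-nn's ops module wraps the same function):

    use candle_core::{Device, Result, Tensor};

    fn silu_demo() -> Result<()> {
        let x = Tensor::new(&[-2f32, 0., 2.], &Device::Cpu)?;
        let y = x.silu()?; // silu(x) = x * sigmoid(x)
        println!("{y}");
        Ok(())
    }
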
e6d86b0819
Add the pow operator. (#1583)
* Add the pow operator.
* Support the pow operation in onnx.
2024-01-13 20:24:06 +01:00
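
An illustrative sketch of the elementwise pow operation; I'm assuming a `pow(&rhs)` method taking a tensor exponent alongside the pre-existing scalar `powf`, so treat the exact signature as unverified:

    use candle_core::{Device, Result, Tensor};

    fn pow_demo() -> Result<()> {
        let base = Tensor::new(&[2f32, 3., 4.], &Device::Cpu)?;
        let exp = Tensor::new(&[2f32, 2., 0.5], &Device::Cpu)?;
        let out = base.pow(&exp)?; // elementwise base^exp -> [4., 9., 2.]
        let sq = base.powf(2.0)?;  // scalar-exponent variant
        println!("{out}\n{sq}");
        Ok(())
    }
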
96f1a28e39
Add a simple full method. (#1455)
* Add a simple implementation of the full method.
* Add the docstring.
2023-12-17 20:15:57 -05:00
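
A quick sketch of a `Tensor::full` call as described here (value, shape, device; illustrative only):

    use candle_core::{Device, Result, Tensor};

    fn full_demo() -> Result<()> {
        let t = Tensor::full(3.5f32, (2, 4), &Device::Cpu)?; // 2x4 tensor, every element 3.5
        println!("{t}");
        Ok(())
    }
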
4cb443d00a
Fix the logsumexp test. (#1426)
2023-12-12 10:56:11 -06:00
77252ffb82
Add logsumexp function (#1424)
2023-12-12 10:32:17 -06:00
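
A sketch of the reduction covered by the pair of logsumexp entries above; I'm assuming the method ended up named `log_sum_exp` and that it reduces over the given dims:

    use candle_core::{Device, Result, Tensor};

    fn logsumexp_demo() -> Result<()> {
        let x = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
        let y = x.log_sum_exp(1)?; // log(sum(exp(x))) reduced over dim 1
        println!("{y}");
        Ok(())
    }
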
8d6c6de8e0
Missing new test.
2023-11-20 14:38:35 +01:00
7ec345c2eb
Adding the test scaffolding.
2023-11-20 14:38:35 +01:00
c6763e3b41
Add a simple implementation of cumsum. (#1334)
* Add a simple implementation of cumsum.
* Add another test.
2023-11-15 21:11:15 +00:00
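
A one-line sketch of the cumulative sum (illustrative only):

    use candle_core::{Device, Result, Tensor};

    fn cumsum_demo() -> Result<()> {
        let x = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
        let y = x.cumsum(0)?; // running sum along dim 0: [1., 3., 6., 10.]
        println!("{y}");
        Ok(())
    }
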
347e31c9ff
Add the tril/triu/eye ops. (#1333)
* Add tril/triu/eye.
* Revert the metal crate tweak.
2023-11-15 20:34:37 +00:00
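
A sketch of the three constructors; the `tril2`/`triu2` names (square masks built from a size, dtype and device) are how I recall the API, so double-check them:

    use candle_core::{DType, Device, Result, Tensor};

    fn tri_demo() -> Result<()> {
        let id = Tensor::eye(3, DType::F32, &Device::Cpu)?;    // 3x3 identity
        let low = Tensor::tril2(3, DType::F32, &Device::Cpu)?; // lower-triangular ones
        let up = Tensor::triu2(3, DType::F32, &Device::Cpu)?;  // upper-triangular ones
        println!("{id}\n{low}\n{up}");
        Ok(())
    }
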
9e666d4229
Add the var method. (#1315)
* Add the var method.
* Add a test.
2023-11-10 22:47:57 +01:00
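
A sketch of the variance reduction; I'm assuming both a dim-dropping `var` and a `var_keepdim` variant, mirroring the other reductions:

    use candle_core::{Device, Result, Tensor};

    fn var_demo() -> Result<()> {
        let x = Tensor::new(&[[1f32, 2., 3.], [2., 4., 6.]], &Device::Cpu)?;
        let v = x.var(1)?;          // per-row variance, shape [2]
        let vk = x.var_keepdim(1)?; // same values, shape [2, 1]
        println!("{v}\n{vk}");
        Ok(())
    }
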
5fc66bd4ba
Support negative steps in arange. (#1218)
2023-10-30 07:40:54 +00:00
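
A sketch of what negative steps enable, via `arange_step` (values illustrative; the end bound is exclusive):

    use candle_core::{Device, Result, Tensor};

    fn arange_demo() -> Result<()> {
        let down = Tensor::arange_step(5f32, 0f32, -1f32, &Device::Cpu)?; // [5., 4., 3., 2., 1.]
        println!("{down}");
        Ok(())
    }
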
154c674a79
Add i64-abs. (#1216)
2023-10-29 15:28:53 +00:00
87eb1658e1
Add pad_with_same. (#1127)
* More model cloning.
* More cloning on quantized models.
* Add pad-with-same.
* Add some tests.
2023-10-18 23:13:37 +01:00
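
A sketch of edge-replication padding; I'm assuming a `pad_with_same(dim, left, right)` signature analogous to `pad_with_zeros`:

    use candle_core::{Device, Result, Tensor};

    fn pad_demo() -> Result<()> {
        let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
        let y = x.pad_with_same(0, 2, 1)?; // repeat the edge values: [1., 1., 1., 2., 3., 3.]
        println!("{y}");
        Ok(())
    }
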
7f7d95e2c3
Add the round-to function. (#1039)
2023-10-05 20:28:09 +01:00
8f7973958c
fix: fix the index_select cuda kernel when the src/target dim differs from the ids dim and the selected dim is > 0 (#1037)
* fix: fix the index_select cuda kernel when the src/target dim differs from the ids dim and the selected dim is > 0
* cargo fmt
2023-10-05 18:46:13 +01:00
c18a856e76
Add the rounding operators. (#1030)
* Add the rounding operators.
* Avoid tracking gradients for the rounding operations.
* Add some rounding tests.
2023-10-04 17:58:44 +01:00
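
A sketch covering the rounding operators here plus the round-to helper from #1039 further up (its decimal-places argument is an assumption on my part); per the bullet above, gradients are not tracked through these ops:

    use candle_core::{Device, Result, Tensor};

    fn rounding_demo() -> Result<()> {
        let x = Tensor::new(&[-1.64f32, -0.4, 0.46, 1.6], &Device::Cpu)?;
        let r = x.round()?;     // nearest integer
        let f = x.floor()?;     // round towards negative infinity
        let c = x.ceil()?;      // round towards positive infinity
        let d = x.round_to(1)?; // round to 1 decimal place (assumed helper from #1039)
        println!("{r}\n{f}\n{c}\n{d}");
        Ok(())
    }
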
043cc25766
Fix for the index-select cuda setup. (#1022)
* Fix for index-select.
* Better fix + add some testing.
2023-10-03 10:21:46 +01:00
fc59bc31bf
fix: add missing gpu fill_* (#996)
2023-09-29 15:49:30 +01:00
8601537e31
Add slice-scatter. (#927)
* Add slice-scatter.
* Add the op.
* Make transpose a no-op when the dimensions are identical.
* Add the backprop.
* And add some gradient tests.
2023-09-22 12:18:16 +01:00
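
A sketch of slice-scatter as I understand it: write `src` into a slice of `self` along one dim, starting at an offset (exact signature unverified):

    use candle_core::{DType, Device, Result, Tensor};

    fn slice_scatter_demo() -> Result<()> {
        let dst = Tensor::zeros((4, 3), DType::F32, &Device::Cpu)?;
        let src = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
        let out = dst.slice_scatter(&src, 0, 1)?; // rows 1..3 of dst replaced by src
        println!("{out}");
        Ok(())
    }
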
7b26e513f1
Add the erf function. (#917)
2023-09-21 06:19:10 +01:00
d7e48234d4
Add an erf-based gelu op (#900)
* Erf based gelu.
* Add the erf backed gelu.
* Test the new gelu op (which is not gelu_new).
2023-09-19 19:54:28 +01:00
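
The erf entry above and this one add `erf` plus an erf-backed gelu alongside the existing tanh approximation; a small sketch (method names assumed):

    use candle_core::{Device, Result, Tensor};

    fn gelu_demo() -> Result<()> {
        let x = Tensor::new(&[-1f32, 0., 1.], &Device::Cpu)?;
        let e = x.erf()?;          // Gauss error function, elementwise
        let exact = x.gelu_erf()?; // 0.5 * x * (1 + erf(x / sqrt(2)))
        let approx = x.gelu()?;    // the pre-existing tanh-based approximation
        println!("{e}\n{exact}\n{approx}");
        Ok(())
    }
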
18d3c803a8
Scalar support in minimum/maximum. (#832)
* Scalar support in minimum/maximum.
* Add a clamp method to tensors.
2023-09-13 08:24:58 +01:00
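
With scalar support, `minimum`/`maximum` accept either a tensor or a plain number, and `clamp` composes the two; a minimal sketch:

    use candle_core::{Device, Result, Tensor};

    fn clamp_demo() -> Result<()> {
        let x = Tensor::new(&[-2f32, -0.5, 0.5, 2.], &Device::Cpu)?;
        let relu = x.maximum(0f32)?;         // scalar rhs: [0., 0., 0.5, 2.]
        let capped = x.minimum(1f32)?;       // [-2., -0.5, 0.5, 1.]
        let clamped = x.clamp(-1f32, 1f32)?; // [-1., -0.5, 0.5, 1.]
        println!("{relu}\n{capped}\n{clamped}");
        Ok(())
    }
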
258ac32c38
Fix cuda randn when generating an odd number of values. (#793)
2023-09-09 18:44:21 +01:00
5320aa6b7d
Move the test-utils bits to a shared place. (#619)
2023-08-27 09:42:22 +01:00
a1812f934f
Add a yolo-v3 example. (#528)
* Add a couple functions required for yolo.
* Add the yolo-v3 example.
* Add minimum and maximum.
* Use the newly introduced maximum.
* Cuda support for min/max + add some testing.
* Allow for more tests to work with accelerate.
* Fix a typo.
2023-08-20 18:19:37 +01:00
2fcb386f17
Add a broadcast variant to matmul. (#523)
* Add a broadcast variant to matmul.
* Get the test to pass.
2023-08-20 13:20:42 +01:00
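
A sketch of the broadcast variant: batch dims that differ (here 2 vs 1) are broadcast before the matmul (shapes chosen purely for illustration):

    use candle_core::{DType, Device, Result, Tensor};

    fn broadcast_matmul_demo() -> Result<()> {
        let a = Tensor::zeros((2, 4, 3), DType::F32, &Device::Cpu)?;
        let b = Tensor::zeros((1, 3, 5), DType::F32, &Device::Cpu)?;
        let c = a.broadcast_matmul(&b)?; // b is broadcast over the batch dim -> shape (2, 4, 5)
        println!("{:?}", c.shape());
        Ok(())
    }
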
c7f92f985e
Further randn tweaks: use the appropriate rng rather than the f64 one, plus some cleanup. (#383)
2023-08-10 05:48:19 +01:00
3bbc08a8df
Fix randn cpu (#382)
* Change distributions: Standard generates values in [0, 1), while Normal is the correct one here.
* Add a test (not sure if this is the best place to put it).
* Remove unnecessary use
2023-08-10 05:33:44 +01:00
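
For context on the fix (the CPU path had been sampling a uniform Standard distribution instead of a Gaussian), a sketch of the call it affects; I'm assuming the `randn(mean, std, shape, device)` argument order:

    use candle_core::{Device, Result, Tensor};

    fn randn_demo() -> Result<()> {
        let x = Tensor::randn(0f32, 1f32, (8, 128), &Device::Cpu)?; // samples from N(0, 1)
        // The kind of sanity check the added test can make: mean near 0, variance near 1.
        println!("mean ~ {}", x.mean_all()?);
        Ok(())
    }
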
51e51da896
Rename the candle crate to candle-core (#301)
* Rename to candle-core.
* More candle-core renaming.
2023-08-02 08:20:22 +01:00
4b3bd79fbd
Remove the embedding ops in favor of index-select. (#299)
* Remove the embedding ops in favor of index-select.
* Also remove the cuda kernels.
2023-08-02 05:42:11 +01:00
cc76c63202
Use index-select for the embeddings as it supports backprop. (#298)
2023-08-01 20:44:43 +01:00
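
The point of this change is that an embedding lookup is just an index-select over the weight rows, and index-select already has a backward pass; an illustrative sketch:

    use candle_core::{Device, Result, Tensor};

    fn embedding_demo() -> Result<()> {
        let weight = Tensor::randn(0f32, 1f32, (10, 4), &Device::Cpu)?; // (vocab, hidden)
        let token_ids = Tensor::new(&[3u32, 1, 4, 1], &Device::Cpu)?;
        let emb = weight.index_select(&token_ids, 0)?; // (4, hidden), differentiable w.r.t. weight
        println!("{:?}", emb.shape());
        Ok(())
    }
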
c950a5c6b1
Cuda support for the mnist training. (#277)
* Cuda support for the mnist training.
* min/max fix + testing.
* Add the argmin/argmax tests.
* More cuda support for argmin/argmax.
* Cuda kernels for argmin and argmax.
2023-07-29 19:48:04 +01:00
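
A small sketch of the argmin/argmax reductions exercised by the new tests:

    use candle_core::{Device, Result, Tensor};

    fn arg_demo() -> Result<()> {
        let x = Tensor::new(&[[3f32, 1., 2.], [0., 5., 4.]], &Device::Cpu)?;
        let amax = x.argmax(1)?; // indices of the per-row maxima: [0, 1]
        let amin = x.argmin(1)?; // indices of the per-row minima: [1, 0]
        println!("{amax}\n{amin}");
        Ok(())
    }
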
3eb2bc6d07
Softmax numerical stability. (#267)
* Softmax numerical stability.
* Fix the flash-attn test.
2023-07-28 13:13:01 +01:00
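
The usual way to make softmax numerically stable is to subtract the per-row max before exponentiating; a sketch in terms of existing tensor ops (not necessarily the exact code in the PR):

    use candle_core::{Result, Tensor};

    fn stable_softmax(x: &Tensor, dim: usize) -> Result<Tensor> {
        let m = x.max_keepdim(dim)?;         // per-row max, kept for broadcasting
        let e = x.broadcast_sub(&m)?.exp()?; // exp of the shifted values cannot overflow
        let s = e.sum_keepdim(dim)?;
        e.broadcast_div(&s)
    }
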
944d70bd9a
Add a test for scatter add. (#238)
* Add a test for scatter add (segfaults on gpus for now).
* Bugfix for the scatter add cuda kernel.
2023-07-25 09:12:14 +01:00
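
A sketch of the scatter-add semantics the test covers: values from `src` are accumulated into `init` at the rows named by `ids` (shapes illustrative):

    use candle_core::{DType, Device, Result, Tensor};

    fn scatter_add_demo() -> Result<()> {
        let init = Tensor::zeros((4, 2), DType::F32, &Device::Cpu)?;
        let src = Tensor::ones((3, 2), DType::F32, &Device::Cpu)?;
        let ids = Tensor::new(&[[0u32, 0], [1, 1], [0, 3]], &Device::Cpu)?; // same shape as src
        let out = init.scatter_add(&ids, &src, 0)?; // out[ids[i][j]][j] += src[i][j]
        println!("{out}");
        Ok(())
    }
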
18cc73954a
Add some testing for index-add (#237)
* Add some testing for index-add.
* Fix the cpu implementation for index-add.
2023-07-25 08:38:33 +01:00
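
Index-add is the 1d-index cousin of scatter-add: each row of `src` is added to the destination row picked by `ids` (argument order assumed to mirror scatter_add):

    use candle_core::{DType, Device, Result, Tensor};

    fn index_add_demo() -> Result<()> {
        let init = Tensor::zeros((5, 2), DType::F32, &Device::Cpu)?;
        let src = Tensor::ones((3, 2), DType::F32, &Device::Cpu)?;
        let ids = Tensor::new(&[0u32, 2, 2], &Device::Cpu)?; // one destination row per src row
        let out = init.index_add(&ids, &src, 0)?; // row 0 gets the ones once, row 2 twice
        println!("{out}");
        Ok(())
    }
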
581b104f97
Indexing cuda (#235)
* Allow using uint8_t for indexing.
* Revert the default cuda feature.
* Add a cuda-kernel for index-select.
* Add a test for gather.
2023-07-24 20:22:47 +01:00
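
The gather test added here exercises per-element index lookup along one dim; a sketch of the op:

    use candle_core::{Device, Result, Tensor};

    fn gather_demo() -> Result<()> {
        let x = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
        let ids = Tensor::new(&[[2u32], [0]], &Device::Cpu)?;
        let g = x.gather(&ids, 1)?; // pick one column per row: [[3.], [4.]]
        println!("{g}");
        Ok(())
    }
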
b50f932e7c
Add some cmp tests. (#233)
* Add some cmp tests.
* Add the cuda kernels for comparison operations.
2023-07-24 16:53:45 +01:00
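
The comparison ops return u8 masks; a small sketch of what the new tests and kernels cover:

    use candle_core::{Device, Result, Tensor};

    fn cmp_demo() -> Result<()> {
        let a = Tensor::new(&[1u32, 2, 3], &Device::Cpu)?;
        let b = Tensor::new(&[3u32, 2, 1], &Device::Cpu)?;
        let eq = a.eq(&b)?; // [0, 1, 0]
        let lt = a.lt(&b)?; // [1, 0, 0]
        let ge = a.ge(&b)?; // [0, 1, 1]
        println!("{eq}\n{lt}\n{ge}");
        Ok(())
    }
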
23827c49cd
Cleanup some todos. (#226)
* Cleanup some todos.
* Fix more todos.
* Optimize for the contiguous case.
* Add the IntDType trait.
* Handle the IntDType trait for more ops.
* Remove a todo.
* Remove a todo.
2023-07-23 16:00:00 +01:00
43c7223292
Rename the .r functions to .dims so as to be a bit more explicit. (#220)
2023-07-22 10:39:27 +01:00
fa08fb3126
Add the index-select op. (#209)
* Add the index-select op.
* Cpu implementation of index-select.
* Add the cpu implementation for index-select.
2023-07-20 14:01:03 +01:00
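
A sketch of the op itself: pick whole slices along one dim using an integer index tensor (the same primitive the embedding change further up relies on):

    use candle_core::{Device, Result, Tensor};

    fn index_select_demo() -> Result<()> {
        let x = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Device::Cpu)?;
        let ids = Tensor::new(&[2u32, 0, 0], &Device::Cpu)?;
        let rows = x.index_select(&ids, 0)?; // [[5., 6.], [1., 2.], [1., 2.]]
        println!("{rows}");
        Ok(())
    }
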
76dcc7a381
Test the broadcasting binary ops. (#196)
2023-07-19 06:18:36 +01:00
a2f72edc0d
Simplify the parameters used by sum and sum_keepdim. (#165)
2023-07-14 08:22:08 +01:00
2bfa791336
Use the same default as pytorch for sum. (#164)
2023-07-13 21:32:32 +01:00
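
Taken together, the two sum entries above land on PyTorch-like defaults: `sum(dims)` drops the reduced dims, `sum_keepdim(dims)` keeps them, and `sum_all()` reduces everything; a minimal sketch of my reading of that API:

    use candle_core::{Device, Result, Tensor};

    fn sum_demo() -> Result<()> {
        let x = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
        let s = x.sum(1)?;          // [6., 15.], dim 1 dropped
        let sk = x.sum_keepdim(1)?; // [[6.], [15.]]
        let all = x.sum_all()?;     // scalar 21
        println!("{s}\n{sk}\n{all}");
        Ok(())
    }
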
5c3864f9f7
Add more sum tests. (#110)
* Add some tests for the sum.
* More sum testing.
2023-07-08 13:15:36 +01:00
2741b39ad3
Use broadcasted scalars for const tensors.
2023-06-29 11:56:40 +01:00
caafef6cc1
Get the cpu tests to run.
2023-06-28 14:32:02 +01:00
0417d9cec8
Add more cuda testing again.
2023-06-28 08:33:43 +01:00
395c84e80a
Also run the backprop tests on cuda.
2023-06-28 08:15:03 +01:00
dbe3e4e7c0
Add some test utils module.
2023-06-27 16:20:28 +01:00
07a682c2ff
Run the tensor tests for the cuda backend too.
2023-06-27 15:37:01 +01:00