674eb35e10
Remove some dead-code pragmas. (#137)
2023-07-11 09:33:59 +01:00
ae79c00e48
Allow for uniform initialization in a single step. (#136)
2023-07-11 08:52:29 +01:00
f29b77ec19
Random initializers. (#128)
...
* Random initialization.
* CPU rng generation.
2023-07-10 18:26:21 +01:00
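The CPU rng generation mentioned in #128 can be illustrated with a plain linear congruential generator; the `Lcg` and `uniform_init` names below are hypothetical sketches, not the crate's actual API (which uses a proper random number generator).

```rust
/// Minimal LCG-based uniform init in [lo, hi). Illustrative only: the
/// constants are the well-known Numerical Recipes 64-bit LCG parameters.
struct Lcg(u64);

impl Lcg {
    fn next_f32(&mut self) -> f32 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Take the top 24 bits and scale to [0, 1).
        ((self.0 >> 40) as f32) / ((1u64 << 24) as f32)
    }
}

/// Fill a buffer with values drawn uniformly from [lo, hi).
fn uniform_init(n: usize, lo: f32, hi: f32, seed: u64) -> Vec<f32> {
    let mut rng = Lcg(seed);
    (0..n).map(|_| lo + (hi - lo) * rng.next_f32()).collect()
}
```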
270997a055
Add the elu op. (#113)
2023-07-09 21:56:31 +01:00
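The elu op added in #113 follows the standard definition, elu(x) = x for x > 0 and alpha * (exp(x) - 1) otherwise; the free function below is a hypothetical scalar sketch, not the crate's tensor API.

```rust
/// ELU activation: identity for positive inputs, exponential saturation
/// toward -alpha for negative inputs. `alpha` is commonly 1.0.
fn elu(x: f32, alpha: f32) -> f32 {
    if x > 0.0 {
        x
    } else {
        alpha * (x.exp() - 1.0)
    }
}
```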
eb64ad0d4d
Cuda kernel for the conv1d op. (#111)
...
* Boilerplate code for conv1d.
* More boilerplate for conv1d.
* Conv1d work.
* Get the conv1d cuda kernel to work.
* Conv1d support when no batch dim.
2023-07-08 18:13:25 +01:00
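The semantics of the conv1d op built up across #111 and the earlier sketch commits can be shown with a naive scalar version (really a cross-correlation, as in most ML frameworks): single channel, stride 1, no padding, and assuming the kernel is no longer than the input. This is a hypothetical helper, not the repo's cuda kernel.

```rust
/// Naive 1d convolution: out[i] = sum_j kernel[j] * input[i + j].
/// Output length is input.len() - kernel.len() + 1.
fn conv1d(input: &[f32], kernel: &[f32]) -> Vec<f32> {
    let out_len = input.len() + 1 - kernel.len();
    (0..out_len)
        .map(|i| {
            kernel
                .iter()
                .enumerate()
                .map(|(j, &w)| w * input[i + j])
                .sum()
        })
        .collect()
}
```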
e676f85f00
Sketch a fast cuda kernel for reduce-sum. (#109)
...
* Sketch a fast cuda kernel for reduce-sum.
* Sketch the rust support code for the fast sum kernel.
* More work on the fast kernel.
* Add some testing ground.
* A couple of fixes for the fast sum kernel.
2023-07-08 12:43:56 +01:00
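The idea behind a fast reduce-sum kernel like the one in #109 is block-wise reduction: each block sums its own slice in parallel, and the partial sums are then combined. A CPU analogue using scoped threads sketches the shape of it; `parallel_sum` is a hypothetical name, not the crate's API.

```rust
/// Chunked parallel reduce-sum: each "block" (here, a thread) sums its
/// slice, then the partial sums are reduced on the calling thread.
fn parallel_sum(data: &[f32], num_blocks: usize) -> f32 {
    let chunk = (data.len() / num_blocks.max(1)).max(1);
    std::thread::scope(|s| {
        let handles: Vec<_> = data
            .chunks(chunk)
            .map(|c| s.spawn(move || c.iter().sum::<f32>()))
            .collect();
        handles.into_iter().map(|h| h.join().unwrap()).sum()
    })
}
```

Note that chunked float addition changes the summation order, so results can differ from a serial sum by rounding; the cuda kernel has the same property.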
02b5c38049
Use cublas bf16. (#101)
2023-07-07 08:00:12 +01:00
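For context on the bf16 switch in #101: bfloat16 keeps an f32's sign bit, all 8 exponent bits, and the top 7 mantissa bits, so truncating the low 16 bits of an f32 is a valid (round-toward-zero) conversion. cublas consumes the packed 16-bit values directly; the helpers below just illustrate the format and are not the crate's API.

```rust
/// f32 -> bf16 by truncation: drop the low 16 mantissa bits.
fn f32_to_bf16_bits(x: f32) -> u16 {
    (x.to_bits() >> 16) as u16
}

/// bf16 -> f32: shift back up, zero-filling the dropped mantissa bits.
fn bf16_bits_to_f32(b: u16) -> f32 {
    f32::from_bits((b as u32) << 16)
}
```

Values whose mantissa fits in 7 bits (e.g. 1.0, -2.5) round-trip exactly; everything else loses low-order precision.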
dd60bd84bb
MKL adjustments. (#87)
2023-07-06 11:37:27 +01:00
c297a50960
Add mkl support for matrix multiply. (#86)
...
* Fix some rebase issues.
* Use mkl instead.
* Use mkl in bert.
* Add the optional mkl feature.
* Conditional compilation based on the mkl feature.
* Add more mkl support.
2023-07-06 11:05:05 +01:00
a424d95473
Add more of the conv1d op.
2023-07-04 11:15:45 +01:00
3aac1047fe
Sketch the conv1d op.
2023-07-04 10:52:34 +01:00
a57b314780
Add a batch dimension on the bert example.
2023-07-04 06:10:52 +01:00
86d691c74c
Better handling of the batch dimension in matmul.
2023-07-03 22:51:40 +01:00
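Handling a batch dimension in matmul means computing an independent m x n = (m x k)(k x n) product per batch index over flat row-major buffers. A naive sketch of those semantics, with a hypothetical `batched_matmul` helper rather than the repo's actual implementation:

```rust
/// Naive batched matmul on flat row-major buffers:
/// c[bi] = a[bi] (m x k) * b[bi] (k x n) for each batch index bi.
fn batched_matmul(a: &[f32], b: &[f32], batch: usize, m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut c = vec![0.0; batch * m * n];
    for bi in 0..batch {
        // Per-batch offsets into the flat a, b, and c buffers.
        let (ao, bo, co) = (bi * m * k, bi * k * n, bi * m * n);
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0;
                for p in 0..k {
                    acc += a[ao + i * k + p] * b[bo + p * n + j];
                }
                c[co + i * n + j] = acc;
            }
        }
    }
    c
}
```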
8ad47907f3
Add the kernels.
2023-06-30 10:26:56 +01:00
c9c468e1aa
Use Map2 for binary ops.
2023-06-29 10:09:15 +01:00
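The Map2 abstraction used here factors out "apply a binary op elementwise" so that each binary kernel only supplies the scalar function. A sketch of that pattern under assumed names (the crate's real trait dispatches over dtypes and strided layouts, which this omits):

```rust
/// Map2-style helper: implementors provide the scalar binary op `f`,
/// and `map` applies it elementwise over two equal-length slices.
trait Map2 {
    fn f(&self, a: f32, b: f32) -> f32;

    fn map(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b).map(|(&x, &y)| self.f(x, y)).collect()
    }
}

/// Elementwise addition as one concrete binary op.
struct Add;

impl Map2 for Add {
    fn f(&self, a: f32, b: f32) -> f32 {
        a + b
    }
}
```

A Map1 for unary ops, as in the commits below, follows the same shape with a single input slice.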
83c7d660ca
Add Map2.
2023-06-29 10:05:06 +01:00
367170da45
Also use Map1 for embedding.
2023-06-29 09:45:27 +01:00
8ad03a5fb6
Use Map1 on unary ops.
2023-06-29 09:37:38 +01:00
fff13dbb4e
Factorize the kernel naming scheme.
2023-06-29 09:29:59 +01:00
d3c7b0d168
Use Map1 for sum.
2023-06-29 09:27:07 +01:00
122e334d0c
Simplify the pattern matching logic in the cuda backend.
2023-06-29 09:21:11 +01:00
6c9e6b5a99
Get the cuda tests to pass.
2023-06-28 15:53:23 +01:00
3f0d9fbb25
Adapt the cuda bits.
2023-06-28 15:43:03 +01:00
e221d38819
Factor the slicing code in cuda.
2023-06-27 15:45:59 +01:00
07a682c2ff
Run the tensor tests for the cuda backend too.
2023-06-27 15:37:01 +01:00
380d61e990
Fix two cuda bugs (matmul and where_cond).
2023-06-27 11:31:04 +01:00
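For reference, the where_cond op fixed above is an elementwise select between two tensors driven by a condition mask. A minimal sketch of those semantics, with `where_cond` here a hypothetical free function over slices rather than the actual cuda kernel:

```rust
/// Elementwise select:
/// out[i] = if cond[i] != 0 { on_true[i] } else { on_false[i] }.
fn where_cond(cond: &[u8], on_true: &[f32], on_false: &[f32]) -> Vec<f32> {
    cond.iter()
        .zip(on_true.iter().zip(on_false))
        .map(|(&c, (&t, &f))| if c != 0 { t } else { f })
        .collect()
}
```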
d7f729fb8f
Refactor the hierarchy.
2023-06-27 11:57:27 +02:00