Commit Graph

17 Commits

Author SHA1 Message Date
3aac1047fe Sketch the conv1d op. 2023-07-04 10:52:34 +01:00
a57b314780 Add a batch dimension on the bert example. 2023-07-04 06:10:52 +01:00
86d691c74c Better handling of the batch dimension in matmul. 2023-07-03 22:51:40 +01:00
8ad47907f3 Add the kernels. 2023-06-30 10:26:56 +01:00
c9c468e1aa Use Map2 for binary ops. 2023-06-29 10:09:15 +01:00
83c7d660ca Add Map2. 2023-06-29 10:05:06 +01:00
367170da45 Also use Map1 for embedding. 2023-06-29 09:45:27 +01:00
8ad03a5fb6 Use Map1 on unary ops. 2023-06-29 09:37:38 +01:00
fff13dbb4e Factorize the kernel naming scheme. 2023-06-29 09:29:59 +01:00
d3c7b0d168 Use Map1 for sum. 2023-06-29 09:27:07 +01:00
122e334d0c Simplify the pattern matching logic in the cuda backend. 2023-06-29 09:21:11 +01:00
6c9e6b5a99 Get the cuda tests to pass. 2023-06-28 15:53:23 +01:00
3f0d9fbb25 Adapt the cuda bits. 2023-06-28 15:43:03 +01:00
e221d38819 Factor the slicing code in cuda. 2023-06-27 15:45:59 +01:00
07a682c2ff Run the tensor tests for the cuda backend too. 2023-06-27 15:37:01 +01:00
380d61e990 Fix two cuda bugs (matmul and where_cond). 2023-06-27 11:31:04 +01:00
d7f729fb8f Refactor the hierarchy. 2023-06-27 11:57:27 +02:00