Commit Graph

792 Commits

Author SHA1 Message Date
367170da45 Also use Map1 for embedding. 2023-06-29 09:45:27 +01:00
8ad03a5fb6 Use Map1 on unary ops. 2023-06-29 09:37:38 +01:00
fff13dbb4e Factorize the kernel naming scheme. 2023-06-29 09:29:59 +01:00
d3c7b0d168 Use Map1 for sum. 2023-06-29 09:27:07 +01:00
122e334d0c Simplify the pattern matching logic in the cuda backend. 2023-06-29 09:21:11 +01:00
eaa3ce359e Cosmetic change. 2023-06-28 22:02:23 +01:00
1328b5cb20 Factor some code out. 2023-06-28 21:56:44 +01:00
c583ee0f2c Add map2. 2023-06-28 21:38:01 +01:00
46c07b924c Tweak some comment. 2023-06-28 21:10:54 +01:00
2ae368e98e Switch from a macro to a trait to make things more generic. 2023-06-28 21:06:56 +01:00
ece3ec6167 Final updates -> moving to deterministic for easier comparison. 2023-06-28 14:56:39 +00:00
926fffa0b7 Ok. 2023-06-28 14:56:39 +00:00
e29dae044d Tmp. 2023-06-28 14:56:38 +00:00
6c9e6b5a99 Get the cuda tests to pass. 2023-06-28 15:53:23 +01:00
3f0d9fbb25 Adapt the cuda bits. 2023-06-28 15:43:03 +01:00
cca699be6c Fix some cpu issue. 2023-06-28 15:09:15 +01:00
1c755c0e5b Remove some todos. 2023-06-28 14:33:06 +01:00
caafef6cc1 Get the cpu tests to run. 2023-06-28 14:32:02 +01:00
14449ff80c Get the cpu backend to compile. 2023-06-28 14:12:38 +01:00
54a6c40f27 Propagate the changes on the cpu backend. 2023-06-28 14:00:49 +01:00
303b853098 Propagate the layout refactoring. 2023-06-28 13:42:23 +01:00
30b355ccd2 Simplify the narrow implementation. 2023-06-28 13:09:59 +01:00
c1bbbf94f6 Start refactoring the stride. 2023-06-28 12:57:30 +01:00
7938d2b848 Add the grad for narrow. 2023-06-28 10:46:00 +01:00
615196e7be Add more gradients. 2023-06-28 09:59:52 +01:00
1ce3843cab Add the relu op. 2023-06-28 09:38:54 +01:00
19183b8e4f Factor out the gemm bits. 2023-06-28 08:51:13 +01:00
0417d9cec8 Add more cuda testing again. 2023-06-28 08:33:43 +01:00
395c84e80a Also run the backprop tests on cuda. 2023-06-28 08:15:03 +01:00
b0f5f2d22d Add some display tests + bugfixes. 2023-06-27 21:37:28 +01:00
8c81a70170 PyTorch like display implementation. 2023-06-27 21:16:35 +01:00
934655a60d Add squeeze/unsqueeze/stack. 2023-06-27 19:32:00 +01:00
1d504cc6b3 Rework the debug trait. 2023-06-27 19:10:30 +01:00
684f66326d Add the get method. 2023-06-27 17:39:58 +01:00
c44e5346f4 Add some helper functions. 2023-06-27 17:37:09 +01:00
dbe3e4e7c0 Add some test utils module. 2023-06-27 16:20:28 +01:00
e221d38819 Factor the slicing code in cuda. 2023-06-27 15:45:59 +01:00
07a682c2ff Run the tensor tests for the cuda backend too. 2023-06-27 15:37:01 +01:00
ca6aa8ff12 Use num-cpus to enable parallelism. 2023-06-27 14:42:26 +01:00
318503cd38 Cache the causal mask in llama. 2023-06-27 12:21:08 +01:00
380d61e990 Fix two cuda bugs (matmul and where_cond). 2023-06-27 11:31:04 +01:00
d7f729fb8f Refactor the hierarchy. 2023-06-27 11:57:27 +02:00