Commit Graph

23 Commits

Author SHA1 Message Date
2c3d871b2e Add a simpler way to specify the dim index for some ops. 2023-07-05 20:22:43 +01:00
6d1e79d378 Bugfix for to_scalar (use the proper start offset). 2023-07-05 06:42:29 +01:00
950b4af49e Proper conv1d dispatch. 2023-07-04 11:29:28 +01:00
a424d95473 Add more of the conv1d op. 2023-07-04 11:15:45 +01:00
3aac1047fe Sketch the conv1d op. 2023-07-04 10:52:34 +01:00
19cbbc5212 Improve how we check that the dims are in bounds. 2023-06-30 09:11:00 +01:00
b50bd880ce Only narrow when needed + deactivate the kv cache. 2023-06-29 19:07:52 +01:00
2741b39ad3 Use broadcasted scalars for const tensors. 2023-06-29 11:56:40 +01:00
122e334d0c Simplify the pattern matching logic in the cuda backend. 2023-06-29 09:21:11 +01:00
3f0d9fbb25 Adapt the cuda bits. 2023-06-28 15:43:03 +01:00
caafef6cc1 Get the cpu tests to run. 2023-06-28 14:32:02 +01:00
14449ff80c Get the cpu backend to compile. 2023-06-28 14:12:38 +01:00
303b853098 Propagate the layout refactoring. 2023-06-28 13:42:23 +01:00
30b355ccd2 Simplify the narrow implementation. 2023-06-28 13:09:59 +01:00
c1bbbf94f6 Start refactoring the stride. 2023-06-28 12:57:30 +01:00
615196e7be Add more gradients. 2023-06-28 09:59:52 +01:00
1ce3843cab Add the relu op. 2023-06-28 09:38:54 +01:00
b0f5f2d22d Add some display tests + bugfixes. 2023-06-27 21:37:28 +01:00
934655a60d Add squeeze/unsqueeze/stack. 2023-06-27 19:32:00 +01:00
1d504cc6b3 Rework the debug trait. 2023-06-27 19:10:30 +01:00
684f66326d Add the get method. 2023-06-27 17:39:58 +01:00
c44e5346f4 Add some helper functions. 2023-06-27 17:37:09 +01:00
d7f729fb8f Refactor the hierarchy. 2023-06-27 11:57:27 +02:00