Commit Graph

612 Commits

Author SHA1 Message Date
459e2e1ae3 Properly handle the stride in conv1d. 2023-07-04 15:05:04 +01:00
b3d4d0fd0f Very inefficient conv1d implementation. 2023-07-04 13:50:41 +01:00
950b4af49e Proper conv1d dispatch. 2023-07-04 11:29:28 +01:00
a424d95473 Add more of the conv1d op. 2023-07-04 11:15:45 +01:00
3aac1047fe Sketch the conv1d op. 2023-07-04 10:52:34 +01:00
a57b314780 Add a batch dimension on the bert example. 2023-07-04 06:10:52 +01:00
86d691c74c Better handling of the batch dimension in matmul. 2023-07-03 22:51:40 +01:00
ad52b0377c Add the varbuilder + check shapes. 2023-07-03 15:32:20 +01:00
0b3cc215f1 Address comments. 2023-07-03 13:52:27 +02:00
5bc66c68fa Adding saving capabilities. 2023-07-03 13:39:24 +02:00
81cec86e75 Adding a bit more docs around safety. 2023-07-03 11:55:54 +02:00
899c76de75 Handle more types in safetensors. 2023-07-03 10:09:46 +01:00
783b7054ee Move more safetensors bits to the shared module. 2023-07-03 09:34:08 +01:00
fe2c07e368 Add the ST error. 2023-07-03 08:44:00 +01:00
cf2789fb81 Move some safetensors bits in the candle-core crate. 2023-07-03 08:37:46 +01:00
78871ffe38 Add dtype support. 2023-07-02 20:12:26 +01:00
bbe0c5fbaa Do not use rayon for a single thread (bis). 2023-06-30 18:47:22 +01:00
6b67d25d9f Do not use rayon for a single thread. 2023-06-30 18:46:32 +01:00
fbc329ed85 Add the verbose cpu cast operations. 2023-06-30 10:33:29 +01:00
8ad47907f3 Add the kernels. 2023-06-30 10:26:56 +01:00
19cbbc5212 Improve how we check that the dims are in bounds. 2023-06-30 09:11:00 +01:00
b50bd880ce Only narrow when needed + deactivate the kv cache. 2023-06-29 19:07:52 +01:00
2741b39ad3 Use broadcasted scalars for const tensors. 2023-06-29 11:56:40 +01:00
b4aab7b95f Put more requirements on the withdtype trait. 2023-06-29 11:37:42 +01:00
c9c468e1aa Use Map2 for binary ops. 2023-06-29 10:09:15 +01:00
83c7d660ca Add Map2. 2023-06-29 10:05:06 +01:00
367170da45 Also use Map1 for embedding. 2023-06-29 09:45:27 +01:00
8ad03a5fb6 Use Map1 on unary ops. 2023-06-29 09:37:38 +01:00
fff13dbb4e Factorize the kernel naming scheme. 2023-06-29 09:29:59 +01:00
d3c7b0d168 Use Map1 for sum. 2023-06-29 09:27:07 +01:00
122e334d0c Simplify the pattern matching logic in the cuda backend. 2023-06-29 09:21:11 +01:00
eaa3ce359e Cosmetic change. 2023-06-28 22:02:23 +01:00
1328b5cb20 Factor some code out. 2023-06-28 21:56:44 +01:00
c583ee0f2c Add map2. 2023-06-28 21:38:01 +01:00
46c07b924c Tweak some comment. 2023-06-28 21:10:54 +01:00
2ae368e98e Switch from a macro to a trait to make things more generic. 2023-06-28 21:06:56 +01:00
6c9e6b5a99 Get the cuda tests to pass. 2023-06-28 15:53:23 +01:00
3f0d9fbb25 Adapt the cuda bits. 2023-06-28 15:43:03 +01:00
cca699be6c Fix some cpu issue. 2023-06-28 15:09:15 +01:00
1c755c0e5b Remove some todos. 2023-06-28 14:33:06 +01:00
caafef6cc1 Get the cpu tests to run. 2023-06-28 14:32:02 +01:00
14449ff80c Get the cpu backend to compile. 2023-06-28 14:12:38 +01:00
54a6c40f27 Propagate the changes on the cpu backend. 2023-06-28 14:00:49 +01:00
303b853098 Propagate the layout refactoring. 2023-06-28 13:42:23 +01:00
30b355ccd2 Simplify the narrow implementation. 2023-06-28 13:09:59 +01:00
c1bbbf94f6 Start refactoring the stride. 2023-06-28 12:57:30 +01:00
7938d2b848 Add the grad for narrow. 2023-06-28 10:46:00 +01:00
615196e7be Add more gradients. 2023-06-28 09:59:52 +01:00
1ce3843cab Add the relu op. 2023-06-28 09:38:54 +01:00
19183b8e4f Factor out the gemm bits. 2023-06-28 08:51:13 +01:00