Commit Graph

1419 Commits

SHA1 Message Date
95a2c8e7da Add helper functions for Fortran-contiguous data. 2023-06-26 13:02:06 +01:00
f6104c4b64 Add the reduce-sum kernel. 2023-06-26 12:35:26 +01:00
16f0f5b9d2 Add a cuda kernel for embeddings. 2023-06-26 11:47:57 +01:00
5952c3fa91 Cleanup the broadcast setup. 2023-06-26 10:49:34 +01:00
217bdcdf4d Fix the error message. 2023-06-26 10:14:34 +01:00
59a59f41a6 Add the cuda mode to llama. 2023-06-26 10:06:44 +01:00
512d12e38d Avoid copying the data around when loading weights. 2023-06-26 08:09:03 +01:00
4ad5d17d8c Slightly more efficient weight loading. 2023-06-26 07:56:25 +01:00
11696e6377 Faster model weight loading. 2023-06-26 07:40:11 +01:00
d867155ef2 Load the weights for llama. 2023-06-26 07:23:59 +01:00
7a3101f15f Llama bugfix. 2023-06-26 07:07:56 +01:00
97424289d1 Fix the llama causal mask inversion. 2023-06-25 21:16:54 +01:00
117f014b55 Add where_cond and properly apply the causal mask. 2023-06-25 21:08:03 +01:00
25bcad290e Fix the causal mask computation. 2023-06-25 20:19:30 +01:00
8e404eb125 Get some first inference to work on llama. 2023-06-25 18:26:15 +01:00
87c5aab005 More llama fixes. 2023-06-25 18:08:41 +01:00
60a5598c8b Fix some shape errors. 2023-06-25 17:56:59 +01:00
817e4b5005 Rework the embeddings so that they work on non-contiguous weights + factor out some code. 2023-06-25 17:37:47 +01:00
334524e2c4 Take as input slices of tensors as well as slices of &Tensors. 2023-06-25 17:07:09 +01:00
8b67f294e8 Fix the cat implementation + more testing. 2023-06-25 15:32:13 +01:00
118cc30908 Add some currently broken tests. 2023-06-25 14:55:25 +01:00
bb6450ebbb Bugfix for Tensor::cat + add some tests. 2023-06-25 14:20:42 +01:00
90c140ff4b Start sketching the llama example. 2023-06-25 13:51:20 +01:00
a9c113248a Take references as input for Tensor::cat. 2023-06-25 13:03:05 +01:00
5e03a1bc29 One more test. 2023-06-25 10:57:46 +01:00
ba0693a908 Fix the reduce_sum implementation and add some tests. 2023-06-25 10:55:04 +01:00
0f369dd870 Add the cpu implementation for reduce_sum. 2023-06-25 10:37:04 +01:00
3852a85af3 Boilerplate code for the sum operator. 2023-06-25 09:35:17 +01:00
7ccf27dda2 More general broadcast setup. 2023-06-25 08:55:09 +01:00
213445c0e5 Move the backprop bits to a separate file. 2023-06-24 20:57:49 +01:00
0988706c88 Support wider shapes for llama. 2023-06-24 20:08:18 +01:00
6b2cd9c51c Add the broadcast operator. 2023-06-24 19:16:03 +01:00
96c098b6cd Remove the unnecessary features. 2023-06-24 18:15:44 +01:00
a7f80e258f Read and write npy files. 2023-06-24 18:12:10 +01:00
a6ca9baf3c Backprop for narrow. 2023-06-24 15:17:57 +01:00
fbbf3951dd More narrow testing. 2023-06-24 15:10:31 +01:00
0f34738831 Fix the cpu implementation for narrow. 2023-06-24 15:01:32 +01:00
1b5f892d73 Add a currently wrong test for narrow. 2023-06-24 08:50:37 +01:00
d6cb4f1c53 Add the source offset when copying the data around. 2023-06-24 08:35:49 +01:00
4db972781f Handle copying for the u32 type. 2023-06-24 08:24:06 +01:00
dd657397b2 Skeleton implementation for the narrow method and op. 2023-06-24 08:17:35 +01:00
3deacba5f9 Reshape can now return a view. 2023-06-24 07:14:09 +01:00
47f9c48e7c Avoid duplicating the storage by refcounting it. 2023-06-24 07:03:21 +01:00
b4653e41be Helper function to build 3d arrays. 2023-06-24 06:29:06 +01:00
ae5dc5fbc6 Softmax tests + fix. 2023-06-23 22:46:36 +01:00
d0a91db8fd Softmax cpu implementation. 2023-06-23 22:26:53 +01:00
8443963d4f Skeleton implementation for softmax. 2023-06-23 22:00:13 +01:00
5d44e76e3f Add the casting operation. 2023-06-23 21:22:07 +01:00
8ed350dc94 Add a couple unary ops. 2023-06-23 20:19:20 +01:00
fe75a01188 Cleanup the tensor creation code. 2023-06-23 19:52:21 +01:00