Commit Graph

22 Commits

SHA1 Message Date
59a59f41a6 Add the cuda mode to llama. 2023-06-26 10:06:44 +01:00
d867155ef2 Load the weights for llama. 2023-06-26 07:23:59 +01:00
7a3101f15f Llama bugfix. 2023-06-26 07:07:56 +01:00
97424289d1 Fix the llama causal mask inversion. 2023-06-25 21:16:54 +01:00
117f014b55 Add where_cond and properly apply the causal mask. 2023-06-25 21:08:03 +01:00
25bcad290e Fix the causal mask computation. 2023-06-25 20:19:30 +01:00
8e404eb125 Get some first inference to work on llama. 2023-06-25 18:26:15 +01:00
87c5aab005 More llama fixes. 2023-06-25 18:08:41 +01:00
60a5598c8b Fix some shape errors. 2023-06-25 17:56:59 +01:00
817e4b5005 Rework the embeddings so that it works on non-contiguous weights + factor out some code. 2023-06-25 17:37:47 +01:00
334524e2c4 Take as input slices of tensors as well as slices of &Tensors. 2023-06-25 17:07:09 +01:00
118cc30908 Add some currently broken tests. 2023-06-25 14:55:25 +01:00
90c140ff4b Start sketching the llama example. 2023-06-25 13:51:20 +01:00
6463d661d8 Tweaks. 2023-06-22 20:25:14 +01:00
aebffcfc13 Add a matmul cuda example. 2023-06-22 19:44:26 +01:00
5276755fb3 Add cuda support for unary ops. 2023-06-22 15:12:59 +01:00
e1eb86db61 Add a first binary op (add). 2023-06-22 13:52:02 +01:00
87a37b3bf3 Retrieve data from the gpu. 2023-06-22 11:01:49 +01:00
97d9142dee Add a first kernel. 2023-06-21 20:48:22 +01:00
fcb4e6b84f Use a reference for the device. 2023-06-21 19:55:57 +01:00
c654ecdb16 Add a specific example for cuda. 2023-06-21 18:56:04 +01:00
b3eb57cd0a Avoid some duplication using a macro + add some basic example to make debugging easier. 2023-06-21 10:08:41 +01:00
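
Several of the commits above (25bcad290e, 117f014b55, 97424289d1) revolve around getting the causal attention mask right: building a lower-triangular mask and using a where_cond-style select to replace masked scores with negative infinity before the softmax. The dependency-free Rust sketch below illustrates that logic only; it is not the code from these commits, and `causal_mask` / `apply_causal_mask` are hypothetical stand-ins for the library's tensor ops.

```rust
// Minimal sketch of a causal mask for self-attention, as fixed in the
// commits above. Not the repository's actual code: `apply_causal_mask`
// is a hypothetical stand-in for a where_cond-style tensor select.

/// Build a seq_len x seq_len mask: entry (j, i) is true where query
/// position j may attend to key position i, i.e. the lower triangle
/// (i <= j). Inverting this test is the kind of bug addressed by
/// "Fix the llama causal mask inversion" (97424289d1).
fn causal_mask(seq_len: usize) -> Vec<Vec<bool>> {
    (0..seq_len)
        .map(|j| (0..seq_len).map(|i| i <= j).collect())
        .collect()
}

/// where_cond-style select: keep the score where the mask is true,
/// otherwise replace it with -inf so the softmax assigns it zero weight.
fn apply_causal_mask(scores: &mut [Vec<f32>], mask: &[Vec<bool>]) {
    for (row, mask_row) in scores.iter_mut().zip(mask) {
        for (s, &keep) in row.iter_mut().zip(mask_row) {
            if !keep {
                *s = f32::NEG_INFINITY;
            }
        }
    }
}

fn main() {
    let n = 4;
    let mask = causal_mask(n);
    // Dummy attention scores, all ones; row j keeps columns 0..=j.
    let mut scores = vec![vec![1.0f32; n]; n];
    apply_causal_mask(&mut scores, &mask);
    for row in &scores {
        println!("{row:?}");
    }
}
```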