cc76c63202
Use index-select for the embeddings as it supports backprop. (#298)
2023-08-01 20:44:43 +01:00
ff876c2103
Llama more training (#297)
...
* Rework the var-builder to handle initializations.
* Add some helper functions for layer creation.
* Improve the layer initializations.
* Get initialized variables.
* Precompute the rotary embeddings when training llamas.
2023-08-01 19:53:41 +01:00
a27239f3d9
Add training for the llama2.c example (#296)
...
* Rework the commands and run inference by default.
* Add the training module and load the training dataset.
* Random dataset iterator.
* Proper valid-loss computation.
* Compute the evaluation loss.
* Add more substance to the training loop.
2023-08-01 17:23:07 +01:00
babee9f011
Merge pull request #259 from LaurentMazare/book_2
...
Book 2 (load/save)
2023-08-01 17:26:57 +02:00
afb5e24a63
Remove map ownership from save.
2023-08-01 17:19:22 +02:00
89d1fd03e5
Adding a new surface for safetensors (global load, global save).
2023-08-01 15:00:38 +02:00
310094310b
Modifying safetensors export to get simple load and save.
2023-08-01 15:00:38 +02:00
836ba3e090
Merge pull request #258 from LaurentMazare/start_book
...
Starting the book.
2023-08-01 14:59:34 +02:00
091e781977
Grammarly pass.
2023-08-01 14:26:02 +02:00
5cead227ef
Addressed comments.
2023-08-01 14:26:02 +02:00
ebd0315623
Typo.
2023-08-01 14:26:02 +02:00
ad9d8fe400
Complexifying our hello world.
2023-08-01 14:26:02 +02:00
5bc5716b85
Revert "Making sure the CI actually works"
...
This reverts commit 699346b603cec1f279d94e9aa3210c193ba973f8.
2023-08-01 14:26:02 +02:00
ba37de94d4
Making sure the CI actually works
2023-08-01 14:26:02 +02:00
6242a1470e
Starting the book.
2023-08-01 14:26:02 +02:00
75e0448114
Move the weight bits into a separate module. (#295)
2023-08-01 10:37:06 +01:00
614f911e9e
Add some batcher variants that handle errors. (#294)
2023-08-01 09:40:34 +01:00
e1e8127f15
Add the batcher. (#293)
2023-08-01 09:16:10 +01:00
fa98ca0c35
Use subcommands in llama2. (#292)
2023-08-01 05:57:41 +01:00
1a07ff8d17
Pre-tokenized evaluation mode for llama2.c. (#291)
2023-08-01 05:36:25 +01:00
f28558d0b7
Evaluate on the pre-tokenized file. (#290)
2023-07-31 21:31:38 +01:00
6b98b66eb3
Remove the end-of-text tokens. (#289)
2023-07-31 20:43:57 +01:00
9ae1f6afee
Add an eval mode to llama2-c. (#288)
...
* Add an eval mode to llama2-c.
* Encode line by line.
* Get the eval to run.
2023-07-31 17:22:14 +01:00
1064b9b031
Add the cross-entropy loss. (#287)
2023-07-31 14:26:36 +01:00
ffeafbfc43
Make the nll op closer to the PyTorch version + add a test. (#286)
2023-07-31 14:14:01 +01:00
b3ea96b62b
Add a prompt and support more models in llama2-c. (#285)
...
* Support more models in llama2-c.
* Add a prompt.
2023-07-31 13:09:30 +01:00
94a43faaca
Use the hub models for llama2.c (#284)
2023-07-31 12:51:14 +01:00
62a9b03715
Add a flag to set the number of epochs in the mnist training. (#283)
...
* Add a flag to change the number of epochs for the mnist training.
* Increase the learning rate for the MLP.
2023-07-31 10:32:14 +01:00
67834119fc
Fix the flash-attention function names. (#282)
2023-07-31 10:04:39 +01:00
0ace420e66
Flash attention without padding (varlen). (#281)
...
* Expose the seqlen variable for flash-attn without padding.
* Fix the batched call.
* Adapt for the varlen variant.
* No need to set the batch strides when in varlen mode.
* Add a test (disabled at the moment).
* Get the test to work properly.
2023-07-31 09:45:39 +01:00
a8d8f9f206
Load a trained checkpoint in the mnist example. (#280)
2023-07-30 17:01:45 +01:00
38ff693af0
Add a flag to save the trained weights. (#279)
2023-07-30 15:41:42 +01:00
ba2254556c
Display the temperature being used for text generation. (#278)
2023-07-30 09:53:05 +01:00
c950a5c6b1
Cuda support for the mnist training. (#277)
...
* Cuda support for the mnist training.
* min/max fix + testing.
* Add the argmin/argmax tests.
* More cuda support for argmin/argmax.
* Cuda kernels for argmin and argmax.
2023-07-29 19:48:04 +01:00
16c33383eb
Improve the mnist training example. (#276)
...
* Improve the mnist training example.
* Add some initialization routine that can be used for nn.
* Proper initialization in the mnist example.
2023-07-29 16:28:22 +01:00
bedcef64dc
Merge pull request #262 from LaurentMazare/update_multiprocess
...
Making multiprocess require flash-attn.
2023-07-29 16:40:39 +02:00
40c80bfbb2
Merge branch 'main' into update_multiprocess
2023-07-29 16:38:35 +02:00
07eb899729
More mnist training. (#275)
2023-07-29 13:29:31 +01:00
c0a8ed19eb
Support for where-cond on cuda for u8 and u32. (#274)
2023-07-29 11:48:58 +01:00
4bf2ebf836
Use u8 tensors for masks. (#273)
2023-07-29 11:32:58 +01:00
97d8712ba5
Remove single function.
2023-07-28 23:31:25 +02:00
97181a77c0
Making multiprocess require flash-attn.
2023-07-28 23:31:24 +02:00
50d8273ae4
Support both llama v1 and llama v2. (#272)
2023-07-28 18:40:59 +01:00
7513a5e005
Line up the llama implementation with the python-transformers one. (#271)
...
* Line up the llama implementation with the python-transformers one.
* Also line up the multiprocess version.
2023-07-28 18:31:28 +01:00
cb8dd5cd53
Back to using the main branch now that the PR has been merged. (#270)
2023-07-28 16:22:44 +01:00
a0e47aba98
Fix the revision used in starcoder to use the safetensors PR. (#269)
2023-07-28 14:02:31 +01:00
fb84ead8f7
Add the starcoder example to the readme. (#268)
...
* Add the starcoder example to the readme.
* Tweak.
2023-07-28 13:26:23 +01:00
3eb2bc6d07
Softmax numerical stability. (#267)
...
* Softmax numerical stability.
* Fix the flash-attn test.
2023-07-28 13:13:01 +01:00
68eab38de6
Cuda fix for starcoder. (#266)
...
* Cuda fix for starcoder.
* Nicer output.
2023-07-28 12:13:41 +01:00
54ccf94472
Merge pull request #265 from LaurentMazare/fix_nccl
...
Fix nccl
2023-07-28 11:37:58 +01:00