Commit Graph

888 Commits

Author SHA1 Message Date
1064b9b031 Add the cross-entropy loss. (#287) 2023-07-31 14:26:36 +01:00
ffeafbfc43 Make the nll op closer to the pytorch version + add a test. (#286) 2023-07-31 14:14:01 +01:00
b3ea96b62b Add a prompt and support more models in llama2-c. (#285)
* Support more models in llama2-c.

* Add a prompt.
2023-07-31 13:09:30 +01:00
94a43faaca Use the hub models for llama2.c (#284) 2023-07-31 12:51:14 +01:00
62a9b03715 Add a flag to set the number of epochs in the mnist training (#283)
* Add a flag to change the number of epochs for the mnist training.

* Increase the learning rate for the MLP.
2023-07-31 10:32:14 +01:00
67834119fc Fix the flash-attention function names. (#282) 2023-07-31 10:04:39 +01:00
0ace420e66 Flash attention without padding (varlen). (#281)
* Expose the seqlen variable for flash-attn without padding.

* Fix the batched call.

* Adapt for the varlen variant.

* No need to set the batch strides when in varlen mode.

* Add a test (disabled at the moment).

* Get the test to work properly.
2023-07-31 09:45:39 +01:00
a8d8f9f206 Load a trained checkpoint in the mnist example. (#280) 2023-07-30 17:01:45 +01:00
38ff693af0 Add a flag to save the trained weights. (#279) 2023-07-30 15:41:42 +01:00
ba2254556c Display the temperature being used for text generation. (#278) 2023-07-30 09:53:05 +01:00
c950a5c6b1 Cuda support for the mnist training. (#277)
* Cuda support for the mnist training.

* min/max fix + testing.

* Add the argmin/argmax tests.

* More cuda support for argmin/argmax.

* Cuda kernels for argmin and argmax.
2023-07-29 19:48:04 +01:00
16c33383eb Improve the mnist training example. (#276)
* Improve the mnist training example.

* Add some initialization routine that can be used for nn.

* Proper initialization in the mnist example.
2023-07-29 16:28:22 +01:00
bedcef64dc Merge pull request #262 from LaurentMazare/update_multiprocess
Making multiprocess require flash-attn.
2023-07-29 16:40:39 +02:00
40c80bfbb2 Merge branch 'main' into update_multiprocess 2023-07-29 16:38:35 +02:00
07eb899729 More mnist training. (#275) 2023-07-29 13:29:31 +01:00
c0a8ed19eb Support for where-cond on cuda for u8 and u32. (#274) 2023-07-29 11:48:58 +01:00
4bf2ebf836 Use u8 tensors for masks. (#273) 2023-07-29 11:32:58 +01:00
97d8712ba5 Remove single function. 2023-07-28 23:31:25 +02:00
97181a77c0 Making multiprocess require flash-attn. 2023-07-28 23:31:24 +02:00
50d8273ae4 Support both llama v1 and llama v2. (#272) 2023-07-28 18:40:59 +01:00
7513a5e005 Line-up the llama implementation with the python-transformers one. (#271)
* Line-up the llama implementation with the python-transformers one.

* Also lineup the multiprocess version.
2023-07-28 18:31:28 +01:00
cb8dd5cd53 Back to using the main branch now that the PR has been merged. (#270) 2023-07-28 16:22:44 +01:00
a0e47aba98 Fix the revision used in starcoder to use the safetensors PR. (#269) 2023-07-28 14:02:31 +01:00
fb84ead8f7 Add the starcoder example to the readme. (#268)
* Add the starcoder example to the readme.

* Tweak.
2023-07-28 13:26:23 +01:00
3eb2bc6d07 Softmax numerical stability. (#267)
* Softmax numerical stability.

* Fix the flash-attn test.
2023-07-28 13:13:01 +01:00
68eab38de6 Cuda fix for starcoder. (#266)
* Cuda fix for starcoder.

* Nicer output.
2023-07-28 12:13:41 +01:00
54ccf94472 Merge pull request #265 from LaurentMazare/fix_nccl
Fix nccl
2023-07-28 11:37:58 +01:00
4002968cf5 Put back `"dep:half" 2023-07-28 10:34:21 +00:00
be256a6ba6 Fixing. 2023-07-28 10:23:05 +00:00
d2dea11ef6 Fixing nccl feature. 2023-07-28 12:19:20 +02:00
3e89df938c Starcoder fix (#264)
* Bugfix for starcoder.

* Get some proper code generation.

* Slightly simpler softmax.
2023-07-28 11:17:49 +01:00
6a54ca115e Add some Bigcode model (#260)
* Start sketching the bigcode gpt model.

* Sketch the bigcode model.

* Implement the attention mechanism.

* Random reshaping.

* Sketch more of the example.

* Add some kv cache.

* Properly generate the position ids.

* Proper attention mask.

* Bail on upcasting.

* Properly apply the attention mask.

* Add the smaller starcoder variants.

* Update for the new hub api.

* Fix a shape issue.

* Fix another shape issue.

* Get some logits out.

* Adjust the weigth names.
2023-07-28 09:57:32 +01:00
4f260ef025 Merge pull request #216 from LaurentMazare/llama_multiprocess2
TP sharding v2
2023-07-28 08:06:13 +01:00
0b97987b21 Merge pull request #261 from LaurentMazare/upgrade_hf_hub
Upgrading hf-hub to `0.2.0` (Modified API to not pass the Repo around all the time)
2023-07-28 07:03:30 +01:00
8435a99edd Added comment about offsets. 2023-07-27 20:11:57 +02:00
ca479a873e Upgrading hf-hub to 0.2.0 (Modified API to not pass the Repo around
all the time)
2023-07-27 20:05:02 +02:00
952eca6b54 Fixing slice errors + comments. 2023-07-27 16:59:32 +02:00
f291065f6c Do not panic on empty ranges. (#257) 2023-07-27 09:28:47 +01:00
25a2086e8f Putting back Send + Sync 2023-07-27 09:58:47 +02:00
7c7e6ba201 Removing inner dependency on safetensors. 2023-07-27 09:58:47 +02:00
1553b58fe5 Tensor are not necessarily sendable (CustomOp1). 2023-07-27 09:58:47 +02:00
b7814f66b4 PyO3 is back. 2023-07-27 09:58:47 +02:00
ed58de7551 Fixed TP sharded version. 2023-07-27 09:58:46 +02:00
1735e4831e TP sharding v2 2023-07-27 09:58:14 +02:00
209f06d7c3 Micro-cleanup. (#256) 2023-07-27 07:55:54 +01:00
6475bfadfe Simplify Tensor::randn. (#255)
* Simplify Tensor::randn.

* Also switch Tensor::rand to use a generic dtype.

* Support sampling for f16.

* Cleanup.
2023-07-27 07:40:36 +01:00
89ba005962 Support backprop for a few more ops. (#254) 2023-07-26 21:31:54 +01:00
4f92420132 Add some flash attn test (#253)
* Add some flash-attn test.

* Add the cpu test.

* Fail when the head is not a multiple of 8.

* Polish the flash attention test.
2023-07-26 20:56:00 +01:00
ded197497c Merge pull request #252 from LaurentMazare/add_book
Adding a cargo book
2023-07-26 17:35:54 +01:00
84ad558e50 Switch to using llama-v2 by default. (#251) 2023-07-26 17:18:27 +01:00