62a9b03715
Add a flag to set the number of epochs in the mnist training ( #283 )
...
* Add a flag to change the number of epochs for the mnist training.
* Increase the learning rate for the MLP.
2023-07-31 10:32:14 +01:00
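For context, a flag like this is typically declared with clap's derive API in Rust examples; the sketch below shows how an epochs flag could be wired up. The struct, flag names, and default values are illustrative assumptions, not the example's actual code.

```rust
use clap::Parser;

/// Hypothetical training arguments; the real mnist example's struct may differ.
#[derive(Parser, Debug)]
struct Args {
    /// Number of passes over the training set.
    #[arg(long, default_value_t = 10)]
    epochs: usize,

    /// Learning rate for the optimizer.
    #[arg(long, default_value_t = 0.05)]
    learning_rate: f64,
}

fn main() {
    let args = Args::parse();
    for epoch in 1..=args.epochs {
        // train_one_epoch(...) would run here in the real example.
        println!("epoch {epoch}/{} (lr = {})", args.epochs, args.learning_rate);
    }
}
```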
67834119fc
Fix the flash-attention function names. ( #282 )
2023-07-31 10:04:39 +01:00
0ace420e66
Flash attention without padding (varlen). ( #281 )
...
* Expose the seqlen variable for flash-attn without padding.
* Fix the batched call.
* Adapt for the varlen variant.
* No need to set the batch strides when in varlen mode.
* Add a test (disabled at the moment).
* Get the test to work properly.
2023-07-31 09:45:39 +01:00
a8d8f9f206
Load a trained checkpoint in the mnist example. ( #280 )
2023-07-30 17:01:45 +01:00
38ff693af0
Add a flag to save the trained weights. ( #279 )
2023-07-30 15:41:42 +01:00
ba2254556c
Display the temperature being used for text generation. ( #278 )
2023-07-30 09:53:05 +01:00
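Temperature scaling divides the logits by the temperature before the softmax, so values below 1.0 sharpen the sampling distribution and values above 1.0 flatten it. A minimal, dependency-free sketch on plain `f32` slices (not the crate's tensor API):

```rust
/// Convert raw logits into sampling probabilities at a given temperature.
/// temperature < 1.0 sharpens the distribution, > 1.0 flattens it.
fn softmax_with_temperature(logits: &[f32], temperature: f32) -> Vec<f32> {
    // Scale the logits, then apply a numerically stable softmax.
    let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect();
    let max = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    let logits = [2.0, 1.0, 0.1];
    println!("t=0.7 -> {:?}", softmax_with_temperature(&logits, 0.7));
    println!("t=1.5 -> {:?}", softmax_with_temperature(&logits, 1.5));
}
```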
c950a5c6b1
Cuda support for the mnist training. ( #277 )
...
* Cuda support for the mnist training.
* min/max fix + testing.
* Add the argmin/argmax tests.
* More cuda support for argmin/argmax.
* Cuda kernels for argmin and argmax.
2023-07-29 19:48:04 +01:00
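For reference, argmax/argmin reduce over a dimension and return the index of the extreme element; the CUDA kernels do this per row on the device. A plain-Rust sketch of the same semantics on a slice:

```rust
/// Index of the largest element, or None if the slice is empty.
fn argmax(xs: &[f32]) -> Option<usize> {
    xs.iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(i, _)| i)
}

/// Index of the smallest element, or None if the slice is empty.
fn argmin(xs: &[f32]) -> Option<usize> {
    xs.iter()
        .enumerate()
        .min_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(i, _)| i)
}

fn main() {
    let logits = [0.1, 2.3, -1.0, 1.7];
    assert_eq!(argmax(&logits), Some(1));
    assert_eq!(argmin(&logits), Some(2));
}
```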
16c33383eb
Improve the mnist training example. ( #276 )
...
* Improve the mnist training example.
* Add some initialization routine that can be used for nn.
* Proper initialization in the mnist example.
2023-07-29 16:28:22 +01:00
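A common choice for such initialization routines is a fan-in based uniform draw; the sketch below uses `bound = sqrt(1 / in_dim)` as an illustrative assumption (the routine added here may use a different bound) and the rand 0.8 API.

```rust
use rand::Rng;

/// Fill an (out_dim, in_dim) weight matrix with values drawn uniformly from
/// [-bound, bound] where bound = sqrt(1 / in_dim). The exact bound used by
/// the crate's routine may differ; this only illustrates the idea.
fn uniform_init(out_dim: usize, in_dim: usize) -> Vec<f32> {
    let bound = (1.0 / in_dim as f32).sqrt();
    let mut rng = rand::thread_rng();
    (0..out_dim * in_dim)
        .map(|_| rng.gen_range(-bound..bound))
        .collect()
}

fn main() {
    // A 10x784 output layer, as in an MNIST linear classifier.
    let w = uniform_init(10, 784);
    println!("initialized {} weights", w.len());
}
```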
bedcef64dc
Merge pull request #262 from LaurentMazare/update_multiprocess
...
Making multiprocess require flash-attn.
2023-07-29 16:40:39 +02:00
40c80bfbb2
Merge branch 'main' into update_multiprocess
2023-07-29 16:38:35 +02:00
07eb899729
More mnist training. ( #275 )
2023-07-29 13:29:31 +01:00
c0a8ed19eb
Support for where-cond on cuda for u8 and u32. ( #274 )
2023-07-29 11:48:58 +01:00
4bf2ebf836
Use u8 tensors for masks. ( #273 )
2023-07-29 11:32:58 +01:00
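The two entries above pair u8 masks with an element-wise select: where-cond takes values from one input where the mask is non-zero and from the other where it is zero. A plain-slice sketch of that semantics (not the crate's actual kernel):

```rust
/// Element-wise select: take on_true[i] where mask[i] != 0, else on_false[i].
/// Mirrors the semantics of a where-cond with a u8 mask, on plain slices.
fn where_cond(mask: &[u8], on_true: &[f32], on_false: &[f32]) -> Vec<f32> {
    assert_eq!(mask.len(), on_true.len());
    assert_eq!(mask.len(), on_false.len());
    mask.iter()
        .zip(on_true.iter().zip(on_false))
        .map(|(&m, (&t, &f))| if m != 0 { t } else { f })
        .collect()
}

fn main() {
    let mask = [1u8, 0, 1, 0];
    let a = [1.0, 2.0, 3.0, 4.0];
    let b = [-1.0, -2.0, -3.0, -4.0];
    assert_eq!(where_cond(&mask, &a, &b), vec![1.0, -2.0, 3.0, -4.0]);
}
```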
97d8712ba5
Remove single function.
2023-07-28 23:31:25 +02:00
97181a77c0
Making multiprocess require flash-attn.
2023-07-28 23:31:24 +02:00
50d8273ae4
Support both llama v1 and llama v2. ( #272 )
2023-07-28 18:40:59 +01:00
7513a5e005
Line-up the llama implementation with the python-transformers one. ( #271 )
...
* Line-up the llama implementation with the python-transformers one.
* Also lineup the multiprocess version.
2023-07-28 18:31:28 +01:00
cb8dd5cd53
Back to using the main branch now that the PR has been merged. ( #270 )
2023-07-28 16:22:44 +01:00
a0e47aba98
Fix the revision used in starcoder to use the safetensors PR. ( #269 )
2023-07-28 14:02:31 +01:00
fb84ead8f7
Add the starcoder example to the readme. ( #268 )
...
* Add the starcoder example to the readme.
* Tweak.
2023-07-28 13:26:23 +01:00
3eb2bc6d07
Softmax numerical stability. ( #267 )
...
* Softmax numerical stability.
* Fix the flash-attn test.
2023-07-28 13:13:01 +01:00
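Softmax is usually made numerically stable by subtracting the row maximum before exponentiating: mathematically a no-op, but it keeps exp from overflowing on large logits. A small self-contained illustration (plain slices, not the crate's tensor API):

```rust
/// Numerically stable softmax over a slice of logits.
fn stable_softmax(logits: &[f32]) -> Vec<f32> {
    // Subtracting the max leaves the result unchanged (the factor cancels in
    // the normalization) but keeps exp() in a safe range.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    // A naive exp() on these logits would overflow to infinity and yield NaNs.
    let logits = [1000.0f32, 999.0, 998.0];
    println!("{:?}", stable_softmax(&logits));
}
```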
68eab38de6
Cuda fix for starcoder. ( #266 )
...
* Cuda fix for starcoder.
* Nicer output.
2023-07-28 12:13:41 +01:00
54ccf94472
Merge pull request #265 from LaurentMazare/fix_nccl
...
Fix nccl
2023-07-28 11:37:58 +01:00
4002968cf5
Put back `"dep:half"`
2023-07-28 10:34:21 +00:00
be256a6ba6
Fixing.
2023-07-28 10:23:05 +00:00
d2dea11ef6
Fixing nccl feature.
2023-07-28 12:19:20 +02:00
3e89df938c
Starcoder fix ( #264 )
...
* Bugfix for starcoder.
* Get some proper code generation.
* Slightly simpler softmax.
2023-07-28 11:17:49 +01:00
6a54ca115e
Add some Bigcode model ( #260 )
...
* Start sketching the bigcode gpt model.
* Sketch the bigcode model.
* Implement the attention mechanism.
* Random reshaping.
* Sketch more of the example.
* Add some kv cache.
* Properly generate the position ids.
* Proper attention mask.
* Bail on upcasting.
* Properly apply the attention mask.
* Add the smaller starcoder variants.
* Update for the new hub api.
* Fix a shape issue.
* Fix another shape issue.
* Get some logits out.
* Adjust the weight names.
2023-07-28 09:57:32 +01:00
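Among the steps above, the kv cache is worth a sketch: during autoregressive decoding, each layer appends the keys/values of the new token so attention never recomputes past positions. A minimal, hypothetical illustration on flat `f32` buffers (the real model caches device tensors):

```rust
/// A minimal per-layer key/value cache for autoregressive decoding.
struct KvCache {
    head_dim: usize,
    keys: Vec<f32>,
    values: Vec<f32>,
}

impl KvCache {
    fn new(head_dim: usize) -> Self {
        Self { head_dim, keys: Vec::new(), values: Vec::new() }
    }

    /// Append the key/value row computed for the newly generated token and
    /// return everything cached so far, so attention only processes one new
    /// query per step.
    fn append(&mut self, k: &[f32], v: &[f32]) -> (&[f32], &[f32]) {
        assert_eq!(k.len(), self.head_dim);
        assert_eq!(v.len(), self.head_dim);
        self.keys.extend_from_slice(k);
        self.values.extend_from_slice(v);
        (&self.keys, &self.values)
    }

    /// Number of cached positions, which is also the next position id.
    fn seq_len(&self) -> usize {
        self.keys.len() / self.head_dim
    }
}

fn main() {
    let mut cache = KvCache::new(4);
    cache.append(&[0.1, 0.2, 0.3, 0.4], &[1.0; 4]);
    cache.append(&[0.5, 0.6, 0.7, 0.8], &[2.0; 4]);
    println!("cached positions: {}", cache.seq_len());
}
```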
4f260ef025
Merge pull request #216 from LaurentMazare/llama_multiprocess2
...
TP sharding v2
2023-07-28 08:06:13 +01:00
0b97987b21
Merge pull request #261 from LaurentMazare/upgrade_hf_hub
...
Upgrading hf-hub to `0.2.0` (Modified API to not pass the Repo around all the time)
2023-07-28 07:03:30 +01:00
8435a99edd
Added comment about offsets.
2023-07-27 20:11:57 +02:00
ca479a873e
Upgrading hf-hub to `0.2.0` (Modified API to not pass the Repo around all the time)
2023-07-27 20:05:02 +02:00
952eca6b54
Fixing slice errors + comments.
2023-07-27 16:59:32 +02:00
f291065f6c
Do not panic on empty ranges. ( #257 )
2023-07-27 09:28:47 +01:00
25a2086e8f
Putting back Send + Sync
2023-07-27 09:58:47 +02:00
7c7e6ba201
Removing inner dependency on safetensors.
2023-07-27 09:58:47 +02:00
1553b58fe5
Tensors are not necessarily sendable (CustomOp1).
2023-07-27 09:58:47 +02:00
b7814f66b4
PyO3 is back.
2023-07-27 09:58:47 +02:00
ed58de7551
Fixed TP sharded version.
2023-07-27 09:58:46 +02:00
1735e4831e
TP sharding v2
2023-07-27 09:58:14 +02:00
209f06d7c3
Micro-cleanup. ( #256 )
2023-07-27 07:55:54 +01:00
6475bfadfe
Simplify Tensor::randn. ( #255 )
...
* Simplify Tensor::randn.
* Also switch Tensor::rand to use a generic dtype.
* Support sampling for f16.
* Cleanup.
2023-07-27 07:40:36 +01:00
89ba005962
Support backprop for a few more ops. ( #254 )
2023-07-26 21:31:54 +01:00
4f92420132
Add some flash attn test ( #253 )
...
* Add some flash-attn test.
* Add the cpu test.
* Fail when the head is not a multiple of 8.
* Polish the flash attention test.
2023-07-26 20:56:00 +01:00
ded197497c
Merge pull request #252 from LaurentMazare/add_book
...
Adding a cargo book
2023-07-26 17:35:54 +01:00
84ad558e50
Switch to using llama-v2 by default. ( #251 )
2023-07-26 17:18:27 +01:00
368f169c6a
Permissions.
2023-07-26 18:12:02 +02:00
8da6568c20
Typo.
2023-07-26 18:11:10 +02:00
07a22fe606
Releasing within the branch to test the setup.
2023-07-26 18:08:34 +02:00
834e1b197b
Adding a documentation book.
2023-07-26 18:06:31 +02:00