Commit Graph

173 Commits

SHA1 Message Date
51e51da896 Rename the candle crate to candle-core (#301)
* Rename to candle-core.

* More candle-core renaming.
2023-08-02 08:20:22 +01:00
4b3bd79fbd Remove the embedding ops in favor of index-select. (#299)
* Remove the embedding ops in favor of index-select.

* Also remove the cuda kernels.
2023-08-02 05:42:11 +01:00
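A note on the change: an embedding lookup is just a row gather on the embedding table, which is exactly what index-select does. A minimal sketch with the current candle-core names (the crate was still plain candle at the time):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A (vocab_size, hidden_dim) embedding table; sizes are made up.
    let table = Tensor::randn(0f32, 1f32, (16, 4), &dev)?;
    // Token ids must be an integer dtype, u32 here.
    let ids = Tensor::new(&[3u32, 1, 7], &dev)?;
    // index_select along dim 0 gathers one row per id: shape (3, 4).
    let embeddings = table.index_select(&ids, 0)?;
    println!("{:?}", embeddings.shape());
    Ok(())
}
```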
ff876c2103 Llama more training (#297)
* Rework the var-builder to handle initializations.

* Add some helper functions for layer creation.

* Improve the layer initializations.

* Get initialized variables.

* Precompute the rotary embeddings when training llamas.
2023-08-01 19:53:41 +01:00
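For context, initialized-variable creation in candle-nn now flows through a VarMap plus VarBuilder; a hedged sketch of that pattern, using the current candle-nn names this PR was working toward:

```rust
use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::{linear, Linear, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // The VarMap owns the trainable variables; the VarBuilder hands out
    // freshly initialized (or checkpoint-loaded) tensors by name.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    // Layer helpers pull their weights from the builder under a prefix.
    let layer: Linear = linear(4, 2, vb.pp("layer1"))?;
    let xs = Tensor::zeros((1, 4), DType::F32, &dev)?;
    let ys = layer.forward(&xs)?;
    println!("{:?}", ys.shape());
    Ok(())
}
```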
a27239f3d9 Add training for the llama2.c example (#296)
* Rework the commands and run inference by default.

* Add the training module and load the training dataset.

* Random dataset iterator.

* Proper valid-loss computation.

* Compute the evaluation loss.

* Add more substance to the training loop.
2023-08-01 17:23:07 +01:00
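The overall shape of such a loop, sketched against the current candle-nn optimizer API rather than the exact code from #296 (`forward` and `batches` are stand-ins for the example's model and dataset):

```rust
use candle_core::{Result, Tensor};
use candle_nn::{Optimizer, VarMap, SGD};

fn train(
    varmap: &VarMap,
    forward: impl Fn(&Tensor) -> Result<Tensor>,
    batches: &[(Tensor, Tensor)],
) -> Result<()> {
    let mut opt = SGD::new(varmap.all_vars(), 0.1)?;
    for (xs, ys) in batches {
        let logits = forward(xs)?;
        let loss = candle_nn::loss::cross_entropy(&logits, ys)?;
        // backward_step computes the gradients and applies the update.
        opt.backward_step(&loss)?;
        println!("loss: {}", loss.to_scalar::<f32>()?);
    }
    Ok(())
}
```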
75e0448114 Move the weight bits in a separate module. (#295) 2023-08-01 10:37:06 +01:00
614f911e9e Add some batcher variants that handle errors. (#294) 2023-08-01 09:40:34 +01:00
e1e8127f15 Add the batcher. (#293) 2023-08-01 09:16:10 +01:00
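Conceptually, the batcher groups an iterator of items into fixed-size batches, and the error-handling variants from #294 surface a failed item instead of silently dropping it. A generic sketch of the idea, not candle's exact types:

```rust
struct Batcher<I> {
    inner: I,
    batch_size: usize,
}

impl<I, T, E> Iterator for Batcher<I>
where
    I: Iterator<Item = Result<T, E>>,
{
    type Item = Result<Vec<T>, E>;

    fn next(&mut self) -> Option<Self::Item> {
        let mut batch = Vec::with_capacity(self.batch_size);
        for _ in 0..self.batch_size {
            match self.inner.next() {
                Some(Ok(item)) => batch.push(item),
                // Propagate the first error rather than dropping it.
                Some(Err(e)) => return Some(Err(e)),
                None => break,
            }
        }
        // A trailing partial batch is still yielded.
        if batch.is_empty() { None } else { Some(Ok(batch)) }
    }
}
```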
fa98ca0c35 Use subcommands in llama2. (#292) 2023-08-01 05:57:41 +01:00
1a07ff8d17 Pre-tokenized evaluation mode for llama2.c. (#291) 2023-08-01 05:36:25 +01:00
f28558d0b7 Evaluate on the pre-tokenized file. (#290) 2023-07-31 21:31:38 +01:00
6b98b66eb3 Remove the end of text tokens. (#289) 2023-07-31 20:43:57 +01:00
9ae1f6afee Add an eval mode to llama2-c (#288)
* Add an eval mode to llama2-c.

* Encode line by line.

* Get the eval to run.
2023-07-31 17:22:14 +01:00
ffeafbfc43 Make the nll op closer to the pytorch version + add a test. (#286) 2023-07-31 14:14:01 +01:00
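The PyTorch convention being matched: nll takes log-probabilities of shape (batch, classes) and integer targets of shape (batch,), and averages the negated log-probability picked at each target index. A hedged sketch in candle-core terms:

```rust
use candle_core::{Result, Tensor};

fn nll(log_probs: &Tensor, targets: &Tensor) -> Result<Tensor> {
    // Pick log_probs[i, targets[i]] for each row via a gather on dim 1.
    let picked = log_probs.gather(&targets.unsqueeze(1)?, 1)?;
    // Mean over the batch, then negate.
    picked.mean_all()?.neg()
}
```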
b3ea96b62b Add a prompt and support more models in llama2-c. (#285)
* Support more models in llama2-c.

* Add a prompt.
2023-07-31 13:09:30 +01:00
94a43faaca Use the hub models for llama2.c (#284) 2023-07-31 12:51:14 +01:00
62a9b03715 Add a flag to set the number of epochs in the mnist training (#283)
* Add a flag to change the number of epochs for the mnist training.

* Increase the learning rate for the MLP.
2023-07-31 10:32:14 +01:00
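A sketch of this kind of flag with clap's derive API; the default values below are made up, not the ones from #283:

```rust
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Number of training epochs.
    #[arg(long, default_value_t = 200)]
    epochs: usize,

    /// Learning rate for the optimizer.
    #[arg(long, default_value_t = 0.05)]
    learning_rate: f64,
}

fn main() {
    let args = Args::parse();
    println!("{} epochs at lr {}", args.epochs, args.learning_rate);
}
```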
a8d8f9f206 Load a trained checkpoint in the mnist example. (#280) 2023-07-30 17:01:45 +01:00
38ff693af0 Add a flag to save the trained weights. (#279) 2023-07-30 15:41:42 +01:00
c950a5c6b1 Cuda support for the mnist training. (#277)
* Cuda support for the mnist training.

* min/max fix + testing.

* Add the argmin/argmax tests.

* More cuda support for argmin/argmax.

* Cuda kernels for argmin and argmax.
2023-07-29 19:48:04 +01:00
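The argmin/argmax ops themselves are simple to use; a small example with the current candle-core names:

```rust
use candle_core::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let t = Tensor::new(&[[1f32, 5., 2.], [7., 0., 3.]], &Device::Cpu)?;
    // argmax over the last dim returns integer indices: [1, 0] here.
    let ids = t.argmax(D::Minus1)?;
    println!("{ids}");
    Ok(())
}
```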
16c33383eb Improve the mnist training example. (#276)
* Improve the mnist training example.

* Add an initialization routine that can be used for nn.

* Proper initialization in the mnist example.
2023-07-29 16:28:22 +01:00
40c80bfbb2 Merge branch 'main' into update_multiprocess 2023-07-29 16:38:35 +02:00
07eb899729 More mnist training. (#275) 2023-07-29 13:29:31 +01:00
4bf2ebf836 Use u8 tensors for masks. (#273) 2023-07-29 11:32:58 +01:00
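u8 serves as the boolean-ish dtype here: a mask of 0s and 1s drives where_cond to select between two tensors. A sketch of the usual attention-masking pattern (names are illustrative, and the mask is assumed to have the same shape as the scores):

```rust
use candle_core::{Result, Tensor};

// mask holds u8 values: 1 keeps the score, 0 replaces it with a large
// negative number so the softmax sends it to ~0.
fn masked_fill(scores: &Tensor, mask: &Tensor) -> Result<Tensor> {
    let neg = Tensor::full(-1e9f32, scores.dims(), scores.device())?;
    mask.where_cond(scores, &neg)
}
```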
97d8712ba5 Remove single function. 2023-07-28 23:31:25 +02:00
97181a77c0 Making multiprocess require flash-attn. 2023-07-28 23:31:24 +02:00
50d8273ae4 Support both llama v1 and llama v2. (#272) 2023-07-28 18:40:59 +01:00
7513a5e005 Line-up the llama implementation with the python-transformers one. (#271)
* Line-up the llama implementation with the python-transformers one.

* Also lineup the multiprocess version.
2023-07-28 18:31:28 +01:00
cb8dd5cd53 Back to using the main branch now that the PR has been merged. (#270) 2023-07-28 16:22:44 +01:00
a0e47aba98 Fix the revision used in starcoder to use the safetensors PR. (#269) 2023-07-28 14:02:31 +01:00
3eb2bc6d07 Softmax numerical stability. (#267)
* Softmax numerical stability.

* Fix the flash-attn test.
2023-07-28 13:13:01 +01:00
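The standard fix: subtracting the per-row maximum before exponentiating leaves the result unchanged but keeps exp from overflowing on large logits. A sketch in candle-core terms:

```rust
use candle_core::{Result, Tensor, D};

fn softmax_stable(xs: &Tensor) -> Result<Tensor> {
    // softmax(x) == softmax(x - max(x)), and the shifted version never
    // feeds exp() a positive value, so it cannot overflow to inf.
    let max = xs.max_keepdim(D::Minus1)?;
    let exps = xs.broadcast_sub(&max)?.exp()?;
    let sum = exps.sum_keepdim(D::Minus1)?;
    exps.broadcast_div(&sum)
}
```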
68eab38de6 Cuda fix for starcoder. (#266)
* Cuda fix for starcoder.

* Nicer output.
2023-07-28 12:13:41 +01:00
3e89df938c Starcoder fix (#264)
* Bugfix for starcoder.

* Get some proper code generation.

* Slightly simpler softmax.
2023-07-28 11:17:49 +01:00
6a54ca115e Add the Bigcode model (#260)
* Start sketching the bigcode gpt model.

* Sketch the bigcode model.

* Implement the attention mechanism.

* Random reshaping.

* Sketch more of the example.

* Add some kv cache (sketched after this entry).

* Properly generate the position ids.

* Proper attention mask.

* Bail on upcasting.

* Properly apply the attention mask.

* Add the smaller starcoder variants.

* Update for the new hub api.

* Fix a shape issue.

* Fix another shape issue.

* Get some logits out.

* Adjust the weight names.
2023-07-28 09:57:32 +01:00
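On the kv-cache bullet above: at each decoding step the new key/value is concatenated to the cached ones along the time dimension, so attention only needs to be computed for the newest token. An illustrative sketch, not the example's actual code:

```rust
use candle_core::{Result, Tensor};

#[derive(Default)]
struct KvCache {
    kv: Option<(Tensor, Tensor)>,
}

impl KvCache {
    // k and v are this step's keys/values, assumed (batch, time, dim).
    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        let (k, v) = match &self.kv {
            None => (k.clone(), v.clone()),
            Some((pk, pv)) => (
                Tensor::cat(&[pk, k], 1)?, // grow along the time dim
                Tensor::cat(&[pv, v], 1)?,
            ),
        };
        self.kv = Some((k.clone(), v.clone()));
        Ok((k, v))
    }
}
```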
4f260ef025 Merge pull request #216 from LaurentMazare/llama_multiprocess2
TP sharding v2
2023-07-28 08:06:13 +01:00
ca479a873e Upgrading hf-hub to 0.2.0 (modified API to not pass the Repo around all the time) 2023-07-27 20:05:02 +02:00
25a2086e8f Putting back Send + Sync 2023-07-27 09:58:47 +02:00
7c7e6ba201 Removing inner dependency on safetensors. 2023-07-27 09:58:47 +02:00
ed58de7551 Fixed TP sharded version. 2023-07-27 09:58:46 +02:00
1735e4831e TP sharding v2 2023-07-27 09:58:14 +02:00
209f06d7c3 Micro-cleanup. (#256) 2023-07-27 07:55:54 +01:00
84ad558e50 Switch to using llama-v2 by default. (#251) 2023-07-26 17:18:27 +01:00
1235aa2536 Use bail rather than wrapping a string where possible. (#249)
* Use bail rather than wrapping a string where possible.

* Revert the cuda default bit.
2023-07-26 15:42:46 +01:00
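The pattern in question: candle exposes a bail! macro that builds and returns the error in one step instead of hand-wrapping a formatted string. A small sketch:

```rust
use candle_core::{bail, Result};

fn check_rank(rank: usize) -> Result<()> {
    if rank != 2 {
        // One step instead of: return Err(Error::Msg(format!(...)))
        bail!("expected a rank-2 tensor, got rank {rank}")
    }
    Ok(())
}
```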
f052ba76cb Lining up the flash attn version with the non-flash one. (#248)
* Move the flash-attn function in the proper crate.

* Causality tweak.
2023-07-26 15:11:45 +01:00
8b1d12bead Merge pull request #246 from LaurentMazare/rename_custom_op
Rename exposed ops.
2023-07-26 14:20:29 +01:00
2ce5f12513 Again set a few extra params in flash-attn. (#245)
* Again set a few extra params.

* Use the appropriate kernel sizes.

* Add all the kernel sizes.

* Parallel compiling.

* Reduce the amount of parallelism.

* Add the missing kernel.

* Fix a typo.

* Remove bf16 support for now.
2023-07-26 14:16:37 +01:00
1a5416ec35 Rename exposed ops. 2023-07-26 12:43:19 +02:00
fa2b64d678 Proper flash-attn parameters. (#244)
* Proper flash-attn parameters.

* Set the flash attention parameters.

* Add more validations.

* Set up the o_ flash attn parameters.

* More flash-attn support.

* Set more flash attn parameters.
2023-07-26 10:13:40 +01:00
e40b150bbe Better handling of dtypes in llama. (#243) 2023-07-26 08:28:33 +01:00
d9f9c859af Add flash attention (#241)
* Add some flash-attn kernel, import the code for flash-attn v2 from Dao-AILab.

* More flash attn.

* Set up the flash attn parameters.

* Get things to compile locally.

* Move the flash attention files in a different directory.

* Build the static C library with nvcc.

* Add more flash attention.

* Update the build part.

* Better caching.

* Exclude flash attention from the default workspace.

* Put flash-attn behind a feature gate.

* Get the flash attn kernel to run.

* Move the flags to a more appropriate place.

* Enable flash attention in llama.

* Use flash attention in llama.
2023-07-26 07:48:10 +01:00
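The feature-gate pattern from this PR, sketched with the current candle-flash-attn entry point (the exact signature at the time may have differed):

```rust
use candle_core::{Result, Tensor};

#[cfg(feature = "flash-attn")]
fn attn(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32) -> Result<Tensor> {
    // Only compiled when the flash-attn feature (and its CUDA build) is on.
    candle_flash_attn::flash_attn(q, k, v, scale, /* causal */ true)
}

#[cfg(not(feature = "flash-attn"))]
fn attn(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32) -> Result<Tensor> {
    // Portable fallback: plain scaled dot-product attention.
    let att = (q.matmul(&k.t()?)? * scale as f64)?;
    let att = candle_nn::ops::softmax(&att, candle_core::D::Minus1)?;
    att.matmul(v)
}
```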
550a13a547 Use the binary decoder for llama2.c. (#230)
* Use the binary decoder for llama2.c.

* Add the temperature (sketched below).

* Formatting tweak.

* Fix the rotary embeddings.
2023-07-24 10:56:08 +01:00
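On the temperature bullet: sampling temperature divides the logits before the softmax, so t < 1 sharpens the distribution and t > 1 flattens it toward uniform. A minimal sketch:

```rust
use candle_core::{Result, Tensor, D};

fn sample_probs(logits: &Tensor, temperature: f64) -> Result<Tensor> {
    // t == 1.0 leaves the distribution unchanged; skip the extra op.
    let scaled = if temperature == 1.0 {
        logits.clone()
    } else {
        (logits / temperature)?
    };
    candle_nn::ops::softmax(&scaled, D::Minus1)
}
```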