b278834267
Support the Accelerate BLAS on macOS. ( #325 )
* Add the accelerate feature.
* FFI tweaks.
2023-08-05 17:25:24 +01:00
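For context, optional BLAS backends like this are usually wired in behind a Cargo feature. A minimal sketch, where the `accelerate` feature name and the `accelerate-src` link crate are assumptions rather than candle's exact wiring:

```rust
// Illustrative only: gating an Accelerate-backed BLAS behind a Cargo feature.
// Assumes the feature and crate are declared in Cargo.toml; built with e.g.
// `cargo build --features accelerate` on macOS.
#[cfg(feature = "accelerate")]
extern crate accelerate_src; // links against Apple's Accelerate framework

fn main() {
    #[cfg(feature = "accelerate")]
    println!("matmul dispatches to Accelerate's BLAS");
    #[cfg(not(feature = "accelerate"))]
    println!("falling back to the default kernels");
}
```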
f7b2a0391d
Transpose the weight matrices for llama2.c. ( #321 )
2023-08-04 13:32:20 +01:00
8b6f5be1cc
Support q5k quantized data. ( #320 )
2023-08-04 09:51:30 +01:00
74845a4dcd
Use the assert! macro as it turns out to be usable in const contexts. ( #316 )
2023-08-03 10:03:43 +01:00
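The point is that `assert!` can be evaluated in const contexts (panics in const code stabilized in Rust 1.57), so invariants can move from runtime checks to compile time. A minimal sketch with an illustrative constant:

```rust
// `assert!` works in const evaluation, so an invariant such as a block-size
// constraint fails at compile time. `QK` is an illustrative constant here,
// not one of candle's.
const QK: usize = 32;
const _: () = assert!(QK % 8 == 0, "block size must be a multiple of 8");

fn main() {
    println!("QK = {QK}");
}
```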
aa76b783eb
Q6K dequantization. ( #315 )
2023-08-03 09:31:20 +01:00
25564357f7
Support some ggml quantized types ( #314 )
* Add the quantized types for GGML loading.
* Support quantization for Q2K.
* More quantization support.
* Fix some clippy lints.
2023-08-03 09:16:26 +01:00
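As background, ggml quantization packs values into fixed-size blocks: 32 elements for the older formats such as Q4_0, and 256-element super-blocks for the k-quants (Q2K through Q6K above). A sketch of how such dtypes can be modeled; this is not candle's actual definition:

```rust
// A sketch (not candle's real types) of modeling ggml quantized dtypes:
// each dtype implies a block length and a packed byte layout.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy)]
enum GgmlDType {
    F32,
    F16,
    Q4_0, // 32-element blocks: one f16 scale + 32 packed 4-bit values
    Q8_0, // 32-element blocks: one f16 scale + 32 bytes
    Q2K,  // 256-element super-blocks used by the k-quants
    Q6K,
}

impl GgmlDType {
    /// Number of elements stored per quantization block.
    fn block_len(self) -> usize {
        match self {
            Self::F32 | Self::F16 => 1,
            Self::Q4_0 | Self::Q8_0 => 32,
            Self::Q2K | Self::Q6K => 256,
        }
    }
}

fn main() {
    let dt = GgmlDType::Q6K;
    println!("{dt:?} packs {} values per block", dt.block_len());
}
```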
634700d84a
Use some consts for ggml values. ( #312 )
2023-08-02 22:03:05 +01:00
e635f18eda
Initial support for reading ggml files. ( #311 )
* Start adding support for reading ggml files.
* Compute the proper tensor size.
* Print the read tensors.
* Fix file reading.
2023-08-02 21:59:02 +01:00
0902846f25
Add the AdamW optimizer. ( #307 )
* Add the AdamW optimizer.
* Add some AdamW test validated against PyTorch.
2023-08-02 14:03:49 +01:00
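For reference, AdamW decouples weight decay from the gradient-based Adam step, which is what distinguishes it from classic Adam with L2 regularization. A minimal scalar sketch of the update rule; candle's optimizer operates on tensors, this is just the arithmetic:

```rust
// Scalar sketch of the AdamW rule (Loshchilov & Hutter), for illustration.
struct AdamW {
    lr: f64, beta1: f64, beta2: f64, eps: f64, weight_decay: f64,
    m: f64, v: f64, t: i32,
}

impl AdamW {
    fn step(&mut self, theta: &mut f64, grad: f64) {
        self.t += 1;
        // First and second moment estimates, then bias correction.
        self.m = self.beta1 * self.m + (1.0 - self.beta1) * grad;
        self.v = self.beta2 * self.v + (1.0 - self.beta2) * grad * grad;
        let m_hat = self.m / (1.0 - self.beta1.powi(self.t));
        let v_hat = self.v / (1.0 - self.beta2.powi(self.t));
        // Decoupled weight decay: applied to the parameter directly rather
        // than folded into the gradient.
        *theta -= self.lr * (m_hat / (v_hat.sqrt() + self.eps) + self.weight_decay * *theta);
    }
}

fn main() {
    let mut opt = AdamW {
        lr: 1e-3, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 1e-2,
        m: 0.0, v: 0.0, t: 0,
    };
    let mut w = 1.0f64;
    opt.step(&mut w, 0.5);
    println!("w after one step: {w}");
}
```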
4fe8a02f88
Update the repo location. ( #305 )
2023-08-02 11:12:18 +01:00
d38943aadc
Add version numbers for all the candle crates ( #303 )
* Switch to candle-gemm for the time being.
* Add the missing versions.
2023-08-02 10:52:13 +01:00
51e51da896
Rename the candle crate to candle-core ( #301 )
* Rename to candle-core.
* More candle-core renaming.
2023-08-02 08:20:22 +01:00
4b3bd79fbd
Remove the embedding ops in favor of index-select. ( #299 )
* Remove the embedding ops in favor of index-select.
* Also remove the cuda kernels.
2023-08-02 05:42:11 +01:00
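An embedding lookup is just a row-wise index-select on the weight matrix, and index-select already has a backward pass, which is what makes dedicated embedding ops redundant. A sketch, assuming the candle-core API of this era:

```rust
// Embedding lookup expressed as index_select; assumes the
// Tensor::index_select(indexes, dim) signature.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A tiny 4-token vocabulary with 3-dimensional embeddings.
    let weights = Tensor::new(
        &[[0f32, 0., 0.], [1., 1., 1.], [2., 2., 2.], [3., 3., 3.]],
        &dev,
    )?;
    let token_ids = Tensor::new(&[2u32, 0, 3], &dev)?;
    // Selecting rows along dim 0 is exactly an embedding lookup, and
    // index_select supports backprop into `weights`.
    let emb = weights.index_select(&token_ids, 0)?;
    println!("{:?}", emb.to_vec2::<f32>()?);
    Ok(())
}
```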
cc76c63202
Use index-select for the embeddings as it supports backprop. ( #298 )
2023-08-01 20:44:43 +01:00
a27239f3d9
Add training for the llama2.c example ( #296 )
* Rework the commands and run inference by default.
* Add the training module and load the training dataset.
* Random dataset iterator.
* Proper valid-loss computation.
* Compute the evaluation loss.
* Add more substance to the training loop.
2023-08-01 17:23:07 +01:00
afb5e24a63
Remove map ownership from save.
2023-08-01 17:19:22 +02:00
89d1fd03e5
Adding new surface for safetensors (global load, global save).
2023-08-01 15:00:38 +02:00
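A sketch of what such a global load/save surface looks like, assuming helpers along the lines of `candle_core::safetensors::{save, load}`; the exact module paths at the time of this commit may differ:

```rust
// Round-trip a name -> Tensor map through a single .safetensors file.
use candle_core::{DType, Device, Result, Tensor};
use std::collections::HashMap;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let mut tensors = HashMap::new();
    tensors.insert("w".to_string(), Tensor::zeros((2, 3), DType::F32, &dev)?);
    // Save every tensor in the map...
    candle_core::safetensors::save(&tensors, "model.safetensors")?;
    // ...and load them all back as a map.
    let loaded = candle_core::safetensors::load("model.safetensors", &dev)?;
    println!("loaded {} tensor(s)", loaded.len());
    Ok(())
}
```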
310094310b
Modifying safetensors export to get simple load and save.
2023-08-01 15:00:38 +02:00
ad9d8fe400
Complexifying our hello world
2023-08-01 14:26:02 +02:00
6b98b66eb3
Remove the end of text tokens. ( #289 )
2023-07-31 20:43:57 +01:00
38ff693af0
Add a flag to save the trained weights. ( #279 )
2023-07-30 15:41:42 +01:00
c950a5c6b1
Cuda support for the mnist training. ( #277 )
* Cuda support for the mnist training.
* min/max fix + testing.
* Add the argmin/argmax tests.
* More cuda support for argmin/argmax.
* Cuda kernels for argmin and argmax.
2023-07-29 19:48:04 +01:00
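A small sketch of the argmin/argmax surface, assuming `Tensor::argmax`/`Tensor::argmin` over a dimension with u32 index output:

```rust
// Row-wise argmin/argmax; swap the device for Device::new_cuda(0)? to
// exercise the cuda kernels added in this commit.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::new(&[[1f32, 5., 2.], [7., 0., 3.]], &dev)?;
    let amax = t.argmax(1)?; // index of the max in each row
    let amin = t.argmin(1)?;
    println!("argmax: {:?}", amax.to_vec1::<u32>()?);
    println!("argmin: {:?}", amin.to_vec1::<u32>()?);
    Ok(())
}
```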
16c33383eb
Improve the mnist training example. ( #276 )
* Improve the mnist training example.
* Add some initialization routine that can be used for nn.
* Proper initialization in the mnist example.
2023-07-29 16:28:22 +01:00
c0a8ed19eb
Support for where-cond on cuda for u8 and u32. ( #274 )
2023-07-29 11:48:58 +01:00
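Where-cond selects elementwise between two tensors based on an integer condition tensor, which is why u8/u32 support matters. A sketch, assuming the `Tensor::where_cond` method:

```rust
// Elementwise select: non-zero entries of `cond` pick from on_true,
// zeros pick from on_false.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let cond = Tensor::new(&[1u8, 0, 1, 0], &dev)?;
    let on_true = Tensor::new(&[1f32, 2., 3., 4.], &dev)?;
    let on_false = Tensor::new(&[-1f32, -2., -3., -4.], &dev)?;
    let picked = cond.where_cond(&on_true, &on_false)?;
    println!("{:?}", picked.to_vec1::<f32>()?); // [1.0, -2.0, 3.0, -4.0]
    Ok(())
}
```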
3eb2bc6d07
Softmax numerical stability. ( #267 )
* Softmax numerical stability.
* Fix the flash-attn test.
2023-07-28 13:13:01 +01:00
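The standard stabilization trick: subtract the row max before exponentiating. Softmax is shift-invariant, so the result is unchanged, but `exp` never sees a large positive argument and cannot overflow. A plain-Rust sketch of the idea:

```rust
// softmax(x) == softmax(x - max(x)); the shifted form avoids inf/NaN.
fn softmax_stable(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    // Without the shift, exp(1000.0) overflows to inf and the output is NaN.
    println!("{:?}", softmax_stable(&[1000.0, 1000.0, 999.0]));
}
```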
3e89df938c
Starcoder fix ( #264 )
* Bugfix for starcoder.
* Get some proper code generation.
* Slightly simpler softmax.
2023-07-28 11:17:49 +01:00
4f260ef025
Merge pull request #216 from LaurentMazare/llama_multiprocess2
TP sharding v2
2023-07-28 08:06:13 +01:00
952eca6b54
Fixing slice errors + comments.
2023-07-27 16:59:32 +02:00
f291065f6c
Do not panic on empty ranges. ( #257 )
2023-07-27 09:28:47 +01:00
25a2086e8f
Putting back Send + Sync
2023-07-27 09:58:47 +02:00
7c7e6ba201
Removing inner dependency on safetensors.
2023-07-27 09:58:47 +02:00
ed58de7551
Fixed TP sharded version.
2023-07-27 09:58:46 +02:00
1735e4831e
TP sharding v2
2023-07-27 09:58:14 +02:00
6475bfadfe
Simplify Tensor::randn. ( #255 )
* Simplify Tensor::randn.
* Also switch Tensor::rand to use a generic dtype.
* Support sampling for f16.
* Cleanup.
2023-07-27 07:40:36 +01:00
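After this change the random constructors take the distribution parameters directly, and their float type fixes the tensor's dtype. A sketch of the resulting API:

```rust
// mean/std (randn) or lo/up (rand) are generic over the float type,
// which also determines the dtype of the produced tensor.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let gauss = Tensor::randn(0f32, 1f32, (2, 3), &dev)?; // f32 normal samples
    let uniform = Tensor::rand(0f64, 1f64, (2, 3), &dev)?; // f64 uniform samples
    println!("{:?} {:?}", gauss.dtype(), uniform.dtype());
    Ok(())
}
```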
89ba005962
Support backprop for a few more ops. ( #254 )
2023-07-26 21:31:54 +01:00
1235aa2536
Use bail rather than wrapping a string where possible. ( #249 )
* Use bail rather than wrapping a string where possible.
* Revert the cuda default bit.
2023-07-26 15:42:46 +01:00
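A sketch of the style this commit favors, assuming candle's exported `bail!` macro: it formats and returns an error in one line instead of constructing a string-wrapped error by hand.

```rust
// `bail!` early-returns a formatted error from a Result-returning function.
use candle_core::{bail, Result};

fn check_rank(rank: usize) -> Result<()> {
    if rank != 2 {
        bail!("expected a rank-2 tensor, got rank {rank}")
    }
    Ok(())
}

fn main() {
    println!("{:?}", check_rank(3));
}
```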
8b1d12bead
Merge pull request #246 from LaurentMazare/rename_custom_op
Rename exposed ops.
2023-07-26 14:20:29 +01:00
1a5416ec35
Rename exposed ops.
2023-07-26 12:43:19 +02:00
fa2b64d678
Proper flash-attn parameters. ( #244 )
* Proper flash-attn parameters.
* Set the flash attention parameters.
* Add more validations.
* Setup the o_ flash attn parameters.
* More flash-attn support.
* Set more flash attn parameters.
2023-07-26 10:13:40 +01:00
d9f9c859af
Add flash attention ( #241 )
* Add some flash-attn kernels, importing the code for flash-attn v2 from Dao-AILab.
* More flash attn.
* Set up the flash attn parameters.
* Get things to compile locally.
* Move the flash attention files in a different directory.
* Build the static C library with nvcc.
* Add more flash attention.
* Update the build part.
* Better caching.
* Exclude flash attention from the default workspace.
* Put flash-attn behind a feature gate.
* Get the flash attn kernel to run.
* Move the flags to a more appropriate place.
* Enable flash attention in llama.
* Use flash attention in llama.
2023-07-26 07:48:10 +01:00
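Since flash attention sits behind a feature gate, a model typically routes through it conditionally. A heavily hedged sketch, assuming the `flash-attn` feature name and an entry point along the lines of `candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)`; the exact signature at the time of this commit may differ:

```rust
// Feature-gated dispatch between the fused CUDA kernel and a fallback.
use candle_core::{Result, Tensor};

#[cfg(feature = "flash-attn")]
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32) -> Result<Tensor> {
    // Fused kernel from the flash-attn v2 port.
    candle_flash_attn::flash_attn(q, k, v, scale, /* causal= */ true)
}

#[cfg(not(feature = "flash-attn"))]
fn attention(_q: &Tensor, _k: &Tensor, _v: &Tensor, _scale: f32) -> Result<Tensor> {
    candle_core::bail!("compile with --features flash-attn to enable the fused kernel")
}
```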
c97d51243c
Add an abstract backprop op type ( #240 )
* Start adding the backprop op type.
* More backprop ops.
* Finish the backprop op.
2023-07-25 14:07:40 +01:00
be9c26180c
Avoid keeping track of the copy ops when not necessary. ( #239 )
2023-07-25 10:06:01 +01:00
944d70bd9a
Add a test for scatter add. ( #238 )
* Add a test for scatter add (segfaults on gpus for now).
* Bugfix for the scatter add cuda kernel.
2023-07-25 09:12:14 +01:00
18cc73954a
Add some testing for index-add ( #237 )
* Add some testing for index-add.
* Fix the cpu implementation for index-add.
2023-07-25 08:38:33 +01:00
74a6a769dd
Cuda kernels for IndexAdd/ScatterAdd. ( #236 )
* Skeleton methods for IndexAdd/ScatterAdd.
* Add a Map2InPlace trait.
* Add the glue code for the index-add/scatter-add kernels.
* Tweak the file name: embeddings -> indexing.
* Add the cuda kernel for indexadd.
* And add the scatter-add kernels.
2023-07-24 21:53:08 +01:00
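Both ops accumulate values from a source tensor into an output at given indices; index-add names output rows, scatter-add names them per element. A sketch of the former, assuming the `Tensor::index_add(indexes, source, dim)` signature:

```rust
// index_add: out[ids[i]] += src[i] along dim 0.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let init = Tensor::zeros((4,), DType::F32, &dev)?;
    let src = Tensor::new(&[1f32, 10., 100.], &dev)?;
    let ids = Tensor::new(&[0u32, 2, 0], &dev)?;
    // Slot 0 receives 1 + 100, slot 2 receives 10; repeats accumulate.
    let out = init.index_add(&ids, &src, 0)?;
    println!("{:?}", out.to_vec1::<f32>()?); // [101.0, 0.0, 10.0, 0.0]
    Ok(())
}
```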
581b104f97
Indexing cuda ( #235 )
* Allow using uint8_t for indexing.
* Revert the default cuda feature.
* Add a cuda-kernel for index-select.
* Add a test for gather.
2023-07-24 20:22:47 +01:00
b50f932e7c
Add some cmp tests. ( #233 )
* Add some cmp tests.
* Add the cuda kernels for comparison operations.
2023-07-24 16:53:45 +01:00
fe87778223
Add the copy op. ( #227 )
* Add the copy op.
* Tweak some cat error messages.
* Handle the contiguous case in to_vec1.
* Fast variant for to_vec2.
* Add a faster to_vec3 variant.
2023-07-23 18:06:47 +01:00
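The `to_vec*` helpers copy a tensor of matching rank back to host memory, with the fast paths above kicking in for contiguous data. A small usage sketch:

```rust
// to_vec1/to_vec2 copy tensors back to nested Vecs on the host.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &dev)?;
    let rows: Vec<Vec<f32>> = t.to_vec2()?; // rank-2 -> nested Vec
    let flat: Vec<f32> = t.flatten_all()?.to_vec1()?;
    println!("{rows:?} {flat:?}");
    Ok(())
}
```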
23827c49cd
Cleanup some todos. ( #226 )
* Cleanup some todos.
* Fix more todos.
* Optimize for the contiguous case.
* Add the IntDType trait.
* Handle the intdtype trait for more ops.
* Remove a todo.
* Remove a todo.
2023-07-23 16:00:00 +01:00
e449ce53a2
Wrapping code to call the custom op. ( #225 )
* Wrapping code to call the custom op.
* Get the rms example to work.
* Get around rustfmt failing in the CI.
* Fix the rms computation.
2023-07-23 11:31:17 +01:00
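A heavily hedged sketch of this custom-op surface: a user type implements a one-input op trait and the tensor method dispatches to it. This is modeled on the trait and method names of later candle versions (`CustomOp1`, `apply_op1`); the names at the time of this commit, right before the rename above, may have differed.

```rust
// A custom elementwise square op with a CPU forward pass only; this sketch
// assumes contiguous f32 input and the later CustomOp1/apply_op1 names.
use candle_core::{CpuStorage, CustomOp1, Device, Layout, Result, Shape, Tensor};

struct Square;

impl CustomOp1 for Square {
    fn name(&self) -> &'static str {
        "square"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        // Only handle f32 data in this sketch.
        let data = match storage {
            CpuStorage::F32(vs) => vs,
            _ => candle_core::bail!("square: only f32 is handled in this sketch"),
        };
        let out: Vec<f32> = data.iter().map(|v| v * v).collect();
        Ok((CpuStorage::F32(out), layout.shape().clone()))
    }
}

fn main() -> Result<()> {
    let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    let sq = t.apply_op1(Square)?;
    println!("{:?}", sq.to_vec1::<f32>()?);
    Ok(())
}
```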