Compare commits


153 Commits

SHA1 Message Date
a35a935118 Transpose the rhs in linear. 2023-08-03 16:51:24 +01:00
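
A hedged sketch of what transposing the rhs means for `Linear`, assuming the PyTorch weight layout `(out_features, in_features)` that candle-nn mirrors; candle-like code, not the exact commit diff:

```rust
use candle_core::{Result, Tensor};

// With weights stored as (out, in), the (batch, in) input must be multiplied
// by the transposed weight matrix to produce a (batch, out) output.
fn linear_forward(x: &Tensor, weight: &Tensor, bias: &Tensor) -> Result<Tensor> {
    let x = x.matmul(&weight.t()?)?;
    x.broadcast_add(bias)
}
```
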
df6667ba88 Add some tracing to llama. (#318) 2023-08-03 13:52:22 +01:00
a79286885c Support safetensors weights in llama2.c inference. (#317) 2023-08-03 11:10:58 +01:00
74845a4dcd Use the assert! function as it turns out to be const. (#316) 2023-08-03 10:03:43 +01:00
aa76b783eb Q6K dequantization. (#315) 2023-08-03 09:31:20 +01:00
25564357f7 Support some ggml quantized types (#314)
* Add the quantized types for GGML loading.

* Support quantization for Q2K.

* More quantization support.

* Fix some clippy lints.
2023-08-03 09:16:26 +01:00
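
As background for the quantization commits above (including the Q6K one), here is a hedged sketch of the block-quantization idea the ggml "Q" formats share: weights are split into fixed-size blocks, each carrying a scale plus low-bit integers, and dequantization is essentially `q * scale` per element. A Q8_0-like layout is shown for simplicity; the real Q2K/Q6K k-quant layouts add sub-block scales and bit-packing:

```rust
// Simplified Q8_0-style block: one scale and 32 8-bit quantized values.
// (In ggml the scale is stored as f16; f32 is used here for brevity.)
struct BlockQ8 {
    d: f32,       // per-block scale
    qs: [i8; 32], // quantized values
}

fn dequantize(block: &BlockQ8) -> [f32; 32] {
    let mut out = [0f32; 32];
    for (o, &q) in out.iter_mut().zip(block.qs.iter()) {
        *o = q as f32 * block.d;
    }
    out
}
```
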
634700d84a Use some consts for ggml values. (#312) 2023-08-02 22:03:05 +01:00
e635f18eda Initial support for reading ggml files. (#311)
* Start adding support for reading ggml files.

* Compute the proper tensor size.

* Print the read tensors.

* Fix file reading.
2023-08-02 21:59:02 +01:00
52414ba5c8 Bugfix for the llama2 wasm example. (#310)
* Clean-up the llama2.c wasm example.

* Use a proper tokenizer.

* Add a prompt.

* Bugfix for the llama2 wasm example.
2023-08-02 17:32:36 +01:00
186c308d51 Wasm llama2 tweaks (#309)
* Clean-up the llama2.c wasm example.

* Use a proper tokenizer.
2023-08-02 15:49:43 +01:00
4f17290ce0 Use AdamW in the llama2 training. (#308) 2023-08-02 14:14:02 +01:00
0902846f25 Add the AdamW optimizer. (#307)
* Add the AdamW optimizer.

* Add some AdamW test validated against PyTorch.
2023-08-02 14:03:49 +01:00
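
As a reference for the AdamW commit above, a hedged single-parameter sketch of the decoupled-weight-decay update (Loshchilov & Hutter); plain Rust with illustrative names, not candle-nn's actual optimizer API:

```rust
// One AdamW step for a single scalar parameter `p` with gradient `g`.
// `m`/`v` are the first/second moment estimates, `t` the 1-based step count.
fn adamw_step(
    p: &mut f32, g: f32, m: &mut f32, v: &mut f32,
    lr: f32, beta1: f32, beta2: f32, eps: f32, weight_decay: f32, t: i32,
) {
    *m = beta1 * *m + (1.0 - beta1) * g;
    *v = beta2 * *v + (1.0 - beta2) * g * g;
    let m_hat = *m / (1.0 - beta1.powi(t)); // bias correction
    let v_hat = *v / (1.0 - beta2.powi(t));
    // Decoupled weight decay: applied to the parameter directly rather than
    // folded into the gradient as in classic Adam + L2 regularization.
    *p -= lr * (m_hat / (v_hat.sqrt() + eps) + weight_decay * *p);
}
```
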
e2acbe1e72 Update the wasm example locations in the readme. (#306) 2023-08-02 11:36:43 +01:00
4fe8a02f88 Update the repo location. (#305) 2023-08-02 11:12:18 +01:00
03a421f714 Add some missing readme files. (#304) 2023-08-02 10:57:12 +01:00
d38943aadc Add version numbers for all the candle crates (#303)
* Switch to candle-gemm for the time being.

* Add the missing versions.
2023-08-02 10:52:13 +01:00
51e51da896 Rename the candle crate to candle-core (#301)
* Rename to candle-core.

* More candle-core renaming.
2023-08-02 08:20:22 +01:00
6e33ff62d6 Update cudarc now that it includes the cublas-f16 and nccl changes. (#300) 2023-08-02 05:54:28 +01:00
4b3bd79fbd Remove the embedding ops in favor of index-select. (#299)
* Remove the embedding ops in favor of index-select.

* Also remove the cuda kernels.
2023-08-02 05:42:11 +01:00
cc76c63202 Use index-select for the embeddings as it supports backprop. (#298) 2023-08-01 20:44:43 +01:00
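
A hedged sketch of why index-select can subsume the embedding op, per the two commits above: the forward pass is a row gather, and its gradient with respect to the table is an index-add of the output gradients back onto the selected rows, which is what makes backprop work. Plain Rust over nested vectors, not candle's tensor API:

```rust
// Forward: pick rows of the embedding table by token id.
fn index_select_fwd(table: &[Vec<f32>], ids: &[usize]) -> Vec<Vec<f32>> {
    ids.iter().map(|&i| table[i].clone()).collect()
}

// Backward: route each output-row gradient back to the row it was read from.
fn index_select_bwd(grad_out: &[Vec<f32>], ids: &[usize], rows: usize, dim: usize) -> Vec<Vec<f32>> {
    let mut grad_table = vec![vec![0f32; dim]; rows];
    for (g, &i) in grad_out.iter().zip(ids) {
        for (gt, gv) in grad_table[i].iter_mut().zip(g) {
            *gt += *gv; // repeated ids accumulate their gradients
        }
    }
    grad_table
}
```
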
ff876c2103 Llama more training (#297)
* Rework the var-builder to handle initializations.

* Add some helper functions for layer creation.

* Improve the layer initializations.

* Get initialized variables.

* Precompute the rotary embeddings when training llamas.
2023-08-01 19:53:41 +01:00
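
On the last bullet above, a hedged sketch of what a rotary-embedding precomputation typically looks like: cos/sin tables built once from the standard theta base of 10000, then reused at every step. Illustrative plain Rust; the shapes and constants are assumptions, not the commit's exact code:

```rust
// Precompute cos/sin tables of shape (max_seq_len, head_dim / 2).
fn rope_tables(max_seq_len: usize, head_dim: usize) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let half = head_dim / 2;
    let inv_freq: Vec<f32> = (0..half)
        .map(|i| 1.0 / 10000f32.powf(2.0 * i as f32 / head_dim as f32))
        .collect();
    let cos: Vec<Vec<f32>> = (0..max_seq_len)
        .map(|t| inv_freq.iter().map(|f| (t as f32 * f).cos()).collect())
        .collect();
    let sin: Vec<Vec<f32>> = (0..max_seq_len)
        .map(|t| inv_freq.iter().map(|f| (t as f32 * f).sin()).collect())
        .collect();
    (cos, sin)
}
```
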
a27239f3d9 Add training for the llama2.c example (#296)
* Rework the commands and run inference by default.

* Add the training module and load the training dataset.

* Random dataset iterator.

* Proper valid-loss computation.

* Compute the evaluation loss.

* Add more substance to the training loop.
2023-08-01 17:23:07 +01:00
babee9f011 Merge pull request #259 from LaurentMazare/book_2
Book 2 (load/save)
2023-08-01 17:26:57 +02:00
afb5e24a63 Remove map ownership from save. 2023-08-01 17:19:22 +02:00
89d1fd03e5 Adding new surface for safetensors (global load, global save). 2023-08-01 15:00:38 +02:00
310094310b Modifying safetensors export to get simple load and save. 2023-08-01 15:00:38 +02:00
836ba3e090 Merge pull request #258 from LaurentMazare/start_book
Starting the book.
2023-08-01 14:59:34 +02:00
091e781977 Grammarly pass. 2023-08-01 14:26:02 +02:00
5cead227ef Addressed comments. 2023-08-01 14:26:02 +02:00
ebd0315623 Typo. 2023-08-01 14:26:02 +02:00
ad9d8fe400 Complexifying our hello world 2023-08-01 14:26:02 +02:00
5bc5716b85 Revert "Making sure the CI actually works"
This reverts commit 699346b603cec1f279d94e9aa3210c193ba973f8.
2023-08-01 14:26:02 +02:00
ba37de94d4 Making sure the CI actually works 2023-08-01 14:26:02 +02:00
6242a1470e Starting the book. 2023-08-01 14:26:02 +02:00
75e0448114 Move the weight bits in a separate module. (#295) 2023-08-01 10:37:06 +01:00
614f911e9e Add some batcher variants that handle errors. (#294) 2023-08-01 09:40:34 +01:00
e1e8127f15 Add the batcher. (#293) 2023-08-01 09:16:10 +01:00
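
A hedged sketch of what a batcher does, covering both batcher commits above: wrap an iterator of items and yield fixed-size batches; the error-handling variants would do the same over `Result` items and surface the first error in a batch. Illustrative plain Rust, not candle-nn's actual types:

```rust
// Collect an iterator into fixed-size batches.
fn batches<T>(items: impl Iterator<Item = T>, batch_size: usize) -> Vec<Vec<T>> {
    let mut out = Vec::new();
    let mut cur = Vec::with_capacity(batch_size);
    for item in items {
        cur.push(item);
        if cur.len() == batch_size {
            out.push(std::mem::take(&mut cur));
        }
    }
    out // the trailing partial batch is dropped here; a real batcher may keep it
}
```
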
fa98ca0c35 Use subcommands in llama2. (#292) 2023-08-01 05:57:41 +01:00
1a07ff8d17 Pre-tokenized evaluation mode for llama2.c. (#291) 2023-08-01 05:36:25 +01:00
f28558d0b7 Evaluate on the pre-tokenized file. (#290) 2023-07-31 21:31:38 +01:00
6b98b66eb3 Remove the end of text tokens. (#289) 2023-07-31 20:43:57 +01:00
9ae1f6afee Add an eval mode to llama2-c (#288)
* Add an eval mode to llama2-c.

* Encode line by line.

* Get the eval to run.
2023-07-31 17:22:14 +01:00
1064b9b031 Add the cross-entropy loss. (#287) 2023-07-31 14:26:36 +01:00
ffeafbfc43 Make the nll op closer to the pytorch version + add a test. (#286) 2023-07-31 14:14:01 +01:00
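
For the two loss commits above, a hedged sketch of the PyTorch-style relationship they reference: cross-entropy is log-softmax followed by negative log-likelihood, computed here in a numerically stable way. Plain Rust for a single sample, not candle-nn's API:

```rust
// cross_entropy(logits, target) = -log_softmax(logits)[target]
//                               = log_sum_exp(logits) - logits[target]
fn cross_entropy(logits: &[f32], target: usize) -> f32 {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = logits.iter().map(|&x| (x - max).exp()).sum::<f32>().ln() + max;
    log_sum_exp - logits[target]
}
```
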
b3ea96b62b Add a prompt and support more models in llama2-c. (#285)
* Support more models in llama2-c.

* Add a prompt.
2023-07-31 13:09:30 +01:00
94a43faaca Use the hub models for llama2.c (#284) 2023-07-31 12:51:14 +01:00
62a9b03715 Add a flag to set the number of epochs in the mnist training (#283)
* Add a flag to change the number of epochs for the mnist training.

* Increase the learning rate for the MLP.
2023-07-31 10:32:14 +01:00
67834119fc Fix the flash-attention function names. (#282) 2023-07-31 10:04:39 +01:00
0ace420e66 Flash attention without padding (varlen). (#281)
* Expose the seqlen variable for flash-attn without padding.

* Fix the batched call.

* Adapt for the varlen variant.

* No need to set the batch strides when in varlen mode.

* Add a test (disabled at the moment).

* Get the test to work properly.
2023-07-31 09:45:39 +01:00
a8d8f9f206 Load a trained checkpoint in the mnist example. (#280) 2023-07-30 17:01:45 +01:00
38ff693af0 Add a flag to save the trained weights. (#279) 2023-07-30 15:41:42 +01:00
ba2254556c Display the temperature being used for text generation. (#278) 2023-07-30 09:53:05 +01:00
c950a5c6b1 Cuda support for the mnist training. (#277)
* Cuda support for the mnist training.

* min/max fix + testing.

* Add the argmin/argmax tests.

* More cuda support for argmin/argmax.

* Cuda kernels for argmin and argmax.
2023-07-29 19:48:04 +01:00
16c33383eb Improve the mnist training example. (#276)
* Improve the mnist training example.

* Add some initialization routine that can be used for nn.

* Proper initialization in the mnist example.
2023-07-29 16:28:22 +01:00
bedcef64dc Merge pull request #262 from LaurentMazare/update_multiprocess
Making multiprocess require flash-attn.
2023-07-29 16:40:39 +02:00
40c80bfbb2 Merge branch 'main' into update_multiprocess 2023-07-29 16:38:35 +02:00
07eb899729 More mnist training. (#275) 2023-07-29 13:29:31 +01:00
c0a8ed19eb Support for where-cond on cuda for u8 and u32. (#274) 2023-07-29 11:48:58 +01:00
4bf2ebf836 Use u8 tensors for masks. (#273) 2023-07-29 11:32:58 +01:00
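
A hedged element-wise sketch of the semantics these two mask commits rely on: `where_cond` selects between two tensors based on a `u8` mask. Plain Rust, not candle's tensor API:

```rust
// out[i] = if mask[i] != 0 { on_true[i] } else { on_false[i] }
fn where_cond(mask: &[u8], on_true: &[f32], on_false: &[f32]) -> Vec<f32> {
    mask.iter()
        .zip(on_true.iter().zip(on_false))
        .map(|(&m, (&t, &f))| if m != 0 { t } else { f })
        .collect()
}
```
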
97d8712ba5 Remove single function. 2023-07-28 23:31:25 +02:00
97181a77c0 Making multiprocess require flash-attn. 2023-07-28 23:31:24 +02:00
50d8273ae4 Support both llama v1 and llama v2. (#272) 2023-07-28 18:40:59 +01:00
7513a5e005 Line-up the llama implementation with the python-transformers one. (#271)
* Line-up the llama implementation with the python-transformers one.

* Also lineup the multiprocess version.
2023-07-28 18:31:28 +01:00
cb8dd5cd53 Back to using the main branch now that the PR has been merged. (#270) 2023-07-28 16:22:44 +01:00
a0e47aba98 Fix the revision used in starcoder to use the safetensors PR. (#269) 2023-07-28 14:02:31 +01:00
fb84ead8f7 Add the starcoder example to the readme. (#268)
* Add the starcoder example to the readme.

* Tweak.
2023-07-28 13:26:23 +01:00
3eb2bc6d07 Softmax numerical stability. (#267)
* Softmax numerical stability.

* Fix the flash-attn test.
2023-07-28 13:13:01 +01:00
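
A hedged sketch of the standard trick the softmax-stability commit refers to: since softmax is shift-invariant, subtracting the row maximum before exponentiating avoids `exp` overflow for large logits without changing the result. Plain Rust, not the commit's actual code:

```rust
// softmax(x) == softmax(x - max(x)), but the latter never overflows.
fn stable_softmax(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```
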
68eab38de6 Cuda fix for starcoder. (#266)
* Cuda fix for starcoder.

* Nicer output.
2023-07-28 12:13:41 +01:00
54ccf94472 Merge pull request #265 from LaurentMazare/fix_nccl
Fix nccl
2023-07-28 11:37:58 +01:00
4002968cf5 Put back `"dep:half"` 2023-07-28 10:34:21 +00:00
be256a6ba6 Fixing. 2023-07-28 10:23:05 +00:00
d2dea11ef6 Fixing nccl feature. 2023-07-28 12:19:20 +02:00
3e89df938c Starcoder fix (#264)
* Bugfix for starcoder.

* Get some proper code generation.

* Slightly simpler softmax.
2023-07-28 11:17:49 +01:00
6a54ca115e Add some Bigcode model (#260)
* Start sketching the bigcode gpt model.

* Sketch the bigcode model.

* Implement the attention mechanism.

* Random reshaping.

* Sketch more of the example.

* Add some kv cache.

* Properly generate the position ids.

* Proper attention mask.

* Bail on upcasting.

* Properly apply the attention mask.

* Add the smaller starcoder variants.

* Update for the new hub api.

* Fix a shape issue.

* Fix another shape issue.

* Get some logits out.

* Adjust the weight names.
2023-07-28 09:57:32 +01:00
4f260ef025 Merge pull request #216 from LaurentMazare/llama_multiprocess2
TP sharding v2
2023-07-28 08:06:13 +01:00
0b97987b21 Merge pull request #261 from LaurentMazare/upgrade_hf_hub
Upgrading hf-hub to `0.2.0` (Modified API to not pass the Repo around all the time)
2023-07-28 07:03:30 +01:00
8435a99edd Added comment about offsets. 2023-07-27 20:11:57 +02:00
ca479a873e Upgrading hf-hub to 0.2.0 (Modified API to not pass the Repo around
all the time)
2023-07-27 20:05:02 +02:00
952eca6b54 Fixing slice errors + comments. 2023-07-27 16:59:32 +02:00
f291065f6c Do not panic on empty ranges. (#257) 2023-07-27 09:28:47 +01:00
25a2086e8f Putting back Send + Sync 2023-07-27 09:58:47 +02:00
7c7e6ba201 Removing inner dependency on safetensors. 2023-07-27 09:58:47 +02:00
1553b58fe5 Tensor are not necessarily sendable (CustomOp1). 2023-07-27 09:58:47 +02:00
b7814f66b4 PyO3 is back. 2023-07-27 09:58:47 +02:00
ed58de7551 Fixed TP sharded version. 2023-07-27 09:58:46 +02:00
1735e4831e TP sharding v2 2023-07-27 09:58:14 +02:00
209f06d7c3 Micro-cleanup. (#256) 2023-07-27 07:55:54 +01:00
6475bfadfe Simplify Tensor::randn. (#255)
* Simplify Tensor::randn.

* Also switch Tensor::rand to use a generic dtype.

* Support sampling for f16.

* Cleanup.
2023-07-27 07:40:36 +01:00
89ba005962 Support backprop for a few more ops. (#254) 2023-07-26 21:31:54 +01:00
4f92420132 Add some flash attn test (#253)
* Add some flash-attn test.

* Add the cpu test.

* Fail when the head is not a multiple of 8.

* Polish the flash attention test.
2023-07-26 20:56:00 +01:00
ded197497c Merge pull request #252 from LaurentMazare/add_book
Adding a cargo book
2023-07-26 17:35:54 +01:00
84ad558e50 Switch to using llama-v2 by default. (#251) 2023-07-26 17:18:27 +01:00
368f169c6a Permissions. 2023-07-26 18:12:02 +02:00
8da6568c20 Typo. 2023-07-26 18:11:10 +02:00
07a22fe606 Releasing within the branch to test the setup. 2023-07-26 18:08:34 +02:00
834e1b197b Adding a documentation book. 2023-07-26 18:06:31 +02:00
89fd988836 Update to the latest gemm. (#250) 2023-07-26 17:00:02 +01:00
1235aa2536 Use bail rather than wrapping a string where possible. (#249)
* Use bail rather than wrapping a string where possible.

* Revert the cuda default bit.
2023-07-26 15:42:46 +01:00
f052ba76cb Lining up the flash attn version with the non-flash one. (#248)
* Move the flash-attn function in the proper crate.

* Causality tweak.
2023-07-26 15:11:45 +01:00
46f2d9f0ac Merge pull request #247 from LaurentMazare/add_number_of_tokens
Add number of tokens.
2023-07-26 14:45:53 +01:00
81bfa46702 Updated. 2023-07-26 15:21:50 +02:00
8b1d12bead Merge pull request #246 from LaurentMazare/rename_custom_op
Rename exposed ops.
2023-07-26 14:20:29 +01:00
035372248e Simple QOL.
- Add ms/token on llama2.c (15ms/token on my personal machine)
- Hide `Run` buttons while models are not ready
- Add dummy `progress` while weights are downloading (I briefly looked
  at putting a real progressbar.. and nothing easy enough came up.)
2023-07-26 15:17:32 +02:00
2ce5f12513 Again set a few extra params in flash-attn. (#245)
* Again set a few extra params.

* Use the appropriate kernel sizes.

* Add all the kernel sizes.

* Parallel compiling.

* Reduce the amount of parallelism.

* Add the missing kernel.

* Fix a typo.

* Remove bf16 support for now.
2023-07-26 14:16:37 +01:00
97990f4afc Add number of tokens. 2023-07-26 14:57:20 +02:00
1a5416ec35 Rename exposed ops. 2023-07-26 12:43:19 +02:00
fa2b64d678 Proper flash-attn parameters. (#244)
* Proper flash-attn parameters.

* Set the flash attention parameters.

* Add more validations.

* Setup the o_ flash attn parameters.

* More flash-attn support.

* Set more flash attn parameters.
2023-07-26 10:13:40 +01:00
e40b150bbe Better handling of dtypes in llama. (#243) 2023-07-26 08:28:33 +01:00
471855e2ee Specific cache dir for the flash attn build artifacts. (#242) 2023-07-26 08:04:02 +01:00
d9f9c859af Add flash attention (#241)
* Add some flash-attn kernel, import the code for flash-attn v2 from Dao-AILab.

* More flash attn.

* Set up the flash attn parameters.

* Get things to compile locally.

* Move the flash attention files in a different directory.

* Build the static C library with nvcc.

* Add more flash attention.

* Update the build part.

* Better caching.

* Exclude flash attention from the default workspace.

* Put flash-attn behind a feature gate.

* Get the flash attn kernel to run.

* Move the flags to a more appropriate place.

* Enable flash attention in llama.

* Use flash attention in llama.
2023-07-26 07:48:10 +01:00
c97d51243c Add an abstract backprop op type (#240)
* Start adding the backprop op type.

* More backprop ops.

* Finish the backprop op.
2023-07-25 14:07:40 +01:00
be9c26180c Avoid keeping track of the copy ops when not necessary. (#239) 2023-07-25 10:06:01 +01:00
944d70bd9a Add a test for scatter add. (#238)
* Add a test for scatter add (segfaults on gpus for now).

* Bugfix for the scatter add cuda kernel.
2023-07-25 09:12:14 +01:00
18cc73954a Add some testing for index-add (#237)
* Add some testing for index-add.

* Fix the cpu implementation for index-add.
2023-07-25 08:38:33 +01:00
74a6a769dd Cuda kernels for IndexAdd/ScatterAdd. (#236)
* Skeleton methods for IndexAdd/ScatterAdd.

* Add a Map2InPlace trait.

* Add the glue code for the index-add/scatter-add kernels.

* Tweak the file name: embeddings -> indexing.

* Add the cuda kernel for indexadd.

* And add the scatter-add kernels.
2023-07-24 21:53:08 +01:00
581b104f97 Indexing cuda (#235)
* Allow using uint8_t for indexing.

* Revert the default cuda feature.

* Add a cuda-kernel for index-select.

* Add a test for gather.
2023-07-24 20:22:47 +01:00
b50f932e7c Add some cmp tests. (#233)
* Add some cmp tests.

* Add the cuda kernels for comparison operations.
2023-07-24 16:53:45 +01:00
160ba09d30 Polish the llama2 wasm ui. (#232)
* Polish the llama2 wasm ui.

* readme update.
2023-07-24 15:28:27 +01:00
5a26cba733 Re-organize the wasm examples (#231)
* Move the whisper example.

* More renaming.

* Add llama2 as a new wasm example.

* Live generation.

* More of the llama wasm example.

* Formatting.
2023-07-24 12:36:02 +01:00
550a13a547 Use the binary decoder for llama2.c. (#230)
* Use the binary decoder for llama2.c.

* Add the temperature.

* Formatting tweak.

* Fix the rotary embeddings.
2023-07-24 10:56:08 +01:00
35b65fed88 Add llama2.c as an example. (#229)
* Start adding llama2.c.

* Model loading.

* Add the llama-v2 model.

* Start converting the weights.

* Rotary embedding tweaks.

* Get the model to generate some tokens.
2023-07-24 09:13:50 +01:00
b6f7dfb682 CPU implementation for the custom RMS example. (#228)
* CPU implementation for the custom RMS example.

* Add the eps parameter.
2023-07-23 20:04:20 +01:00
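
For reference, a hedged sketch of the RMS normalization the custom-op example computes, including the `eps` parameter the second bullet adds: `y_i = x_i / sqrt(mean(x^2) + eps)`. Plain Rust, independent of the `CustomOp1` machinery:

```rust
fn rms_norm(xs: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = xs.iter().map(|&x| x * x).sum::<f32>() / xs.len() as f32;
    let scale = 1.0 / (mean_sq + eps).sqrt();
    xs.iter().map(|&x| x * scale).collect()
}
```
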
fe87778223 Add the copy op. (#227)
* Add the copy op.

* Tweak some cat error messages.

* Handle the contiguous case in to_vec1.

* Fast variant for to_vec2.

* Add add a faster to_vec3 variant.
2023-07-23 18:06:47 +01:00
23827c49cd Cleanup some todos. (#226)
* Cleanup some todos.

* Fix more todo.

* Optimize for the contiguous case.

* Add the IntDType trait.

* Handle the intdtype trait for more ops.

* Remove a todo.

* Remove a todo.
2023-07-23 16:00:00 +01:00
e449ce53a2 Wrapping code to call the custom op. (#225)
* Wrapping code to call the custom op.

* Get the rms example to work.

* Get around rustfmt failing in the CI.

* Fix the rms computation.
2023-07-23 11:31:17 +01:00
b8a10425ad Kernel build example (#224)
* Build example kernels.

* Add some sample custom kernel.

* Get the example kernel to compile.

* Add some cuda code.

* More cuda custom op.

* More cuda custom ops.
2023-07-23 07:15:37 +01:00
5f20acf080 Revert "Add the layer norm files. (#222)" (#223)
This reverts commit c8459d199d.
2023-07-22 16:51:11 +01:00
c8459d199d Add the layer norm files. (#222) 2023-07-22 15:06:35 +01:00
1f26042693 Move some shared functions to the nn module. (#221) 2023-07-22 13:25:11 +01:00
43c7223292 Rename the .r functions to .dims so as to be a bit more explicit. (#220) 2023-07-22 10:39:27 +01:00
52c5d8c087 Add the gather op. (#219)
* Start adding gather.

* Gather cpu implementation + use in simple training.

* Add scatter_add for the gradient of gather.

* Simple cpu implementation of scatter_add.

* Use gather in the simple-training backprop.
2023-07-22 07:21:28 +01:00
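
A hedged one-dimensional sketch of the duality this commit uses: scatter-add is the adjoint of gather, so the gradient of a gather routes each output gradient back to the index it was read from, and duplicate indices accumulate. Plain Rust, not candle's API:

```rust
// Forward: out[j] = src[idx[j]].
fn gather1(src: &[f32], idx: &[usize]) -> Vec<f32> {
    idx.iter().map(|&i| src[i]).collect()
}

// Backward: grad_src[idx[j]] += grad_out[j].
fn scatter_add1(grad_out: &[f32], idx: &[usize], src_len: usize) -> Vec<f32> {
    let mut grad_src = vec![0f32; src_len];
    for (&g, &i) in grad_out.iter().zip(idx) {
        grad_src[i] += g;
    }
    grad_src
}
```
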
6eeea1b04e Polish the index-add op and use it in the index-select backprop (#218)
* Add the cpu version of index-add.

* More cpu support for index-add.

* Use index-add in the backprop.
2023-07-22 05:31:46 +01:00
27174a82aa Start adding index-add. 2023-07-21 20:12:48 +01:00
5cc843550d Add binary and ternary custom ops. (#217) 2023-07-21 17:29:50 +01:00
4a100875bf Use a macro to handle the dtype pattern matching. (#215) 2023-07-21 16:03:51 +01:00
a6bcdfb269 Custom ops with a single argument (#214)
* Add the CustomOp1 trait.

* Add an example of custom op.

* Polish the custom op example.

* Add some backward pass test for custom ops.
2023-07-21 15:18:05 +01:00
b02229ce92 Add some epsilon tolerance to grad tests so that they work on cuda / mkl. (#213) 2023-07-21 12:45:14 +01:00
410654525f Refactor the reduce ops in order to introduce argmin/argmax. (#212)
* Refactor the reduce ops in order to introduce argmin/argmax.

* Clippy fixes.

* Use the newly introduced argmax.

* Fix the strided case.

* Handle the non-contiguous case.
2023-07-21 11:41:08 +01:00
c60831aad4 Add more gradient tests + bugfixes. (#211)
* Add more gradient tests + bugfixes.

* More tests and fixes.

* More tests.
2023-07-21 06:52:39 +01:00
4845d5cc64 More realistic training setup. (#210)
* More realistic training setup.

* Compute the model accuracy.

* Very inefficient backprop for index select.

* More backprop.

* Fix some backprop issues.

* Backprop fix.

* Another broadcasting backprop fix.

* Better backprop for reducing ops.

* Training again.

* Add some gradient tests.

* Get the training to work.
2023-07-20 18:25:41 +01:00
fa08fb3126 Add the index-select op. (#209)
* Add the index-select op.

* Cpu implementation of index-select.

* Add the cpu implementation for index-select.
2023-07-20 14:01:03 +01:00
2a8f28d687 Op refactor (#208)
* Add the binary and unary op enums to factorize some code.

* Bugfix.
2023-07-20 12:28:45 +01:00
e9c052bf94 Add the comparison operations. (#207)
* Add the comparison operations.

* Add the helper functions on the tensor side.

* More cmp operations.

* Cpu implementation for the comparison operations.
2023-07-20 09:40:31 +01:00
dc416243a3 Bump the hf-hub dependency to 0.1.3. (#206) 2023-07-20 07:27:52 +01:00
12d6dc018d Support for MQA for llama v2. (#205)
* Support for MQA for llama v2.

* More llama-v2.

* Move the rotary embedding precomputation in the cache.

* Add a v2 flag.

* Use the hf model.
2023-07-20 06:39:04 +01:00
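
A hedged sketch of the key idea behind MQA in the commit above: a small number of key/value heads is shared across many query heads, so each kv head is repeated `n_query_heads / n_kv_heads` times before standard per-head attention. Plain Rust over a per-head list, not candle tensors:

```rust
// Repeat each kv head n_rep times so it lines up with the query heads.
fn repeat_kv<T: Clone>(kv_heads: Vec<T>, n_rep: usize) -> Vec<T> {
    kv_heads
        .into_iter()
        .flat_map(|h| std::iter::repeat(h).take(n_rep))
        .collect()
}
```
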
c34f932319 Fix the mkl build. (#204)
* Fix the mkl build.

* Fix the build properly.
2023-07-19 19:41:11 +01:00
536c5e702e Cuda kernels for fast min/max reductions (#203)
* Add the min/max cuda kernels.

* Better integration of the cuda kernels.
2023-07-19 18:12:27 +01:00
001f9a59ce Merge pull request #201 from LaurentMazare/remove_wrapper
[Proposal] Remove SafeTensor wrapper (allows finer control for users).
2023-07-19 19:02:37 +02:00
9515e8ea6c Merge branch 'main' into remove_wrapper 2023-07-19 18:53:55 +02:00
ad12e20f6b Add cpu support for min and max. (#202)
* Add cpu support for min and max.

* Add min/max all.
2023-07-19 17:11:44 +01:00
e6584476c4 Merge pull request #200 from LaurentMazare/removing_candle_hub
Removing `candle-hub` internal to extract into `hf-hub` standalone.
2023-07-19 17:27:55 +02:00
cb687b4897 Add some more developed training examples. (#199)
* Use contiguous tensors for variables.

* Sketch the mnist example.

* Start adding the reduce ops.

* Renaming.

* Refactor the reduce operations.

* Bugfix for the broadcasting vectorization.
2023-07-19 15:37:52 +01:00
439321745a Removing candle-hub internal to extract into hf-hub standalone. 2023-07-19 15:04:38 +02:00
176 changed files with 13324 additions and 2994 deletions

42
.github/workflows/book-cd.yml vendored Normal file

@ -0,0 +1,42 @@
name: Deploy Rust book
on:
# TODO put this back only when merging after this PR lands.
pull_request:
push:
branches:
- main
jobs:
deploy:
runs-on: ubuntu-latest
permissions:
contents: write # To push a branch
pull-requests: write # To create a PR from that branch
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Install latest mdbook
run: |
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
mkdir mdbook
curl -sSL $url | tar -xz --directory=./mdbook
echo `pwd`/mdbook >> $GITHUB_PATH
- name: Deploy GitHub Pages
run: |
# This assumes your book is in the root of your repository.
# Just add a `cd` here if you need to change to another directory.
cd candle-book
mdbook build
git worktree add gh-pages
git config user.name "Deploy from CI"
git config user.email ""
cd gh-pages
# Delete the ref to avoid keeping history.
git update-ref -d refs/heads/gh-pages
rm -rf *
mv ../book/* .
git add .
git commit -m "Deploy $GITHUB_SHA to gh-pages"
git push --force --set-upstream origin gh-pages

29
.github/workflows/book.yml vendored Normal file

@ -0,0 +1,29 @@
name: CI
on:
pull_request:
jobs:
test:
name: Test candle-book
runs-on: ubuntu-latest
permissions:
contents: write # To push a branch
pull-requests: write # To create a PR from that branch
steps:
- uses: actions/checkout@master
- name: Install Rust
run: |
rustup set profile minimal
rustup toolchain install stable
rustup default stable
- name: Install latest mdbook
run: |
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
mkdir bin
curl -sSL $url | tar -xz --directory=bin
echo "$(pwd)/bin" >> $GITHUB_PATH
- name: Run tests
run: cd candle-book && cargo build && mdbook test -L ../target/debug/deps/

8
.gitignore vendored

@ -1,6 +1,7 @@
# Generated by Cargo
# will have compiled files and executables
debug/
data/
dist/
target/
@ -23,6 +24,7 @@ flamegraph.svg
*.swp
trace-*.json
candle-wasm-example/*.wav
candle-wasm-example/*.safetensors
candle-wasm-example/package-lock.json
candle-wasm-examples/*/*.bin
candle-wasm-examples/*/*.wav
candle-wasm-examples/*/*.safetensors
candle-wasm-examples/*/package-lock.json

3
.gitmodules vendored Normal file

@ -0,0 +1,3 @@
[submodule "candle-examples/examples/flash-attn/cutlass"]
path = candle-flash-attn/cutlass
url = https://github.com/NVIDIA/cutlass.git

Cargo.toml

@ -2,46 +2,47 @@
members = [
"candle-core",
"candle-examples",
"candle-hub",
"candle-nn",
"candle-pyo3",
"candle-transformers",
"candle-wasm-example",
"candle-wasm-examples/llama2-c",
"candle-wasm-examples/whisper",
]
exclude = [
"candle-flash-attn",
"candle-kernels",
]
[workspace.package]
version = "0.1.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT/Apache-2.0"
[workspace.dependencies]
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
# Re-enable this once 0.9.13 as been released as it would include the cublas-f16 changes
# cudarc = { version = "0.9.13", optional = true, features = ["f16"] }
cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", features = ["f16"] }
futures = "0.3.28"
# TODO: Switch back to the official gemm implementation once the following are available.
# https://github.com/sarah-ek/gemm/pull/8.
# https://github.com/sarah-ek/gemm/pull/9.
gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vec-plus-wasm-simd" }
half = { version = "2.3.1", features = ["num-traits"] }
indicatif = "0.17.5"
intel-mkl-src = { version = "0.8.1", features = ["mkl-dynamic-lp64-iomp"] }
cudarc = { version = "0.9.13", features = ["f16"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
gemm = { version = "0.15.5", package = "candle-gemm" }
hf-hub = "0.2.0"
half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = "0.7.1"
num_cpus = "1.15.0"
num-traits = "0.2.15"
rand = "0.8.5"
reqwest = "0.11.18"
safetensors = "0.3.1"
serde = { version = "1.0.166", features = ["derive"] }
serde = { version = "1.0.171", features = ["derive"] }
serde_json = "1.0.99"
sha256 = "=1.1.4"
thiserror = "1"
tokenizers = { version = "0.13.3", default-features = false }
tokio = "1.28.2"
tokio-test = "0.4.2"
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"

README.md

@ -1,59 +1,95 @@
# candle
ML framework for Rust
[![Latest version](https://img.shields.io/crates/v/candle-core.svg)](https://crates.io/crates/candle-core)
[![Documentation](https://docs.rs/candle-core/badge.svg)](https://docs.rs/candle-core)
![License](https://img.shields.io/crates/l/candle-core.svg)
Candle is a minimalist ML framework for Rust with a focus on ease of use and
on performance (including GPU support). Try our online demos:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
```rust
let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
let c = a.matmul(&b)?;
println!("{c}");
```
## Check out our examples
Check out our [examples](./candle-examples/examples/):
- [Whisper](./candle-examples/examples/whisper/)
- [Llama](./candle-examples/examples/llama/)
- [Bert](./candle-examples/examples/bert/) (Useful for sentence embeddings)
- [Falcon](./candle-examples/examples/falcon/)
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [Llama and Llama-v2](./candle-examples/examples/llama/): general LLM.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
generation.
Run them using the following commands:
```
cargo run --example bert --release
cargo run --example whisper --release
cargo run --example llama --release
cargo run --example falcon --release
cargo run --example bert --release
cargo run --example bigcode --release
```
In order to use **CUDA** add `--features cuda` to the example command line.
There are also some wasm examples for whisper and
[llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
`trunk` or try them online:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
For llama2, run the following command to retrieve the weight files and start a
test server:
```bash
cd candle-wasm-examples/llama2-c
wget https://karpathy.ai/llama2c/model.bin
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
trunk serve --release --public-url /candle-llama2/ --port 8081
```
And then browse to
[http://localhost:8081/candle-llama2](http://localhost:8081/candle-llama2).
<!--- ANCHOR: features --->
## Features
- Simple syntax (looks and feels like PyTorch)
- CPU and Cuda backends, m1, f16, bf16 (and tentatively wasm)
- Simple syntax, looks and feels like PyTorch.
- CPU and Cuda backends, m1, f16, bf16.
- Enable serverless (CPU), small and fast deployments
- Model training
- Distributed computing (NCCL).
- Models out of the box (Llama, Whisper, Falcon, ...)
- Emphasis on enabling users to use custom ops/kernels
- WASM support, run your models in a browser.
- Model training.
- Distributed computing using NCCL.
- Models out of the box: Llama, Whisper, Falcon, StarCoder...
- Embed user-defined ops/kernels, such as [flash-attention
v2](https://github.com/LaurentMazare/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
<!--- ANCHOR_END: features --->
## How to use?
<!--- ANCHOR: cheatsheet --->
Cheatsheet:
| | Using PyTorch | Using Candle |
|------------|------------------------------------------|------------------------------------------------------------------|
| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(` |
| | | ` &[[1f32, 2.], [3., 4.]],` |
| | | ` &Device::Cpu)?` |
| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?` |
| Creation | `torch.zeros((2, 2))` | `Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?` |
| Indexing | `tensor[:, :4]` | `tensor.i((.., ..4))?` |
| Operations | `tensor.view((2, 2))` | `tensor.reshape((2, 2))?` |
| Operations | `a.matmul(b)` | `a.matmul(&b)?` |
| Arithmetic | `a + b` | `&a + &b` |
| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::Cuda(0))?` |
| Dtype | `tensor.to(dtype=torch.float16)` | `tensor.to_dtype(&DType::F16)?` |
| Saving | `torch.save({"A": A}, "model.bin")` | `tensor.save_safetensors("A", "model.safetensors")?` |
| Loading | `weights = torch.load("model.bin")` | TODO (see the examples for now) |
| Saving | `torch.save({"A": A}, "model.bin")` | `candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")?` |
| Loading | `weights = torch.load("model.bin")` | `candle::safetensors::load("model.safetensors", &device)` |
<!--- ANCHOR_END: cheatsheet --->
## Structure

1
candle-book/.gitignore vendored Normal file

@ -0,0 +1 @@
book

6
candle-book/book.toml Normal file

@ -0,0 +1,6 @@
[book]
authors = ["Nicolas Patry"]
language = "en"
multilingual = false
src = "src"
title = "Candle Documentation"

candle-book/src/README.md

@ -0,0 +1,6 @@
# Introduction
{{#include ../../README.md:features}}
This book will introduce, step by step, how to use `candle`.

candle-book/src/SUMMARY.md

@ -0,0 +1,27 @@
# Summary
[Introduction](README.md)
# User Guide
- [Installation](guide/installation.md)
- [Hello World - MNIST](guide/hello_world.md)
- [PyTorch cheatsheet](guide/cheatsheet.md)
# Reference Guide
- [Running a model](inference/README.md)
- [Using the hub](inference/hub.md)
- [Serialization](inference/serialization.md)
- [Advanced Cuda usage](inference/cuda/README.md)
- [Writing a custom kernel](inference/cuda/writing.md)
- [Porting a custom kernel](inference/cuda/porting.md)
- [Error management](error_manage.md)
- [Creating apps](apps/README.md)
- [Creating a WASM app](apps/wasm.md)
- [Creating a REST api webserver](apps/rest.md)
- [Creating a desktop Tauri app](apps/dekstop.md)
- [Training](training/README.md)
- [MNIST](training/mnist.md)
- [Fine-tuning](training/finetuning.md)
- [Using MKL](advanced/mkl.md)

candle-book/src/advanced/mkl.md

@ -0,0 +1 @@
# Using MKL

candle-book/src/apps/README.md

@ -0,0 +1 @@
# Creating apps

candle-book/src/apps/dekstop.md

@ -0,0 +1 @@
# Creating a desktop Tauri app

candle-book/src/apps/rest.md

@ -0,0 +1 @@
# Creating a REST api webserver

candle-book/src/apps/wasm.md

@ -0,0 +1 @@
# Creating a WASM app

candle-book/src/chapter_1.md

@ -0,0 +1 @@
# Chapter 1

candle-book/src/error_manage.md

@ -0,0 +1 @@
# Error management

candle-book/src/guide/cheatsheet.md

@ -0,0 +1,3 @@
# PyTorch cheatsheet
{{#include ../../../README.md:cheatsheet}}

candle-book/src/guide/hello_world.md

@ -0,0 +1,195 @@
# Hello world!
We will now create the hello world of the ML world: a model capable of solving the MNIST dataset.
Open `src/main.rs` and fill in this content:
```rust
# extern crate candle_core;
use candle_core::{DType, Device, Result, Tensor};
struct Model {
first: Tensor,
second: Tensor,
}
impl Model {
fn forward(&self, image: &Tensor) -> Result<Tensor> {
let x = image.matmul(&self.first)?;
let x = x.relu()?;
x.matmul(&self.second)
}
}
fn main() -> Result<()> {
// Use Device::new_cuda(0)?; to use the GPU.
let device = Device::Cpu;
let first = Tensor::zeros((784, 100), DType::F32, &device)?;
let second = Tensor::zeros((100, 10), DType::F32, &device)?;
let model = Model { first, second };
let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
Ok(())
}
```
Everything should now run with:
```bash
cargo run --release
```
## Using a `Linear` layer.
Now that we have this, we might want to make things a bit more complex, for instance by adding a `bias` and creating
the classical `Linear` layer. We can do so as follows:
```rust
# extern crate candle_core;
# use candle_core::{DType, Device, Result, Tensor};
struct Linear{
weight: Tensor,
bias: Tensor,
}
impl Linear{
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.matmul(&self.weight)?;
x.broadcast_add(&self.bias)
}
}
struct Model {
first: Linear,
second: Linear,
}
impl Model {
fn forward(&self, image: &Tensor) -> Result<Tensor> {
let x = self.first.forward(image)?;
let x = x.relu()?;
self.second.forward(&x)
}
}
```
This changes the code that runs the model into a new `main` function:
```rust
# extern crate candle_core;
# use candle_core::{DType, Device, Result, Tensor};
# struct Linear{
# weight: Tensor,
# bias: Tensor,
# }
# impl Linear{
# fn forward(&self, x: &Tensor) -> Result<Tensor> {
# let x = x.matmul(&self.weight)?;
# x.broadcast_add(&self.bias)
# }
# }
#
# struct Model {
# first: Linear,
# second: Linear,
# }
#
# impl Model {
# fn forward(&self, image: &Tensor) -> Result<Tensor> {
# let x = self.first.forward(image)?;
# let x = x.relu()?;
# self.second.forward(&x)
# }
# }
fn main() -> Result<()> {
// Use Device::new_cuda(0)?; to use the GPU.
// Use Device::Cpu; to use the CPU.
let device = Device::cuda_if_available(0)?;
// Creating a dummy model
let weight = Tensor::zeros((784, 100), DType::F32, &device)?;
let bias = Tensor::zeros((100, ), DType::F32, &device)?;
let first = Linear{weight, bias};
let weight = Tensor::zeros((100, 10), DType::F32, &device)?;
let bias = Tensor::zeros((10, ), DType::F32, &device)?;
let second = Linear{weight, bias};
let model = Model { first, second };
let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
// Inference on the model
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
Ok(())
}
```
Now it works; this is a great way to create your own layers.
However, most of the classical layers are already implemented in [candle-nn](https://github.com/LaurentMazare/candle/tree/main/candle-nn).
## Using `candle_nn`.
For instance, [Linear](https://github.com/LaurentMazare/candle/blob/main/candle-nn/src/linear.rs) is already there.
This `Linear` is coded with the PyTorch layout in mind so that existing models can be reused more easily: it uses the transpose of the weights rather than the weights directly.
So instead we can simplify our example:
```bash
cargo add --git https://github.com/LaurentMazare/candle.git candle-nn
```
And rewrite our example using it:
```rust
# extern crate candle_core;
# extern crate candle_nn;
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::Linear;
struct Model {
first: Linear,
second: Linear,
}
impl Model {
fn forward(&self, image: &Tensor) -> Result<Tensor> {
let x = self.first.forward(image)?;
let x = x.relu()?;
self.second.forward(&x)
}
}
fn main() -> Result<()> {
// Use Device::new_cuda(0)?; to use the GPU.
let device = Device::Cpu;
// This has changed (784, 100) -> (100, 784) !
let weight = Tensor::zeros((100, 784), DType::F32, &device)?;
let bias = Tensor::zeros((100, ), DType::F32, &device)?;
let first = Linear::new(weight, Some(bias));
let weight = Tensor::zeros((10, 100), DType::F32, &device)?;
let bias = Tensor::zeros((10, ), DType::F32, &device)?;
let second = Linear::new(weight, Some(bias));
let model = Model { first, second };
let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
Ok(())
}
```
Feel free to modify this example to use `Conv2d` to create a classical convnet instead.
Now that we have the dummy code running, we can move on to more advanced topics:
- [For PyTorch users](./guide/cheatsheet.md)
- [Running existing models](./inference/README.md)
- [Training models](./training/README.md)

candle-book/src/guide/installation.md

@ -0,0 +1,24 @@
# Installation
Start by creating a new app:
```bash
cargo new myapp
cd myapp
cargo add --git https://github.com/LaurentMazare/candle.git candle
```
At this point, candle will be built **without** CUDA support.
To get CUDA support, use the `cuda` feature:
```bash
cargo add --git https://github.com/LaurentMazare/candle.git candle --features cuda
```
You can check that everything works properly:
```bash
cargo build
```
You may also want to look at the `mkl` feature, which can provide faster inference on CPU; see [Using MKL](./advanced/mkl.md).

candle-book/src/inference/README.md

@ -0,0 +1 @@
# Running a model

candle-book/src/inference/cuda/README.md

@ -0,0 +1 @@
# Advanced Cuda usage

candle-book/src/inference/cuda/porting.md

@ -0,0 +1 @@
# Porting a custom kernel

candle-book/src/inference/cuda/writing.md

@ -0,0 +1 @@
# Writing a custom kernel

candle-book/src/inference/hub.md

@ -0,0 +1 @@
# Using the hub

candle-book/src/inference/serialization.md

@ -0,0 +1 @@
# Serialization

candle-book/src/training/README.md

@ -0,0 +1 @@
# Training

candle-book/src/training/finetuning.md

@ -0,0 +1 @@
# Fine-tuning

candle-book/src/training/mnist.md

@ -0,0 +1 @@
# MNIST

candle-core/Cargo.toml

@ -1,18 +1,17 @@
[package]
name = "candle"
version = "0.1.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/LaurentMazare/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT/Apache-2.0"
name = "candle-core"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
readme = "README.md"
[dependencies]
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", optional = true }
candle-kernels = { path = "../candle-kernels", version = "0.1.0", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }

candle-core/examples/basics.rs

@ -2,9 +2,14 @@
extern crate intel_mkl_src;
use anyhow::Result;
use candle::{Device, Tensor};
use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
let c = a.matmul(&b)?;
println!("{a} {b} {c}");
let data = &[[3f32, 1., 4., 1., 5.], [2., 7., 1., 8., 2.]];
let t1 = Tensor::new(data, &Device::Cpu)?;
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 7., 1., 8., 2.]];


@ -1,81 +0,0 @@
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
// use candle::quantized::GgmlType;
use candle::{DType, Device, Result, Tensor};
// use clap::{Parser, Subcommand};
// fn softmax<D: candle::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
// let dim = dim.to_index(xs.shape(), "softmax")?;
// let max = xs.max_keepdim(dim)?;
// let diff = xs.broadcast_sub(&max)?;
// let num = diff.exp()?;
// let den = num.sum_keepdim(dim)?;
// num.broadcast_div(&den)
// }
trait Benchmark {
type PreProcessData;
type RunResult;
fn preprocess() -> Result<Self::PreProcessData>;
fn run_one(_: &Self::PreProcessData) -> Result<Self::RunResult>;
const ITERS: usize;
}
struct Matmul;
impl Benchmark for Matmul {
type PreProcessData = (Tensor, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let lhs = Tensor::randn((1024, 1024), DType::F32, &Device::Cpu, 1.0, 0.0)?;
let rhs = Tensor::randn((1024, 1024), DType::F32, &Device::Cpu, 1.0, 0.0)?;
Ok((lhs, rhs))
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.matmul(&d.1)
}
const ITERS: usize = 100;
}
// struct Softmax;
// impl Benchmark for Softmax {
// type PreProcessData = Tensor;
// type RunResult = Tensor;
// fn preprocess() -> Result<Self::PreProcessData> {
// // Typical whisper tiny size.
// let x = Tensor::randn(0f32, 1., (1, 6, 200, 1500), &Device::Cpu)?;
// Ok(x)
// }
//
// fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
// softmax(d, D::Minus1)
// }
//
// const ITERS: usize = 100;
// }
fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
use std::hint::black_box;
let iters = iters.unwrap_or(B::ITERS);
let d = B::preprocess()?;
let start = std::time::Instant::now();
for _iter in 0..iters {
let _res = black_box(B::run_one(black_box(&d))?);
}
println!("{:?}", start.elapsed() / iters as u32);
Ok(())
}
fn main() -> Result<()> {
run::<Matmul>(None)?;
Ok(())
}

candle-core/examples/cuda_basics.rs

@ -2,7 +2,7 @@
extern crate intel_mkl_src;
use anyhow::Result;
use candle::{Device, Tensor};
use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;


@ -4,7 +4,7 @@ extern crate intel_mkl_src;
use std::str::FromStr;
use anyhow::Result;
use candle::{Device, Tensor};
use candle_core::{Device, Tensor};
fn cos_sin(n: usize, device: &Device) -> Result<Tensor> {
let thetas: Vec<_> = (0..n).map(|i| (i as f32 / n as f32)).collect();

candle-core/src/backend.rs

@ -1,6 +1,7 @@
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
pub(crate) trait BackendStorage: Sized {
pub trait BackendStorage: Sized {
type Device: BackendDevice;
fn try_clone(&self, _: &Layout) -> Result<Self>;
@ -16,16 +17,15 @@ pub(crate) trait BackendStorage: Sized {
fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self>;
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>;
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self>;
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;
fn binary_impl<B: crate::op::BinaryOp>(&self, _: &Self, _: &Layout, _: &Layout)
-> Result<Self>;
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
@ -37,7 +37,26 @@ pub(crate) trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv1D,
) -> Result<Self>;
fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self>;
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self>;
fn matmul(
&self,
@ -50,7 +69,7 @@ pub(crate) trait BackendStorage: Sized {
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
}
pub(crate) trait BackendDevice: Sized + std::fmt::Debug + Clone {
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
type Storage: BackendStorage;
// TODO: Make the usize generic and part of a generic DeviceLocation.

candle-core/src/backprop.rs

@ -1,6 +1,20 @@
use crate::{op::Op, Error, Result, Tensor, TensorId};
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;
// arg has been reduced to node via reduce_dims, expand it back to arg.
// This has to handle keepdims.
fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result<Tensor> {
if arg.rank() == node.rank() {
// keepdim = true
node.broadcast_as(arg.shape())
} else {
// keepdim = false
// first expand the reduced dims.
node.reshape(reduced_dims)?.broadcast_as(arg.shape())
}
}
impl Tensor {
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
/// elements having dependencies on the latter ones, e.g. the first element if any is the
@ -24,7 +38,10 @@ impl Tensor {
nodes
} else if let Some(op) = node.op() {
match op {
Op::WhereCond(t1, t2, t3) => {
Op::IndexAdd(t1, t2, t3, _)
| Op::ScatterAdd(t1, t2, t3, _)
| Op::CustomOp3(t1, t2, t3, _)
| Op::WhereCond(t1, t2, t3) => {
let (tg, nodes) = walk(t1, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(t2, nodes, already_seen);
@ -38,11 +55,10 @@ impl Tensor {
kernel: rhs,
..
}
| Op::Add(lhs, rhs)
| Op::Mul(lhs, rhs)
| Op::Sub(lhs, rhs)
| Op::Div(lhs, rhs)
| Op::Embedding(lhs, rhs)
| Op::CustomOp2(lhs, rhs, _)
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
| Op::IndexSelect(lhs, rhs, _)
| Op::Matmul(lhs, rhs) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
track_grad |= tg;
@ -65,24 +81,17 @@ impl Tensor {
}
}
Op::Reshape(node)
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Sum(node, _)
| Op::Cmp(node, _)
| Op::Reduce(node, _, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
| Op::Narrow(node, _, _, _)
| Op::Softmax(node, _)
| Op::Sqr(node)
| Op::Sqrt(node)
| Op::Gelu(node)
| Op::Relu(node)
| Op::Unary(node, _)
| Op::Elu(node, _)
| Op::Exp(node)
| Op::Log(node)
| Op::Sin(node)
| Op::Cos(node)
| Op::Abs(node)
| Op::Neg(node) => {
| Op::CustomOp1(node, _) => {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
nodes
@ -116,19 +125,19 @@ impl Tensor {
// this is out of scope.
if let Some(op) = node.op() {
match op {
Op::Add(lhs, rhs) => {
Op::Binary(lhs, rhs, BinaryOp::Add) => {
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&grad)?;
}
Op::Sub(lhs, rhs) => {
Op::Binary(lhs, rhs, BinaryOp::Sub) => {
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
}
Op::Mul(lhs, rhs) => {
Op::Binary(lhs, rhs, BinaryOp::Mul) => {
let lhs_grad = grad.mul(rhs)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
@ -136,13 +145,13 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::Div(lhs, rhs) => {
Op::Binary(lhs, rhs, BinaryOp::Div) => {
let lhs_grad = grad.div(rhs)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
*rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
}
Op::WhereCond(pred, t, f) => {
let zeros = grad.zeros_like()?;
@ -153,9 +162,30 @@ impl Tensor {
let f_grad = pred.where_cond(&zeros, &grad)?;
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }),
Op::Embedding(_lhs, _rhs) => {
return Err(Error::BackwardNotSupported { op: "embedding" })
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Gather(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
}
Op::ScatterAdd(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
*init_sum_grad = init_sum_grad.add(&grad)?;
let src_grad = grad.gather(indexes, *dim)?;
let src_sum_grad = grads.or_insert(src)?;
*src_sum_grad = src_sum_grad.add(&src_grad)?;
}
Op::IndexAdd(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
*init_sum_grad = init_sum_grad.add(&grad)?;
let src_grad = grad.index_select(indexes, *dim)?;
let src_sum_grad = grads.or_insert(src)?;
*src_sum_grad = src_sum_grad.add(&src_grad)?;
}
Op::IndexSelect(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
}
Op::Matmul(lhs, rhs) => {
// Skipping checks, the op went ok, we can skip
@ -195,41 +225,69 @@ impl Tensor {
}
}
let arg_grad = grad.sum(sum_dims.as_slice())?;
let mut arg_grad = grad.sum_keepdim(sum_dims.as_slice())?;
for _i in 0..left_dims {
arg_grad = arg_grad.squeeze(0)?
}
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.broadcast_add(&arg_grad)?
*sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?;
}
Op::Sum(arg, _sum_dims) => {
Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
let grad = broadcast_back(arg, &grad, reduced_dims)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.broadcast_add(&grad)?
*sum_grad = sum_grad.add(&grad)?;
}
Op::Cmp(_args, _) => {}
Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
let node = broadcast_back(arg, node, reduced_dims)?;
let grad = broadcast_back(arg, &grad, reduced_dims)?;
let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
}
Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
let node = broadcast_back(arg, node, reduced_dims)?;
let grad = broadcast_back(arg, &grad, reduced_dims)?;
let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
}
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
}
Op::Copy(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad)?
}
Op::Affine { arg, mul, .. } => {
let arg_grad = grad.affine(*mul, 0.)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Log(arg) => {
Op::Unary(arg, UnaryOp::Log) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad * *node)?)?
*sum_grad = sum_grad.add(&(grad / arg)?)?
}
Op::Sin(arg) => {
Op::Unary(arg, UnaryOp::Sin) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
}
Op::Cos(arg) => {
Op::Unary(arg, UnaryOp::Cos) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
}
Op::Abs(_args) => return Err(Error::BackwardNotSupported { op: "abs" }),
Op::Exp(arg) => {
Op::Unary(arg, UnaryOp::Abs) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad / arg)?)?
let ones = arg.ones_like()?;
let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?)?;
*sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
}
Op::Neg(arg) => {
Op::Unary(arg, UnaryOp::Exp) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad * *node)?)?
}
Op::Unary(arg, UnaryOp::Neg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&grad)?
}
@ -259,24 +317,60 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Softmax(_arg, _) => {
return Err(Error::BackwardNotSupported { op: "softmax" })
}
Op::Reduce(_, ReduceOp::ArgMin, _) => {}
Op::Reduce(_, ReduceOp::ArgMax, _) => {}
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
Op::Relu(_) => return Err(Error::BackwardNotSupported { op: "relu" }),
Op::Elu(..) => return Err(Error::BackwardNotSupported { op: "elu" }),
Op::Sqr(arg) => {
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
Op::Unary(arg, UnaryOp::Relu) => {
let sum_grad = grads.or_insert(arg)?;
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
}
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::CustomOp1(arg, c) => {
if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
}
Op::CustomOp2(arg1, arg2, c) => {
let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
if let Some(arg_grad1) = arg_grad1 {
let sum_grad = grads.or_insert(arg1)?;
*sum_grad = sum_grad.add(&arg_grad1)?
}
if let Some(arg_grad2) = arg_grad2 {
let sum_grad = grads.or_insert(arg2)?;
*sum_grad = sum_grad.add(&arg_grad2)?
}
}
Op::CustomOp3(arg1, arg2, arg3, c) => {
let (arg_grad1, arg_grad2, arg_grad3) =
c.bwd(arg1, arg2, arg3, node, &grad)?;
if let Some(arg_grad1) = arg_grad1 {
let sum_grad = grads.or_insert(arg1)?;
*sum_grad = sum_grad.add(&arg_grad1)?
}
if let Some(arg_grad2) = arg_grad2 {
let sum_grad = grads.or_insert(arg2)?;
*sum_grad = sum_grad.add(&arg_grad2)?
}
if let Some(arg_grad3) = arg_grad3 {
let sum_grad = grads.or_insert(arg3)?;
*sum_grad = sum_grad.add(&arg_grad3)?
}
}
Op::Unary(arg, UnaryOp::Sqr) => {
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Sqrt(arg) => {
let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
Op::Unary(arg, UnaryOp::Sqrt) => {
let arg_grad = grad.div(node)?.affine(0.5, 0.)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}

File diff suppressed because it is too large.

candle-core/src/cuda_backend.rs

@ -1,9 +1,12 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
use candle_kernels as kernels;
pub use cudarc;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{
CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
CudaFunction, CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig,
ValidAsZeroBits,
};
use half::{bf16, f16};
use std::sync::{Arc, Mutex};
@ -32,9 +35,6 @@ pub enum CudaError {
#[error("internal error '{0}'")]
InternalError(&'static str),
#[error("internal error '{0}'")]
WrappedError(Box<dyn std::error::Error + 'static + std::marker::Send + std::marker::Sync>),
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
MatMulNonContiguous {
lhs_stride: Vec<usize>,
@ -100,7 +100,7 @@ impl std::ops::Deref for CudaDevice {
}
}
trait WrapErr<O> {
pub trait WrapErr<O> {
fn w(self) -> std::result::Result<O, crate::Error>;
}
@ -170,7 +170,7 @@ impl CudaDevice {
})
}
fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
if !self.has_func(module_name, module_name) {
// Leaking the string here is a bit sad but we need a &'static str and this is only
// done once per kernel name.
@ -255,6 +255,8 @@ impl BackendDevice for CudaDevice {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
@ -282,6 +284,8 @@ impl BackendDevice for CudaDevice {
}
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
@ -398,6 +402,82 @@ trait Map2 {
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
};
Ok(out)
}
}
trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
) -> Result<()>;
fn map(
&self,
dst: &mut S,
dst_s: &Shape,
src: &S,
src_l: &Layout,
d: &CudaDevice,
) -> Result<()> {
match (dst, src) {
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
}
}
}
trait Map1Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
wrap: W,
) -> Result<S>;
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U8(s) => self.f(s, d, l, S::U8)?,
S::U32(s) => self.f(s, d, l, S::U32)?,
S::BF16(s) => self.f(s, d, l, S::BF16)?,
S::F16(s) => self.f(s, d, l, S::F16)?,
S::F32(s) => self.f(s, d, l, S::F32)?,
S::F64(s) => self.f(s, d, l, S::F64)?,
};
Ok(out)
}
}
trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
dev: &CudaDevice,
) -> Result<S>;
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
_ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
};
Ok(out)
@ -515,14 +595,15 @@ impl<'a> Map1 for Sum<'a> {
}
}
struct FastSum<'a>(&'a [usize]);
impl<'a> Map1 for FastSum<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
struct FastReduce<'a>(&'a [usize], ReduceOp);
impl<'a> Map1Any for FastReduce<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
wrap: W,
) -> Result<S> {
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
let src_el: usize = src_dims.iter().product();
@ -557,16 +638,36 @@ impl<'a> Map1 for FastSum<'a> {
.htod_copy([dims.as_slice(), stride.as_slice()].concat())
.w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("fast_sum"), kernels::REDUCE)?;
let out = dev.alloc_zeros::<T>(dst_el).w()?;
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
let (name, check_empty, return_index) = match self.1 {
ReduceOp::Sum => ("fast_sum", false, false),
ReduceOp::Min => ("fast_min", true, false),
ReduceOp::Max => ("fast_max", true, false),
ReduceOp::ArgMin => ("fast_argmin", true, true),
ReduceOp::ArgMax => ("fast_argmax", true, true),
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
if return_index {
// SAFETY: filled in by the follow-up kernel launch.
let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(S::U32(out))
} else {
// SAFETY: filled in by the follow-up kernel launch.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(wrap(out))
}
}
}
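FastReduce picks the kernel by ReduceOp: the arg-reductions (fast_argmin, fast_argmax) always allocate a u32 output while sum/min/max keep the input dtype, and the empty-tensor check only applies to the ops with no identity element. As a reference for what fast_argmax computes, here is a CPU sketch for a contiguous reduction over the last dimension (illustrative only, not the CUDA kernel):

fn argmax_last_dim(xs: &[f32], n_cols: usize) -> Vec<u32> {
    xs.chunks(n_cols)
        .map(|row| {
            let mut best = 0usize;
            for (i, v) in row.iter().enumerate() {
                if *v > row[best] {
                    best = i;
                }
            }
            best as u32
        })
        .collect()
}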
impl<U: crate::op::UnaryOp> Map1 for U {
impl<U: UnaryOpT> Map1 for U {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
@ -589,46 +690,200 @@ impl<U: crate::op::UnaryOp> Map1 for U {
}
}
struct Embedding<'a>(&'a CudaStorage, &'a Layout);
impl<'a> Map1 for Embedding<'a> {
struct IndexSelect<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map1 for IndexSelect<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
rhs: &CudaSlice<T>,
src: &CudaSlice<T>,
dev: &CudaDevice,
rhs_l: &Layout,
src_l: &Layout,
) -> Result<CudaSlice<T>> {
let ids_l = &self.1;
let ids = match &self.0.slice {
CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
let (name, ids) = match &self.0.slice {
CudaStorageSlice::U32(slice) => {
("is_u32", *slice.slice(ids_l.start_offset()..).device_ptr())
}
CudaStorageSlice::U8(slice) => {
("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
}
_ => Err(CudaError::UnexpectedDType {
msg: "embedding ids should be u32",
msg: "index_select ids should be u8 or u32",
expected: DType::U32,
got: self.0.dtype(),
})
.w()?,
};
let ids = &ids;
let shape = ids_l.shape();
let (v_size, h_size) = rhs_l
.shape()
.r2()
.map_err(|e| CudaError::WrappedError(Box::new(e)))
.w()?;
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("emb"), kernels::EMBEDDINGS)?;
let ids_shape = ids_l.shape();
let ids_dims = ids_shape.dims();
let ids_el = ids_shape.elem_count();
let cfg = LaunchConfig::for_num_elems(ids_el as u32);
let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
};
let left_size: usize = src_l.dims()[..self.2].iter().product();
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
let dim_size = src_l.dims()[self.2];
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
let params = (
ids_el,
ids_dims.len(),
&ds,
ids,
&src,
&out,
left_size,
dim_size,
right_size,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
struct Gather<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map1 for Gather<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
src_l: &Layout,
) -> Result<CudaSlice<T>> {
let ids = &self.0;
let ids_l = &self.1;
let dim = self.2;
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => {
("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
}
CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
msg: "gather ids should be u8 or u32",
expected: DType::U32,
got: ids.dtype(),
})?,
};
let el = ids_l.shape().elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
};
let left_sz: usize = src_l.dims()[..dim].iter().product();
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_sz = src_l.dims()[dim];
let ids_dim_sz = ids_l.dims()[dim];
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
el, ids, &src, &out, left_sz, src_dim_sz, ids_dim_sz, right_sz,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
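The gather kernel decomposes the source into (left, dim, right) blocks and, for each output position, replaces the coordinate along dim with the corresponding id. A CPU reference of the same indexing (illustrative only):

fn gather_ref(src: &[f32], ids: &[u32], left: usize, src_dim: usize, ids_dim: usize, right: usize) -> Vec<f32> {
    let mut out = vec![0f32; left * ids_dim * right];
    for l in 0..left {
        for i in 0..ids_dim {
            for r in 0..right {
                let idx = ids[(l * ids_dim + i) * right + r] as usize;
                // out[l, i, r] = src[l, ids[l, i, r], r]
                out[(l * ids_dim + i) * right + r] = src[(l * src_dim + idx) * right + r];
            }
        }
    }
    out
}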
struct IndexAdd<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map2InPlace for IndexAdd<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
) -> Result<()> {
let ids = &self.0;
let ids_l = &self.1;
let dim = self.2;
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
msg: "index-add ids should be u8 or u32",
expected: DType::U32,
got: ids.dtype(),
})?,
};
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
};
let left_sz: usize = src_l.dims()[..dim].iter().product();
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_sz = src_l.dims()[dim];
let dst_dim_sz = dst_shape.dims()[dim];
let ids_dim_sz = ids_l.dims()[0];
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
// SAFETY: Set later by running the kernel.
let params = (
ids, ids_dim_sz, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(())
}
}
struct ScatterAdd<'a>(&'a CudaStorage, &'a Layout, usize);
impl<'a> Map2InPlace for ScatterAdd<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
) -> Result<()> {
let ids = &self.0;
let ids_l = &self.1;
let dim = self.2;
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
msg: "scatter-add ids should be u8 or u32",
expected: DType::U32,
got: ids.dtype(),
})?,
};
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
};
let left_sz: usize = src_l.dims()[..dim].iter().product();
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_sz = src_l.dims()[dim];
let dst_dim_sz = dst_shape.dims()[dim];
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
// SAFETY: Set later by running the kernel.
let params = (ids, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(())
}
}
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
impl<'a> Map2 for Conv1D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@ -680,16 +935,22 @@ impl<'a> Map2 for WhereCond<'a> {
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
let ids_l = &self.1;
let ids = match &self.0.slice {
CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
let (ids, name) = match &self.0.slice {
CudaStorageSlice::U8(slice) => {
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
(ptr, "where_u8")
}
CudaStorageSlice::U32(slice) => {
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
(ptr, "where_u32")
}
_ => Err(CudaError::UnexpectedDType {
msg: "where conditions should be u32",
msg: "where conditions should be u8 or u32",
expected: DType::U32,
got: self.0.dtype(),
})
.w()?,
};
let ids = &ids;
let shape = ids_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
@ -699,7 +960,7 @@ impl<'a> Map2 for WhereCond<'a> {
.w()?;
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
@ -709,7 +970,7 @@ impl<'a> Map2 for WhereCond<'a> {
}
}
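With both where_u8 and where_u32 kernels in place, a mask no longer has to be u32. A hypothetical usage sketch, assuming the Tensor::where_cond front end routes here on a cuda device:

fn masked_select(device: &candle_core::Device) -> candle_core::Result<candle_core::Tensor> {
    use candle_core::Tensor;
    let mask = Tensor::new(&[1u8, 0, 1], device)?;
    let on_true = Tensor::new(&[1f32, 2., 3.], device)?;
    let on_false = Tensor::new(&[9f32, 8., 7.], device)?;
    mask.where_cond(&on_true, &on_false) // [1., 8., 3.]
}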
impl<U: crate::op::BinaryOp> Map2 for U {
impl<U: crate::op::BinaryOpT> Map2 for U {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
lhs: &CudaSlice<T>,
@ -737,6 +998,43 @@ impl<U: crate::op::BinaryOp> Map2 for U {
}
}
struct Cmp(CmpOp);
impl Map2Any for Cmp {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
lhs: &CudaSlice<T>,
lhs_l: &Layout,
rhs: &CudaSlice<T>,
rhs_l: &Layout,
dev: &CudaDevice,
) -> Result<S> {
let shape = lhs_l.shape();
let dims = shape.dims();
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let dims_and_strides = dev
.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
.w()?;
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let name = match self.0 {
CmpOp::Eq => "eq",
CmpOp::Ne => "ne",
CmpOp::Lt => "lt",
CmpOp::Le => "le",
CmpOp::Gt => "gt",
CmpOp::Ge => "ge",
};
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }.w()?;
Ok(S::U8(out))
}
}
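Cmp always writes a u8 output whatever the operand dtype, which is why it goes through Map2Any rather than Map2. Hypothetical usage, assuming the eq/ne/lt/le/gt/ge methods added on Tensor alongside CmpOp:

fn less_or_equal(a: &candle_core::Tensor, b: &candle_core::Tensor) -> candle_core::Result<candle_core::Tensor> {
    a.le(b) // dtype U8, elementwise 0/1 flags
}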
fn slice_src_and_dst<'a, T>(
src: &'a CudaSlice<T>,
src_l: &Layout,
@ -762,6 +1060,50 @@ pub struct CudaStorage {
device: CudaDevice,
}
pub trait CudaDType: Sized {
fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>>;
fn wrap_cuda_slice(s: CudaSlice<Self>, dev: CudaDevice) -> CudaStorage;
}
macro_rules! cuda_dtype {
($ty:ty, $dtype:ident) => {
impl CudaDType for $ty {
fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>> {
match &s.slice {
CudaStorageSlice::$dtype(data) => Ok(&data),
_ => Err(crate::Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
msg: "unexpected dtype",
}
.bt()),
}
}
fn wrap_cuda_slice(slice: CudaSlice<Self>, device: CudaDevice) -> CudaStorage {
let slice = CudaStorageSlice::$dtype(slice);
CudaStorage { slice, device }
}
}
};
}
cuda_dtype!(u8, U8);
cuda_dtype!(u32, U32);
cuda_dtype!(f16, F16);
cuda_dtype!(bf16, BF16);
cuda_dtype!(f32, F32);
cuda_dtype!(f64, F64);
impl CudaStorage {
pub fn wrap_cuda_slice<T: CudaDType>(slice: CudaSlice<T>, device: CudaDevice) -> CudaStorage {
T::wrap_cuda_slice(slice, device)
}
pub fn as_cuda_slice<T: CudaDType>(&self) -> Result<&CudaSlice<T>> {
T::as_cuda_slice(self)
}
}
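The CudaDType trait plus the cuda_dtype! macro give custom ops a typed view into the dtype-erased storage without matching on every CudaStorageSlice variant. A small sketch (assuming a build with the cuda feature, where this CudaStorage is the type re-exported at the crate root):

use cudarc::driver::DeviceSlice; // for len()

fn f32_elem_count(storage: &candle_core::CudaStorage) -> candle_core::Result<usize> {
    // Fails with UnexpectedDType if the underlying storage is not f32.
    let slice = storage.as_cuda_slice::<f32>()?;
    Ok(slice.len())
}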
fn gemm_config<T>(
alpha: T,
beta: T,
@ -788,8 +1130,7 @@ fn gemm_config<T>(
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})
.w()?
})?
};
// The b tensor has dims batching, m, k (lhs)
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
@ -801,8 +1142,7 @@ fn gemm_config<T>(
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})
.w()?
})?
};
// The setup below was copied from:
// https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
@ -827,8 +1167,7 @@ fn gemm_config<T>(
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})
.w()?,
})?,
};
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
@ -838,8 +1177,7 @@ fn gemm_config<T>(
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})
.w()?,
})?,
};
Ok(StridedBatchedConfig {
@ -876,7 +1214,6 @@ impl BackendStorage for CudaStorage {
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
use cudarc::driver::DevicePtr;
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
@ -955,23 +1292,25 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
let device = self.device().clone();
let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;
let slice = FastReduce(sum_dims, op).map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
let device = self.device().clone();
let slice = Cmp(op).map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?;
Ok(Self { slice, device })
}
fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
fn unary_impl<U: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device().clone();
let slice = U::V.map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
fn binary_impl<B: crate::op::BinaryOp>(
fn binary_impl<B: BinaryOpT>(
&self,
rhs: &Self,
lhs_l: &Layout,
@ -1042,11 +1381,46 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
let device = self.device().clone();
let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
Ok(Self { slice, device })
}
fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
let device = self.device().clone();
let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
Ok(Self { slice, device })
}
fn scatter_add(
&self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
let device = self.device().clone();
let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
Ok(acc)
}
fn index_add(
&self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
let device = self.device().clone();
let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
Ok(acc)
}
fn matmul(
&self,
@ -1110,7 +1484,7 @@ impl BackendStorage for CudaStorage {
.w()?;
CudaStorageSlice::F64(out)
}
_ => Err(CudaError::InternalError("dtype mismatch in matmul op")).w()?,
_ => Err(CudaError::InternalError("dtype mismatch in matmul op"))?,
};
let device = dev.clone();
Ok(Self { slice, device })
@ -1198,8 +1572,7 @@ impl BackendStorage for CudaStorage {
}
_ => Err(CudaError::InternalError(
"dtype mismatch in copy_strided op",
))
.w()?,
))?,
}
Ok(())
}

candle-core/src/device.rs

@ -71,8 +71,7 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
}
fn to_cpu_storage(&self) -> CpuStorage {
let mut vec = Vec::new();
vec.reserve(N1 * N2 * N3);
let mut vec = Vec::with_capacity(N1 * N2 * N3);
for i1 in 0..N1 {
for i2 in 0..N2 {
vec.extend(self[i1][i2])
@ -117,12 +116,12 @@ impl Device {
}
}
pub(crate) fn rand_uniform(
pub(crate) fn rand_uniform_f64(
&self,
shape: &Shape,
dtype: DType,
lo: f64,
up: f64,
shape: &Shape,
dtype: DType,
) -> Result<Storage> {
match self {
Device::Cpu => {
@ -136,12 +135,21 @@ impl Device {
}
}
pub(crate) fn rand_normal(
pub(crate) fn rand_uniform<T: crate::FloatDType>(
&self,
lo: T,
up: T,
shape: &Shape,
dtype: DType,
) -> Result<Storage> {
self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
}
pub(crate) fn rand_normal_f64(
&self,
mean: f64,
std: f64,
shape: &Shape,
dtype: DType,
) -> Result<Storage> {
match self {
Device::Cpu => {
@ -155,6 +163,15 @@ impl Device {
}
}
pub(crate) fn rand_normal<T: crate::FloatDType>(
&self,
mean: T,
std: T,
shape: &Shape,
) -> Result<Storage> {
self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
}
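The typed wrappers take lo/up and mean/std in the tensor's own float type and derive the dtype from T::DTYPE, keeping a single f64 entry point per backend. A hypothetical front-end usage, assuming Tensor::rand forwards to Device::rand_uniform:

fn random_f32() -> candle_core::Result<candle_core::Tensor> {
    use candle_core::{Device, Tensor};
    // F32 tensor with uniform samples in [0, 1).
    Tensor::rand(0f32, 1f32, (2, 3), &Device::Cpu)
}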
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {

candle-core/src/dtype.rs

@ -119,3 +119,33 @@ with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
pub trait IntDType: WithDType {
fn is_true(&self) -> bool;
fn as_usize(&self) -> usize;
}
impl IntDType for u32 {
fn is_true(&self) -> bool {
*self != 0
}
fn as_usize(&self) -> usize {
*self as usize
}
}
impl IntDType for u8 {
fn is_true(&self) -> bool {
*self != 0
}
fn as_usize(&self) -> usize {
*self as usize
}
}
pub trait FloatDType: WithDType {}
impl FloatDType for f16 {}
impl FloatDType for bf16 {}
impl FloatDType for f32 {}
impl FloatDType for f64 {}
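IntDType abstracts over the two index dtypes so indexing code can be written once, and FloatDType marks the dtypes that are valid for the typed random constructors above. A generic sketch over IntDType (illustrative):

use candle_core::IntDType;

fn count_true<I: IntDType>(mask: &[I]) -> usize {
    mask.iter().filter(|v| v.is_true()).count()
}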

candle-core/src/dummy_cuda_backend.rs

@ -1,4 +1,5 @@
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
#[derive(Debug, Clone)]
@ -40,11 +41,11 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self> {
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -52,16 +53,11 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self> {
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn binary_impl<B: crate::op::BinaryOp>(
&self,
_: &Self,
_: &Layout,
_: &Layout,
) -> Result<Self> {
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -79,7 +75,34 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

candle-core/src/error.rs

@ -79,6 +79,19 @@ pub enum Error {
nth_shape: Shape,
},
#[error("Cannot divide tensor of shape {shape:?} equally along dim {dim} into {n_parts}")]
ShapeMismatchSplit {
shape: Shape,
dim: usize,
n_parts: usize,
},
#[error("{op} can only be performed on a single dimension")]
OnlySingleDimension { op: &'static str, dims: Vec<usize> },
#[error("empty tensor for {op}")]
EmptyTensor { op: &'static str },
// === Device Errors ===
#[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
DeviceMismatchBinaryOp {
@ -106,11 +119,11 @@ pub enum Error {
msg: &'static str,
},
#[error("{op} invalid index {index} with vocab {vocab_size}")]
#[error("{op} invalid index {index} with dim size {size}")]
InvalidIndex {
op: &'static str,
index: usize,
vocab_size: usize,
size: usize,
},
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
@ -168,6 +181,7 @@ pub enum Error {
#[error("unsupported safetensor dtype {0:?}")]
UnsupportedSafeTensorDtype(safetensors::Dtype),
/// Arbitrary errors wrapping.
#[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>),
@ -176,6 +190,10 @@ pub enum Error {
inner: Box<Self>,
backtrace: Box<std::backtrace::Backtrace>,
},
/// User generated error message, typically created via `bail!`.
#[error("{0}")]
Msg(String),
}
pub type Result<T> = std::result::Result<T, Error>;
@ -197,3 +215,24 @@ impl Error {
}
}
}
#[macro_export]
macro_rules! bail {
($msg:literal $(,)?) => {
return Err($crate::Error::Msg(format!($msg).into()).bt())
};
($err:expr $(,)?) => {
return Err($crate::Error::Msg(format!($err).into()).bt())
};
($fmt:expr, $($arg:tt)*) => {
return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt())
};
}
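bail! early-returns an Error::Msg with a backtrace attached, in the spirit of anyhow::bail!. Usage sketch:

fn checked_div(a: f64, b: f64) -> candle_core::Result<f64> {
    if b == 0. {
        candle_core::bail!("division by zero: {a} / {b}")
    }
    Ok(a / b)
}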
pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
match (r1, r2) {
(Ok(r1), Ok(r2)) => Ok((r1, r2)),
(Err(e), _) => Err(e),
(_, Err(e)) => Err(e),
}
}

candle-core/src/ggml.rs (new file, 516 lines)

@ -0,0 +1,516 @@
//! Support for the GGML file format.
use crate::{DType, Device, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use half::f16;
// Default to QK_K 256 rather than 64.
pub const QK_K: usize = 256;
pub const K_SCALE_SIZE: usize = 12;
pub const QK4_0: usize = 32;
pub const QK4_1: usize = 32;
pub const QK5_0: usize = 32;
pub const QK5_1: usize = 32;
pub const QK8_0: usize = 32;
pub const QK8_1: usize = 32;
#[repr(C)]
struct BlockQ4_0 {
d: f16,
qs: [u8; QK4_0 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);
#[repr(C)]
struct BlockQ4_1 {
d: f16,
m: f16,
qs: [u8; QK4_1 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);
#[repr(C)]
struct BlockQ5_0 {
d: f16,
qh: [u8; 4],
qs: [u8; QK5_0 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);
#[repr(C)]
struct BlockQ5_1 {
d: f16,
m: f16,
qh: [u8; 4],
qs: [u8; QK5_1 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
#[repr(C)]
struct BlockQ8_0 {
d: f16,
qs: [u8; QK8_0],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
#[repr(C)]
struct BlockQ8_1 {
d: f16,
s: f16,
qs: [u8; QK8_1],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
#[repr(C)]
struct BlockQ2K {
scales: [u8; QK_K / 16],
qs: [u8; QK_K / 4],
d: f16,
dmin: f16,
}
const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());
#[repr(C)]
struct BlockQ3K {
hmask: [u8; QK_K / 8],
qs: [u8; QK_K / 4],
scales: [u8; 12],
d: f16,
}
const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>());
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
#[repr(C)]
struct BlockQ4K {
d: f16,
dmin: f16,
scales: [u8; K_SCALE_SIZE],
qs: [u8; QK_K / 2],
}
const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());
#[repr(C)]
struct BlockQ5K {
d: f16,
dmin: f16,
scales: [u8; K_SCALE_SIZE],
qh: [u8; QK_K / 8],
qs: [u8; QK_K / 2],
}
const _: () =
assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());
#[repr(C)]
struct BlockQ6K {
ql: [u8; QK_K / 2],
qh: [u8; QK_K / 4],
scales: [i8; QK_K / 16],
d: f16,
}
const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354
fn dequantize_row_q2k(xs: &[BlockQ2K], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK_K != 0 {
crate::bail!("dequantize_row_q2k: {k} is not divisible by {QK_K}")
}
let mut ys_index = 0;
for x in xs {
let d = x.d.to_f32();
let min = x.dmin.to_f32();
let q = &x.qs;
let mut is = 0;
for n in (0..QK_K).step_by(128) {
// Step by 32 over q.
let q = &q[n / 4..];
let mut shift = 0;
for _j in 0..4 {
let sc = x.scales[is];
is += 1;
let dl = d * (sc & 0xF) as f32;
let ml = min * (sc >> 4) as f32;
for q in &q[..16] {
let y = dl * ((q >> shift) & 3) as i8 as f32 - ml;
ys[ys_index] = y;
ys_index += 1;
}
let sc = x.scales[is];
is += 1;
let dl = d * (sc & 0xF) as f32;
let ml = min * (sc >> 4) as f32;
for q in &q[16..32] {
let y = dl * ((q >> shift) & 3) as i8 as f32 - ml;
ys[ys_index] = y;
ys_index += 1;
}
shift += 2;
}
}
}
Ok(())
}
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
if j < 4 {
let d = q[j] & 63;
let m = q[j + 4] & 63;
(d, m)
} else {
let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
(d, m)
}
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735
fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK_K != 0 {
crate::bail!("dequantize_row_q4k: {k} is not divisible by {QK_K}")
}
let mut ys_index = 0;
for x in xs.iter() {
let d = x.d.to_f32();
let min = x.dmin.to_f32();
let q = &x.qs;
let mut is = 0;
for j in (0..QK_K).step_by(64) {
let q = &q[j / 2..j / 2 + 32];
let (sc, m) = get_scale_min_k4(is, &x.scales);
let d1 = d * sc as f32;
let m1 = min * m as f32;
let (sc, m) = get_scale_min_k4(is + 1, &x.scales);
let d2 = d * sc as f32;
let m2 = min * m as f32;
for q in q {
let y = d1 * (q & 0xF) as f32 - m1;
ys[ys_index] = y;
ys_index += 1;
}
for q in q {
let y = d2 * (q >> 4) as f32 - m2;
ys[ys_index] = y;
ys_index += 1;
}
is += 2;
}
}
Ok(())
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067
fn dequantize_row_q6k(xs: &[BlockQ6K], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK_K != 0 {
crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}")
}
for x in xs.iter() {
let d = x.d.to_f32();
let ql = &x.ql;
let qh = &x.qh;
let sc = &x.scales;
for n in (0..QK_K).step_by(128) {
let idx = n / 128;
let ys = &mut ys[n..];
let sc = &sc[8 * idx..];
let ql = &ql[64 * idx..];
let qh = &qh[32 * idx..];
for l in 0..32 {
let is = l / 16;
let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32;
let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32;
let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32;
let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32;
ys[l] = d * sc[is] as f32 * q1 as f32;
ys[l + 32] = d * sc[is + 2] as f32 * q2 as f32;
ys[l + 64] = d * sc[is + 4] as f32 * q3 as f32;
ys[l + 96] = d * sc[is + 6] as f32 * q4 as f32;
}
}
}
Ok(())
}
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Magic {
Ggjt,
Ggla,
Ggmf,
Ggml,
Ggsn,
}
impl TryFrom<u32> for Magic {
type Error = crate::Error;
fn try_from(value: u32) -> Result<Self> {
let magic = match value {
0x67676a74 => Self::Ggjt,
0x67676c61 => Self::Ggla,
0x67676d66 => Self::Ggmf,
0x67676d6c => Self::Ggml,
0x6767736e => Self::Ggsn,
_ => crate::bail!("unknown magic {value:08x}"),
};
Ok(magic)
}
}
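The magic constants are four ASCII bytes packed into a u32; written out big-endian they spell the tag. A quick check:

#[test]
fn magic_tags() {
    assert_eq!(0x67676a74u32.to_be_bytes(), *b"ggjt");
    assert_eq!(0x6767736eu32.to_be_bytes(), *b"ggsn");
}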
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionedMagic {
GgmlUnversioned,
GgmfV1,
GgjtV1,
GgjtV2,
GgjtV3,
}
impl VersionedMagic {
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
let magic = reader.read_u32::<LittleEndian>()?;
let magic = Magic::try_from(magic)?;
if magic == Magic::Ggml {
return Ok(Self::GgmlUnversioned);
}
let version = reader.read_u32::<LittleEndian>()?;
let versioned_magic = match (magic, version) {
(Magic::Ggmf, 1) => Self::GgmfV1,
(Magic::Ggjt, 1) => Self::GgjtV1,
(Magic::Ggjt, 2) => Self::GgjtV2,
(Magic::Ggjt, 3) => Self::GgjtV3,
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
};
Ok(versioned_magic)
}
fn align32(&self) -> bool {
match self {
Self::GgmlUnversioned | Self::GgmfV1 => false,
Self::GgjtV1 | Self::GgjtV2 | Self::GgjtV3 => true,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HParams {
pub n_vocab: u32,
pub n_embd: u32,
pub n_mult: u32,
pub n_head: u32,
pub n_layer: u32,
pub n_rot: u32,
pub ftype: u32,
}
impl HParams {
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
let n_vocab = reader.read_u32::<LittleEndian>()?;
let n_embd = reader.read_u32::<LittleEndian>()?;
let n_mult = reader.read_u32::<LittleEndian>()?;
let n_head = reader.read_u32::<LittleEndian>()?;
let n_layer = reader.read_u32::<LittleEndian>()?;
let n_rot = reader.read_u32::<LittleEndian>()?;
let ftype = reader.read_u32::<LittleEndian>()?;
Ok(Self {
n_vocab,
n_embd,
n_mult,
n_head,
n_layer,
n_rot,
ftype,
})
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct Vocab {
pub token_score_pairs: Vec<(Vec<u8>, f32)>,
}
impl Vocab {
fn read<R: std::io::Read>(reader: &mut R, n_vocab: usize) -> Result<Self> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L556
let mut token_score_pairs = Vec::with_capacity(n_vocab);
for _index in 0..n_vocab {
let len = reader.read_u32::<LittleEndian>()? as usize;
let mut word = vec![0u8; len];
reader.read_exact(&mut word)?;
let score = reader.read_f32::<LittleEndian>()?;
token_score_pairs.push((word, score))
}
Ok(Self { token_score_pairs })
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GgmlDType {
F32,
F16,
Q4_0,
Q4_1,
Q5_0,
Q5_1,
Q8_0,
Q8_1,
Q2K,
Q3K,
Q4K,
Q5K,
Q6K,
}
impl GgmlDType {
fn from_u32(u: u32) -> Result<Self> {
let dtype = match u {
0 => Self::F32,
1 => Self::F16,
2 => Self::Q4_0,
3 => Self::Q4_1,
6 => Self::Q5_0,
7 => Self::Q5_1,
8 => Self::Q8_0,
9 => Self::Q8_1,
10 => Self::Q2K,
11 => Self::Q3K,
12 => Self::Q4K,
13 => Self::Q5K,
14 => Self::Q6K,
_ => crate::bail!("unknown dtype for tensor {u}"),
};
Ok(dtype)
}
fn type_size(&self) -> usize {
match self {
Self::F32 => 4,
Self::F16 => 2,
Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
Self::Q2K => std::mem::size_of::<BlockQ2K>(),
Self::Q3K => std::mem::size_of::<BlockQ3K>(),
Self::Q4K => std::mem::size_of::<BlockQ4K>(),
Self::Q5K => std::mem::size_of::<BlockQ5K>(),
Self::Q6K => std::mem::size_of::<BlockQ6K>(),
}
}
fn blck_size(&self) -> usize {
match self {
Self::F32 => 1,
Self::F16 => 1,
Self::Q4_0 => QK4_0,
Self::Q4_1 => QK4_1,
Self::Q5_0 => QK5_0,
Self::Q5_1 => QK5_1,
Self::Q8_0 => QK8_0,
Self::Q8_1 => QK8_1,
Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K => QK_K,
}
}
}
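type_size is bytes per block and blck_size is elements per block, so the on-disk size of a tensor is elem_count * type_size / blck_size, exactly the computation in read_one_tensor below. For instance Q4_0 packs 32 weights into 18 bytes (an f16 scale plus 16 bytes of 4-bit quants), i.e. 4.5 bits per weight. An illustrative helper:

fn size_in_bytes(n_elems: usize, type_size: usize, blck_size: usize) -> usize {
    assert_eq!(n_elems % blck_size, 0);
    n_elems / blck_size * type_size
}
// e.g. a 4096-element row in Q4_0: 4096 / 32 * 18 = 2304 bytes.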
#[derive(Debug)]
pub struct Content {
pub magic: VersionedMagic,
pub hparams: HParams,
pub vocab: Vocab,
pub tensors: Vec<(String, Tensor)>,
}
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
reader: &mut R,
magic: VersionedMagic,
device: &Device,
) -> Result<(String, Tensor)> {
let n_dims = reader.read_u32::<LittleEndian>()?;
let name_len = reader.read_u32::<LittleEndian>()?;
let dtype = reader.read_u32::<LittleEndian>()?;
let dtype = GgmlDType::from_u32(dtype)?;
let mut dims = vec![0u32; n_dims as usize];
reader.read_u32_into::<LittleEndian>(&mut dims)?;
let mut name = vec![0u8; name_len as usize];
reader.read_exact(&mut name)?;
let name = String::from_utf8_lossy(&name).into_owned();
if magic.align32() {
let pos = reader.stream_position()?;
reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?;
}
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
let tensor_elems = dims.iter().product::<usize>();
let size_in_bytes = tensor_elems * dtype.type_size() / dtype.blck_size();
println!("{name} {dtype:?} {dims:?}");
// TODO: Mmap version to avoid copying the data around?
let mut raw_data = vec![0u8; size_in_bytes];
reader.read_exact(&mut raw_data)?;
let tensor = match dtype {
GgmlDType::F32 => Tensor::from_raw_buffer(&raw_data, DType::F32, &dims, device)?,
GgmlDType::F16 => Tensor::from_raw_buffer(&raw_data, DType::F16, &dims, device)?,
GgmlDType::Q2K => {
let mut f32_data = vec![0f32; tensor_elems];
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ2K>();
let raw_data =
unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ2K, n_blocks) };
dequantize_row_q2k(raw_data, &mut f32_data)?;
// Maybe we should use bf16 instead?
Tensor::from_vec(f32_data, dims, device)?
}
GgmlDType::Q4K => {
let mut f32_data = vec![0f32; tensor_elems];
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ4K>();
let raw_data =
unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ4K, n_blocks) };
dequantize_row_q4k(raw_data, &mut f32_data)?;
Tensor::from_vec(f32_data, dims, device)?
}
GgmlDType::Q6K => {
let mut f32_data = vec![0f32; tensor_elems];
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ6K>();
let raw_data =
unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ6K, n_blocks) };
dequantize_row_q6k(raw_data, &mut f32_data)?;
Tensor::from_vec(f32_data, dims, device)?
}
_ => crate::bail!("quantized type {dtype:?} used in {name} is not supported yet"),
};
Ok((name, tensor))
}
impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(
reader: &mut R,
device: &Device,
) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?;
let magic = VersionedMagic::read(reader)?;
let hparams = HParams::read(reader)?;
let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
let mut tensors = vec![];
while reader.stream_position()? != last_position {
let (name, tensor) = read_one_tensor(reader, magic, device)?;
tensors.push((name, tensor))
}
Ok(Self {
magic,
hparams,
vocab,
tensors,
})
}
}
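Reading a whole ggml file is then a single call. A hypothetical usage sketch (the file name is illustrative):

use candle_core::{ggml, Device};

fn main() -> candle_core::Result<()> {
    let mut file = std::fs::File::open("llama-q4k.bin").expect("model file");
    let content = ggml::Content::read(&mut file, &Device::Cpu)?;
    println!("n_vocab: {}, tensors: {}", content.hparams.n_vocab, content.tensors.len());
    Ok(())
}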

candle-core/src/indexer.rs

@ -7,7 +7,7 @@ impl Tensor {
/// Intended to be used by the trait `.i()`
///
/// ```
/// # use candle::{Tensor, DType, Device, IndexOp};
/// # use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
///
/// let c = a.i(0..1)?;
@ -22,7 +22,7 @@ impl Tensor {
/// let c = a.i((.., ..=2))?;
/// assert_eq!(c.shape().dims(), &[2, 3]);
///
/// # Ok::<(), candle::Error>(())
/// # Ok::<(), candle_core::Error>(())
/// ```
fn index(&self, indexers: &[TensorIndexer]) -> Result<Self, Error> {
let mut x = self.clone();
@ -42,7 +42,7 @@ impl Tensor {
Bound::Excluded(n) => *n,
Bound::Unbounded => dims[i],
};
let out = x.narrow(current_dim, start, stop - start)?;
let out = x.narrow(current_dim, start, stop.saturating_sub(start))?;
current_dim += 1;
out
}
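The saturating_sub change makes an empty range such as 1..1 produce a zero-length narrow instead of a usize underflow. A hypothetical sketch:

fn empty_slice() -> candle_core::Result<()> {
    use candle_core::{DType, Device, IndexOp, Tensor};
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let empty = a.i(1..1)?; // shape [0, 3] rather than a panic
    assert_eq!(empty.shape().dims(), &[0, 3]);
    Ok(())
}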

candle-core/src/lib.rs

@ -1,8 +1,8 @@
//! ML framework for Rust
//!
//! ```rust
//! use candle::{Tensor, DType, Device};
//! # use candle::Error;
//! use candle_core::{Tensor, DType, Device};
//! # use candle_core::Error;
//! # fn main() -> Result<(), Error>{
//!
//! let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
@ -33,18 +33,19 @@
//!
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
mod backend;
mod backprop;
pub mod backend;
pub mod backprop;
mod conv;
mod convert;
mod cpu_backend;
pub mod cpu_backend;
#[cfg(feature = "cuda")]
mod cuda_backend;
pub mod cuda_backend;
mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
mod error;
pub mod error;
pub mod ggml;
mod indexer;
pub mod layout;
#[cfg(feature = "mkl")]
@ -52,7 +53,7 @@ mod mkl;
pub mod npy;
mod op;
pub mod safetensors;
mod shape;
pub mod shape;
mod storage;
mod strided_index;
mod tensor;
@ -61,10 +62,11 @@ mod variable;
pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation};
pub use dtype::{DType, WithDType};
pub use dtype::{DType, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;
pub use layout::Layout;
pub use op::{CustomOp1, CustomOp2, CustomOp3};
pub use shape::{Shape, D};
pub use storage::Storage;
pub use strided_index::{StridedBlocks, StridedIndex};

candle-core/src/op.rs

@ -1,15 +1,74 @@
use crate::Tensor;
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
use half::{bf16, f16};
use num_traits::float::Float;
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
Eq,
Ne,
Le,
Ge,
Lt,
Gt,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
Sum,
Min,
Max,
ArgMin,
ArgMax,
}
impl ReduceOp {
pub(crate) fn name(&self) -> &'static str {
match self {
Self::ArgMax => "argmax",
Self::ArgMin => "argmin",
Self::Min => "min",
Self::Max => "max",
Self::Sum => "sum",
}
}
}
// These ops return the same type as their input type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
Add,
Mul,
Sub,
Div,
}
// Unary ops with no argument
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
Exp,
Log,
Sin,
Cos,
Abs,
Neg,
Sqr,
Sqrt,
Gelu,
Relu,
}
#[derive(Clone)]
pub(crate) enum Op {
Add(Tensor, Tensor),
Mul(Tensor, Tensor),
Sub(Tensor, Tensor),
Div(Tensor, Tensor),
pub enum Op {
Binary(Tensor, Tensor, BinaryOp),
Unary(Tensor, UnaryOp),
Cmp(Tensor, CmpOp),
// The third argument is the reduced shape with `keepdim=true`.
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
Embedding(Tensor, Tensor),
Gather(Tensor, Tensor, usize),
ScatterAdd(Tensor, Tensor, Tensor, usize),
IndexSelect(Tensor, Tensor, usize),
IndexAdd(Tensor, Tensor, Tensor, usize),
WhereCond(Tensor, Tensor, Tensor),
#[allow(dead_code)]
@ -28,29 +87,126 @@ pub(crate) enum Op {
mul: f64,
add: f64,
},
Sum(Tensor, Vec<usize>),
ToDType(Tensor),
Copy(Tensor),
Broadcast(Tensor),
Exp(Tensor),
Log(Tensor),
Sin(Tensor),
Cos(Tensor),
Abs(Tensor),
Narrow(Tensor, usize, usize, usize),
Neg(Tensor),
Reshape(Tensor),
Softmax(Tensor, usize),
Sqr(Tensor),
Sqrt(Tensor),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
Gelu(Tensor),
Relu(Tensor),
Elu(Tensor, f64),
// TODO: Support for custom ops.
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>),
CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>),
}
pub(crate) trait UnaryOp {
/// Unary ops that can be defined in user-land.
pub trait CustomOp1: Send + Sync {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp2: Send + Sync {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp3: Send + Sync {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_arg3: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
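A user-land op only has to provide name and cpu_fwd; cuda_fwd and bwd default to errors. A minimal CustomOp1 sketch that squares an f32 tensor, assuming contiguous input for brevity (the Tensor-level entry point wired to Storage::custom_op1 is not shown in this diff):

use candle_core::{CpuStorage, CustomOp1, Layout, Result, Shape};

struct Square;

impl CustomOp1 for Square {
    fn name(&self) -> &'static str {
        "square"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        let xs = match storage {
            CpuStorage::F32(xs) => xs,
            _ => candle_core::bail!("square: only f32 is supported in this sketch"),
        };
        let (o1, o2) = match layout.contiguous_offsets() {
            Some(o) => o,
            None => candle_core::bail!("square: input must be contiguous"),
        };
        let ys: Vec<f32> = xs[o1..o2].iter().map(|x| x * x).collect();
        Ok((CpuStorage::F32(ys), layout.shape().clone()))
    }
}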
pub trait UnaryOpT {
const NAME: &'static str;
const KERNEL: &'static str;
const V: Self;
@ -73,7 +229,7 @@ pub(crate) trait UnaryOp {
fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
}
pub(crate) trait BinaryOp {
pub trait BinaryOpT {
const NAME: &'static str;
const KERNEL: &'static str;
const V: Self;
@ -115,7 +271,7 @@ pub(crate) struct Relu;
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
impl BinaryOp for $op {
impl BinaryOpT for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("b", $name);
const V: Self = $op;
@ -169,7 +325,7 @@ bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => {
impl UnaryOp for $op {
impl UnaryOpT for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("u", $name);
const V: Self = $op;
@ -201,7 +357,7 @@ macro_rules! unary_op {
};
($op: ident, $name: literal, $a: ident, $e: expr, $f32_vec:ident, $f64_vec:ident) => {
impl UnaryOp for $op {
impl UnaryOpT for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("u", $name);
const V: Self = $op;
@ -259,7 +415,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
impl UnaryOp for Gelu {
impl UnaryOpT for Gelu {
const NAME: &'static str = "gelu";
const V: Self = Gelu;
#[inline(always)]
@ -325,7 +481,7 @@ impl UnaryOp for Gelu {
}
}
impl UnaryOp for Relu {
impl UnaryOpT for Relu {
const NAME: &'static str = "relu";
const KERNEL: &'static str = "urelu";
const V: Self = Relu;
@ -354,3 +510,63 @@ impl UnaryOp for Relu {
v
}
}
/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
/// properly checked when creating a new value.
#[derive(Clone)]
pub struct BackpropOp(Option<Op>);
impl BackpropOp {
pub(crate) fn none() -> Self {
BackpropOp(None)
}
pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
let op = if arg.track_op() {
Some(f(arg.clone()))
} else {
None
};
Self(op)
}
pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
let op = if arg1.track_op() || arg2.track_op() {
Some(f(arg1.clone(), arg2.clone()))
} else {
None
};
Self(op)
}
pub(crate) fn new3(
arg1: &Tensor,
arg2: &Tensor,
arg3: &Tensor,
f: impl Fn(Tensor, Tensor, Tensor) -> Op,
) -> Self {
let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
} else {
None
};
Self(op)
}
pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
Some(f(args))
} else {
None
};
Self(op)
}
}
impl std::ops::Deref for BackpropOp {
type Target = Option<Op>;
fn deref(&self) -> &Self::Target {
&self.0
}
}

candle-core/src/safetensors.rs

@ -1,7 +1,9 @@
use crate::{DType, Device, Error, Result, Tensor, WithDType};
use safetensors::tensor as st;
pub use safetensors::tensor::SafeTensors;
use safetensors::tensor::SafeTensors;
use std::borrow::Cow;
use std::collections::HashMap;
use std::path::Path;
impl From<DType> for st::Dtype {
fn from(value: DType) -> Self {
@ -52,6 +54,27 @@ impl st::View for Tensor {
}
}
impl st::View for &Tensor {
fn dtype(&self) -> st::Dtype {
(*self).dtype().into()
}
fn shape(&self) -> &[usize] {
self.dims()
}
fn data(&self) -> Cow<[u8]> {
// This copies data from GPU to CPU.
// TODO: Avoid the unwrap here.
Cow::Owned(convert_back(self).unwrap())
}
fn data_len(&self) -> usize {
let n: usize = self.dims().iter().product();
let bytes_per_element = (*self).dtype().size_in_bytes();
n * bytes_per_element
}
}
impl Tensor {
pub fn save_safetensors<P: AsRef<std::path::Path>>(
&self,
@ -63,15 +86,15 @@ impl Tensor {
}
}
fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
let v = view.data();
fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
let size_in_bytes = T::DTYPE.size_in_bytes();
let elem_count = v.len() / size_in_bytes;
if (v.as_ptr() as usize) % size_in_bytes == 0 {
let elem_count = data.len() / size_in_bytes;
if (data.as_ptr() as usize) % size_in_bytes == 0 {
// SAFETY This is safe because we just checked that this
// was correctly aligned.
let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) };
Tensor::from_slice(data, view.shape(), device)
let data: &[T] =
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
Tensor::from_slice(data, shape, device)
} else {
// XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
// Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
@ -81,13 +104,57 @@ fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<
// We're downgrading the `c` pointer from T to u8, which removes alignment
// constraints.
unsafe {
std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len());
std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
c.set_len(elem_count)
}
Tensor::from_slice(&c, view.shape(), device)
Tensor::from_slice(&c, shape, device)
}
}
fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
data: &[u8],
shape: &[usize],
device: &Device,
conv: F,
) -> Result<Tensor> {
let size_in_bytes = std::mem::size_of::<T>();
let elem_count = data.len() / size_in_bytes;
if (data.as_ptr() as usize) % size_in_bytes == 0 {
// SAFETY This is safe because we just checked that this
// was correctly aligned.
let data: &[T] =
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
Tensor::from_vec(data, shape, device)
} else {
// XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
// Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
let mut c: Vec<T> = Vec::with_capacity(elem_count);
// SAFETY: We just created c, so the allocated memory is necessarily
// contiguous and non overlapping with the view's data.
// We're downgrading the `c` pointer from T to u8, which removes alignment
// constraints.
unsafe {
std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
c.set_len(elem_count)
}
let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
Tensor::from_vec(c, shape, device)
}
}
fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
view: &st::TensorView<'_>,
device: &Device,
conv: F,
) -> Result<Tensor> {
convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
}
fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
convert_slice::<T>(view.data(), view.shape(), device)
}
fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
let size_in_bytes = T::DTYPE.size_in_bytes();
let length = vs.len() * size_in_bytes;
@ -112,19 +179,55 @@ impl<'a> Load for st::TensorView<'a> {
}
}
pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
impl Tensor {
pub fn from_raw_buffer(
data: &[u8],
dtype: DType,
shape: &[usize],
device: &Device,
) -> Result<Self> {
match dtype {
DType::U8 => convert_slice::<u8>(data, shape, device),
DType::U32 => convert_slice::<u32>(data, shape, device),
DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
DType::F16 => convert_slice::<half::f16>(data, shape, device),
DType::F32 => convert_slice::<f32>(data, shape, device),
DType::F64 => convert_slice::<f64>(data, shape, device),
}
}
}
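from_raw_buffer reinterprets little-endian bytes in place when they are suitably aligned, falling back to a copy otherwise (see convert_slice above). Illustrative usage:

fn bytes_to_tensor() -> candle_core::Result<candle_core::Tensor> {
    use candle_core::{DType, Device, Tensor};
    // Two little-endian f32 values: 1.0 and 2.0.
    let raw: [u8; 8] = [0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40];
    Tensor::from_raw_buffer(&raw, DType::F32, &[2], &Device::Cpu)
}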
fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
match view.dtype() {
st::Dtype::U8 => convert_::<u8>(view, device),
st::Dtype::U32 => convert_::<u8>(view, device),
st::Dtype::U16 => {
let conv = |x| Ok(u32::from(x));
convert_with_cast_::<u16, u32, _>(view, device, conv)
}
st::Dtype::U32 => convert_::<u32>(view, device),
st::Dtype::BF16 => convert_::<half::bf16>(view, device),
st::Dtype::F16 => convert_::<half::f16>(view, device),
st::Dtype::F32 => convert_::<f32>(view, device),
st::Dtype::F64 => convert_::<f64>(view, device),
st::Dtype::I32 => {
let conv = |x| {
u32::try_from(x)
.map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
};
convert_with_cast_::<i32, u32, _>(view, device, conv)
}
st::Dtype::I64 => {
let conv = |x| {
u32::try_from(x)
.map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
};
convert_with_cast_::<i64, u32, _>(view, device, conv)
}
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
}
}
pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
// TODO: This makes an unnecessary copy when the tensor is on the cpu.
let tensor = tensor.flatten_all()?;
match tensor.dtype() {
@ -137,6 +240,19 @@ pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
}
}
pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
let data = std::fs::read(filename.as_ref())?;
let st = safetensors::SafeTensors::deserialize(&data)?;
st.tensors()
.into_iter()
.map(|(name, view)| Ok((name, view.load(device)?)))
.collect()
}
pub fn save<P: AsRef<Path>>(tensors: &HashMap<&str, Tensor>, filename: P) -> Result<()> {
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}
pub struct MmapedFile(memmap2::Mmap);
impl MmapedFile {
@ -173,11 +289,15 @@ mod tests {
}
#[test]
fn save_multiple_tensors() {
fn save_load_multiple_tensors() {
let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap();
let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
st::serialize_to_file(map, &None, std::path::Path::new("multi.safetensors")).unwrap();
save(&map, "multi.safetensors").unwrap();
let weights = load("multi.safetensors", &Device::Cpu).unwrap();
assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]);
assert_eq!(weights.get("u").unwrap().dims(), &[1, 2]);
let bytes = std::fs::read("multi.safetensors").unwrap();
assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
std::fs::remove_file("multi.safetensors").unwrap();

candle-core/src/shape.rs

@ -41,6 +41,12 @@ impl From<usize> for Shape {
}
}
impl From<(usize,)> for Shape {
fn from(d1: (usize,)) -> Self {
Self(vec![d1.0])
}
}
impl From<(usize, usize)> for Shape {
fn from(d12: (usize, usize)) -> Self {
Self(vec![d12.0, d12.1])
@ -87,6 +93,12 @@ macro_rules! extract_dims {
}
}
}
impl crate::Tensor {
pub fn $fn_name(&self) -> Result<$out_type> {
self.shape().$fn_name()
}
}
impl std::convert::TryInto<$out_type> for Shape {
type Error = crate::Error;
fn try_into(self) -> std::result::Result<$out_type, Self::Error> {
@ -328,23 +340,23 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
}
}
extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(dims0, 0, |_: &Vec<usize>| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(
r3,
dims3,
3,
|d: &[usize]| (d[0], d[1], d[2]),
(usize, usize, usize)
);
extract_dims!(
r4,
dims4,
4,
|d: &[usize]| (d[0], d[1], d[2], d[3]),
(usize, usize, usize, usize)
);
extract_dims!(
r5,
dims5,
5,
|d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
(usize, usize, usize, usize, usize)

candle-core/src/storage.rs

@ -1,5 +1,6 @@
use crate::backend::BackendStorage;
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@ -80,26 +81,48 @@ impl Storage {
}
}
pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.sum(layout, s)?;
pub(crate) fn cmp(
&self,
op: CmpOp,
rhs: &Self,
lhs_layout: &Layout,
rhs_layout: &Layout,
) -> Result<Self> {
self.same_device(rhs, "cmp")?;
self.same_dtype(rhs, "cmp")?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.sum(layout, s)?;
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive
// anyway.
Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "cmp",
}
.bt())
}
}
}
// This assumes a contiguous layout and no offset.
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
pub(crate) fn reduce_op(&self, op: ReduceOp, layout: &Layout, s: &[usize]) -> Result<Self> {
match self {
Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
Self::Cuda(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
Storage::Cpu(storage) => {
let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Cuda(storage))
}
}
Ok(())
}
pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
@ -115,8 +138,65 @@ impl Storage {
}
}
pub(crate) fn unary_impl<B: op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
// TODO: Different code path for the contiguous case?
pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
match self {
Self::Cpu(storage) => {
let (storage, shape) = c.cpu_fwd(storage, l)?;
Ok((Self::Cpu(storage), shape))
}
Self::Cuda(storage) => {
let (storage, shape) = c.cuda_fwd(storage, l)?;
Ok((Self::Cuda(storage), shape))
}
}
}
pub(crate) fn custom_op2(
&self,
l1: &Layout,
t2: &Self,
l2: &Layout,
c: &dyn CustomOp2,
) -> Result<(Self, Shape)> {
self.same_device(t2, c.name())?;
match (self, t2) {
(Self::Cpu(s1), Self::Cpu(s2)) => {
let (s, shape) = c.cpu_fwd(s1, l1, s2, l2)?;
Ok((Self::Cpu(s), shape))
}
(Self::Cuda(s1), Self::Cuda(s2)) => {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
Ok((Self::Cuda(s), shape))
}
_ => unreachable!(),
}
}
pub(crate) fn custom_op3(
&self,
l1: &Layout,
t2: &Self,
l2: &Layout,
t3: &Self,
l3: &Layout,
c: &dyn CustomOp3,
) -> Result<(Self, Shape)> {
self.same_device(t2, c.name())?;
self.same_device(t3, c.name())?;
match (self, t2, t3) {
(Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => {
let (s, shape) = c.cpu_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Cpu(s), shape))
}
(Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Cuda(s), shape))
}
_ => unreachable!(),
}
}
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.unary_impl::<B>(layout)?;
@ -129,7 +209,7 @@ impl Storage {
}
}
pub(crate) fn binary_impl<B: op::BinaryOp>(
pub(crate) fn binary_impl<B: op::BinaryOpT>(
&self,
rhs: &Self,
lhs_layout: &Layout,
@ -215,21 +295,96 @@ impl Storage {
}
}
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
self.same_device(rhs, "embedding")?;
pub(crate) fn gather(
&self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
d: usize,
) -> Result<Self> {
self.same_device(indexes, "index-add")?;
match (self, indexes) {
(Self::Cpu(s), Self::Cpu(indexes)) => {
let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(s), Self::Cuda(indexes)) => {
let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Cuda(storage))
}
_ => unreachable!(),
}
}
pub(crate) fn scatter_add(
&self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
source: &Self,
source_l: &Layout,
d: usize,
) -> Result<Self> {
self.same_device(indexes, "scatter-add")?;
self.same_device(source, "scatter-add")?;
match (self, indexes, source) {
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
_ => unreachable!(),
}
}
pub(crate) fn index_add(
&self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
source: &Self,
source_l: &Layout,
d: usize,
) -> Result<Self> {
self.same_device(indexes, "index-add")?;
self.same_device(source, "index-add")?;
match (self, indexes, source) {
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
_ => unreachable!(),
}
}
pub(crate) fn index_select(
&self,
rhs: &Self,
lhs_l: &Layout,
rhs_l: &Layout,
d: usize,
) -> Result<Self> {
self.same_device(rhs, "index-select")?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
let storage = lhs.embedding(layout, rhs, rhs_l)?;
(Self::Cpu(lhs), Self::Cpu(rhs)) => {
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
let storage = lhs.embedding(layout, rhs, rhs_l)?;
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "embedding",
op: "index-select",
}
.bt()),
}
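All of these new Storage methods share the same dispatch shape: verify the devices match (and dtypes where it matters), match on the (Cpu, Cuda) pair, and delegate to the backend storage. At the tensor level, the cmp path surfaces as the comparison operators exercised in the tests further down; a minimal sketch of the user-facing side:

use candle_core::{Device, Tensor};

fn masks() -> candle_core::Result<()> {
    let t1 = Tensor::new(&[0f32, 1., 2.], &Device::Cpu)?;
    let t2 = Tensor::new(&[1f32, 1., 0.], &Device::Cpu)?;
    // Comparisons dispatch through Storage::cmp and yield u8 masks.
    assert_eq!(t1.eq(&t2)?.to_vec1::<u8>()?, &[0, 1, 0]);
    assert_eq!(t1.lt(&t2)?.to_vec1::<u8>()?, &[1, 0, 0]);
    Ok(())
}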

File diff suppressed because it is too large.

View File

@ -34,25 +34,50 @@ impl Var {
Ok(Self(inner))
}
pub fn rand<S: Into<Shape>>(
s: S,
dtype: DType,
device: &Device,
lo: f64,
up: f64,
) -> Result<Self> {
let inner = Tensor::rand_impl(s, dtype, device, lo, up, true)?;
pub fn from_tensor(t: &Tensor) -> Result<Self> {
let inner = t.make_var()?;
Ok(Self(inner))
}
pub fn randn<S: Into<Shape>>(
pub fn rand_f64<S: Into<Shape>>(
lo: f64,
up: f64,
s: S,
dtype: DType,
device: &Device,
) -> Result<Self> {
let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?;
Ok(Self(inner))
}
pub fn randn_f64<S: Into<Shape>>(
mean: f64,
std: f64,
s: S,
dtype: DType,
device: &Device,
) -> Result<Self> {
let inner = Tensor::randn_impl(s, dtype, device, mean, std, true)?;
let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?;
Ok(Self(inner))
}
pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
lo: T,
up: T,
s: S,
device: &Device,
) -> Result<Self> {
let inner = Tensor::rand_impl(lo, up, s, device, true)?;
Ok(Self(inner))
}
pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
mean: T,
std: T,
s: S,
device: &Device,
) -> Result<Self> {
let inner = Tensor::randn_impl(mean, std, s, device, true)?;
Ok(Self(inner))
}
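Var's random constructors are reshuffled here: the old dtype-taking versions become rand_f64/randn_f64, while the new generic rand/randn infer the dtype from the element type. A minimal sketch, assuming f32 implements the new crate::FloatDType bound:

use candle_core::{DType, Device, Var};

fn init_vars() -> candle_core::Result<()> {
    // Generic version: the f32 bounds select DType::F32.
    let _w = Var::rand(-1f32, 1f32, (2, 3), &Device::Cpu)?;
    // Explicit-dtype version with f64 parameters.
    let _b = Var::randn_f64(0.0, 0.02, (2, 3), DType::F32, &Device::Cpu)?;
    Ok(())
}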

View File

@ -0,0 +1,116 @@
use candle_core::backend::BackendStorage;
use candle_core::cpu_backend;
use candle_core::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
mod test_utils;
use test_utils::to_vec1_round;
fn fwd<T: num_traits::Float>(v: T, alpha: f64) -> T {
if v.is_sign_positive() {
v
} else {
let alpha = T::from(alpha).unwrap_or(T::nan());
(v.exp() - T::one()) * alpha
}
}
struct Elu {
alpha: f64,
}
impl CustomOp1 for Elu {
fn name(&self) -> &'static str {
"elu"
}
fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
let storage = candle_core::map_dtype!(
"elu",
s,
|s| cpu_backend::unary_map(s, l, |v| fwd(v, self.alpha)),
(BF16, F16, F32, F64)
);
Ok((storage, l.shape().clone()))
}
}
#[test]
fn custom_op1_no_backward() -> Result<()> {
let cpu = &Device::Cpu;
let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
let t = (t - 5.)?;
let elu_t = t.custom_op1(Elu { alpha: 1. })?;
assert_eq!(
to_vec1_round(&elu_t, 4)?,
&[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
);
Ok(())
}
// Define a similar struct as Elu but with backward support.
fn bwd<T: num_traits::Float>(v: T, alpha: f64) -> T {
if v.is_sign_positive() {
T::one()
} else {
let alpha = T::from(alpha).unwrap_or(T::nan());
v.exp() * alpha
}
}
struct EluBackward {
alpha: f64,
}
impl CustomOp1 for EluBackward {
fn name(&self) -> &'static str {
"elu-bwd"
}
fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
let storage = candle_core::map_dtype!(
"elu-bwd",
s,
|s| cpu_backend::unary_map(s, l, |v| bwd(v, self.alpha)),
(BF16, F16, F32, F64)
);
Ok((storage, l.shape().clone()))
}
}
struct EluWithBackward(Elu);
impl EluWithBackward {
fn new(alpha: f64) -> Self {
Self(Elu { alpha })
}
}
impl CustomOp1 for EluWithBackward {
fn name(&self) -> &'static str {
"elu"
}
fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
self.0.cpu_fwd(s, l)
}
fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
let alpha = self.0.alpha;
let bwd = arg.custom_op1(EluBackward { alpha })?;
Ok(Some(grad_res.mul(&bwd)?))
}
}
#[test]
fn custom_op1_with_backward() -> Result<()> {
let cpu = &Device::Cpu;
let t = candle_core::Var::new(&[-2f32, 0f32, 2f32], cpu)?;
let elu_t = t.custom_op1(EluWithBackward::new(2.))?;
assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]);
let grads = elu_t.backward()?;
let grad_x = grads.get(&t).unwrap();
assert_eq!(to_vec1_round(grad_x, 4)?, [0.2707, 1.0, 1.0]);
Ok(())
}
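For reference, the forward/backward pair these two tests check is the standard ELU and its derivative:

$$
\mathrm{elu}_\alpha(x)=\begin{cases}x & x\ge 0\\ \alpha\,(e^{x}-1) & x<0\end{cases}
\qquad
\mathrm{elu}_\alpha'(x)=\begin{cases}1 & x\ge 0\\ \alpha\,e^{x} & x<0\end{cases}
$$

With $\alpha = 2$ and $x = -2$ this gives $2(e^{-2}-1) \approx -1.7293$ for the forward value and $2e^{-2} \approx 0.2707$ for the gradient, matching the assertions above.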

View File

@ -1,5 +1,5 @@
use anyhow::Result;
use candle::{DType, Device::Cpu, Tensor};
use candle_core::{DType, Device::Cpu, Tensor};
#[test]
fn display_scalar() -> Result<()> {

View File

@ -1,5 +1,5 @@
use anyhow::{Context, Result};
use candle::{Device, Shape, Var};
use candle_core::{Device, Shape, Tensor, Var};
mod test_utils;
fn simple_grad(device: &Device) -> Result<()> {
@ -79,7 +79,97 @@ fn grad_descent(device: &Device) -> Result<()> {
Ok(())
}
fn unary_grad(device: &Device) -> Result<()> {
let x = Var::new(&[3f32, 1., 4., 0.15], device)?;
let x = x.as_tensor();
let y = (x.log()? + 1.)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [2.0986123, 1.0, 2.3862944, -0.89712]);
assert_eq!(grad_x.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
let y = x.exp()?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(
y.to_vec1::<f32>()?,
[20.085537, 2.7182817, 54.59815, 1.1618342]
);
assert_eq!(
grad_x.to_vec1::<f32>()?,
[20.085537, 2.7182817, 54.59815, 1.1618342]
);
let y = x.exp()?.sqr()?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(
y.to_vec1::<f32>()?,
[403.4288, 7.3890557, 2980.9578, 1.3498588]
);
// exp(x)^2 = exp(2*x)
assert_eq!(
grad_x.to_vec1::<f32>()?,
[806.8576, 14.778111, 5961.9155, 2.6997175]
);
let y = x.sin()?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[0.1411, 0.8415, -0.7568, 0.1494],
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[-0.99, 0.5403, -0.6536, 0.9888],
);
let y = x.cos()?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[-0.99, 0.5403, -0.6536, 0.9888],
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[-0.1411, -0.8415, 0.7568, -0.1494],
);
let y = x.sqr()?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [9.0, 1.0, 16.0, 0.0225]);
assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, 8.0, 0.3]);
let y = x.sqr()?.sqrt()?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [3.0, 1.0, 4.0, 0.15]);
assert_eq!(grad_x.to_vec1::<f32>()?, [1.0, 1.0, 1.0, 1.0]);
let y = x.neg()?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [-3.0, -1.0, -4.0, -0.15]);
assert_eq!(grad_x.to_vec1::<f32>()?, [-1.0, -1.0, -1.0, -1.0]);
let y = x.affine(0.2, 1.)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [1.6, 1.2, 1.8, 1.03]);
assert_eq!(grad_x.to_vec1::<f32>()?, [0.2, 0.2, 0.2, 0.2]);
let y = Tensor::new(1f32, device)?.broadcast_div(x)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
assert_eq!(
grad_x.to_vec1::<f32>()?,
[-0.11111111, -1.0, -0.0625, -44.444443],
);
let y = x.broadcast_div(&Tensor::new(0.5f32, device)?)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [6., 2., 8., 0.3]);
assert_eq!(grad_x.to_vec1::<f32>()?, [2., 2., 2., 2.]);
Ok(())
}
test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
test_device!(unary_grad, unary_grad_cpu, unary_grad_gpu);
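The two broadcast_div checks at the end of unary_grad follow directly from the usual derivatives:

$$\frac{d}{dx}\,\frac{1}{x}=-\frac{1}{x^{2}},\qquad \frac{d}{dx}\,\frac{x}{c}=\frac{1}{c}$$

e.g. $x = 3$ gives $-1/9 \approx -0.1111$, and $c = 0.5$ gives a constant gradient of $2$, matching the assertions.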

View File

@ -1,5 +1,5 @@
use anyhow::Result;
use candle::{Device, IndexOp, Tensor};
use candle_core::{Device, IndexOp, Tensor};
mod test_utils;
@ -58,6 +58,19 @@ fn range_index() -> Result<()> {
let result = tensor.i(..=1)?;
assert_eq!(result.dims(), &[2, 3]);
assert_eq!(result.to_vec2::<u32>()?, &[[0, 1, 2], [3, 4, 5]]);
// Empty range
let result = tensor.i(1..1)?;
assert_eq!(result.dims(), &[0, 3]);
let empty: [[u32; 3]; 0] = [];
assert_eq!(result.to_vec2::<u32>()?, &empty);
// Similar to PyTorch, allow empty ranges when the computed length is negative.
#[allow(clippy::reversed_empty_ranges)]
let result = tensor.i(1..0)?;
assert_eq!(result.dims(), &[0, 3]);
let empty: [[u32; 3]; 0] = [];
assert_eq!(result.to_vec2::<u32>()?, &empty);
Ok(())
}

View File

@ -1,5 +1,6 @@
mod test_utils;
use candle::{Device, IndexOp, Result, Tensor};
use candle_core as candle;
fn contiguous(device: &Device) -> Result<()> {
let tensor = Tensor::arange(0u32, 24u32, device)?.reshape((2, 3, 4))?;

View File

@ -1,10 +1,9 @@
mod test_utils;
use candle::{DType, Device, IndexOp, Result, Tensor};
use test_utils::to_vec3_round;
use candle_core::{DType, Device, IndexOp, Result, Tensor};
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
let (dim1, dim2) = tensor.shape().r2()?;
let (dim1, dim2) = tensor.dims2()?;
assert_eq!(dim1, 5);
assert_eq!(dim2, 2);
Ok(())
@ -12,7 +11,7 @@ fn zeros(device: &Device) -> Result<()> {
fn add_mul(device: &Device) -> Result<()> {
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
let dim1 = tensor.shape().r1()?;
let dim1 = tensor.dims1()?;
assert_eq!(dim1, 3);
let content: Vec<f32> = tensor.to_vec1()?;
assert_eq!(content, [3., 1., 4.]);
@ -28,7 +27,7 @@ fn add_mul(device: &Device) -> Result<()> {
fn tensor_2d(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?;
let dims = tensor.shape().r2()?;
let dims = tensor.dims2()?;
assert_eq!(dims, (2, 5));
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
assert_eq!(content, data);
@ -41,7 +40,7 @@ fn binary_op(device: &Device) -> Result<()> {
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]];
let tensor2 = Tensor::new(data2, device)?;
let tensor = (&tensor + (&tensor * &tensor)? / (&tensor + &tensor2))?;
let dims = tensor.shape().r2()?;
let dims = tensor.dims2()?;
assert_eq!(dims, (2, 5));
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
assert_eq!(content[0], [4.125, 1.1666666, 5.7777777, 1.1666666, 7.5]);
@ -56,7 +55,7 @@ fn binary_op(device: &Device) -> Result<()> {
fn transpose(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?.t()?;
let dims = tensor.shape().r2()?;
let dims = tensor.dims2()?;
assert_eq!(dims, (5, 2));
assert_eq!(
tensor.to_vec2::<f32>()?,
@ -68,42 +67,6 @@ fn transpose(device: &Device) -> Result<()> {
Ok(())
}
fn softmax(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
let t0 = tensor.log()?.softmax(0)?;
let t1 = tensor.log()?.softmax(1)?;
let t2 = tensor.log()?.softmax(2)?;
assert_eq!(
to_vec3_round(t0, 4)?,
&[
// 3/5, 1/2, 4/11
[[0.6, 0.5, 0.3636], [0.1111, 0.7143, 0.5294]],
// 2/5, 1/2, 7/11
[[0.4, 0.5, 0.6364], [0.8889, 0.2857, 0.4706]]
]
);
assert_eq!(
to_vec3_round(t1, 4)?,
&[
// 3/4, 1/6, 4/13
[[0.75, 0.1667, 0.3077], [0.25, 0.8333, 0.6923]],
// 2/10, 1/3, 7/15
[[0.2, 0.3333, 0.4667], [0.8, 0.6667, 0.5333]]
]
);
assert_eq!(
to_vec3_round(t2, 4)?,
&[
// (3, 1, 4) / 8, (1, 5, 9) / 15
[[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
// (2, 1, 7) / 10, (8, 2, 8) / 18
[[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
]
);
Ok(())
}
fn sum(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
@ -201,6 +164,278 @@ fn sum(device: &Device) -> Result<()> {
Ok(())
}
fn min(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
tensor.min_keepdim(2)?.to_vec3::<u32>()?,
&[[[1], [1]], [[1], [2]]]
);
assert_eq!(
tensor.min_keepdim(0)?.to_vec3::<u32>()?,
&[[[2, 1, 4], [1, 2, 8]]],
);
let data: Vec<u32> = (200..4000u32).collect();
let tensor = Tensor::new(data.as_slice(), device)?;
assert_eq!(tensor.min_keepdim(0)?.to_vec1::<u32>()?, &[200]);
let tensor = tensor.reshape((1900, 2))?;
assert_eq!(
tensor.min_keepdim(0)?.min_keepdim(1)?.to_vec2::<u32>()?,
&[[200]]
);
assert_eq!(
tensor.min_keepdim(1)?.min_keepdim(0)?.to_vec2::<u32>()?,
&[[200]]
);
assert_eq!(tensor.min_keepdim(0)?.to_vec2::<u32>()?, &[[200, 201]]);
// Make the tensor non-contiguous.
let tensor = tensor.t()?.contiguous()?.t()?;
assert_eq!(
tensor.min_keepdim(0)?.min_keepdim(1)?.to_vec2::<u32>()?,
&[[200]]
);
assert_eq!(
tensor.min_keepdim(1)?.min_keepdim(0)?.to_vec2::<u32>()?,
&[[200]]
);
assert_eq!(tensor.min_keepdim(0)?.to_vec2::<u32>()?, &[[200, 201]]);
let t1 = tensor.reshape((190, 5, 4))?;
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
for tensor in [t1, t2] {
assert_eq!(
tensor
.min_keepdim(0)?
.min_keepdim(2)?
.min_keepdim(1)?
.to_vec3::<u32>()?,
&[[[200]]]
);
assert_eq!(
tensor.min_keepdim(0)?.to_vec3::<u32>()?,
&[[
[200, 201, 202, 203],
[204, 205, 206, 207],
[208, 209, 210, 211],
[212, 213, 214, 215],
[216, 217, 218, 219]
]]
);
}
Ok(())
}
fn max(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
tensor.max_keepdim(2)?.to_vec3::<u32>()?,
&[[[4], [9]], [[7], [8]]]
);
assert_eq!(
tensor.max_keepdim(0)?.to_vec3::<u32>()?,
&[[[3, 1, 7], [8, 5, 9]]],
);
let data: Vec<u32> = (200..4000u32).collect();
let tensor = Tensor::new(data.as_slice(), device)?;
assert_eq!(tensor.max_keepdim(0)?.to_vec1::<u32>()?, &[3999]);
let tensor = tensor.reshape((1900, 2))?;
assert_eq!(
tensor.max_keepdim(0)?.max_keepdim(1)?.to_vec2::<u32>()?,
&[[3999]]
);
assert_eq!(
tensor.max_keepdim(1)?.max_keepdim(0)?.to_vec2::<u32>()?,
&[[3999]]
);
assert_eq!(tensor.max_keepdim(0)?.to_vec2::<u32>()?, &[[3998, 3999]]);
// Make the tensor non-contiguous.
let tensor = tensor.t()?.contiguous()?.t()?;
assert_eq!(
tensor.max_keepdim(0)?.max_keepdim(1)?.to_vec2::<u32>()?,
&[[3999]]
);
assert_eq!(
tensor.max_keepdim(1)?.max_keepdim(0)?.to_vec2::<u32>()?,
&[[3999]]
);
assert_eq!(tensor.max_keepdim(0)?.to_vec2::<u32>()?, &[[3998, 3999]]);
let t1 = tensor.reshape((190, 5, 4))?;
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
for tensor in [t1, t2] {
assert_eq!(
tensor
.max_keepdim(0)?
.max_keepdim(2)?
.max_keepdim(1)?
.to_vec3::<u32>()?,
&[[[3999]]]
);
assert_eq!(
tensor.max_keepdim(0)?.to_vec3::<u32>()?,
&[[
[3980, 3981, 3982, 3983],
[3984, 3985, 3986, 3987],
[3988, 3989, 3990, 3991],
[3992, 3993, 3994, 3995],
[3996, 3997, 3998, 3999]
]]
);
}
Ok(())
}
fn argmin(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
tensor.argmin_keepdim(2)?.to_vec3::<u32>()?,
&[[[1], [0]], [[1], [1]]]
);
assert_eq!(
tensor.argmin_keepdim(0)?.to_vec3::<u32>()?,
&[[[1, 0, 0], [0, 1, 1]]],
);
let data: Vec<u32> = (200..4000u32).collect();
let tensor = Tensor::new(data.as_slice(), device)?;
assert_eq!(tensor.argmin_keepdim(0)?.to_vec1::<u32>()?, &[0]);
let tensor = tensor.reshape((1900, 2))?;
assert_eq!(
tensor
.argmin_keepdim(0)?
.argmin_keepdim(1)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(
tensor
.argmin_keepdim(1)?
.argmin_keepdim(0)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(tensor.argmin_keepdim(0)?.to_vec2::<u32>()?, &[[0, 0]]);
// Make the tensor non-contiguous.
let tensor = tensor.t()?.contiguous()?.t()?;
assert_eq!(
tensor
.argmin_keepdim(0)?
.argmin_keepdim(1)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(
tensor
.argmin_keepdim(1)?
.argmin_keepdim(0)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(tensor.argmin_keepdim(0)?.to_vec2::<u32>()?, &[[0, 0]]);
let t1 = tensor.reshape((190, 5, 4))?;
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
for tensor in [t1, t2] {
assert_eq!(
tensor
.argmin_keepdim(0)?
.argmin_keepdim(2)?
.argmin_keepdim(1)?
.to_vec3::<u32>()?,
&[[[0]]]
);
assert_eq!(
tensor.argmin_keepdim(0)?.to_vec3::<u32>()?,
&[[
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
]]
);
}
Ok(())
}
fn argmax(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
tensor.argmax_keepdim(2)?.to_vec3::<u32>()?,
&[[[2], [2]], [[2], [0]]]
);
assert_eq!(
tensor.argmax_keepdim(0)?.to_vec3::<u32>()?,
&[[[0, 0, 1], [1, 0, 0]]],
);
let data: Vec<u32> = (200..4000u32).collect();
let tensor = Tensor::new(data.as_slice(), device)?;
assert_eq!(tensor.argmax_keepdim(0)?.to_vec1::<u32>()?, &[3799]);
let tensor = tensor.reshape((1900, 2))?;
assert_eq!(
tensor
.argmax_keepdim(0)?
.argmax_keepdim(1)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(
tensor
.argmax_keepdim(1)?
.argmax_keepdim(0)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(tensor.argmax_keepdim(0)?.to_vec2::<u32>()?, &[[1899, 1899]]);
// Make the tensor non-contiguous.
let tensor = tensor.t()?.contiguous()?.t()?;
assert_eq!(
tensor
.argmax_keepdim(0)?
.argmax_keepdim(1)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(
tensor
.argmax_keepdim(1)?
.argmax_keepdim(0)?
.to_vec2::<u32>()?,
&[[0]]
);
assert_eq!(tensor.argmax_keepdim(0)?.to_vec2::<u32>()?, &[[1899, 1899]]);
let t1 = tensor.reshape((190, 5, 4))?;
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
for tensor in [t1, t2] {
assert_eq!(
tensor
.argmax_keepdim(0)?
.argmax_keepdim(2)?
.argmax_keepdim(1)?
.to_vec3::<u32>()?,
&[[[0]]]
);
assert_eq!(
tensor.argmax_keepdim(0)?.to_vec3::<u32>()?,
&[[
[189, 189, 189, 189],
[189, 189, 189, 189],
[189, 189, 189, 189],
[189, 189, 189, 189],
[189, 189, 189, 189],
]]
);
}
Ok(())
}
fn narrow(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
@ -270,7 +505,11 @@ fn cat(device: &Device) -> Result<()> {
[2.0, 7.0, 1.0, 8.0, 2.0]
]
);
// TODO: This is not the expected answer, to be fixed!
// PyTorch equivalent:
// import torch
// t1 = torch.tensor([[3, 1, 4, 1, 5], [2, 7, 1, 8, 2]])
// t2 = torch.tensor([[5]*5, [2, 7, 1, 8, 2]])
// torch.cat([t1.t(), t2.t()], dim=1).t()
assert_eq!(
Tensor::cat(&[&t1.t()?, &t2.t()?], 1)?
.t()?
@ -282,7 +521,6 @@ fn cat(device: &Device) -> Result<()> {
[2.0, 7.0, 1.0, 8.0, 2.0]
]
);
// TODO: This is not the expected answer, to be fixed!
assert_eq!(
Tensor::cat(&[&t1, &t2], 1)?.to_vec2::<f32>()?,
[
@ -296,8 +534,167 @@ fn cat(device: &Device) -> Result<()> {
fn embeddings(device: &Device) -> Result<()> {
let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
let hs = Tensor::embedding(&ids, &t)?;
let hs = t.embedding(&ids)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
Ok(())
}
fn cmp(device: &Device) -> Result<()> {
let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?;
assert_eq!(t1.eq(&t2)?.to_vec2::<u8>()?, &[[0, 0], [0, 1], [1, 0]]);
assert_eq!(t1.ne(&t2)?.to_vec2::<u8>()?, &[[1, 1], [1, 0], [0, 1]]);
assert_eq!(t1.le(&t2)?.to_vec2::<u8>()?, &[[1, 0], [1, 1], [1, 1]]);
assert_eq!(t1.lt(&t2)?.to_vec2::<u8>()?, &[[1, 0], [1, 0], [0, 1]]);
assert_eq!(t1.gt(&t2)?.to_vec2::<u8>()?, &[[0, 1], [0, 0], [0, 0]]);
assert_eq!(t1.ge(&t2)?.to_vec2::<u8>()?, &[[0, 1], [0, 1], [1, 0]]);
Ok(())
}
fn index_select(device: &Device) -> Result<()> {
let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
t.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0]
]
);
let hs = t.index_select(&ids, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 2.0, 1.0],
[3.0, 5.0, 4.0],
[6.0, 8.0, 7.0],
[9.0, 11.0, 10.0]
]
);
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
);
Ok(())
}
fn index_add(device: &Device) -> Result<()> {
let ids = Tensor::new(&[0u32, 1u32, 1u32], device)?;
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
t.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0]
]
);
let init = Tensor::ones((4, 2), DType::F32, device)?;
let hs = init.index_add(&ids, &t, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[[1.0, 4.0], [4.0, 10.0], [7.0, 16.0], [10.0, 22.0]],
);
let init = Tensor::zeros((4, 2), DType::F32, device)?;
let ids = Tensor::new(&[1u32, 0u32, 0u32], device)?;
let hs = init.index_add(&ids, &t, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[[3.0, 0.0], [9.0, 3.0], [15.0, 6.0], [21.0, 9.0]],
);
let init = Tensor::zeros((6, 3), DType::F32, device)?;
let ids = Tensor::new(&[5u32, 0u32, 1u32, 0u32], device)?;
let hs = init.index_add(&ids, &t, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[12.0, 14.0, 16.0],
[6.0, 7.0, 8.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 1.0, 2.0]
]
);
Ok(())
}
fn scatter_add(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
t.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0]
]
);
let ids = Tensor::new(&[[0u32, 1, 2], [3, 4, 0], [3, 3, 1], [2, 0, 4]], device)?;
let init = Tensor::ones((4, 5), DType::F32, device)?;
let hs = init.scatter_add(&ids, &t, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[1.0, 2.0, 3.0, 1.0, 1.0],
[6.0, 1.0, 1.0, 4.0, 5.0],
[1.0, 9.0, 1.0, 14.0, 1.0],
[11.0, 1.0, 10.0, 1.0, 12.0]
]
);
let init = Tensor::ones((6, 3), DType::F32, device)?;
let hs = init.scatter_add(&ids, &t, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[1.0, 11.0, 6.0],
[1.0, 2.0, 9.0],
[10.0, 1.0, 3.0],
[10.0, 8.0, 1.0],
[1.0, 5.0, 12.0],
[1.0, 1.0, 1.0]
]
);
Ok(())
}
fn gather(device: &Device) -> Result<()> {
let ids = Tensor::new(&[[0u32], [2u32], [1u32], [0u32]], device)?;
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
t.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0]
]
);
let hs = t.gather(&ids, 1)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0], [5.0], [7.0], [9.0]]);
let ids = Tensor::new(
&[[0u32, 0u32], [2u32, 0u32], [1u32, 1u32], [0u32, 2u32]],
device,
)?;
let hs = t.gather(&ids, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[[0.0, 0.0], [5.0, 3.0], [7.0, 7.0], [9.0, 11.0]]
);
let ids = Tensor::new(&[[0u32, 2u32, 0u32]], device)?;
let hs = t.gather(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0]]);
let ids = Tensor::new(&[[0u32, 2u32, 0u32], [0u32, 1u32, 1u32]], device)?;
let hs = t.gather(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
Ok(())
}
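In index notation, the three indexing ops tested above behave as follows along dimension 0 of a 2-D tensor (for the additive ops, out starts from init):

$$
\begin{aligned}
\texttt{gather}:\;& \mathrm{out}[i][j] = t[\mathrm{ids}[i][j]][j]\\
\texttt{scatter\_add}:\;& \mathrm{out}[\mathrm{ids}[i][j]][j] \mathrel{+}= \mathrm{src}[i][j]\\
\texttt{index\_add}:\;& \mathrm{out}[\mathrm{ids}[i]][j] \mathrel{+}= \mathrm{src}[i][j]
\end{aligned}
$$

The dimension-1 variants index the column instead, and index_select picks whole rows or columns: $\mathrm{out}[i][j] = t[\mathrm{ids}[i]][j]$ for dimension 0.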
@ -458,9 +855,17 @@ test_device!(narrow, narrow_cpu, narrow_gpu);
test_device!(broadcast, broadcast_cpu, broadcast_gpu);
test_device!(cat, cat_cpu, cat_gpu);
test_device!(sum, sum_cpu, sum_gpu);
test_device!(min, min_cpu, min_gpu);
test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(softmax, softmax_cpu, softmax_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);

View File

@ -1,6 +1,6 @@
#![allow(dead_code)]
use candle::{Result, Tensor};
use candle_core::{Result, Tensor};
#[macro_export]
macro_rules! test_device {
@ -20,6 +20,23 @@ macro_rules! test_device {
};
}
pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
let b = 10f32.powi(digits);
let t = t.to_vec1::<f32>()?;
let t = t.iter().map(|t| f32::round(t * b) / b).collect();
Ok(t)
}
pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
let b = 10f32.powi(digits);
let t = t.to_vec2::<f32>()?;
let t = t
.iter()
.map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
.collect();
Ok(t)
}
pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
let b = 10f32.powi(digits);
let t = t.to_vec3::<f32>()?;

View File

@ -1,28 +1,33 @@
[package]
name = "candle-examples"
version = "0.1.0"
edition = "2021"
description = "Examples for the candle ML framework."
repository = "https://github.com/LaurentMazare/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT/Apache-2.0"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
readme = "README.md"
[dependencies]
candle = { path = "../candle-core" }
candle-nn = { path = "../candle-nn" }
candle-transformers = { path = "../candle-transformers" }
candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.1.0" }
candle-transformers = { path = "../candle-transformers", version = "0.1.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.0", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
candle-hub = { path = "../candle-hub" }
byteorder = { workspace = true }
clap = { workspace = true }
hf-hub = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
@ -30,7 +35,16 @@ tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
[build-dependencies]
anyhow = { workspace = true }
[features]
default = []
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
[[example]]
name = "llama_multiprocess"
required-features = ["cuda", "nccl", "flash-attn"]

View File

@ -0,0 +1 @@
# candle-examples

candle-examples/build.rs (new file, 238 lines)
View File

@ -0,0 +1,238 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::io::Write;
use std::path::PathBuf;
struct KernelDirectories {
kernel_dir: &'static str,
rust_target: &'static str,
include_dirs: &'static [&'static str],
}
const DIRS: [KernelDirectories; 1] = [KernelDirectories {
kernel_dir: "examples/custom-ops/kernels/",
rust_target: "examples/custom-ops/cuda_kernels.rs",
include_dirs: &[],
}];
impl KernelDirectories {
fn maybe_build_ptx(
&self,
cu_file: &std::path::Path,
ptx_file: &std::path::Path,
compute_cap: usize,
) -> Result<()> {
let should_compile = if ptx_file.exists() {
let ptx_modified = ptx_file.metadata()?.modified()?;
let cu_modified = cu_file.metadata()?.modified()?;
cu_modified.duration_since(ptx_modified).is_ok()
} else {
true
};
if should_compile {
#[cfg(feature = "cuda")]
{
let mut command = std::process::Command::new("nvcc");
let out_dir = ptx_file.parent().context("no parent for ptx file")?;
let include_dirs: Vec<String> =
self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
command
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--ptx")
.args(["--default-stream", "per-thread"])
.args(["--output-directory", out_dir.to_str().unwrap()])
.arg(format!("-I/{}", self.kernel_dir))
.args(include_dirs)
.arg(cu_file);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
}
#[cfg(not(feature = "cuda"))]
std::fs::OpenOptions::new()
.create(true)
.write(true)
.open(ptx_file)?;
}
Ok(())
}
fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
println!("cargo:rerun-if-changed={}", self.kernel_dir);
let kernel_dir = PathBuf::from(self.kernel_dir);
let out_dir = out_dir.join(self.kernel_dir);
if !out_dir.exists() {
std::fs::create_dir_all(&out_dir)?;
}
let mut cu_files = vec![];
let mut cuh_files = vec![];
for file in std::fs::read_dir(kernel_dir)?.flatten() {
let file = file.path();
match file.extension().and_then(|v| v.to_str()) {
Some("cu") => cu_files.push(file),
Some("cuh") => cuh_files.push(file),
_ => {}
}
}
let mut ptx_paths = vec![];
for cu_file in cu_files.iter() {
let file_stem = cu_file
.file_stem()
.with_context(|| format!("no stem {cu_file:?}"))?;
let file_stem = file_stem.to_string_lossy().into_owned();
let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
ptx_paths.push(ptx_file);
}
let regenerate_rs_file = true;
if regenerate_rs_file {
let mut file = std::fs::File::create(self.rust_target)?;
for ptx_path in ptx_paths {
let name = ptx_path
.file_stem()
.context("empty stem")?
.to_string_lossy();
file.write_all(b"#[rustfmt::skip]\n")?;
let const_definition = format!(
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
name.to_uppercase().replace('.', "_"),
self.kernel_dir,
);
file.write_all(const_definition.as_bytes())?;
file.write_all(b"\n")?;
}
}
Ok(())
}
}
fn main() -> Result<()> {
println!("cargo:rerun-if-changed=build.rs");
let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
let out_dir = PathBuf::from(out_dir);
#[cfg(feature = "cuda")]
set_cuda_include_dir()?;
#[cfg(feature = "cuda")]
let compute_cap = compute_cap()?;
#[cfg(not(feature = "cuda"))]
let compute_cap = 0;
for d in DIRS {
d.process(&out_dir, compute_cap)?
}
Ok(())
}
fn set_cuda_include_dir() -> Result<()> {
// NOTE: copied from cudarc build.rs.
let env_vars = [
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
let env_vars = env_vars
.into_iter()
.map(std::env::var)
.filter_map(Result::ok)
.map(Into::<PathBuf>::into);
let roots = [
"/usr",
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/CUDA",
];
let roots = roots.into_iter().map(Into::<PathBuf>::into);
let root = env_vars
.chain(roots)
.find(|path| path.join("include").join("cuda.h").is_file())
.context("cannot find include/cuda.h")?;
println!(
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
root.join("include").display()
);
Ok(())
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
// Grab compute code from nvidia-smi
let mut compute_cap = {
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
cap.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?
};
// Grab available GPU codes from nvcc and select the highest one
let max_nvcc_code = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
if !codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
);
}
*codes.last().unwrap()
};
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
// then choose the highest gpu code in nvcc
if compute_cap > max_nvcc_code {
println!(
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
);
compute_cap = max_nvcc_code;
}
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
compute_cap = compute_cap_str
.parse::<usize>()
.with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
}
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
Ok(compute_cap)
}

View File

@ -4,9 +4,9 @@ mod model;
use anyhow::{anyhow, Error as E, Result};
use candle::Tensor;
use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use model::{BertModel, Config, DTYPE};
use tokenizers::{PaddingParams, Tokenizer};
@ -69,10 +69,11 @@ impl Args {
)
} else {
let api = Api::new()?;
let api = api.repo(repo);
(
api.get(&repo, "config.json")?,
api.get(&repo, "tokenizer.json")?,
api.get(&repo, "model.safetensors")?,
api.get("config.json")?,
api.get("tokenizer.json")?,
api.get("model.safetensors")?,
)
};
let config = std::fs::read_to_string(config_filename)?;
@ -161,7 +162,7 @@ fn main() -> Result<()> {
let embeddings = model.forward(&token_ids, &token_type_ids)?;
println!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
println!("pooled embeddings {:?}", embeddings.shape());
let mut similarities = vec![];
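The pooling step just above is a plain mean over the token axis (padding included, as the comment notes):

$$\bar e = \frac{1}{T}\sum_{t=1}^{T} h_t$$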

View File

@ -87,7 +87,7 @@ impl LayerNorm {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
let (_bsize, _seq_len, hidden_size) = x.dims3()?;
let x = x.to_dtype(internal_dtype)?;
let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;
@ -262,7 +262,7 @@ impl BertEmbeddings {
fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (_bsize, seq_len) = input_ids.shape().r2()?;
let (_bsize, seq_len) = input_ids.dims2()?;
let input_embeddings = self.word_embeddings.forward(input_ids)?;
let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
let mut embeddings = (&input_embeddings + token_type_embeddings)?;
@ -333,7 +333,7 @@ impl BertSelfAttention {
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
let attention_probs = {
let _enter_sm = self.span_softmax.enter();
attention_scores.softmax(candle::D::Minus1)?
candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
};
let attention_probs = self.dropout.forward(&attention_probs)?;
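For reference, the LayerNorm hunk at the top of this file computes the standard normalization (only the mean-subtraction step is visible above; f16/bf16 inputs are upcast to f32 internally), assuming the usual affine form:

$$y = \gamma \odot \frac{x-\mu}{\sqrt{\sigma^{2}+\varepsilon}} + \beta,\qquad \mu=\frac{1}{H}\sum_i x_i,\quad \sigma^{2}=\frac{1}{H}\sum_i (x_i-\mu)^{2}$$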

View File

@ -0,0 +1,156 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use clap::Parser;
mod model;
use model::{Config, GPTBigCode};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: GPTBigCode,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
}
impl TextGeneration {
fn new(
model: GPTBigCode,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp);
Self {
model,
tokenizer,
logits_processor,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
print!("{prompt}");
std::io::stdout().flush()?;
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut new_tokens = vec![];
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let (context_size, past_len) = if self.model.config().use_cache && index > 0 {
(1, tokens.len().saturating_sub(1))
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, past_len)?;
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
let token = self
.tokenizer
.decode(vec![next_token], true)
.map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"{sample_len} tokens generated ({:.3} token/s)",
sample_len as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
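The context-size logic above is what makes the KV cache pay off: after the first step, only the newest token is re-encoded while the model reuses cached keys/values for the prefix. A hypothetical standalone version of that selection, for illustration:

// Returns (context_size, past_len), mirroring the loop above.
fn context_size(use_cache: bool, index: usize, n_tokens: usize) -> (usize, usize) {
    if use_cache && index > 0 {
        // Later steps: feed only the most recent token.
        (1, n_tokens.saturating_sub(1))
    } else {
        // First step (or no cache): encode the whole prompt.
        (n_tokens, 0)
    }
}

// context_size(true, 0, 5) == (5, 0); context_size(true, 3, 8) == (1, 7).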
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "bigcode/starcoderbase-1b")]
model_id: String,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
weight_file: Option<String>,
}
fn main() -> Result<()> {
let args = Args::parse();
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = repo.get("tokenizer.json")?;
let filenames = match args.weight_file {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => ["model.safetensors"]
.iter()
.map(|f| repo.get(f))
.collect::<std::result::Result<Vec<_>, _>>()?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let weights = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
.collect::<Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|f| Ok(f.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
let config = Config::starcoder_1b();
let model = GPTBigCode::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -0,0 +1,359 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = if bias {
Some(vb.get(size2, "bias")?)
} else {
None
};
Ok(Linear::new(weight, bias))
}
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?;
let bias = vb.get(size, "bias")?;
Ok(LayerNorm::new(weight, bias, eps))
}
fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j <= i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), device)?;
Ok(mask)
}
#[derive(Debug)]
pub struct Config {
pub vocab_size: usize,
// max_position_embeddings aka n_positions
pub max_position_embeddings: usize,
// num_hidden_layers aka n_layer
pub num_hidden_layers: usize,
// hidden_size aka n_embd
pub hidden_size: usize,
pub layer_norm_epsilon: f64,
pub n_inner: Option<usize>,
// num_attention_heads aka n_head
pub num_attention_heads: usize,
pub multi_query: bool,
pub use_cache: bool,
}
impl Config {
#[allow(dead_code)]
pub fn starcoder_1b() -> Self {
Self {
vocab_size: 49152,
max_position_embeddings: 8192,
num_hidden_layers: 24,
hidden_size: 2048,
layer_norm_epsilon: 1e-5,
n_inner: Some(8192),
num_attention_heads: 16,
multi_query: true,
use_cache: true,
}
}
#[allow(dead_code)]
pub fn starcoder_3b() -> Self {
Self {
vocab_size: 49152,
max_position_embeddings: 8192,
num_hidden_layers: 36,
hidden_size: 2816,
layer_norm_epsilon: 1e-5,
n_inner: Some(11264),
num_attention_heads: 22,
multi_query: true,
use_cache: true,
}
}
#[allow(dead_code)]
pub fn starcoder_7b() -> Self {
Self {
vocab_size: 49152,
max_position_embeddings: 8192,
num_hidden_layers: 42,
hidden_size: 4096,
layer_norm_epsilon: 1e-5,
n_inner: Some(16384),
num_attention_heads: 32,
multi_query: true,
use_cache: true,
}
}
#[allow(dead_code)]
pub fn starcoder() -> Self {
Self {
vocab_size: 49152,
max_position_embeddings: 8192,
num_hidden_layers: 40,
hidden_size: 6144,
layer_norm_epsilon: 1e-5,
n_inner: Some(24576),
num_attention_heads: 48,
multi_query: true,
use_cache: true,
}
}
}
struct Attention {
c_attn: Linear,
c_proj: Linear,
kv_cache: Option<Tensor>,
use_cache: bool,
embed_dim: usize,
kv_dim: usize,
num_heads: usize,
head_dim: usize,
multi_query: bool,
}
impl Attention {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let hidden_size = cfg.hidden_size;
let head_dim = hidden_size / cfg.num_attention_heads;
let kv_heads = if cfg.multi_query {
1
} else {
cfg.num_attention_heads
};
let kv_dim = kv_heads * head_dim;
let c_attn = linear(hidden_size, hidden_size + 2 * kv_dim, true, vb.pp("c_attn"))?;
let c_proj = linear(hidden_size, hidden_size, true, vb.pp("c_proj"))?;
Ok(Self {
c_proj,
c_attn,
embed_dim: hidden_size,
kv_cache: None,
use_cache: cfg.use_cache,
kv_dim,
head_dim,
num_heads: cfg.num_attention_heads,
multi_query: cfg.multi_query,
})
}
fn attn(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
attention_mask: &Tensor,
) -> Result<Tensor> {
if query.dtype() != DType::F32 {
// If we start supporting f16 models, we may need the upcasting scaling bits.
// https://github.com/huggingface/transformers/blob/a0042379269bea9182c1f87e6b2eee4ba4c8cce8/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L133
candle::bail!("upcasting is not supported {:?}", query.dtype())
}
let scale_factor = 1f64 / (self.head_dim as f64).sqrt();
let initial_query_shape = query.shape();
let key_len = key.dim(D::Minus1)?;
let (query, key, attn_shape, attn_view) = if self.multi_query {
let (b_sz, query_len, _) = query.dims3()?;
let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?;
let attn_shape = (b_sz, query_len, self.num_heads, key_len);
let attn_view = (b_sz, query_len * self.num_heads, key_len);
(query, key.clone(), attn_shape, attn_view)
} else {
let (b_sz, _num_heads, query_len, _head_dim) = query.dims4()?;
let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?;
let key = key.reshape((b_sz * self.num_heads, self.head_dim, key_len))?;
let attn_shape = (b_sz, self.num_heads, query_len, key_len);
let attn_view = (b_sz * self.num_heads, query_len, key_len);
(query, key, attn_shape, attn_view)
};
let attn_weights =
(query.matmul(&key.contiguous()?)? * scale_factor)?.reshape(attn_shape)?;
let attention_mask = attention_mask.broadcast_as(attn_shape)?;
let mask_value =
Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
let value = value.contiguous()?;
let attn_output = if self.multi_query {
attn_weights
.reshape(attn_view)?
.matmul(&value)?
.reshape(initial_query_shape)?
} else {
attn_weights.matmul(&value)?
};
Ok(attn_output)
}
fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let qkv = self.c_attn.forward(hidden_states)?;
let (query, key_value) = if self.multi_query {
let query = qkv.i((.., .., ..self.embed_dim))?;
let key_value = qkv.i((.., .., self.embed_dim..self.embed_dim + 2 * self.kv_dim))?;
(query, key_value)
} else {
let mut dims = qkv.dims().to_vec();
dims.pop();
dims.push(self.embed_dim);
dims.push(self.head_dim * 3);
let qkv = qkv.reshape(dims)?.transpose(1, 2)?;
let query = qkv.i((.., .., .., ..self.head_dim))?;
let key_value = qkv.i((.., .., .., self.head_dim..3 * self.head_dim))?;
(query, key_value)
};
let mut key_value = key_value;
if self.use_cache {
if let Some(kv_cache) = &self.kv_cache {
// TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for
// arbitrarily large sizes.
key_value = Tensor::cat(&[kv_cache, &key_value], D::Minus2)?.contiguous()?;
}
self.kv_cache = Some(key_value.clone())
}
let key = key_value.narrow(D::Minus1, 0, self.head_dim)?;
let value = key_value.narrow(D::Minus1, self.head_dim, self.head_dim)?;
let attn_output = self.attn(&query, &key.t()?, &value, attention_mask)?;
let attn_output = if self.multi_query {
attn_output
} else {
attn_output
.transpose(1, 2)?
.reshape(hidden_states.shape())?
};
let attn_output = self.c_proj.forward(&attn_output)?;
Ok(attn_output)
}
}
struct Mlp {
c_fc: Linear,
c_proj: Linear,
}
impl Mlp {
fn load(inner_dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let c_fc = linear(cfg.hidden_size, inner_dim, true, vb.pp("c_fc"))?;
let c_proj = linear(inner_dim, cfg.hidden_size, true, vb.pp("c_proj"))?;
Ok(Self { c_fc, c_proj })
}
fn forward(&mut self, hidden_states: &Tensor) -> Result<Tensor> {
let hidden_states = self.c_fc.forward(hidden_states)?.gelu()?;
let hidden_states = self.c_proj.forward(&hidden_states)?;
Ok(hidden_states)
}
}
// TODO: Add cross-attention?
struct Block {
ln_1: LayerNorm,
attn: Attention,
ln_2: LayerNorm,
mlp: Mlp,
}
impl Block {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let hidden_size = cfg.hidden_size;
let inner_dim = cfg.n_inner.unwrap_or(4 * hidden_size);
let ln_1 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_1"))?;
let attn = Attention::load(vb.pp("attn"), cfg)?;
let ln_2 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_2"))?;
let mlp = Mlp::load(inner_dim, vb.pp("mlp"), cfg)?;
Ok(Self {
ln_1,
attn,
ln_2,
mlp,
})
}
fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let residual = hidden_states;
let hidden_states = self.ln_1.forward(hidden_states)?;
let attn_outputs = self.attn.forward(&hidden_states, attention_mask)?;
let hidden_states = (&attn_outputs + residual)?;
let residual = &hidden_states;
let hidden_states = self.ln_2.forward(&hidden_states)?;
let hidden_states = self.mlp.forward(&hidden_states)?;
let hidden_states = (&hidden_states + residual)?;
Ok(hidden_states)
}
}
pub struct GPTBigCode {
wte: Embedding,
wpe: Embedding,
blocks: Vec<Block>,
ln_f: LayerNorm,
lm_head: Linear,
bias: Tensor,
config: Config,
}
impl GPTBigCode {
pub fn config(&self) -> &Config {
&self.config
}
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
let hidden_size = cfg.hidden_size;
let vb_t = vb.pp("transformer");
let wte = embedding(cfg.vocab_size, hidden_size, vb_t.pp("wte"))?;
let wpe = embedding(cfg.max_position_embeddings, hidden_size, vb_t.pp("wpe"))?;
let blocks = (0..cfg.num_hidden_layers)
.map(|i| Block::load(vb_t.pp(&format!("h.{i}")), &cfg))
.collect::<Result<Vec<_>>>()?;
let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?;
let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?;
Ok(Self {
wte,
wpe,
blocks,
lm_head,
ln_f,
bias,
config: cfg,
})
}
pub fn forward(&mut self, input_ids: &Tensor, past_len: usize) -> Result<Tensor> {
let dev = input_ids.device();
let (b_sz, seq_len) = input_ids.dims2()?;
let key_len = past_len + seq_len;
let attention_mask = self.bias.i((past_len..key_len, ..key_len))?.unsqueeze(0)?;
// MQA models: (batch_size, query_length, n_heads, key_length)
// MHA models: (batch_size, n_heads, query_length, key_length)
let seq_len_dim = if self.config.multi_query { 2 } else { 1 };
let attention_mask = attention_mask.unsqueeze(seq_len_dim)?;
let position_ids = Tensor::arange(past_len as u32, (past_len + seq_len) as u32, dev)?;
let position_ids = position_ids.unsqueeze(0)?.broadcast_as((b_sz, seq_len))?;
let input_embeds = self.wte.forward(input_ids)?;
let position_embeds = self.wpe.forward(&position_ids)?;
let mut hidden_states = (&input_embeds + &position_embeds)?;
for block in self.blocks.iter_mut() {
hidden_states = block.forward(&hidden_states, &attention_mask)?;
}
let hidden_states = self.ln_f.forward(&hidden_states)?;
let hidden_states = hidden_states
.reshape((b_sz, seq_len, self.config.hidden_size))?
.narrow(1, seq_len - 1, 1)?;
let logits = self.lm_head.forward(&hidden_states)?.squeeze(1)?;
Ok(logits)
}
}
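A hypothetical standalone check of the mask construction, mirroring make_causal_mask above; during cached generation with past_len = p and seq_len = 1, the slice bias.i((p..p + 1, ..p + 1)) in forward is a single row of ones, so the new token attends to the entire prefix:

use candle_core::{Device, Result, Tensor};

fn causal_mask_demo() -> Result<()> {
    let t = 3;
    // 1 where j <= i, i.e. a lower-triangular pattern including the diagonal.
    let mask: Vec<_> = (0..t)
        .flat_map(|i| (0..t).map(move |j| u8::from(j <= i)))
        .collect();
    let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
    assert_eq!(
        mask.to_vec2::<u8>()?,
        &[[1, 0, 0], [1, 1, 0], [1, 1, 1]]
    );
    Ok(())
}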

View File

@ -0,0 +1,2 @@
#[rustfmt::skip]
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));

View File

@ -0,0 +1,35 @@
#include <stdint.h>
#include "reduction_utils.cuh"
template <typename scalar_t>
__device__ void
rms_norm_kernel(scalar_t *__restrict__ out, // [num_tokens, hidden_size]
const scalar_t *__restrict__ input, // [num_tokens, hidden_size]
const float epsilon, const uint32_t num_tokens,
const uint32_t hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x;
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] = ((scalar_t)(x * s_variance));
}
}
extern "C" __global__ void rms_f32(
float *__restrict__ out, // [num_tokens, hidden_size]
const float *__restrict__ input, // [num_tokens, hidden_size]
const float epsilon, const uint32_t num_tokens,
const uint32_t hidden_size) {
rms_norm_kernel(out, input, epsilon, num_tokens, hidden_size);
}
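In the kernel's notation, with H = hidden_size and one CUDA block per token row (picked by blockIdx.x), this is the usual RMS normalization: threads stride over the hidden dimension and blockReduceSum combines the per-thread partial sums of squares. As a formula:

out_i = x_i \cdot \mathrm{rsqrt}\left(\frac{1}{H}\sum_{j=1}^{H} x_j^2 + \varepsilon\right)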

View File

@ -0,0 +1,46 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
template <typename T> __inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(0xffffffff, val, mask, 32);
return val;
}
/* Calculate the sum of all elements in a block */
template <typename T> __inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
// Use blockDim.x / 32 rather than a shift so the guard below also works
// when blockDim.x is not a multiple of 32.
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
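A minimal plain-Rust sketch of the XOR-butterfly pattern that warpReduceSum relies on (illustrative only; a real warp exchanges values via __shfl_xor_sync): each step folds lanes whose ids differ in one bit, so after the masks 16, 8, 4, 2, 1 every lane holds the sum of all 32 lanes.
fn warp_reduce_sum(mut lanes: [f32; 32]) -> [f32; 32] {
    let mut mask = 16usize;
    while mask > 0 {
        let prev = lanes;
        for lane in 0..32 {
            // Each lane adds the value held by its butterfly partner.
            lanes[lane] = prev[lane] + prev[lane ^ mask];
        }
        mask >>= 1;
    }
    lanes
}
fn main() {
    assert!(warp_reduce_sum([1f32; 32]).iter().all(|&v| v == 32.0));
}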

View File

@ -0,0 +1,95 @@
// This example illustrates how to implement custom operations. These operations can provide their
// own forward pass (CPU and GPU versions) as well as their backward pass.
//
// In this example we add the RMS normalization operation and implement it for f32.
#![allow(dead_code)]
#![allow(unused)]
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod cuda_kernels;
use clap::Parser;
use candle::backend::BackendStorage;
use candle::cpu_backend;
use candle::{CpuStorage, CustomOp1, DType, Device, Layout, Result, Shape, Tensor};
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
}
struct LayerNorm {
eps: f32,
}
impl CustomOp1 for LayerNorm {
fn name(&self) -> &'static str {
"layer-norm"
}
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
let (dim1, dim2) = layout.shape().dims2()?;
let slice = storage.as_slice::<f32>()?;
let src = match layout.contiguous_offsets() {
None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => &slice[o1..o2],
};
let mut dst = Vec::with_capacity(dim1 * dim2);
for idx1 in 0..dim1 {
let src = &src[idx1 * dim2..(idx1 + 1) * dim2];
let variance = src.iter().map(|x| x * x).sum::<f32>();
let s_variance = 1f32 / (variance / dim2 as f32 + self.eps).sqrt();
dst.extend(src.iter().map(|x| x * s_variance))
}
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, layout.shape().clone()))
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
storage: &candle::CudaStorage,
layout: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::{cudarc, WrapErr};
use cudarc::driver::{LaunchAsync, LaunchConfig};
let (d1, d2) = layout.shape().dims2()?;
let d1 = d1 as u32;
let d2 = d2 as u32;
let dev = storage.device().clone();
let slice = storage.as_cuda_slice::<f32>()?;
let slice = match layout.contiguous_offsets() {
None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => slice.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
let params = (&dst, &slice, self.eps, d1, d2);
let cfg = LaunchConfig {
grid_dim: (d1, 1, 1),
block_dim: (d2, 1, 1),
shared_mem_bytes: 0,
};
unsafe { func.launch(cfg, params) }.w()?;
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
Ok((dst, layout.shape().clone()))
}
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
println!("{t}");
let t = t.custom_op1(LayerNorm { eps: 1e-5 })?;
println!("{t}");
Ok(())
}
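A hypothetical CPU-side check for the example above (the helper name is ours): recomputing the expected per-row scaling makes it easy to compare the CUDA output against the cpu_fwd path.
fn expected_rms(row: &[f32], eps: f32) -> Vec<f32> {
    // Same formula as LayerNorm::cpu_fwd above: scale by 1/sqrt(mean(x^2) + eps).
    let variance = row.iter().map(|x| x * x).sum::<f32>() / row.len() as f32;
    let s = 1f32 / (variance + eps).sqrt();
    row.iter().map(|x| x * s).collect()
}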

View File

@ -5,10 +5,10 @@ extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_hub::{api::sync::Api, Repo, RepoType};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
mod model;
@ -123,14 +123,18 @@ fn main() -> Result<()> {
let device = candle_examples::device(args.cpu)?;
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = Repo::with_revision(args.model_id, RepoType::Model, args.revision);
let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
let repo = api.repo(Repo::with_revision(
args.model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = repo.get("tokenizer.json")?;
let mut filenames = vec![];
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
let filename = api.get(&repo, rfilename)?;
let filename = repo.get(rfilename)?;
filenames.push(filename);
}
println!("retrieved the files in {:?}", start.elapsed());

View File

@ -182,7 +182,7 @@ impl FalconRotaryEmbedding {
key: &Tensor,
past_kv_len: usize,
) -> Result<(Tensor, Tensor)> {
let (_batch, seq_len, _head_dim) = query.shape().r3()?;
let (_batch, seq_len, _head_dim) = query.dims3()?;
let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
let cos = cos.narrow(0, past_kv_len, seq_len)?;
let sin = sin.narrow(0, past_kv_len, seq_len)?;
@ -245,7 +245,7 @@ impl FalconAttention {
}
fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let (b_sz, seq_len, _) = fused_qkv.shape().r3()?;
let (b_sz, seq_len, _) = fused_qkv.dims3()?;
if !self.multi_query {
let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?;
let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?;
@ -267,7 +267,7 @@ impl FalconAttention {
let fused_qkv = self.query_key_value.forward(x)?;
let head_dim = self.head_dim;
let (query, key, value) = self.split_heads(&fused_qkv)?;
let (b_sz, seq_len, _, _) = query.shape().r4()?;
let (b_sz, seq_len, _, _) = query.dims4()?;
let query = query
.transpose(1, 2)?
.reshape((b_sz * self.num_heads, seq_len, head_dim))?;
@ -309,11 +309,13 @@ impl FalconAttention {
// Only handle the case where alibi is None here, and non-flash attention.
let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
let attention_scores = attention_scores
.broadcast_add(&mask.squeeze(1)?)?
.to_dtype(DType::F32)?
.softmax(D::Minus1)?
.to_dtype(x.dtype())?;
let attention_scores = candle_nn::ops::softmax(
&attention_scores
.broadcast_add(&mask.squeeze(1)?)?
.to_dtype(DType::F32)?,
D::Minus1,
)?
.to_dtype(x.dtype())?;
let attn_output = attention_scores
.matmul(&value)?
.reshape((b_sz, self.num_heads, seq_len, head_dim))?
@ -422,7 +424,7 @@ pub struct Falcon {
fn make_causal_mask(t: usize) -> Result<Tensor> {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
Ok(mask)
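Switching the mask elements from u32 to u8 keeps the same triangular pattern in a quarter of the memory. A plain-Rust sketch of the values produced, where 1 marks a masked (future) position:
fn causal_mask(t: usize) -> Vec<u8> {
    (0..t)
        .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
        .collect()
}
// causal_mask(3) == [0, 1, 1,
//                    0, 0, 1,
//                    0, 0, 0]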
@ -465,7 +467,7 @@ impl Falcon {
}
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
let (b_sz, seq_len) = input_ids.shape().r2()?;
let (b_sz, seq_len) = input_ids.dims2()?;
let mut hidden_state = self.word_embeddings.forward(input_ids)?;
let past_kv_len = match &self.blocks[0].self_attention.kv_cache {
Some((k, _)) => k.dim(1)?,

View File

@ -15,10 +15,10 @@ extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle::{DType, Device, Tensor, D};
use candle_hub::{api::sync::Api, Repo, RepoType};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::api::sync::Api;
mod model;
use model::{Config, Llama};
@ -76,23 +76,6 @@ Whate'er it bodes, henceforward will I bear
Upon my target three fair-shining suns.
";
fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
let n_elem = config.n_embd / config.n_head;
let theta: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let shape = [1, MAX_SEQ_LEN, n_elem / 2, 1];
let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], D::Minus1)?)
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -127,17 +110,44 @@ struct Args {
/// Use f32 computations rather than f16.
#[arg(long)]
use_f32: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
v1: bool,
#[arg(long)]
use_flash_attn: bool,
}
fn main() -> Result<()> {
use tokenizers::Tokenizer;
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let device = candle_examples::device(args.cpu)?;
let config = Config::config_7b();
let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
let config = if args.v1 {
Config::config_7b_v1(args.use_flash_attn)
} else {
Config::config_7b_v2(args.use_flash_attn)
};
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let (llama, tokenizer_filename) = match args.npy {
Some(filename) => {
let vb = VarBuilder::from_npz(filename, dtype, &device)?;
@ -146,15 +156,22 @@ fn main() -> Result<()> {
}
None => {
let api = Api::new()?;
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
println!("loading the model weights");
let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
let model_id = args.model_id.unwrap_or_else(|| {
if args.v1 {
"Narsil/amall-7b".to_string()
} else {
"meta-llama/Llama-2-7b-hf".to_string()
}
});
println!("loading the model weights from {model_id}");
let api = api.model(model_id);
let tokenizer_filename = api.get("tokenizer.json")?;
let mut filenames = vec![];
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
let filename = api.get(&repo, rfilename)?;
let filename = api.get(rfilename)?;
filenames.push(filename);
}
@ -180,8 +197,6 @@ fn main() -> Result<()> {
.get_ids()
.to_vec();
println!("pre-computing the positional embeddings");
let freqs_cis = precompute_freqs_cis(&config, &device)?;
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
let mut new_tokens = vec![];
@ -196,12 +211,7 @@ fn main() -> Result<()> {
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let freqs_cis = if cache.use_kv_cache {
freqs_cis.narrow(1, index_pos, ctxt.len())?
} else {
freqs_cis.clone()
};
let logits = llama.forward(&input, &freqs_cis)?;
let logits = llama.forward(&input, index_pos)?;
let logits = logits.squeeze(0)?;
index_pos += ctxt.len();

View File

@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Linear, VarBuilder};
use candle_nn::{Embedding, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@ -12,10 +12,13 @@ pub struct Config {
pub n_layer: usize,
pub n_head: usize,
pub n_embd: usize,
pub n_key_value_head: usize,
pub use_flash_attn: bool,
pub rms_norm_eps: f64,
}
impl Config {
pub fn config_7b() -> Self {
pub fn config_7b_v1(use_flash_attn: bool) -> Self {
Self {
hidden_size: 4096,
intermediate_size: 11008,
@ -23,8 +26,40 @@ impl Config {
n_layer: 32,
n_head: 32,
n_embd: 4096,
n_key_value_head: 32,
use_flash_attn,
rms_norm_eps: 1e-6,
}
}
pub fn config_7b_v2(use_flash_attn: bool) -> Self {
Self {
hidden_size: 4096,
intermediate_size: 11008,
vocab_size: 32000,
n_layer: 32,
n_head: 32,
n_embd: 4096,
n_key_value_head: 32,
use_flash_attn,
rms_norm_eps: 1e-5,
}
}
}
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
#[derive(Clone)]
@ -33,17 +68,37 @@ pub struct Cache {
pub use_kv_cache: bool,
#[allow(clippy::type_complexity)]
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
cos: Tensor,
sin: Tensor,
device: Device,
}
impl Cache {
pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Self {
Self {
pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
// precompute freqs_cis
let n_elem = config.n_embd / config.n_head;
let theta: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
Ok(Self {
masks: Arc::new(Mutex::new(HashMap::new())),
use_kv_cache,
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
device: device.clone(),
}
cos,
sin,
})
}
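// A plain-Rust sketch of the tables built in `new` above (illustrative, no
// candle types): row p of `cos` holds cos(p * theta_k) for
// theta_k = 10000^(-2k / n_elem), with the half-row duplicated to cover the
// full head dimension, matching the cat(&[idx_theta, idx_theta], -1) above.
fn rope_tables(n_elem: usize, positions: usize) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let theta: Vec<f32> = (0..n_elem)
        .step_by(2)
        .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
        .collect();
    let row = |p: usize, f: fn(f32) -> f32| {
        let half: Vec<f32> = theta.iter().map(|&t| f(p as f32 * t)).collect();
        [half.clone(), half].concat()
    };
    let cos = (0..positions).map(|p| row(p, f32::cos)).collect();
    let sin = (0..positions).map(|p| row(p, f32::sin)).collect();
    (cos, sin)
}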
fn mask(&self, t: usize) -> Result<Tensor> {
@ -51,9 +106,8 @@ impl Cache {
if let Some(mask) = masks.get(&t) {
Ok(mask.clone())
} else {
// TODO: If we support bool or u8 tensors, this would be better.
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
masks.insert(t, mask.clone());
@ -67,8 +121,9 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
Ok(Linear::new(weight, None))
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
@ -78,27 +133,27 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
struct RmsNorm {
scale: Tensor,
eps: f64,
span: tracing::Span,
}
impl RmsNorm {
fn load(size: usize, vb: VarBuilder) -> Result<Self> {
fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let scale = vb.get(size, "weight")?;
Ok(Self::new(scale))
}
fn new(scale: Tensor) -> Self {
Self { scale }
Ok(Self { scale, eps, span })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let in_dtype = x.dtype();
// This is a no-op if x's dtype is already f32.
let x = x.to_dtype(DType::F32)?;
let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
let (b_sz, seq_len, hidden_size) = x.dims3()?;
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
let size = self.scale.shape().r1()?;
let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
let size = self.scale.dims1()?;
let scale = self
.scale
.to_dtype(DType::F32)?
@ -110,63 +165,69 @@ impl RmsNorm {
}
struct CausalSelfAttention {
c_attn: Linear,
c_proj: Linear,
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
n_head: usize,
n_key_value_head: usize,
head_dim: usize,
cache: Cache,
use_flash_attn: bool,
span: tracing::Span,
span_rot: tracing::Span,
}
#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
impl CausalSelfAttention {
fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
Self {
c_attn,
c_proj,
n_head,
cache: cache.clone(),
}
}
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
let mut dims = x.dims().to_vec();
let fcis_dims = freqs_cis.dims();
let freqs_cis = if dims[2] < fcis_dims[1] {
freqs_cis.narrow(1, 0, dims[2])?
} else {
freqs_cis.clone()
};
let v = dims.pop().unwrap();
dims.push(v / 2);
dims.push(2);
let x = x.reshape(dims)?;
let re_x = x.narrow(D::Minus1, 0, 1)?;
let im_x = x.narrow(D::Minus1, 1, 1)?;
let re_f = freqs_cis
.narrow(D::Minus1, 0, 1)?
.broadcast_as(re_x.shape())?;
let im_f = freqs_cis
.narrow(D::Minus1, 1, 1)?
.broadcast_as(im_x.shape())?;
let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
let rope = Tensor::cat(&[&re, &im], D::Minus1)?;
let rope = rope.flatten_from(D::Minus2)?;
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_rot.enter();
let (b_sz, _, seq_len, n_embd) = x.dims4()?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
Ok(rope)
}
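// The rotate-half trick above is complex rotation written with real tensors:
// splitting the head dimension into halves (x1, x2), the output is
// (x1 * cos - x2 * sin, x2 * cos + x1 * sin), i.e. each pair is rotated by
// the angle cached for its absolute position.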
fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
let x_dtype = x.dtype();
let (b_sz, seq_len, n_embd) = x.shape().r3()?;
let qkv = self.c_attn.forward(x)?;
let qkv = qkv.to_dtype(DType::F32)?;
let q = qkv.narrow(D::Minus1, 0, n_embd)?;
let k = qkv.narrow(D::Minus1, n_embd, n_embd)?;
let v = qkv.narrow(D::Minus1, 2 * n_embd, n_embd)?;
let target_dim = [b_sz, seq_len, self.n_head, n_embd / self.n_head];
let k = k.reshape(target_dim.as_slice())?.transpose(1, 2)?;
let q = q.reshape(target_dim.as_slice())?.transpose(1, 2)?;
let mut v = v.reshape(target_dim.as_slice())?.transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, freqs_cis)?;
let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
let q = q
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
.transpose(1, 2)?;
let mut v = v
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
.transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?;
if self.cache.use_kv_cache {
let mut cache = self.cache.kvs.lock().unwrap();
@ -189,39 +250,70 @@ impl CausalSelfAttention {
cache[block_idx] = Some((k.clone(), v.clone()))
}
let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = att.softmax(D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided value tensors for now.
let y = att.matmul(&v.contiguous()?)?;
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;
let y = if self.use_flash_attn {
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
let q = q.transpose(1, 2)?;
let k = k.transpose(1, 2)?;
let v = v.transpose(1, 2)?;
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?
} else {
let in_dtype = q.dtype();
let q = q.to_dtype(DType::F32)?;
let k = k.to_dtype(DType::F32)?;
let v = v.to_dtype(DType::F32)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided value tensors for now.
att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
};
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = y.to_dtype(x_dtype)?;
let y = self.c_proj.forward(&y)?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_head / self.n_key_value_head;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
let x = x
.unsqueeze(2)?
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
Ok(x)
}
}
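// Shape sketch for the grouped-query expansion above, with n_head = 4 and
// n_key_value_head = 2 (so n_rep = 2):
//   (b, 2, s, d) --unsqueeze(2)--> (b, 2, 1, s, d)
//                --expand-------> (b, 2, 2, s, d)
//                --reshape------> (b, 4, s, d)
// i.e. every kv head is repeated n_rep times to line up with the query heads.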
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let size_in = cfg.hidden_size;
let size = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
let q_proj = vb.get((size_in, size), "q_proj.weight")?;
let k_proj = vb.get((size_in, size), "k_proj.weight")?;
let v_proj = vb.get((size_in, size), "v_proj.weight")?;
// Invert the transformation from:
// https://github.com/huggingface/transformers/blob/2642d8d04b14c18199ebe7b35f976da02df61752/src/transformers/models/llama/convert_llama_weights_to_hf.py#L101
let n_head = cfg.n_head;
let q_proj = q_proj
.reshape((n_head, 2, size / n_head / 2, size_in))?
.transpose(1, 2)?
.reshape((size_in, size))?;
let k_proj = k_proj
.reshape((n_head, 2, size / n_head / 2, size_in))?
.transpose(1, 2)?
.reshape((size_in, size))?;
let attn_weight = Tensor::cat(&[q_proj, k_proj, v_proj], 0)?;
let c_attn = Linear::new(attn_weight, None);
let o_proj = linear(size, size_in, vb.pp("o_proj"))?;
Ok(Self::new(c_attn, o_proj, cfg.n_head, cache))
let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
n_head: cfg.n_head,
n_key_value_head: cfg.n_key_value_head,
head_dim: cfg.hidden_size / cfg.n_head,
cache: cache.clone(),
use_flash_attn: cfg.use_flash_attn,
span,
span_rot,
})
}
}
@ -236,29 +328,29 @@ struct Mlp {
c_fc1: Linear,
c_fc2: Linear,
c_proj: Linear,
span: tracing::Span,
}
impl Mlp {
fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
Self {
c_fc1,
c_fc2,
c_proj,
}
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "mlp");
let h_size = cfg.hidden_size;
let i_size = cfg.intermediate_size;
let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
Ok(Self::new(c_fc1, c_fc2, c_proj))
Ok(Self {
c_fc1,
c_fc2,
c_proj,
span,
})
}
}
@ -267,39 +359,37 @@ struct Block {
attn: CausalSelfAttention,
rms_2: RmsNorm,
mlp: Mlp,
span: tracing::Span,
}
impl Block {
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
Self {
rms_1,
attn,
rms_2,
mlp,
}
}
fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
let x = (self
.attn
.forward(&self.rms_1.forward(x)?, freqs_cis, block_idx)?
+ x)?;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let _enter = self.span.enter();
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "block");
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
Ok(Self::new(
input_layernorm,
let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let rms_2 = RmsNorm::load(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
rms_1,
attn,
post_attention_layernorm,
rms_2,
mlp,
))
span,
})
}
}
@ -311,20 +401,11 @@ pub struct Llama {
}
impl Llama {
fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
Self {
wte,
blocks,
ln_f,
lm_head,
}
}
pub fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
let (_b_sz, seq_len) = x.shape().r2()?;
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, freqs_cis, block_idx)?;
x = block.forward(&x, index_pos, block_idx)?;
}
let x = self.ln_f.forward(&x)?;
let x = x.i((.., seq_len - 1, ..))?;
@ -335,11 +416,16 @@ impl Llama {
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layer)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
.collect();
Ok(Self::new(wte, blocks, norm, lm_head))
Ok(Self {
wte,
blocks,
ln_f,
lm_head,
})
}
}

View File

@ -0,0 +1,289 @@
// https://github.com/karpathy/llama2.c
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod model;
mod training;
mod weights;
use clap::{Parser, Subcommand};
use anyhow::{Error as E, Result};
use byteorder::{LittleEndian, ReadBytesExt};
use candle::{IndexOp, Tensor};
use candle_transformers::generation::LogitsProcessor;
use std::io::Write;
use tokenizers::Tokenizer;
use model::{Config, Llama};
use weights::TransformerWeights;
#[derive(Parser, Debug, Clone)]
struct InferenceCmd {
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
#[arg(long, default_value = "")]
prompt: String,
/// Config file in binary or safetensors format.
#[arg(long)]
config: Option<String>,
#[arg(long, default_value = "karpathy/tinyllamas")]
model_id: String,
/// The model file to be fetched from the hub. Possible values are
/// 'stories15M.bin' and 'stories42M.bin'; see more at:
/// https://huggingface.co/karpathy/tinyllamas/tree/main
#[arg(long, default_value = "stories15M.bin")]
which_model: String,
}
#[derive(Parser, Debug, Clone)]
struct EvaluationCmd {
/// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
/// script from llama2.c https://github.com/karpathy/llama2.c
#[arg(long)]
pretokenized_dir: Option<String>,
#[arg(long, default_value_t = 32)]
batch_size: usize,
/// Config file in binary format.
#[arg(long)]
config: Option<String>,
#[arg(long, default_value = "karpathy/tinyllamas")]
model_id: String,
/// The model file to be fetched from the hub. Possible values are
/// 'stories15M.bin' and 'stories42M.bin'; see more at:
/// https://huggingface.co/karpathy/tinyllamas/tree/main
#[arg(long, default_value = "stories15M.bin")]
which_model: String,
}
#[derive(Parser, Debug, Clone)]
pub struct TrainingCmd {
/// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
/// script from llama2.c https://github.com/karpathy/llama2.c
#[arg(long)]
pretokenized_dir: String,
#[arg(long, default_value_t = 32)]
batch_size: usize,
#[arg(long, default_value_t = 0.001)]
learning_rate: f64,
}
#[derive(Subcommand, Debug, Clone)]
enum Task {
Inference(InferenceCmd),
Eval(EvaluationCmd),
Train(TrainingCmd),
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
/// The task to be performed: inference, training or evaluation.
#[command(subcommand)]
task: Option<Task>,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Tokenizer config file.
#[arg(long)]
tokenizer: Option<String>,
}
impl Args {
fn tokenizer(&self) -> Result<Tokenizer> {
let tokenizer_path = match &self.tokenizer {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
api.get("tokenizer.json")?
}
};
Tokenizer::from_file(tokenizer_path).map_err(E::msg)
}
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
match &args.task {
None => {
let cmd = InferenceCmd {
temperature: None,
prompt: "".to_string(),
config: None,
model_id: "karpathy/tinyllamas".to_string(),
which_model: "stories15M.bin".to_string(),
};
run_inference(&cmd, &args)?
}
Some(Task::Inference(cmd)) => run_inference(cmd, &args)?,
Some(Task::Eval(cmd)) => run_eval(cmd, &args)?,
Some(Task::Train(cmd)) => training::run(cmd, &args)?,
}
Ok(())
}
fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
use std::io::BufRead;
let config_path = match &args.config {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
println!("loading the model weights from {}", args.model_id);
let api = api.model(args.model_id.clone());
api.get(&args.which_model)?
}
};
let tokenizer = common_args.tokenizer()?;
let device = candle_examples::device(common_args.cpu)?;
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
let tokens = match &args.pretokenized_dir {
None => {
let api = hf_hub::api::sync::Api::new()?;
let model_id = "roneneldan/TinyStories"; // TODO: Make this configurable.
println!("loading the evaluation dataset from {}", model_id);
let api = api.dataset(model_id.to_string());
let dataset_path = api.get("TinyStories-valid.txt")?;
let file = std::fs::File::open(dataset_path)?;
let file = std::io::BufReader::new(file);
let mut tokens = vec![];
for line in file.lines() {
let line = line?.replace("<|endoftext|>", "<s>");
let line = tokenizer.encode(line, false).map_err(E::msg)?;
tokens.push(line.get_ids().to_vec())
}
tokens.concat()
}
Some(pretokenized_dir) => {
// Use shard 0 for the test split, similar to llama2.c
// https://github.com/karpathy/llama2.c/blob/ce05cc28cf1e3560b873bb21837638a434520a67/tinystories.py#L121
let path = std::path::PathBuf::from(pretokenized_dir).join("data00.bin");
let bytes = std::fs::read(path)?;
// Tokens are encoded as u16.
let mut tokens = vec![0u16; bytes.len() / 2];
std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens)?;
tokens.into_iter().map(|u| u as u32).collect::<Vec<u32>>()
}
};
println!("dataset loaded and encoded: {} tokens", tokens.len());
let seq_len = model.config.seq_len;
let iter = (0..tokens.len()).step_by(seq_len).flat_map(|start_idx| {
if start_idx + seq_len + 1 > tokens.len() {
None
} else {
let tokens = &tokens[start_idx..start_idx + seq_len + 1];
let inputs = Tensor::new(&tokens[..seq_len], &device);
let targets = Tensor::new(&tokens[1..], &device);
Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
}
});
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
println!("{}", loss.to_vec0::<f32>()?);
}
Ok(())
}
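// A minimal sketch of the windowing used above (the helper name is ours):
// inputs and targets are the same seq_len-sized window shifted by one token,
// so the model is evaluated on next-token prediction.
fn window(tokens: &[u32], seq_len: usize, start: usize) -> Option<(&[u32], &[u32])> {
    if start + seq_len + 1 > tokens.len() {
        None
    } else {
        Some((
            &tokens[start..start + seq_len],
            &tokens[start + 1..start + seq_len + 1],
        ))
    }
}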
fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let config_path = match &args.config {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
println!("loading the model weights from {}", args.model_id);
let api = api.model(args.model_id.clone());
api.get(&args.which_model)?
}
};
let tokenizer = common_args.tokenizer()?;
let device = candle_examples::device(common_args.cpu)?;
let is_safetensors = config_path
.extension()
.map_or(false, |v| v == "safetensors");
let (vb, config) = if is_safetensors {
let config = Config::tiny();
let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
(vb, config)
} else {
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
println!("{config:?}");
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
(vb, config)
};
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
let mut index_pos = 0;
print!("{}", args.prompt);
let mut tokens = tokenizer
.encode(args.prompt.clone(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let start_gen = std::time::Instant::now();
for index in 0.. {
if tokens.len() >= model.config.seq_len {
break;
}
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos)?;
let logits = logits.i((0, logits.dim(1)? - 1))?;
index_pos += ctxt.len();
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
// Extracting the last token as a string is complicated; here we just apply
// some simple heuristics that seem to work well enough for this example. See
// the following for more details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
println!(
"\n{} tokens generated ({:.2} token/s)\n",
tokens.len(),
tokens.len() as f64 / dt.as_secs_f64(),
);
Ok(())
}

View File

@ -0,0 +1,340 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::linear_no_bias as linear;
use candle_nn::{embedding, Embedding, Linear, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
#[derive(Debug, Clone)]
pub struct Config {
pub dim: usize, // transformer dimension
pub hidden_dim: usize, // for ffn layers
pub n_layers: usize, // number of layers
pub n_heads: usize, // number of query heads
pub n_kv_heads: usize, // number of key/value heads (can be < query heads because of multiquery)
pub vocab_size: usize, // vocabulary size, usually 256 (byte-level)
pub seq_len: usize, // max sequence length
pub norm_eps: f64,
}
impl Config {
pub fn tiny() -> Self {
Self {
dim: 288,
hidden_dim: 768,
n_layers: 6,
n_heads: 6,
n_kv_heads: 6,
vocab_size: 32000,
seq_len: 256,
norm_eps: 1e-5,
}
}
}
#[derive(Clone)]
pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
pub use_kv_cache: bool,
#[allow(clippy::type_complexity)]
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
cos: Tensor,
sin: Tensor,
device: Device,
}
impl Cache {
pub fn new(use_kv_cache: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let n_elem = cfg.dim / cfg.n_heads;
let theta: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), vb.device())?;
let idx_theta = Tensor::arange(0, cfg.seq_len as u32, vb.device())?
.to_dtype(DType::F32)?
.reshape((cfg.seq_len, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let precomputed_cos = idx_theta.cos()?;
let precomputed_sin = idx_theta.sin()?;
let freq_cis_real = vb
.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")
.unwrap_or(precomputed_cos);
let freq_cis_imag = vb
.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")
.unwrap_or(precomputed_sin);
let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
Ok(Self {
masks: Arc::new(Mutex::new(HashMap::new())),
use_kv_cache,
kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])),
cos,
sin,
device: vb.device().clone(),
})
}
fn mask(&self, t: usize) -> Result<Tensor> {
let mut masks = self.masks.lock().unwrap();
if let Some(mask) = masks.get(&t) {
Ok(mask.clone())
} else {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
masks.insert(t, mask.clone());
Ok(mask)
}
}
}
fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
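// silu(x) = x * sigmoid(x) = x / (1 + e^(-x)); the expression above computes
// the right-hand form directly on the tensor.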
struct RmsNorm {
scale: Tensor,
eps: f64,
}
impl RmsNorm {
fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let scale = vb.get_or_init(size, "weight", candle_nn::Init::Const(1.))?;
Ok(Self { scale, eps })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let (b_sz, seq_len, hidden_size) = x.dims3()?;
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
let size = self.scale.dims1()?;
let scale = self
.scale
.to_dtype(DType::F32)?
.broadcast_as((b_sz, seq_len, size))?;
let x = (scale * x_normed)?;
Ok(x)
}
}
struct CausalSelfAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
n_head: usize,
n_key_value_head: usize,
head_dim: usize,
cache: Cache,
}
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
let cos = cos.unsqueeze(1)?;
let sin = sin.unsqueeze(1)?;
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
Ok(rope)
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?;
if self.cache.use_kv_cache {
let mut cache = self.cache.kvs.lock().unwrap();
if let Some((cache_k, cache_v)) = &cache[block_idx] {
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
}
cache[block_idx] = Some((k.clone(), v.clone()))
}
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;
let q = q.transpose(1, 2)?.contiguous()?;
let k = k.transpose(1, 2)?.contiguous()?;
let v = v.transpose(1, 2)?.contiguous()?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided value tensors for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_head / self.n_key_value_head;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
let x = x
.unsqueeze(3)?
.expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
.reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
Ok(x)
}
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let size_in = cfg.dim;
let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
n_head: cfg.n_heads,
n_key_value_head: cfg.n_kv_heads,
head_dim: cfg.dim / cfg.n_heads,
cache: cache.clone(),
})
}
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
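// Usage note: combined with Cache::mask (where mask[i][j] = u8::from(j > i)),
// masked_fill replaces every future position in the attention scores with
// -inf before the softmax, so probability mass only flows to positions <= i.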
struct Mlp {
c_fc1: Linear,
c_fc2: Linear,
c_proj: Linear,
}
impl Mlp {
fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
Self {
c_fc1,
c_fc2,
c_proj,
}
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let h_size = cfg.dim;
let i_size = cfg.hidden_dim;
let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
Ok(Self::new(c_fc1, c_fc2, c_proj))
}
}
struct Block {
rms_1: RmsNorm,
attn: CausalSelfAttention,
rms_2: RmsNorm,
mlp: Mlp,
}
impl Block {
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
Self {
rms_1,
attn,
rms_2,
mlp,
}
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
Ok(Self::new(
input_layernorm,
attn,
post_attention_layernorm,
mlp,
))
}
}
pub struct Llama {
wte: Embedding,
blocks: Vec<Block>,
ln_f: RmsNorm,
lm_head: Linear,
pub config: Config,
}
impl Llama {
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, _seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx)?;
}
let x = self.ln_f.forward(&x)?;
let logits = self.lm_head.forward(&x)?;
logits.to_dtype(DType::F32)
}
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
.collect();
Ok(Self {
wte,
blocks,
ln_f,
lm_head,
config: cfg,
})
}
}

View File

@ -0,0 +1,175 @@
#![allow(dead_code)]
#![allow(unused)]
use crate::model::{Cache, Config, Llama};
use candle::{DType, Device, Result, Tensor};
pub struct Dataset {
valid_tokens: Vec<memmap2::Mmap>,
train_tokens: Vec<memmap2::Mmap>,
}
fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
let file = std::fs::File::open(p)?;
let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
Ok(mmap)
}
impl Dataset {
pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
let dir = dir.as_ref();
let mut bin_files = vec![];
for file in std::fs::read_dir(dir)?.flatten() {
let file = file.path();
if let Some(extension) = file.extension() {
if extension == "bin" {
bin_files.push(file)
}
}
}
if bin_files.len() < 2 {
candle::bail!("found less than two bin files in {:?}", dir)
}
bin_files.sort();
let valid_tokens = mmap_file(&bin_files[0])?;
let train_tokens = bin_files[1..]
.iter()
.map(mmap_file)
.collect::<Result<Vec<_>>>()?;
Ok(Self {
valid_tokens: vec![valid_tokens],
train_tokens,
})
}
}
struct DatasetRandomIter<'a> {
all_tokens: &'a [memmap2::Mmap],
tokens: Vec<&'a memmap2::Mmap>,
current_tokens: &'a memmap2::Mmap,
indexes_in_bytes: Vec<usize>,
seq_len: usize,
device: Device,
}
impl<'a> DatasetRandomIter<'a> {
pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
use rand::seq::SliceRandom;
use rand::thread_rng;
let all_tokens = if valid {
&ds.valid_tokens
} else {
&ds.train_tokens
};
let mut tokens = all_tokens.iter().collect::<Vec<_>>();
tokens.shuffle(&mut thread_rng());
let current_tokens = tokens.pop().unwrap();
let seq_len_in_bytes = seq_len * 2;
let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
.step_by(seq_len_in_bytes)
.collect::<Vec<_>>();
indexes_in_bytes.shuffle(&mut thread_rng());
Self {
all_tokens,
tokens,
current_tokens,
indexes_in_bytes,
seq_len,
device,
}
}
}
impl<'a> Iterator for DatasetRandomIter<'a> {
type Item = Result<(Tensor, Tensor)>;
fn next(&mut self) -> Option<Self::Item> {
use byteorder::{LittleEndian, ReadBytesExt};
use rand::seq::SliceRandom;
use rand::thread_rng;
let seq_len = self.seq_len;
if self.indexes_in_bytes.is_empty() {
if self.tokens.is_empty() {
self.tokens = self.all_tokens.iter().collect();
self.tokens.shuffle(&mut thread_rng());
}
self.current_tokens = self.tokens.pop().unwrap();
let seq_len_in_bytes = self.seq_len * 2;
self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
.step_by(seq_len_in_bytes)
.collect::<Vec<_>>();
self.indexes_in_bytes.shuffle(&mut thread_rng());
}
let start_idx = self.indexes_in_bytes.pop().unwrap();
let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
let mut tokens = vec![0u16; bytes.len() / 2];
if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
return Some(Err(err.into()));
}
let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
let inputs = Tensor::new(&tokens[..seq_len], &self.device);
let targets = Tensor::new(&tokens[1..], &self.device);
Some(candle::error::zip(inputs, targets))
}
}
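// Byte arithmetic note: tokens are stored on disk as little-endian u16, so a
// training window of seq_len + 1 tokens spans 2 * (seq_len + 1) bytes and the
// shuffled offsets advance in steps of seq_len * 2.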
fn valid_loss(
dataset: &Dataset,
model: &Llama,
args: &crate::TrainingCmd,
device: &Device,
) -> Result<f64> {
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
let mut sum_ce = 0f64;
let mut cnt = 0usize;
for inp_tgt in batch_iter.take(50) {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
sum_ce += loss.to_vec0::<f32>()? as f64;
cnt += 1;
}
Ok(sum_ce / cnt as f64)
}
pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let device = candle_examples::device(common_args.cpu)?;
let dataset = Dataset::new(&args.pretokenized_dir)?;
println!(
"loaded dataset, train: {} files, valid: {} files",
dataset.train_tokens.len(),
dataset.valid_tokens.len()
);
let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
let config = Config::tiny();
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
let cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
let params = candle_nn::ParamsAdamW {
lr: args.learning_rate,
..Default::default()
};
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
for (batch_index, batch) in batch_iter.enumerate() {
let (inp, tgt) = batch?;
let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
opt.backward_step(&loss)?;
if batch_index > 0 && batch_index % 100 == 0 {
// TODO: Add a way to deactivate the backprop graph tracking when computing the
// validation loss.
let loss = valid_loss(&dataset, &model, args, &device)?;
println!("{batch_index} {loss}");
}
if batch_index > 0 && batch_index % 1000 == 0 {
varmap.save("checkpoint.safetensors")?
}
}
Ok(())
}

View File

@ -0,0 +1,161 @@
use anyhow::Result;
use byteorder::{LittleEndian, ReadBytesExt};
use candle::{DType, Device, IndexOp, Shape, Tensor};
use candle_nn::VarBuilder;
use crate::model::Config;
pub struct TransformerWeights {
// token embedding table
token_embedding_table: Tensor, // (vocab_size, dim)
// weights for rmsnorms
rms_att_weight: Tensor, // (layer, dim) rmsnorm weights
rms_ffn_weight: Tensor, // (layer, dim)
// weights for matmuls
wq: Tensor, // (layer, dim, dim)
wk: Tensor, // (layer, dim, dim)
wv: Tensor, // (layer, dim, dim)
wo: Tensor, // (layer, dim, dim)
// weights for ffn
w1: Tensor, // (layer, hidden_dim, dim)
w2: Tensor, // (layer, dim, hidden_dim)
w3: Tensor, // (layer, hidden_dim, dim)
// final rmsnorm
rms_final_weight: Tensor, // (dim,)
// freq_cis for RoPE relative positional embeddings
freq_cis_real: Tensor, // (seq_len, head_size/2)
freq_cis_imag: Tensor, // (seq_len, head_size/2)
}
fn read_i32<R: std::io::Read>(r: &mut R) -> Result<i32> {
let mut buf = [0u8; 4];
r.read_exact(&mut buf)?;
Ok(i32::from_le_bytes(buf))
}
fn read_tensor<R: std::io::Read, S: Into<Shape>>(
r: &mut R,
shape: S,
dev: &Device,
) -> Result<Tensor> {
let shape = shape.into();
let mut data_t = vec![0f32; shape.elem_count()];
r.read_f32_into::<LittleEndian>(&mut data_t)?;
let tensor = Tensor::from_vec(data_t, shape, dev)?;
Ok(tensor)
}
impl Config {
pub fn from_reader<R: std::io::Read>(r: &mut R) -> Result<Self> {
let dim = read_i32(r)? as usize;
let hidden_dim = read_i32(r)? as usize;
let n_layers = read_i32(r)? as usize;
let n_heads = read_i32(r)? as usize;
let n_kv_heads = read_i32(r)? as usize;
let vocab_size = read_i32(r)? as usize;
let seq_len = read_i32(r)? as usize;
Ok(Self {
dim,
hidden_dim,
n_layers,
n_heads,
n_kv_heads,
vocab_size,
seq_len,
norm_eps: 1e-5,
})
}
pub fn head_size(&self) -> usize {
self.dim / self.n_heads
}
}
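// Layout note: the llama2.c checkpoint starts with seven little-endian i32
// header fields, read above in order (dim, hidden_dim, n_layers, n_heads,
// n_kv_heads, vocab_size, seq_len); the f32 tensor data follows immediately
// and is consumed in the same order by TransformerWeights::from_reader below.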
impl TransformerWeights {
pub fn from_reader<R: std::io::Read>(r: &mut R, c: &Config, dev: &Device) -> Result<Self> {
let token_embedding_table = read_tensor(r, (c.vocab_size, c.dim), dev)?;
let rms_att_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
let wq = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
let wk = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
let wv = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
let wo = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
let rms_ffn_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
let w1 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
let w2 = read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?;
let w3 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
let rms_final_weight = read_tensor(r, c.dim, dev)?;
let head_size = c.head_size();
let freq_cis_real = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
let freq_cis_imag = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
Ok(Self {
token_embedding_table,
rms_att_weight,
wq,
wk,
wv,
wo,
rms_ffn_weight,
w1,
w2,
w3,
rms_final_weight,
freq_cis_real,
freq_cis_imag,
})
}
pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
let mut ws = std::collections::HashMap::new();
let mut insert = |name: &str, t: Tensor| {
ws.insert(name.to_string(), t);
};
insert("rot.freq_cis_real", self.freq_cis_real.clone());
insert("rot.freq_cis_imag", self.freq_cis_imag.clone());
insert(
"model.embed_tokens.weight",
self.token_embedding_table.clone(),
);
insert("lm_head.weight", self.token_embedding_table.clone());
insert("model.norm.weight", self.rms_final_weight.clone());
for layer in 0..cfg.n_layers {
ws.insert(
format!("model.layers.{layer}.self_attn.q_proj.weight"),
self.wq.i(layer)?,
);
ws.insert(
format!("model.layers.{layer}.self_attn.k_proj.weight"),
self.wk.i(layer)?,
);
ws.insert(
format!("model.layers.{layer}.self_attn.v_proj.weight"),
self.wv.i(layer)?,
);
ws.insert(
format!("model.layers.{layer}.self_attn.o_proj.weight"),
self.wo.i(layer)?,
);
ws.insert(
format!("model.layers.{layer}.mlp.gate_proj.weight"),
self.w1.i(layer)?,
);
ws.insert(
format!("model.layers.{layer}.mlp.down_proj.weight"),
self.w2.i(layer)?,
);
ws.insert(
format!("model.layers.{layer}.mlp.up_proj.weight"),
self.w3.i(layer)?,
);
ws.insert(
format!("model.layers.{layer}.input_layernorm.weight"),
self.rms_att_weight.i(layer)?,
);
ws.insert(
format!("model.layers.{layer}.post_attention_layernorm.weight"),
self.rms_ffn_weight.i(layer)?,
);
}
let vb = VarBuilder::from_tensors(ws, DType::F32, device);
Ok(vb)
}
}

View File

@ -0,0 +1,251 @@
// An implementation of LLaMA https://github.com/facebookresearch/llama
//
// This is based on nanoGPT in a similar way to:
// https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py
//
// The tokenizer config can be retrieved from:
// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json
//
// In order to convert the llama weights to a .npz file, run:
// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use cudarc::driver::safe::CudaDevice;
use cudarc::nccl::safe::{Comm, Id};
use hf_hub::api::sync::Api;
use std::io::Write;
use std::rc::Rc;
mod model;
use model::{Config, Llama};
const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = r"
EDWARD:
I wonder how our princely father 'scaped,
Or whether he be 'scaped away or no
From Clifford's and Northumberland's pursuit:
Had he been ta'en, we should have heard the news;
Had he been slain, we should have heard the news;
Or had he 'scaped, methinks we should have heard
The happy tidings of his good escape.
How fares my brother? why is he so sad?
RICHARD:
I cannot joy, until I be resolved
Where our right valiant father is become.
I saw him in the battle range about;
And watch'd him how he singled Clifford forth.
Methought he bore him in the thickest troop
As doth a lion in a herd of neat;
Or as a bear, encompass'd round with dogs,
Who having pinch'd a few and made them cry,
The rest stand all aloof, and bark at him.
So fared our father with his enemies;
So fled his enemies my warlike father:
Methinks, 'tis prize enough to be his son.
See how the morning opes her golden gates,
And takes her farewell of the glorious sun!
How well resembles it the prime of youth,
Trimm'd like a younker prancing to his love!
EDWARD:
Dazzle mine eyes, or do I see three suns?
RICHARD:
Three glorious suns, each one a perfect sun;
Not separated with the racking clouds,
But sever'd in a pale clear-shining sky.
See, see! they join, embrace, and seem to kiss,
As if they vow'd some league inviolable:
Now are they but one lamp, one light, one sun.
In this the heaven figures some event.
EDWARD:
'Tis wondrous strange, the like yet never heard of.
I think it cites us, brother, to the field,
That we, the sons of brave Plantagenet,
Each one already blazing by our meeds,
Should notwithstanding join our lights together
And over-shine the earth as this the world.
Whate'er it bodes, henceforward will I bear
Upon my target three fair-shining suns.
";
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
#[arg(long)]
num_shards: usize,
#[arg(long)]
rank: Option<usize>,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 100)]
sample_len: usize,
/// Disable the key-value cache.
#[arg(long)]
no_kv_cache: bool,
/// The initial prompt.
#[arg(long)]
prompt: Option<String>,
#[arg(long)]
model_id: Option<String>,
}
fn main() -> Result<()> {
use tokenizers::Tokenizer;
let args = Args::parse();
let config = Config::config_7b();
let dtype = DType::F16;
let api = Api::new()?;
let model_id = args
.model_id
.unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
println!("loading the model weights from {model_id}");
let api = api.model(model_id);
let tokenizer_filename = api.get("tokenizer.json")?;
let mut filenames = vec![];
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
let filename = api.get(rfilename)?;
filenames.push(filename);
}
if args.rank.is_none() {
let children: Vec<_> = (0..args.num_shards)
.map(|rank| {
let mut args: std::collections::VecDeque<_> = std::env::args().collect();
args.push_back("--rank".to_string());
args.push_back(format!("{rank}"));
let name = args.pop_front().unwrap();
std::process::Command::new(name).args(args).spawn().unwrap()
})
.collect();
for mut child in children {
child.wait().unwrap();
}
return Ok(());
}
let i = args.rank.unwrap();
let num_shards = args.num_shards;
let rank = i;
// Primitive IPC
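// Rank 0 generates the NCCL unique id and publishes it through a file; the
// other ranks poll until the file appears and read the id back. The write
// goes to a .tmp file which is then renamed so the publication is atomic.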
let id = if rank == 0 {
let id = Id::new().unwrap();
std::fs::File::create("nccl_id.txt.tmp")?
.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())
.unwrap();
std::fs::rename("nccl_id.txt.tmp", "nccl_id.txt")?;
id
} else {
let path = std::path::PathBuf::from("nccl_id.txt");
while !path.exists() {
std::thread::sleep(std::time::Duration::from_secs(1));
}
let data = std::fs::read("nccl_id.txt")?;
let internal: [i8; 128] = data
.into_iter()
.map(|i| i as i8)
.collect::<Vec<_>>()
.try_into()
.unwrap();
let id: Id = Id::uninit(internal);
id
};
let device = CudaDevice::new(i)?;
let comm = Rc::new(Comm::from_rank(device, i, num_shards, id).unwrap());
if rank == 0 {
std::fs::remove_file("nccl_id.txt")?;
}
println!("Rank {rank:?} spawned");
let device = Device::new_cuda(i)?;
let cache = model::Cache::new(&config, &device)?;
println!("building the model");
let handles = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
.collect::<Result<Vec<_>>>()?;
let tensors: Vec<_> = handles
.iter()
.map(|h| Ok(h.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
let llama = Llama::load(vb, &cache, &config, comm)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
let mut new_tokens = vec![];
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
for index in 0..args.sample_len {
let start_gen = std::time::Instant::now();
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, index_pos)?;
let logits = logits.squeeze(0)?;
index_pos += ctxt.len();
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
if rank == 0 {
println!("> {:?}", start_gen.elapsed());
println!(
"{} token: {} '{}'",
index + 1,
next_token,
tokenizer.decode(vec![next_token], true).map_err(E::msg)?
);
}
}
let dt = start_gen.elapsed();
if rank == 0 {
println!(
"{} tokens generated ({} token/s)\n----\n{}\n----",
args.sample_len,
args.sample_len as f64 / dt.as_secs_f64(),
tokenizer.decode(new_tokens, true).map_err(E::msg)?
);
}
Ok(())
}

View File

@ -0,0 +1,459 @@
use candle::backend::BackendStorage;
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
use candle_nn::{Embedding, Linear, VarBuilder};
use cudarc::nccl::safe::{Comm, ReduceOp};
use half::f16;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use super::MAX_SEQ_LEN;
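// Tensor-parallel linear layers: the column-parallel variant shards its weight
// along the output dimension so every rank computes a slice of the output,
// while the row-parallel variant shards along the input dimension and needs an
// all-reduce to sum the partial results across ranks.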
struct TensorParallelColumnLinear {
linear: Linear,
}
impl TensorParallelColumnLinear {
fn new(linear: Linear) -> Self {
Self { linear }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
self.linear.forward(x)
}
}
struct TensorParallelRowLinear {
linear: Linear,
comm: Rc<Comm>,
}
struct AllReduce {
comm: Rc<Comm>,
}
/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
/// But for the purposes of this example, it will work.
unsafe impl Sync for AllReduce {}
/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
/// But for the purposes of this example, it will work.
unsafe impl Send for AllReduce {}
impl CustomOp1 for AllReduce {
fn name(&self) -> &'static str {
"allreduce"
}
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
todo!("implement allreduce for cpu is not necessary for single node");
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
s: &candle::CudaStorage,
l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::WrapErr;
let elem_count = l.shape().elem_count();
let dev = s.device().clone();
let s = s.as_cuda_slice::<f16>()?;
// let s = match l.contiguous_offsets() {
// None => Err(Error::Wrapped("input has to be contiguous".into()))?,
// Some((o1, o2)) => s.slice(o1..o2),
// };
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
Ok((dst, l.shape().clone()))
}
}
fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
x.custom_op1(AllReduce { comm: comm.clone() })
}
impl TensorParallelRowLinear {
fn new(linear: Linear, comm: Rc<Comm>) -> Self {
Self { linear, comm }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = self.linear.forward(x)?;
all_reduce_sum(&x, &self.comm)
}
}
impl TensorParallelColumnLinear {
fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
let rank = comm.rank();
let size = comm.world_size();
let weight = vb.get_sharded("weight", 0, rank, size)?;
Ok(Self::new(Linear::new(weight, None)))
}
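// Load several projections (here q/k/v) as a single fused column-parallel
// layer: each rank loads its shard of every projection and concatenates the
// shards along the output dimension.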
fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc<Comm>) -> Result<Self> {
let rank = comm.rank();
let size = comm.world_size();
let weights: Vec<_> = prefixes
.iter()
.map(|p| vb.pp(p).get_sharded("weight", 0, rank, size).unwrap())
.collect();
let weight = Tensor::cat(&weights, 0)?;
Ok(Self::new(Linear::new(weight, None)))
}
}
impl TensorParallelRowLinear {
fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
let rank = comm.rank();
let size = comm.world_size();
let weight = vb.get_sharded("weight", 1, rank, size)?;
Ok(Self::new(Linear::new(weight, None), comm))
}
}
pub struct Config {
pub hidden_size: usize,
pub intermediate_size: usize,
pub vocab_size: usize,
pub n_layer: usize,
pub n_head: usize,
pub n_embd: usize,
pub n_key_value_head: usize,
}
impl Config {
pub fn config_7b() -> Self {
Self {
hidden_size: 4096,
intermediate_size: 11008,
vocab_size: 32000,
n_layer: 32,
n_head: 32,
n_embd: 4096,
n_key_value_head: 32,
}
}
}
#[derive(Clone)]
pub struct Cache {
#[allow(clippy::type_complexity)]
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
cos: Tensor,
sin: Tensor,
}
impl Cache {
pub fn new(config: &Config, device: &Device) -> Result<Self> {
// precompute freqs_cis
let n_elem = config.n_embd / config.n_head;
let theta: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
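// The HF layout rotates the first half of the hidden dimension against the
// second half instead of interleaving even/odd pairs, hence idx_theta is
// duplicated along the last axis.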
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
let cos = idx_theta.cos()?.to_dtype(DType::F16)?;
let sin = idx_theta.sin()?.to_dtype(DType::F16)?;
Ok(Self {
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
cos,
sin,
})
}
}
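// SiLU activation: silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).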
fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
Ok(Linear::new(weight, None))
}
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
Ok(Embedding::new(embeddings, cfg.hidden_size))
}
struct RmsNorm {
scale: Tensor,
}
impl RmsNorm {
fn load(size: usize, vb: VarBuilder) -> Result<Self> {
let scale = vb.get(size, "weight")?;
Ok(Self::new(scale))
}
fn new(scale: Tensor) -> Self {
Self { scale }
}
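// RMSNorm: y = scale * x / sqrt(mean(x^2) + eps), computed in f32 for
// numerical stability and cast back to the input dtype.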
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let in_dtype = x.dtype();
// This is a no-op if x's dtype is already f32.
let x = x.to_dtype(DType::F32)?;
let (b_sz, seq_len, hidden_size) = x.shape().dims3()?;
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
let size = self.scale.shape().dims1()?;
let scale = self
.scale
.to_dtype(DType::F32)?
.broadcast_as((b_sz, seq_len, size))?;
let x = (scale * x_normed)?;
let x = x.to_dtype(in_dtype)?;
Ok(x)
}
}
struct CausalSelfAttention {
qkv_proj: TensorParallelColumnLinear,
o_proj: TensorParallelRowLinear,
n_head: usize,
n_key_value_head: usize,
head_dim: usize,
cache: Cache,
}
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, _, seq_len, n_embd) = x.shape().dims4()?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
Ok(rope)
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let (b_sz, seq_len, _) = x.shape().dims3()?;
let qkv = self.qkv_proj.forward(x)?;
let n_embd = self.n_head * self.head_dim;
let q = qkv.i((.., .., ..self.n_head * self.head_dim))?;
let k = qkv.i((
..,
..,
self.n_head * self.head_dim
..self.n_head * self.head_dim + self.n_key_value_head * self.head_dim,
))?;
let v = qkv.i((
..,
..,
self.n_head * self.head_dim + self.n_key_value_head * self.head_dim..,
))?;
// todo!("Q {:?} K {:?} V {:?} - x {:?}", q.shape(), k.shape(), v.shape(), x.shape());
let q = q
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
.transpose(1, 2)?;
let mut v = v
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
.transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?;
let mut cache = self.cache.kvs.lock().unwrap();
if let Some((cache_k, cache_v)) = &cache[block_idx] {
k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
// The sequence dimension is dim 2 after the transpose above, so trim the
// cache along that dimension once it grows beyond MAX_SEQ_LEN.
let k_seq_len = k.dims()[2];
if k_seq_len > MAX_SEQ_LEN {
k = k
.narrow(2, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
.contiguous()?
}
let v_seq_len = v.dims()[2];
if v_seq_len > MAX_SEQ_LEN {
v = v
.narrow(2, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
.contiguous()?
}
}
cache[block_idx] = Some((k.clone(), v.clone()));
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;
let q = q.transpose(1, 2)?;
let k = k.transpose(1, 2)?;
let v = v.transpose(1, 2)?;
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
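// The causal mask is only needed when attending over more than one query
// token, i.e. for the prompt; with the kv-cache each decoding step feeds a
// single token and the mask would be a no-op.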
let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
.transpose(1, 2)?;
// Convert to contiguous as matmul doesn't support strided inputs for now.
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}
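// Repeat the key/value heads n_head / n_key_value_head times so grouped-query
// attention can reuse the standard multi-head attention computation.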
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_head / self.n_key_value_head;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().dims4()?;
let x = x
.unsqueeze(2)?
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
Ok(x)
}
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
let qkv_proj = TensorParallelColumnLinear::load_multi(
vb.clone(),
&["q_proj", "k_proj", "v_proj"],
comm.clone(),
)?;
let o_proj = TensorParallelRowLinear::load(vb.pp("o_proj"), comm.clone())?;
Ok(Self {
qkv_proj,
o_proj,
n_head: cfg.n_head / comm.world_size(),
n_key_value_head: cfg.n_key_value_head / comm.world_size(),
head_dim: cfg.hidden_size / cfg.n_head,
cache: cache.clone(),
})
}
}
struct Mlp {
c_fc1: TensorParallelColumnLinear,
c_fc2: TensorParallelColumnLinear,
c_proj: TensorParallelRowLinear,
}
impl Mlp {
fn new(
c_fc1: TensorParallelColumnLinear,
c_fc2: TensorParallelColumnLinear,
c_proj: TensorParallelRowLinear,
) -> Self {
Self {
c_fc1,
c_fc2,
c_proj,
}
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}
fn load(vb: VarBuilder, _cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?;
Ok(Self::new(c_fc1, c_fc2, c_proj))
}
}
struct Block {
rms_1: RmsNorm,
attn: CausalSelfAttention,
rms_2: RmsNorm,
mlp: Mlp,
}
impl Block {
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
Self {
rms_1,
attn,
rms_2,
mlp,
}
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
let mlp = Mlp::load(vb.pp("mlp"), cfg, comm)?;
let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
Ok(Self::new(
input_layernorm,
attn,
post_attention_layernorm,
mlp,
))
}
}
pub struct Llama {
wte: Embedding,
blocks: Vec<Block>,
ln_f: RmsNorm,
lm_head: Linear,
}
impl Llama {
fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
Self {
wte,
blocks,
ln_f,
lm_head,
}
}
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = x.shape().dims2()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx)?;
}
let x = self.ln_f.forward(&x)?;
let x = x.i((.., seq_len - 1, ..))?;
let logits = self.lm_head.forward(&x)?;
logits.to_dtype(DType::F32)
}
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layer)
.map(|i| {
Block::load(
vb.pp(&format!("model.layers.{i}")),
cache,
cfg,
comm.clone(),
)
.unwrap()
})
.collect();
Ok(Self::new(wte, blocks, norm, lm_head))
}
}

View File

@ -0,0 +1,163 @@
// This should reach 91.5% accuracy.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use clap::{Parser, ValueEnum};
use candle::{DType, Result, Tensor, D};
use candle_nn::{loss, ops, Linear, VarBuilder, VarMap};
const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;
fn linear_z(in_dim: usize, out_dim: usize, vs: VarBuilder) -> Result<Linear> {
let ws = vs.get_or_init((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
let bs = vs.get_or_init(out_dim, "bias", candle_nn::init::ZERO)?;
Ok(Linear::new(ws, Some(bs)))
}
trait Model: Sized {
fn new(vs: VarBuilder) -> Result<Self>;
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
struct LinearModel {
linear: Linear,
}
impl Model for LinearModel {
fn new(vs: VarBuilder) -> Result<Self> {
let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
Ok(Self { linear })
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self.linear.forward(xs)
}
}
struct Mlp {
ln1: Linear,
ln2: Linear,
}
impl Model for Mlp {
fn new(vs: VarBuilder) -> Result<Self> {
let ln1 = candle_nn::linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
let ln2 = candle_nn::linear(100, LABELS, vs.pp("ln2"))?;
Ok(Self { ln1, ln2 })
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.ln1.forward(xs)?;
let xs = xs.relu()?;
self.ln2.forward(&xs)
}
}
struct TrainingArgs {
learning_rate: f64,
load: Option<String>,
save: Option<String>,
epochs: usize,
}
fn training_loop<M: Model>(
m: candle_nn::vision::Dataset,
args: &TrainingArgs,
) -> anyhow::Result<()> {
let dev = candle::Device::cuda_if_available(0)?;
let train_labels = m.train_labels;
let train_images = m.train_images.to_device(&dev)?;
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
let mut varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
let model = M::new(vs.clone())?;
if let Some(load) = &args.load {
println!("loading weights from {load}");
varmap.load(load)?
}
let sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate);
let test_images = m.test_images.to_device(&dev)?;
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
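// Note: this runs full-batch gradient descent, one step per epoch over the
// entire training set, which is fine for MNIST-sized data.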
for epoch in 1..args.epochs {
let logits = model.forward(&train_images)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
let loss = loss::nll(&log_sm, &train_labels)?;
sgd.backward_step(&loss)?;
let test_logits = model.forward(&test_images)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&test_labels)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::<f32>()?;
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
println!(
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
loss.to_scalar::<f32>()?,
100. * test_accuracy
);
}
if let Some(save) = &args.save {
println!("saving trained weights in {save}");
varmap.save(save)?
}
Ok(())
}
#[derive(ValueEnum, Clone)]
enum WhichModel {
Linear,
Mlp,
}
#[derive(Parser)]
struct Args {
#[clap(value_enum, default_value_t = WhichModel::Linear)]
model: WhichModel,
#[arg(long)]
learning_rate: Option<f64>,
#[arg(long, default_value_t = 200)]
epochs: usize,
/// The file to save the trained weights to, in safetensors format.
#[arg(long)]
save: Option<String>,
/// The file to load the trained weights from, in safetensors format.
#[arg(long)]
load: Option<String>,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Load the dataset
let m = candle_nn::vision::mnist::load_dir("data")?;
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());
println!("test-labels: {:?}", m.test_labels.shape());
let default_learning_rate = match args.model {
WhichModel::Linear => 1.,
WhichModel::Mlp => 0.05,
};
let training_args = TrainingArgs {
epochs: args.epochs,
learning_rate: args.learning_rate.unwrap_or(default_learning_rate),
load: args.load,
save: args.save,
};
match args.model {
WhichModel::Linear => training_loop::<LinearModel>(m, &training_args),
WhichModel::Mlp => training_loop::<Mlp>(m, &training_args),
}
}

View File

@ -142,7 +142,7 @@ impl EncodecEuclideanCodebook {
}
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
-let quantize = Tensor::embedding(embed_ind, &self.embed)?;
+let quantize = self.embed.embedding(embed_ind)?;
Ok(quantize)
}
}

View File

@ -123,7 +123,7 @@ impl MusicgenSinusoidalPositionalEmbedding {
}
fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
-let (_b_sz, _codebooks, seq_len) = input_ids.shape().r3()?;
+let (_b_sz, _codebooks, seq_len) = input_ids.dims3()?;
if seq_len > self.weights.dim(0)? {
self.weights = get_embedding(seq_len, self.embedding_dim)?
}
@ -170,7 +170,7 @@ impl MusicgenAttention {
kv_states: Option<&Tensor>,
attention_mask: &Tensor,
) -> Result<Tensor> {
-let (b_sz, tgt_len, _) = xs.shape().r3()?;
+let (b_sz, tgt_len, _) = xs.dims3()?;
let query_states = (self.q_proj.forward(xs)? * self.scaling)?;
let kv_states = kv_states.unwrap_or(xs);
@ -187,7 +187,7 @@ impl MusicgenAttention {
let attn_weights = attn_weights
.reshape((b_sz, self.num_heads, tgt_len, src_len))?
.broadcast_add(attention_mask)?;
-let attn_weights = attn_weights.softmax(D::Minus1)?;
+let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
// TODO: layer_head_mask?
let attn_output = attn_weights
.matmul(&value_states)?
@ -308,7 +308,7 @@ impl MusicgenDecoder {
fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
let dev = input_ids.device();
-let (b_sz_times_codebooks, seq_len) = input_ids.shape().r2()?;
+let (b_sz_times_codebooks, seq_len) = input_ids.dims2()?;
let b_sz = b_sz_times_codebooks / self.num_codebooks;
let input = input_ids.reshape((b_sz, self.num_codebooks, seq_len))?;
let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, dev)?;
@ -352,7 +352,7 @@ impl MusicgenForCausalLM {
}
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
-let (b_sz, seq_len) = input_ids.shape().r2()?;
+let (b_sz, seq_len) = input_ids.dims2()?;
let hidden_states = self.decoder.forward(input_ids)?;
let lm_logits = self
.lm_heads

View File

@ -223,7 +223,7 @@ impl T5Attention {
.transpose(1, 2)?;
let scores = q.matmul(&k.t()?)?;
// TODO: position_bias_masked
-let attn_weights = scores.softmax(D::Minus1)?;
+let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
let attn_output = attn_weights.matmul(&v)?;
let attn_output = self.o.forward(&attn_output)?;
Ok(attn_output)
@ -338,7 +338,7 @@ impl T5Stack {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let input_embeds = self.shared.as_ref().forward(input_ids)?;
-let (_b_sz, _seq_len) = input_embeds.shape().r2()?;
+let (_b_sz, _seq_len) = input_embeds.dims2()?;
let mut hidden_states = self.dropout.forward(&input_embeds)?;
for block in self.block.iter() {

View File

@ -18,10 +18,8 @@ fn fft<T: Float>(inp: &[T]) -> Vec<T> {
}
let mut out = vec![zero; n * 2];
-let mut even = vec![];
-even.reserve(n / 2);
-let mut odd = vec![];
-odd.reserve(n / 2);
+let mut even = Vec::with_capacity(n / 2);
+let mut odd = Vec::with_capacity(n / 2);
for (i, &inp) in inp.iter().enumerate() {
if i % 2 == 0 {

View File

@ -11,9 +11,9 @@ extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{safetensors::Load, DType, Device, Tensor};
-use candle_hub::{api::sync::Api, Repo, RepoType};
-use candle_nn::VarBuilder;
+use candle_nn::{ops::softmax, VarBuilder};
use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;
@ -120,19 +120,17 @@ impl Decoder {
// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
-no_speech_prob = logits
-.get(0)?
-.softmax(0)?
+no_speech_prob = softmax(&logits.get(0)?, 0)?
.get(NO_SPEECH_TOKEN as usize)?
.to_scalar::<f32>()? as f64;
}
-let (seq_len, _) = logits.shape().r2()?;
+let (seq_len, _) = logits.dims2()?;
let logits = logits
.get(seq_len - 1)?
.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
-let prs = (&logits / t)?.softmax(0)?;
+let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
distr.sample(&mut self.rng) as u32
@ -146,8 +144,7 @@ impl Decoder {
.unwrap()
};
tokens.push(next_token);
-let prob = logits
-.softmax(candle::D::Minus1)?
+let prob = softmax(&logits, candle::D::Minus1)?
.get(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
@ -195,7 +192,7 @@ impl Decoder {
}
fn run(&mut self, mel: &Tensor) -> Result<Vec<Segment>> {
-let (_, _, content_frames) = mel.shape().r3()?;
+let (_, _, content_frames) = mel.dims3()?;
let mut seek = 0;
let mut segments = vec![];
while seek < content_frames {
@ -282,28 +279,23 @@ fn main() -> Result<()> {
std::path::PathBuf::from(args.input.expect("You are using a local model but didn't specify an input file; please add `--input example.wav` to read some audio file")),
)
} else {
-let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let api = Api::new()?;
+let dataset = api.dataset("Narsil/candle-examples".to_string());
+let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let sample = if let Some(input) = args.input {
if let Some(sample) = input.strip_prefix("sample:") {
-api.get(
-&Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
-&format!("samples_{sample}.wav"),
-)?
+dataset.get(&format!("samples_{sample}.wav"))?
} else {
std::path::PathBuf::from(input)
}
} else {
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
-api.get(
-&Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
-"samples_jfk.wav",
-)?
+dataset.get("samples_jfk.wav")?
};
(
api.get(&repo, "config.json")?,
api.get(&repo, "tokenizer.json")?,
api.get(&repo, "model.safetensors")?,
repo.get("config.json")?,
repo.get("tokenizer.json")?,
repo.get("model.safetensors")?,
sample,
)
};

View File

@ -2,7 +2,7 @@
// back when using RUST_LIB_BACKTRACE=1.
use anyhow::Result;
use candle::{Device, Tensor};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation:
@ -132,7 +132,7 @@ impl MultiHeadAttention {
}
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
-let (n_batch, n_ctx, n_state) = x.shape().r3()?;
+let (n_batch, n_ctx, n_state) = x.dims3()?;
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
}
@ -144,7 +144,7 @@ impl MultiHeadAttention {
v: &Tensor,
mask: Option<&Tensor>,
) -> Result<Tensor> {
-let (_, n_ctx, n_state) = q.shape().r3()?;
+let (_, n_ctx, n_state) = q.dims3()?;
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
let q = (self.reshape_head(q)? * scale)?;
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
@ -154,7 +154,7 @@ impl MultiHeadAttention {
let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
qk = qk.broadcast_add(&mask)?
}
-let w = qk.softmax(candle::D::Minus1)?;
+let w = softmax(&qk, candle::D::Minus1)?;
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
Ok(wv)
}
@ -270,7 +270,7 @@ impl AudioEncoder {
let x = self.conv1.forward(x)?.gelu()?;
let x = self.conv2.forward(&x)?.gelu()?;
let x = x.transpose(1, 2)?;
-let (_bsize, seq_len, _hidden) = x.shape().r3()?;
+let (_bsize, seq_len, _hidden) = x.dims3()?;
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
let mut x = x.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter() {

View File

@ -0,0 +1,24 @@
[package]
name = "candle-flash-attn"
version = "0.1.0"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT/Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], version = "0.1.0", package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
num_cpus = "1.15.0"
rayon = "1.7.0"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", version = "0.1.0", features = ["cuda"] }

View File

@ -0,0 +1 @@
# candle-flash-attn

candle-flash-attn/build.rs Normal file
View File

@ -0,0 +1,252 @@
// Build script to run nvcc and generate the C glue code for launching the flash-attention kernel.
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
// variable in order to cache the compiled artifacts and avoid recompiling too often.
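// For example (hypothetical cache path):
//   CANDLE_FLASH_ATTN_BUILD_DIR=$HOME/.cache/candle-flash-attn cargo build --release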
use anyhow::{Context, Result};
use rayon::prelude::*;
use std::path::PathBuf;
use std::str::FromStr;
const KERNEL_FILES: [&str; 9] = [
"flash_api.cu",
"flash_fwd_hdim128_fp16_sm80.cu",
"flash_fwd_hdim160_fp16_sm80.cu",
"flash_fwd_hdim192_fp16_sm80.cu",
"flash_fwd_hdim224_fp16_sm80.cu",
"flash_fwd_hdim256_fp16_sm80.cu",
"flash_fwd_hdim32_fp16_sm80.cu",
"flash_fwd_hdim64_fp16_sm80.cu",
"flash_fwd_hdim96_fp16_sm80.cu",
// "flash_fwd_hdim128_bf16_sm80.cu",
// "flash_fwd_hdim160_bf16_sm80.cu",
// "flash_fwd_hdim192_bf16_sm80.cu",
// "flash_fwd_hdim224_bf16_sm80.cu",
// "flash_fwd_hdim256_bf16_sm80.cu",
// "flash_fwd_hdim32_bf16_sm80.cu",
// "flash_fwd_hdim64_bf16_sm80.cu",
// "flash_fwd_hdim96_bf16_sm80.cu",
];
fn main() -> Result<()> {
let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
|_| num_cpus::get_physical(),
|s| usize::from_str(&s).unwrap(),
);
rayon::ThreadPoolBuilder::new()
.num_threads(num_cpus)
.build_global()
.unwrap();
println!("cargo:rerun-if-changed=build.rs");
for kernel_file in KERNEL_FILES.iter() {
println!("cargo:rerun-if-changed=kernels/{kernel_file}");
}
println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
println!("cargo:rerun-if-changed=kernels/flash.h");
println!("cargo:rerun-if-changed=kernels/philox.cuh");
println!("cargo:rerun-if-changed=kernels/softmax.h");
println!("cargo:rerun-if-changed=kernels/utils.h");
println!("cargo:rerun-if-changed=kernels/kernel_traits.h");
println!("cargo:rerun-if-changed=kernels/block_info.h");
println!("cargo:rerun-if-changed=kernels/static_switch.h");
let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
Err(_) =>
{
#[allow(clippy::redundant_clone)]
out_dir.clone()
}
Ok(build_dir) => PathBuf::from(build_dir),
};
set_cuda_include_dir()?;
let compute_cap = compute_cap()?;
let out_file = build_dir.join("libflashattention.a");
let kernel_dir = PathBuf::from("kernels");
let cu_files: Vec<_> = KERNEL_FILES
.iter()
.map(|f| {
let mut obj_file = out_dir.join(f);
obj_file.set_extension("o");
(kernel_dir.join(f), obj_file)
})
.collect();
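// Rebuild only when the archive is missing or any .cu source is newer than it.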
let should_compile = if out_file.exists() {
cu_files.iter().any(|(cu_file, _)| {
let out_modified = out_file.metadata().unwrap().modified().unwrap();
let in_modified = cu_file.metadata().unwrap().modified().unwrap();
in_modified.duration_since(out_modified).is_ok()
})
} else {
true
};
if should_compile {
cu_files
.par_iter()
.map(|(cu_file, obj_file)| {
let mut command = std::process::Command::new("nvcc");
command
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("-c")
.args(["-o", obj_file.to_str().unwrap()])
.args(["--default-stream", "per-thread"])
.arg("-Icutlass/include")
.arg("--expt-relaxed-constexpr")
.arg(cu_file);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
Ok(())
})
.collect::<Result<()>>()?;
let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
let mut command = std::process::Command::new("nvcc");
command
.arg("--lib")
.args(["-o", out_file.to_str().unwrap()])
.args(obj_files);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while linking:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
}
println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=flashattention");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=stdc++");
/* laurent: I tried using the cc cuda integration as below but this led to ptxas never
finishing for some reason. Calling nvcc manually worked fine.
cc::Build::new()
.cuda(true)
.include("cutlass/include")
.flag("--expt-relaxed-constexpr")
.flag("--default-stream")
.flag("per-thread")
.flag(&format!("--gpu-architecture=sm_{compute_cap}"))
.file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
.compile("flashattn");
*/
Ok(())
}
fn set_cuda_include_dir() -> Result<()> {
// NOTE: copied from cudarc build.rs.
let env_vars = [
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
let env_vars = env_vars
.into_iter()
.map(std::env::var)
.filter_map(Result::ok)
.map(Into::<PathBuf>::into);
let roots = [
"/usr",
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/CUDA",
];
let roots = roots.into_iter().map(Into::<PathBuf>::into);
let root = env_vars
.chain(roots)
.find(|path| path.join("include").join("cuda.h").is_file())
.context("cannot find include/cuda.h")?;
println!(
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
root.join("include").display()
);
Ok(())
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
// Grab compute code from nvidia-smi
let mut compute_cap = {
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
cap.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?
};
// Grab available GPU codes from nvcc and select the highest one
let max_nvcc_code = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
if !codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
);
}
*codes.last().unwrap()
};
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
// then choose the highest gpu code in nvcc
if compute_cap > max_nvcc_code {
println!(
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
);
compute_cap = max_nvcc_code;
}
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
compute_cap = compute_cap_str
.parse::<usize>()
.with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
}
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
Ok(compute_cap)
}

View File

@ -0,0 +1,41 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Varlen=true>
struct BlockInfo {
template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
, actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
{
}
template <typename index_t>
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}
template <typename index_t>
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}
const int sum_s_q;
const int sum_s_k;
const uint32_t actual_seqlen_q;
const uint32_t actual_seqlen_k;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash

View File

@ -0,0 +1,141 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
// #ifdef OLD_GENERATOR_PATH
// #include <ATen/CUDAGeneratorImpl.h>
// #else
// #include <ATen/cuda/CUDAGeneratorImpl.h>
// #endif
//
// #include <ATen/cuda/CUDAGraphsUtils.cuh>
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
using index_t = uint32_t;
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
// The number of heads.
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precompute h / h_k.
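// e.g. h = 32 query heads with h_k = 8 key/value heads gives h_h_k_ratio = 4.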
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void * __restrict__ o_ptr;
// The stride between rows of O.
index_t o_batch_stride;
index_t o_row_stride;
index_t o_head_stride;
// The pointer to the P matrix.
void * __restrict__ p_ptr;
// The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr;
// The dimensions.
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;
// An array of length b + 1 holding the starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
int *__restrict__ blockmask;
// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
// uint16_t p_dropout_in_uint16_t;
uint8_t p_dropout_in_uint8_t;
// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
float scale_softmax_rp_dropout;
// Random state.
// at::PhiloxCudaState philox_args;
bool is_bf16;
bool is_causal;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_bwd_params : public Flash_fwd_params {
// The dO and dQKV matrices.
void *__restrict__ do_ptr;
void *__restrict__ dq_ptr;
void *__restrict__ dk_ptr;
void *__restrict__ dv_ptr;
// To accumulate dQ
void *__restrict__ dq_accum_ptr;
void *__restrict__ dk_accum_ptr;
void *__restrict__ dv_accum_ptr;
// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
// dimension void *__restrict__ dk_accum_ptr; void *__restrict__
// dv_accum_ptr;
// The stride between rows of the dO, dQ, dK and dV matrices.
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
index_t do_batch_stride;
index_t do_row_stride;
index_t do_head_stride;
index_t dq_batch_stride;
index_t dk_batch_stride;
index_t dv_batch_stride;
index_t dq_row_stride;
index_t dk_row_stride;
index_t dv_row_stride;
index_t dq_head_stride;
index_t dk_head_stride;
index_t dv_head_stride;
// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);

View File

@ -0,0 +1,112 @@
#include "flash_fwd_launch_template.h"
// TODO: Switch back to handling bf16.
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
FWD_HEADDIM_SWITCH(params.d, [&] {
run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
});
}
// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// FP16_SWITCH(!params.is_bf16, [&] {
// FWD_HEADDIM_SWITCH(params.d, [&] {
// run_mha_fwd_<elem_type, kHeadDim>(params, stream);
// });
// });
// }
extern "C" void run_mha(
void *q_ptr,
void *k_ptr,
void *v_ptr,
void *o_ptr,
void *softmax_lse_ptr,
int32_t *cu_seqlens_q_ptr,
int32_t *cu_seqlens_k_ptr,
uint32_t q_batch_stride,
uint32_t k_batch_stride,
uint32_t v_batch_stride,
uint32_t o_batch_stride,
uint32_t q_row_stride,
uint32_t k_row_stride,
uint32_t v_row_stride,
uint32_t o_row_stride,
uint32_t q_head_stride,
uint32_t k_head_stride,
uint32_t v_head_stride,
uint32_t o_head_stride,
uint32_t b,
uint32_t h,
uint32_t h_k,
uint32_t d,
uint32_t d_rounded,
float softmax_scale,
uint32_t seqlen_q,
uint32_t seqlen_k,
uint32_t seqlen_q_rounded,
uint32_t seqlen_k_rounded,
int is_causal
) {
Flash_fwd_params params;
// Reset the parameters
memset(&params, 0, sizeof(params));
// Set the pointers and strides.
params.q_ptr = q_ptr;
params.k_ptr = k_ptr;
params.v_ptr = v_ptr;
params.o_ptr = o_ptr;
params.softmax_lse_ptr = softmax_lse_ptr;
// All strides are in elements, not bytes.
params.q_batch_stride = q_batch_stride;
params.k_batch_stride = k_batch_stride;
params.v_batch_stride = v_batch_stride;
params.o_batch_stride = o_batch_stride;
params.q_row_stride = q_row_stride;
params.k_row_stride = k_row_stride;
params.v_row_stride = v_row_stride;
params.o_row_stride = o_row_stride;
params.q_head_stride = q_head_stride;
params.k_head_stride = k_head_stride;
params.v_head_stride = v_head_stride;
params.o_head_stride = o_head_stride;
// Set the dimensions.
params.b = b;
params.h = h;
params.h_k = h_k;
params.h_h_k_ratio = h / h_k;
params.seqlen_q = seqlen_q;
params.seqlen_k = seqlen_k;
params.seqlen_q_rounded = seqlen_q_rounded;
params.seqlen_k_rounded = seqlen_k_rounded;
params.d = d;
params.d_rounded = d_rounded;
params.is_causal = is_causal;
// Set the different scale values.
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
params.p_dropout = 1.; // probability to keep
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
params.rp_dropout = 1.f / params.p_dropout;
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
params.is_bf16 = 0;
params.cu_seqlens_q = cu_seqlens_q_ptr;
params.cu_seqlens_k = cu_seqlens_k_ptr;
params.p_ptr = nullptr; // used for `return_softmax`.
cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd(params, stream);
}

View File

@ -0,0 +1,19 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// if (params.p_dropout == 1.f) {
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
// }
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
}

View File

@ -0,0 +1,32 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// if (params.p_dropout == 1.f) {
// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, false, elem_type>, false>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, true, elem_type>, false>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, true, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 64, 4, false, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 128, 4, false, false, elem_type>, false>(params, stream);
// // 1st ones are good for H100, A100
// // 2nd one is good for A6000 bc we get slightly better occupancy
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, false, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, true, elem_type>, true>(params, stream);
// // 1st one is good for H100, A100, A6000
// }
// }
template<>
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
}

View File

@ -0,0 +1,17 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// });
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
}

View File

@ -0,0 +1,27 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, true, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 128, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 8, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 128, 8, false, elem_type>>(params, stream);
// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest.
// // For A100, H100, 1st is fastest.
// });
// }
template<>
void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
}

View File

@ -0,0 +1,16 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// });
// }
template<> void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
}

View File

@ -0,0 +1,27 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// // This one is slightly faster for causal?
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 8, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 128, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 128, 8, false, elem_type>>(params, stream);
// });
// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout
// // For A6000, 1st is faster when causal, 3rd is faster when not causal
// }
template<>
void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
}

View File

@ -0,0 +1,9 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<> void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
}

Some files were not shown because too many files have changed in this diff.