Compare commits

..

109 Commits

SHA1 Message Date
c2261d0222 Merge. 2024-01-07 20:27:33 +01:00
0eb90ed783 Simpler repro for the neon optimization issue + bugfix (#1544)
* Simpler repro for the neon optimization issue.

* Bugfix for q4k.

* Improve the fix, share the dot-prod bit.

* Clippy fixes.

* Fix for q6k.

* Also fix for q2k.

* Use the new shared dotprod.

* Add more testing.
2024-01-07 20:21:49 +01:00
89b5a06858 Use bindgen-cuda for the custom-kernel example. (#1536)
* Use bindgen-cuda for the custom-kernel example.

* Only depend on the kernels when cuda is enabled.

* Skip rustfmt.
2024-01-07 17:18:46 +01:00
30313c3081 Moving to a proper build crate bindgen_cuda. (#1531)
* Moving to a proper build crate `bindgen_cuda`.

* Fmt.
2024-01-07 12:29:24 +01:00
e72d52b1a2 Unpin more of the workspace relative dependencies. (#1535) 2024-01-07 12:26:20 +01:00
b4cb982e49 Simplifying our internal cargo dependencies. (#1529) 2024-01-07 12:04:14 +01:00
06d186355b Change the test more consistently. 2024-01-06 15:20:55 +01:00
2bbd544832 Non-random data for better quantization quality 2024-01-06 15:16:01 +01:00
84250bf52f fix index_pos bug when kv cache is disabled. (#1517)
* fix index_pos bug when kv cache is disabled

* Tweak the fix.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-01-06 11:43:01 +01:00
8d1a57c9a0 chore: update flash attention kernels (#1518)
* chore: update flash attention kernels

* fmt

* remove unused kernels

* force f32

* correct stride
2024-01-05 18:28:55 +01:00
504d0b9ac7 Potential bug on q4k. 2024-01-05 14:15:47 +01:00
3a7304cb0d add link to gpt-from-scratch-rs (#1525) 2024-01-05 11:59:46 +01:00
fa3ea98ba9 Adding bfloat16 support for the cast kernels. (#1520) 2024-01-04 12:12:56 +01:00
135ae5f3eb Simplify the one-hot implementation, support arbitrary rank. (#1514)
* Simplify the one-hot implementation, support arbitrary rank.

* More cleanup.
2024-01-01 11:40:17 +01:00
41614b4a9b Add one-hot/cold encoding (#1489)
* add one-hot encoding

* one_hot: improve error handling, use generic to_vecN::<D>

Bails if the index value is equal to or greater than the depth value,
which would result in an out-of-bounds error.

A redundant check is added to ensure the index value does not exceed
the size of the one-hot matrix, which would also result in an
out-of-bounds error.

Bails if the index value is less than -1. If the index value is -1,
then it ignores the setting of the on_value for the index value. Only
values that are less than -1 are considered errors.

* one-hot: use two generics, one_hot::<I, O>, for input and output data types

Separating the input and output data types allows the input tensor
indices to be a different data type than the output encoded tensor data type.

For example, one_hot::<i64, u8>(...) will take an input tensor of i64 values
and encode the output tensor using u8 values.

The generic I::DTYPE must match the data type of the input indices, otherwise
the method will bail.

Additionally, this method adds an `allow_f64` option to enable the input indices
data type to be f64 values. f64 values are disabled by default.

TODO: indices data type and the generic I data type are currently not compile-time
checked.

* one_hot: remove input generic, use indices dtype matching

This commit removes the to_f64() type cast and explicitly
matches the DType from the input tensor. Currently, only U8,
U32 and I64 are supported for input tensors.

The match arms on the dtype are verbose. It would be nice
to use a generic type with the WithDType trait bound to
pass to the to_vecN method and then return an inner value.

Open to suggestions for better approaches here to reduce
the match arm verbosity.

* one_hot: use flat_map iterator over dims instead of nested for loop

This commit replaces the nested for loops with a flat_map iter over
the dimensions of the input tensor.

This commit also adds a test for a rank 3 input tensor.

* one_hot: use mandatory on/off-values, remove const msgs

This commit also updates doc tests, comments and test cases.

* Small cleanups.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-01-01 11:18:40 +01:00
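A minimal usage sketch for the one-hot API described in this commit, assuming it landed in candle-nn as one_hot(indices, depth, on_value, off_value) with mandatory on/off values; an index of -1 leaves its whole row at the off value:

use candle::{Device, Tensor};
use candle_nn::encoding::one_hot;

fn main() -> candle::Result<()> {
    let device = Device::Cpu;
    // Rank-1 indices; the encoding adds one trailing dimension of size `depth`.
    let indices = Tensor::new(&[2i64, 0, -1], &device)?;
    let encoded = one_hot(indices, 4, 1u8, 0u8)?;
    assert_eq!(
        encoded.to_vec2::<u8>()?,
        [[0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 0, 0]]
    );
    Ok(())
}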
03ce8caf40 Format properly the Stable Diffusion example run with params (#1511)
Move the --sd-version flag out of the prompt.
2024-01-01 11:13:35 +01:00
b0fe5e4453 Do not implement Module for BatchNorm. (#1513) 2024-01-01 10:13:13 +01:00
1fb2dd905c Add support for tiny-llama-1.1b. (#1512) 2023-12-31 12:18:25 +01:00
a0facd0e67 Small tweaks to batch-norm. (#1505) 2023-12-30 17:06:07 +01:00
4290b81244 [Breaking] Add training to batchnorm with exponential moving average (#1504)
* Add training to batchnorm with exponential moving average

* Add more checks to batch norm

* Resolve some review comments

* Add with_momentum variants of `new` methods

* Add check for range of momentum variable; update batch norm test

* Run cargo fmt

* Add back num_features parameter

* Format; tiny simplification
2023-12-30 16:42:08 +01:00
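A sketch of the exponential-moving-average update that training-mode batch norm applies to its running statistics (standard formulation; the exact candle-nn field and method names may differ):

fn ema_update(running: f64, batch_stat: f64, momentum: f64) -> f64 {
    // Mirrors the momentum range check mentioned in the commit above.
    assert!((0.0..=1.0).contains(&momentum), "momentum must be in [0, 1]");
    (1.0 - momentum) * running + momentum * batch_stat
}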
51e577a682 Add Policy Gradient to Reinforcement Learning examples (#1500)
* added policy_gradient, modified main, ddpg and README

* fixed typo in README

* removed unnecessary imports

* small refactor

* Use clap for picking up the subcommand to run.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-12-30 09:01:29 +01:00
0a245e6fa4 Metal: support unary abs (#1503)
* Metal: support unary abs

* cargo fmt
2023-12-30 00:00:12 +01:00
87d7f81b43 Metal: more u8/u32 (#1502)
* Adds more metal u8

* Metal: more u32
2023-12-29 23:56:21 +01:00
4373534d59 Metal: i64 basic support (#1495)
* Adds basic metal i64 support

* metal copy i64
2023-12-29 19:42:50 +01:00
f4a2787217 Merge pull request #1498 from huggingface/debugging_windows_ci
Fix CI
2023-12-29 12:33:50 +01:00
488e02a3f6 Merge pull request #1496 from bayedieng/unary
Implement urecip op for metal backend
2023-12-29 12:20:52 +01:00
adc95ca2bf Ignore skipped. 2023-12-29 12:15:57 +01:00
4907c63ea1 Ignore stop on remote forks. 2023-12-29 12:12:10 +01:00
d76ac20e0e Fix. 2023-12-29 12:06:38 +01:00
f5c98f22c7 Merge pull request #1491 from mimiquate/metal-errors
Improves metal's not implemented error messages
2023-12-29 12:03:40 +01:00
5b12fbb143 Trying to fix flakiness by making hub_2 and hub_3 serial tests (potential issue on mingw with mmap). 2023-12-29 11:13:33 +01:00
cc06ba2294 fix bad pattern matching and function name 2023-12-29 09:46:24 +00:00
a6bd0b47a5 Fix the CI. 2023-12-29 10:17:52 +01:00
b59b1b2bb6 remove generated png 2023-12-28 21:50:58 +00:00
3922b42c18 add urecip op to metal backend 2023-12-28 21:50:12 +00:00
1e442d4bb9 Fix lints for clippy 1.75. (#1494) 2023-12-28 20:26:20 +01:00
cd889c0f8a add config_amazon_mistral_lite (#1493)
Co-authored-by: Ubuntu <danielclough@users.noreply.github.com>
2023-12-28 19:59:58 +01:00
8e93e76a91 fixes error message 2023-12-28 15:03:05 -03:00
b3e838f3e2 cargo fmt 2023-12-28 14:07:34 -03:00
8bf892403a Improves metal's not implemented error messages 2023-12-28 11:04:06 -03:00
d35f0a1376 Bump the crate version to 0.3.3. (#1490) 2023-12-28 13:38:30 +01:00
65cb90bd40 Add some mention to SOLAR-10.7B in the readme. (#1487) 2023-12-27 15:25:39 +01:00
996a7f2e24 Rework the llama example config, add the solar model. (#1485) 2023-12-26 22:24:04 +01:00
3071ea6c3e Use the new hub helper function. (#1484) 2023-12-26 09:44:30 +01:00
37c539f2b7 Helper function to load sharded safetensors files (#1481)
* Fix the quantized mistral example.

* Add a helper function to load sharded safetensors weights.

* Use the sharded loader.
2023-12-25 21:49:21 +01:00
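A hypothetical sketch of what such a sharded-weights helper can look like: read the safetensors index JSON to collect the shard file names, then fetch each shard from the hub (names and error handling here are illustrative, not necessarily the exact candle API):

use std::collections::HashSet;
use std::path::PathBuf;

fn hub_load_safetensors(
    repo: &hf_hub::api::sync::ApiRepo,
    json_file: &str,
) -> anyhow::Result<Vec<PathBuf>> {
    let json: serde_json::Value =
        serde_json::from_reader(std::fs::File::open(repo.get(json_file)?)?)?;
    let mut files = HashSet::new();
    // Every tensor entry maps to the shard file that stores it.
    if let Some(map) = json["weight_map"].as_object() {
        for file in map.values().filter_map(|v| v.as_str()) {
            files.insert(file.to_string());
        }
    }
    files.into_iter().map(|f| Ok(repo.get(&f)?)).collect()
}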
eae3a20d43 Merge pull request #1479 from huggingface/upsample_metal
Adding upsample_nearest_2d.
2023-12-25 14:25:53 +01:00
13a5d15ebc Adding upsample_nearest_2d. 2023-12-25 14:25:19 +01:00
1505d85276 Merge pull request #1461 from huggingface/metal-conv
Adding the convolutions (1d + 2d) to candle on metal.
2023-12-25 12:48:09 +01:00
95e18ef675 Fixing matmul for convolutions. 2023-12-25 12:29:34 +01:00
7135791dd5 Fix the quantized mistral example. (#1478) 2023-12-25 09:31:24 +01:00
88589d8815 Support mistral instruct v0.2. (#1475)
* Support mistral instruct v0.2.

* Use the safetensors model now that they are available.
2023-12-23 16:18:49 +01:00
5b35fd0fcf MMLU evaluation for Phi. (#1474)
* MMLU evaluation for Phi.

* Improve the evaluation.
2023-12-23 15:28:36 +01:00
ba1fae590e Validate the kernel size in pooling ops. (#1473)
* Validate the kernel size in pooling ops.

* Revert the changes to basics.
2023-12-23 11:19:22 +01:00
78d982e1bd Fix for mamba 2.8b. (#1472) 2023-12-23 11:01:39 +01:00
d8b9a727fc Support different mamba models. (#1471) 2023-12-23 10:46:02 +01:00
ceb78d3e28 Sketch the minimal mamba example. (#1465)
* Sketch the minimal mamba example.

* Fix rustfmt.

* Forward pass for mamba.

* Finish the forward pass.

* Inference fixes.

* Bugfixes.

* More fixes.

* Add a readme.
2023-12-22 00:28:50 +01:00
f6408a3779 feat: add clear_kv_cache to mistral and qmistral models (#1464) 2023-12-21 21:19:19 +01:00
10d94659c3 Adding the convolutions (1d + 2d) to candle on metal. 2023-12-21 10:39:24 +01:00
563a79afa1 make fn name generic (#1459)
Co-authored-by: Ubuntu <danielclough@users.noreply.github.com>
2023-12-21 02:16:31 +01:00
8ede5f4210 add fn config_chat_ml (#1458)
* add fn config_chat_ml

* Add a link to the original config.

---------

Co-authored-by: Ubuntu <danielclough@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
2023-12-20 21:03:24 +01:00
9fc210fae8 Merge pull request #1318 from huggingface/metal4
Starting to fix some tests.
2023-12-20 15:37:31 +01:00
9b5e4843a6 Optimizing decode matmul (Phi at 28tok/s on M3).
Adding some benchmarks to help check matmul performance.
2023-12-20 09:54:19 +01:00
03641293ee Clippy pass. 2023-12-18 15:22:43 +01:00
064ba17bd7 Remove print. 2023-12-18 11:04:16 +01:00
e8ee253ee0 Missing cast. 2023-12-18 11:01:18 +01:00
8bd3d6b94b Index add. 2023-12-18 10:46:01 +01:00
6a3ca7da0c Scatter add. 2023-12-18 10:32:22 +01:00
96f1a28e39 Add a simple full method. (#1455)
* Add a simple implementation of the full method.

* Add the docstring.
2023-12-17 20:15:57 -05:00
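A minimal usage sketch; as the docstring in the tensor.rs diff further down notes, the result is a broadcast view, so call .contiguous() when a dense layout is needed:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // Every element takes the specified value; the shape is broadcast from a scalar.
    let t = Tensor::full(3.5f32, (2, 4), &Device::Cpu)?;
    assert_eq!(t.dims(), &[2, 4]);
    let _dense = t.contiguous()?;
    Ok(())
}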
586b6f6fff Adding gather op. 2023-12-17 23:34:12 +01:00
e4b0cc59f5 Adding CMP 2023-12-17 22:32:25 +01:00
0a6e0a8c9a Implement randn (CPU-> device) 2023-12-17 19:09:08 +01:00
972903021c Finish reduce kernels. 2023-12-17 19:07:00 +01:00
94817dac56 Bump the crate version to 0.3.2. (#1452) 2023-12-17 05:34:53 -06:00
1e86717bf2 Fix a couple typos (#1451)
* Mixtral quantized instruct.

* Fix a couple typos.
2023-12-17 05:20:05 -06:00
c630622a07 Expose AdamW parameters (#1449)
* Expose AdamW parameters

* Use reference
2023-12-16 18:41:56 -06:00
c4cfcf1539 Tweak the readme for phi and the default sample length. (#1450) 2023-12-16 18:11:36 -06:00
1782e93de6 Mixtral quantized instruct. (#1447) 2023-12-16 16:16:39 -06:00
cfdf9640a3 Readme tweaks. (#1446) 2023-12-16 06:23:12 -06:00
e12cbfd73b Update the readme to mention mixtral. (#1443) 2023-12-15 19:29:03 -06:00
30a958e5dd Quantized mixtral model (#1442)
* Add the Mixtral model.

* Add more of the mixtral layers.

* Add the final layers for mixtral.

* Sketch the expert selection.

* Add some expert routing logic.

* Hopefully finish the routing logic for mixtral.

* Add the mixtral example.

* Fix the weight filenames.

* Bugfix.

* Another fix.

* Yet another fix + remove the unused pragma.

* Shape fix.

* Support for quantized mixtral.

* Support mixtral in the quantized example.

* Mlp or moe type.

* Fix the expert field namings.

* Refactor the mlp bit.

* More MoE logic.

* Add the MoE quantized logic.

* Fix the experts length.
2023-12-15 19:16:06 -06:00
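A toy sketch of the expert-routing idea behind these commits (illustrative only, not the candle-transformers API): per token, keep the top-2 experts by gate logit and mix their outputs with softmax weights, assuming at least two experts.

fn top2_route(gate_logits: &[f32]) -> [(usize, f32); 2] {
    // Rank experts by gate logit, descending.
    let mut idx: Vec<usize> = (0..gate_logits.len()).collect();
    idx.sort_by(|&a, &b| gate_logits[b].total_cmp(&gate_logits[a]));
    let (i0, i1) = (idx[0], idx[1]);
    // Softmax over the two selected logits gives the mixing weights.
    let m = gate_logits[i0].max(gate_logits[i1]);
    let (e0, e1) = ((gate_logits[i0] - m).exp(), (gate_logits[i1] - m).exp());
    [(i0, e0 / (e0 + e1)), (i1, e1 / (e0 + e1))]
}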
614842b311 Add the Mixtral model. (#1437)
* Add the Mixtral model.

* Add more of the mixtral layers.

* Add the final layers for mixtral.

* Sketch the expert selection.

* Add some expert routing logic.

* Hopefully finish the routing logic for mixtral.

* Add the mixtral example.

* Fix the weight filenames.

* Bugfix.

* Another fix.

* Yet another fix + remove the unused pragma.

* Shape fix.

* Add a readme.
2023-12-15 14:19:56 -06:00
79eab519fd Fix phi example (#1436)
* Fix phi example

* Remove the cuda mention.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-12-15 07:01:10 -06:00
6bc92e63cb Addressing a lot of comments. 2023-12-15 13:06:04 +01:00
aa04015098 Remove unwrap(). 2023-12-15 12:23:28 +01:00
8b5059e951 Remove test file. 2023-12-15 11:55:30 +01:00
26540641c1 Renamed all kernel names. 2023-12-15 11:24:47 +01:00
34d83377f6 Better error message on older macos 2023-12-15 11:18:54 +01:00
77197379cc More cleanup. 2023-12-15 11:17:05 +01:00
916a8c5464 Revert candle-transformers. 2023-12-15 11:15:21 +01:00
243e83f2b9 Adding a bunch of docs !
Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
2023-12-15 11:03:05 +01:00
e60f9b5dfc Speedup ShardedSafeTensors to load Tensors with default hints (#1384)
* Speedup ShardedSafeTensors to load Tensors with default hints

* Tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-12-14 08:08:56 -06:00
7be982f6f7 Mention phi-2 in the readme. (#1434) 2023-12-14 08:02:27 -06:00
104e196d46 Phi 2 wasm (#1432)
* add phi 2.0 quantized model wasm

* cols

* spell

* bug
2023-12-14 06:04:17 -06:00
5e33c85c8f Quantized version for phi-v2. (#1430)
* Quantized version for phi-v2.

* More quantized support.
2023-12-13 21:16:34 -06:00
2b3a018be7 Support for phi-2. (#1429)
* Support for phi-2.

* Use the v2 naming scheme.
2023-12-13 20:59:29 -06:00
4cb443d00a Fix the logsumexp test. (#1426) 2023-12-12 10:56:11 -06:00
77252ffb82 Add logsumexp function (#1424) 2023-12-12 10:32:17 -06:00
18eb87f25f Upsample grad (#1420)
* encode size of upsample in enum

* working convolution method for limited 2d kernels

* add test for sf 3 interpolation

* add higher dimensional tests, fix to work with multichannel input

* Remove commented out line.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-12-10 08:43:24 +01:00
9bd94c1ffa Speed up bert with approx gelu (#1410) 2023-12-06 17:46:37 +01:00
236b820e28 Another prelu bugfix. (#1407) 2023-12-06 09:54:41 +01:00
2648e797c2 Use the proper broadcasting for prelu. (#1406) 2023-12-05 07:09:31 +01:00
b5c283e86f Add the prelu layer. (#1402) 2023-12-03 16:06:09 +00:00
8418154ee0 Add nvcc ccbin support to examples (#1401) 2023-12-03 16:01:16 +00:00
99b7273b03 Add compute cap env support to examples (#1400) 2023-12-03 16:00:24 +00:00
16161145ae Add the leo models to the quantized examples. (#1398) 2023-12-03 12:30:41 +00:00
0738df5290 Add more mentions to SDXL Turbo in the readme. (#1397) 2023-12-03 10:41:21 +00:00
37bf1ed012 Stable Diffusion Turbo Support (#1395)
* Add support for SD Turbo

* Set Leading as default in euler_ancestral discrete

* Use the appropriate default values for n_steps and guidance_scale.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-12-03 08:37:10 +01:00
dd40edfe73 Add Euler Ancestral Discrete Scheduler (#1390)
* Add Euler Ancestral Discrete Scheduler

* Fix a bug of init_noise_sigma generation

* minor fixes

* use partition_point instead of custom bsearch

* Fix some clippy lints.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2023-12-02 19:59:23 +00:00
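The partition_point change replaces a hand-rolled binary search with the standard library's; a hypothetical sketch of the pattern on a decreasing sigma schedule:

fn main() {
    // Sigmas sorted in decreasing order, as in ancestral Euler schedules.
    let sigmas = [14.6f32, 8.0, 3.5, 1.1, 0.0];
    let target = 3.5f32;
    // Index of the first element for which the predicate flips to false.
    let idx = sigmas.partition_point(|&s| s > target);
    assert_eq!(idx, 2);
}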
5aa1a65dab Add quantized Starling, fix open-chat prompt (#1393)
* Add quantized Starling, fix open-chat prompt

* Fix open-chat and starling prompts
2023-12-02 16:47:19 +00:00
158 changed files with 7211 additions and 2963 deletions

View File

@@ -8,6 +8,8 @@ jobs:
   start-runner:
     name: Start self-hosted EC2 runner
     runs-on: ubuntu-latest
+    # Don't run on forks, they won't have access to secrets anyway.
+    if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
     env:
       AWS_REGION: us-east-1
       EC2_AMI_ID: ami-03cfed9ea28f4b002
@@ -70,7 +72,7 @@ jobs:
     runs-on: ubuntu-latest
     env:
       AWS_REGION: us-east-1
-    if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
+    if: ${{ (success() || failure()) && github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }} # required to stop the runner even if the error happened in the previous jobs
     steps:
       - name: Configure AWS credentials
         uses: aws-actions/configure-aws-credentials@v1

View File

@@ -63,7 +63,7 @@ This documents the main changes to the `candle` crate.
   [760](https://github.com/huggingface/candle/pull/760).
 - Add the Segment-Anything Model (SAM) as an example
   [773](https://github.com/huggingface/candle/pull/773).
-- TinyViT backbone for the segemnt anything example
+- TinyViT backbone for the segment anything example
   [787](https://github.com/huggingface/candle/pull/787).
 - Shape with holes support
   [770](https://github.com/huggingface/candle/pull/770).

View File

@@ -19,7 +19,7 @@ exclude = [
 resolver = "2"

 [workspace.package]
-version = "0.3.1"
+version = "0.3.3"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -31,7 +31,16 @@ license = "MIT OR Apache-2.0"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
+candle = { path = "./candle-core", package = "candle-core" }
+candle-datasets = { path = "./candle-datasets" }
+candle-flash-attn = { path = "./candle-flash-attn" }
+candle-kernels = { path = "./candle-kernels" }
+candle-metal-kernels = { path = "./candle-metal-kernels" }
+candle-nn = { path = "./candle-nn" }
+candle-onnx = { path = "./candle-onnx" }
+candle-transformers = { path = "./candle-transformers" }
 clap = { version = "4.2.4", features = ["derive"] }
+criterion = { version = "0.5.1", default-features=false }
 cudarc = { version = "0.9.14", features = ["f16"] }
 gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
 hf-hub = "0.3.0"
@@ -61,7 +70,7 @@ tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
-metal = { version = "0.27.0", features = ["mps"], package = "candle-metal" }
+metal = { version = "0.27.0", features = ["mps"]}

 [profile.release-with-debug]
 inherits = "release"

View File

@@ -54,19 +54,25 @@ These online demos run entirely in your browser:
 - [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
 - [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
 - [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
-- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
+- [Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
 - [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
 - [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.

 We also provide a some command line based examples using state of the art models:

-- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
+- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
+  the SOLAR-10.7B variant.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Phi-v1 and Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
+- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
 - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
   pre-trained on 1T tokens of English and code datasets.
+- [Minimal Mamba](./candle-examples/examples/minimal-mamba/): a minimal
+  implementation of the Mamba state space model.
 - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
-  performance larger than all publicly available 13b models as of 2023-09-28.
+  better performance than all publicly available 13b models as of 2023-09-28.
+- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
+  experts 8x7b general LLM with better performance than a Llama 2 70B model with
+  much faster inference.
 - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
 - [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
 - [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
@@ -78,7 +84,7 @@ We also provide a some command line based examples using state of the art models

 <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">

 - [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
-  image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
+  image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions.

 <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">

@@ -122,7 +128,7 @@ There are also some wasm examples for whisper and
 [whisper](https://huggingface.co/spaces/lmz/candle-whisper),
 [llama2](https://huggingface.co/spaces/lmz/candle-llama2),
 [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
-[Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
+[Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
 [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).

 For LLaMA2, run the following command to retrieve the weight files and start a
@@ -141,8 +147,10 @@ And then head over to

 ## Useful External Resources

 - [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
   very detailed tutorial showing how to convert a PyTorch model to Candle.
-- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and ergonomic LoRA implemenation for Candle. `candle-lora` has
-  out-of-the-box LoRA support for many models from Candle, which can be found [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
+- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and
+  ergonomic LoRA implementation for Candle. `candle-lora` has
+  out-of-the-box LoRA support for many models from Candle, which can be found
+  [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
 - [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
   including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
 - [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
@@ -150,6 +158,7 @@ And then head over to
 - [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
 - [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
 - [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
+- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.

 If you have an addition to this list, please submit a pull request.

@@ -168,11 +177,13 @@ If you have an addition to this list, please submit a pull request.
     - WASM support, run your models in a browser.
 - Included models.
     - Language Models.
-        - LLaMA v1 and v2.
+        - LLaMA v1 and v2 with variants such as SOLAR-10.7B.
         - Falcon.
         - StarCoder.
-        - Phi v1.5.
+        - Phi 1, 1.5, and 2.
+        - Minimal Mamba
         - Mistral 7b v0.1.
+        - Mixtral 8x7b v0.1.
         - StableLM-3B-4E1T.
         - Replit-code-v1.5-3B.
         - Bert.
@@ -180,8 +191,9 @@ If you have an addition to this list, please submit a pull request.
     - Quantized LLMs.
         - Llama 7b, 13b, 70b, as well as the chat and code variants.
         - Mistral 7b, and 7b instruct.
-        - Zephyr 7b a and b (Mistral based).
-        - OpenChat 3.5 (Mistral based).
+        - Mixtral 8x7b.
+        - Zephyr 7b a and b (Mistral-7b based).
+        - OpenChat 3.5 (Mistral-7b based).
     - Text to text.
         - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
         - Marian MT (Machine Translation).

View File

@@ -11,11 +11,11 @@ readme = "README.md"

 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
-candle-nn = { path = "../candle-nn", version = "0.3.1" }
-candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
+candle = { workspace = true }
+candle-datasets = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
+candle-flash-attn = { workspace = true, optional = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }

View File

@@ -28,6 +28,7 @@ let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap()
 #[rustfmt::skip]
 #[test]
 fn book_hub_2() {
+    {
     // ANCHOR: book_hub_2
     use candle::Device;
     use hf_hub::api::sync::Api;
@@ -45,9 +46,10 @@ let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap()
     assert_eq!(weights.len(), 206);
 }
-#[rustfmt::skip]
-#[test]
-fn book_hub_3() {
+    // #[rustfmt::skip]
+    // #[test]
+    // fn book_hub_3() {
+    {
     // ANCHOR: book_hub_3
     use candle::{DType, Device, Tensor};
     use hf_hub::api::sync::Api;
@@ -102,6 +104,7 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
     assert_eq!(view.shape(), &[768, 768]);
     assert_eq!(tp_tensor.dims(), &[192, 768]);
 }
+}

 #[rustfmt::skip]
 #[test]

View File

@@ -12,8 +12,8 @@ readme = "README.md"

 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
-candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true }
+candle-kernels = { workspace = true, optional = true }
+candle-metal-kernels = { workspace = true, optional = true }
 metal = { workspace = true, optional = true}
 cudarc = { workspace = true, optional = true }
 gemm = { workspace = true }
@@ -34,6 +34,8 @@ zip = { workspace = true }

 [dev-dependencies]
 anyhow = { workspace = true }
 clap = { workspace = true }
+criterion = { workspace = true }

 [features]
 default = []
@@ -42,3 +44,8 @@ cudnn = ["cuda", "cudarc/cudnn"]
 mkl = ["dep:libc", "dep:intel-mkl-src"]
 accelerate = ["dep:libc", "dep:accelerate-src"]
 metal = ["dep:metal", "dep:candle-metal-kernels"]
+
+[[bench]]
+name = "matmul"
+harness = false

View File

@@ -0,0 +1,42 @@
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
+use std::time::Instant;
+
+fn run(a: &Tensor, b: &Tensor) {
+    a.matmul(&b.t().unwrap()).unwrap();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let b = 1;
+    let m = 1;
+    let n = 2048;
+    let k = 2048;
+
+    let device = Device::new_metal(0).unwrap();
+    let dtype = DType::F32;
+    let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap();
+    let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap();
+
+    let flops = b * m * n * k;
+
+    let mut group = c.benchmark_group("matmul_metal");
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run(black_box(&lhs), black_box(&rhs));
+            }
+            if let Device::Metal(device) = &device {
+                device.wait_until_completed().unwrap();
+            } else {
+                panic!("Expected metal device");
+            }
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
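With the [[bench]] registration above, this Criterion benchmark should be runnable on an Apple-silicon host with something like the following (assuming the metal feature wiring; the exact invocation is not confirmed by this diff):

cargo bench -p candle-core --bench matmul --features metal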

View File

@@ -114,7 +114,7 @@ impl Tensor {
                 | Op::Unary(_node, UnaryOp::Round) => nodes,
                 Op::Reshape(node)
                 | Op::UpsampleNearest1D(node)
-                | Op::UpsampleNearest2D(node)
+                | Op::UpsampleNearest2D { arg: node, .. }
                 | Op::AvgPool2D { arg: node, .. }
                 | Op::MaxPool2D { arg: node, .. }
                 | Op::Copy(node)
@@ -350,9 +350,27 @@ impl Tensor {
                 Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
                     op: "upsample-nearest1d",
                 })?,
-                Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
-                    op: "upsample-nearest2d",
-                })?,
+                Op::UpsampleNearest2D {
+                    arg,
+                    target_h,
+                    target_w,
+                } => {
+                    let (_n, c, h, w) = arg.dims4()?;
+                    if target_h % h != 0 || target_w % w != 0 {
+                        crate::bail!("backward not supported for non integer upscaling factors")
+                    }
+                    let scale_h = target_h / h;
+                    let scale_w = target_w / w;
+                    if scale_h != scale_w {
+                        crate::bail!("backward not supported for non uniform upscaling factors")
+                    };
+                    let kernel =
+                        Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
+                    let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = conv_sum;
+                }
                 Op::SliceScatter0(lhs, rhs, start_rhs) => {
                     let rhs_sum_grad = grads.or_insert(rhs)?;
                     let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
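A scalar sketch of the gradient rule this hunk implements for an integer, uniform scale s (single channel for brevity): nearest-neighbor upsampling copies each input pixel to an s x s output block, so that pixel's gradient is the sum of the block's output gradients, which is exactly what the ones-kernel conv2d with stride s computes.

fn upsample_nearest_grad(grad_out: &[f32], h: usize, w: usize, s: usize) -> Vec<f32> {
    let ow = w * s;
    let mut grad_in = vec![0f32; h * w];
    for oy in 0..h * s {
        for ox in 0..ow {
            // Each output pixel (oy, ox) was copied from input (oy / s, ox / s).
            grad_in[(oy / s) * w + ox / s] += grad_out[oy * ow + ox];
        }
    }
    grad_in
}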

View File

@@ -201,10 +201,9 @@ impl Device {
                     Ok(Storage::Cuda(storage))
                 }
             }
-            Device::Metal(_device) => {
-                // let storage = device.rand_uniform(shape, dtype, lo, up)?;
-                // Ok(Storage::Metal(storage))
-                crate::bail!("Metal rand_uniform not implemented")
+            Device::Metal(device) => {
+                let storage = device.rand_uniform(shape, dtype, lo, up)?;
+                Ok(Storage::Metal(storage))
             }
         }
     }

View File

@@ -64,7 +64,7 @@ impl Tensor {
 #[derive(Debug)]
 /// Generic structure used to index a slice of the tensor
 pub enum TensorIndexer {
-    /// This selects the elemnts for which an index has some specific value.
+    /// This selects the elements for which an index has some specific value.
     Select(usize),
     /// This is a regular slice, purely indexing a chunk of the tensor
     Narrow(Bound<usize>, Bound<usize>),

File diff suppressed because it is too large.

View File

@@ -132,7 +132,11 @@ pub enum Op {
     },
     UpsampleNearest1D(Tensor),
-    UpsampleNearest2D(Tensor),
+    UpsampleNearest2D {
+        arg: Tensor,
+        target_h: usize,
+        target_w: usize,
+    },

     Cat(Vec<Tensor>, usize),

View File

@@ -353,7 +353,7 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
             q3 = q3.add(32);

             // Prepare low and high bits
-            // We hardcode the shifts here to avoid loading them into a seperate register
+            // We hardcode the shifts here to avoid loading them into a separate register
             let q3l_0 = _mm256_and_si256(q3bits, m3);
             let q3h_0 = if j == 0 {
                 _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
@@ -586,7 +586,7 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
             let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
             q5 = q5.add(32);

-            //Similar to q3k we hardcode the shifts here to avoid loading them into a seperate register
+            //Similar to q3k we hardcode the shifts here to avoid loading them into a separate register
             let q5l_0 = _mm256_and_si256(q5bits, m4);
             let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
             let q5l_0_right_shift = match j {

View File

@@ -41,7 +41,7 @@ impl VersionedMagic {
             (Magic::Gguf, 1) => Self::GgufV1,
             (Magic::Gguf, 2) => Self::GgufV2,
             (Magic::Gguf, 3) => Self::GgufV3,
-            _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
+            _ => crate::bail!("gguf: unsupported magic/version {magic:?}/{version}"),
         };
         Ok(versioned_magic)
     }
@@ -463,7 +463,7 @@ impl Content {
     ) -> Result<QTensor> {
         let tensor_info = match self.tensor_infos.get(name) {
             Some(tensor_info) => tensor_info,
-            None => crate::bail!("cannot find tensor-infor for {name}"),
+            None => crate::bail!("cannot find tensor info for {name}"),
         };
         tensor_info.read(reader, self.tensor_data_offset)
     }

View File

@@ -12,6 +12,14 @@ use core::arch::arm::*;
 #[cfg(target_arch = "aarch64")]
 use core::arch::aarch64::*;

+#[inline(always)]
+unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
+    // TODO: dotprod
+    let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
+    let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+    vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
+}
+
 #[inline(always)]
 pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
     let qk = QK8_0;
@@ -43,15 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
             let v1_0l = vld1q_s8(y0.qs.as_ptr());
             let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));

-            // TODO: Support dotprod when it's available outside of nightly.
-            let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
-            let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
-            let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
-            let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
-            let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-            let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+            let pl0 = vdotq_s32(v0_0ls, v1_0l);
+            let ph0 = vdotq_s32(v0_0hs, v1_0h);

             sumv0 = vmlaq_n_f32(
                 sumv0,
                 vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
@@ -82,14 +83,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
             let y0_0 = vld1q_s8(y0.qs.as_ptr());
             let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));

-            // TODO dotprod once this is the intrinsics are.
-            let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
-            let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
-            let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
-            let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
-            let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
-            let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
+            let p0 = vdotq_s32(x0_0, y0_0);
+            let p1 = vdotq_s32(x0_1, y0_1);

             sumv0 = vmlaq_n_f32(
                 sumv0,
@@ -118,10 +113,7 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res
         for i in (0..QK_K).step_by(16) {
             let xs = vld1q_s8(xs.add(i));
             let ys = vld1q_s8(ys.add(i));
-            let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
-            let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
-            let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
+            let xy = vdotq_s32(xs, ys);
             sum_i = vaddq_s32(sum_i, xy)
         }
         sumf += vaddvq_s32(sum_i) as f32 * scale
@@ -191,30 +183,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
                 let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
                 let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));

-                // TODO: dotprod
-
-                let p0 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
-                    vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
-                );
-                let p1 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
-                    vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
-                );
+                let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
+                let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
                 let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-                isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
+                isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
                 scale = scale.add(2);

-                let p2 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
-                    vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
-                );
-                let p3 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
-                    vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
-                );
+                let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
+                let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
                 let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-                isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
+                isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
                 scale = scale.add(2);

                 let q8bytes = vld1q_s8_x4(q8);
@@ -234,29 +212,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
                 let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
                 let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));

-                // TODO: dotprod case.
-                let p0 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
-                    vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
-                );
-                let p1 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
-                    vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
-                );
+                let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
+                let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
                 let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-                isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
+                isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
                 scale = scale.add(2);

-                let p2 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
-                    vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
-                );
-                let p3 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
-                    vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
-                );
+                let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
+                let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
                 let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-                isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
+                isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
                 scale = scale.add(2);
             }
             sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
@@ -333,28 +298,14 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
                 let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
                 let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));

-                // TODO: dotprod
-
-                let p0 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
-                    vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
-                );
-                let p1 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
-                    vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
-                );
-                sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
+                let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
+                let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
+                sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
                 scales = scales.add(1);

-                let p2 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
-                    vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
-                );
-                let p3 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
-                    vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
-                );
-                sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
+                let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
+                let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
+                sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
                 scales = scales.add(1);
             }
             sumf += d * sumi as f32 - dmin * sumi_mins as f32;
@@ -417,22 +368,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
             for j in 0..QK_K / 64 {
                 let q4bits = vld1q_u8_x2(q4);
                 q4 = q4.add(32);
-                // TODO: dotprod
                 let q8bytes = vld1q_s8_x2(q8);
                 q8 = q8.add(32);
                 let q4bytes = int8x16x2_t(
                     vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
                     vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
                 );
-                let p0 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
-                    vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
-                );
-                let p1 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
-                    vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
-                );
-                sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
+                let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
+                let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
+                sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;

                 let q8bytes = vld1q_s8_x2(q8);
                 q8 = q8.add(32);
@@ -440,15 +384,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
                     vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
                     vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
                 );
-                let p2 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
-                    vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
-                );
-                let p3 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
-                    vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
-                );
-                sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
+                let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
+                let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
+                sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
             }
             sumf += d * (sumi1 + sumi2) as f32;
         }
@@ -526,27 +464,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
                     vreinterpretq_s8_u8(q3h_3),
                 );

-                // TODO: dotprod
-                let p0 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
-                    vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
-                );
-                let p1 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
-                    vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
-                );
-                let p2 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
-                    vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
-                );
-                let p3 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
-                    vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
-                );
-                isum += vaddvq_s16(p0) as i32 * *scale as i32
-                    + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
-                    + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
-                    + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+                let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
+                let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
+                let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
+                let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
+                isum += vaddvq_s32(p0) * *scale as i32
+                    + vaddvq_s32(p1) * *scale.add(1) as i32
+                    + vaddvq_s32(p2) * *scale.add(2) as i32
+                    + vaddvq_s32(p3) * *scale.add(3) as i32;
                 scale = scale.add(4);

                 let q3h_0 = vbicq_u8(m2, qhbits.0);
@@ -571,27 +496,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
                     vreinterpretq_s8_u8(q3h_3),
                 );

-                // TODO: dotprod
-                let p0 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
-                    vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
-                );
-                let p1 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
-                    vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
-                );
-                let p2 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
-                    vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
-                );
-                let p3 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
-                    vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
-                );
-                isum += vaddvq_s16(p0) as i32 * *scale as i32
-                    + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
-                    + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
-                    + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+                let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
+                let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
+                let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
+                let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
+                isum += vaddvq_s32(p0) * *scale as i32
+                    + vaddvq_s32(p1) * *scale.add(1) as i32
+                    + vaddvq_s32(p2) * *scale.add(2) as i32
+                    + vaddvq_s32(p3) * *scale.add(3) as i32;
                 scale = scale.add(4);

                 if j == 0 {
@@ -649,7 +561,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
             let mut is = 0usize;
             // TODO: dotprod
-
             for _j in 0..QK_K / 128 {
                 let q2bits = vld1q_u8_x2(q2);
                 q2 = q2.add(32);
@@ -696,14 +607,7 @@ unsafe fn multiply_accum_with_scale(
     q2bytes: int8x16x2_t,
     q8bytes: int8x16x2_t,
 ) -> i32 {
-    let p1 = vaddq_s16(
-        vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
-        vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
-    );
-    let p2 = vaddq_s16(
-        vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
-        vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
-    );
-    vaddvq_s16(p1) as i32 * aux[is + index] as i32
-        + vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
+    let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
+    let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
+    vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
 }
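For reference, a scalar version of the reduction the shared vdotq_s32 helper feeds: every call site above only uses the horizontal sum of the four i32 lanes, which equals the plain i8 dot product (the lane layout of the vmull/vpaddl emulation differs from the hardware dotprod instruction, but the total is identical).

fn dot_i8(a: &[i8; 16], b: &[i8; 16]) -> i32 {
    // Widen to i32 before multiplying so the 16 products cannot overflow.
    a.iter().zip(b.iter()).map(|(&x, &y)| x as i32 * y as i32).sum()
}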

View File

@@ -478,23 +478,6 @@ extract_dims!(
     (usize, usize, usize, usize, usize)
 );

-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn stride() {
-        let shape = Shape::from(());
-        assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
-        let shape = Shape::from(42);
-        assert_eq!(shape.stride_contiguous(), [1]);
-        let shape = Shape::from((42, 1337));
-        assert_eq!(shape.stride_contiguous(), [1337, 1]);
-        let shape = Shape::from((299, 792, 458));
-        assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
-    }
-}
-
 pub trait ShapeWithOneHole {
     fn into_shape(self, el_count: usize) -> Result<Shape>;
 }
@@ -627,3 +610,20 @@ impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
         Ok((d1, d2, d3, d4, d).into())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn stride() {
+        let shape = Shape::from(());
+        assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
+        let shape = Shape::from(42);
+        assert_eq!(shape.stride_contiguous(), [1]);
+        let shape = Shape::from((42, 1337));
+        assert_eq!(shape.stride_contiguous(), [1337, 1]);
+        let shape = Shape::from((299, 792, 458));
+        assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
+    }
+}

View File

@ -1,4 +1,4 @@
//! Tensors are N-dimenional matrixes of elements using a single data type. //! Tensors are N-dimensional matrixes of elements using a single data type.
#![allow(clippy::redundant_closure_call)] #![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage}; use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{ use crate::op::{
@ -361,6 +361,16 @@ impl Tensor {
Self::new_impl(array, shape, device, false) Self::new_impl(array, shape, device, false)
} }
/// Returns a new tensor with all the elements having the same specified value. Note that
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
pub fn full<D: crate::WithDType, S: Into<Shape>>(
value: D,
shape: S,
device: &Device,
) -> Result<Self> {
Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
}
/// Creates a new 1D tensor from an iterator. /// Creates a new 1D tensor from an iterator.
pub fn from_iter<D: crate::WithDType>( pub fn from_iter<D: crate::WithDType>(
iter: impl IntoIterator<Item = D>, iter: impl IntoIterator<Item = D>,
@ -386,7 +396,7 @@ impl Tensor {
device: &Device, device: &Device,
) -> Result<Self> { ) -> Result<Self> {
if D::is_zero(&step) { if D::is_zero(&step) {
crate::bail!("step cannot be zero") bail!("step cannot be zero")
} }
let mut data = vec![]; let mut data = vec![];
let mut current = start; let mut current = start;
@ -669,7 +679,7 @@ impl Tensor {
} }
/// Split a tensor into the specified number of chunks, this may return less chunks than /// Split a tensor into the specified number of chunks, this may return less chunks than
/// specificed. /// specified.
pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> { pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
let dim = dim.to_index(self.shape(), "chunk")?; let dim = dim.to_index(self.shape(), "chunk")?;
let size = self.dim(dim)?; let size = self.dim(dim)?;
@ -994,7 +1004,11 @@ impl Tensor {
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`. /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> { pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?; let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest2D); let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
arg,
target_h,
target_w,
});
let storage = self let storage = self
.storage() .storage()
.upsample_nearest2d(self.layout(), target_h, target_w)?; .upsample_nearest2d(self.layout(), target_h, target_w)?;
@ -1027,6 +1041,9 @@ impl Tensor {
let kernel_size = kernel_size.to_usize2(); let kernel_size = kernel_size.to_usize2();
let stride = stride.to_usize2(); let stride = stride.to_usize2();
let (n, c, h, w) = self.dims4()?; let (n, c, h, w) = self.dims4()?;
if h < kernel_size.0 || w < kernel_size.1 {
bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
}
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1; let h_out = (h - kernel_size.0) / stride.0 + 1;
let w_out = (w - kernel_size.1) / stride.1 + 1; let w_out = (w - kernel_size.1) / stride.1 + 1;
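As a worked example of the formula above: with a 5x5 input, a 3x3 kernel and stride 2, `h_out = (5 - 3) / 2 + 1 = 2`, while a 6x6 kernel on the same input now bails instead of underflowing in the subtraction. A minimal sketch (assuming the `_with_stride` variant):

```rust
use candle_core::{Device, Result, Tensor};

fn pool_example() -> Result<()> {
    let x = Tensor::arange(0f32, 25., &Device::Cpu)?.reshape((1, 1, 5, 5))?;
    let y = x.avg_pool2d_with_stride(3, 2)?;
    assert_eq!(y.dims(), [1, 1, 2, 2]);
    Ok(())
}
```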
@ -1062,6 +1079,9 @@ impl Tensor {
let kernel_size = kernel_size.to_usize2(); let kernel_size = kernel_size.to_usize2();
let stride = stride.to_usize2(); let stride = stride.to_usize2();
let (n, c, h, w) = self.dims4()?; let (n, c, h, w) = self.dims4()?;
if h < kernel_size.0 || w < kernel_size.1 {
bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
}
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1; let h_out = (h - kernel_size.0) / stride.0 + 1;
let w_out = (w - kernel_size.1) / stride.1 + 1; let w_out = (w - kernel_size.1) / stride.1 + 1;
@ -1784,7 +1804,7 @@ impl Tensor {
let is_permutation = let is_permutation =
dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i)); dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
if !is_permutation { if !is_permutation {
crate::bail!( bail!(
"dimension mismatch in permute, tensor {:?}, dims: {:?}", "dimension mismatch in permute, tensor {:?}, dims: {:?}",
self.dims(), self.dims(),
dims dims
@ -1863,10 +1883,7 @@ impl Tensor {
Storage::Metal(metal.storage_from_cpu_storage(storage)?) Storage::Metal(metal.storage_from_cpu_storage(storage)?)
} }
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Metal(storage), Device::Cpu) => { (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
// println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
Storage::Cpu(storage.to_cpu_storage()?)
}
(Storage::Cuda(storage), Device::Cuda(cuda)) => { (Storage::Cuda(storage), Device::Cuda(cuda)) => {
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids // TODO: Avoid passing through the cpu storage here, especially if the gpu ids
// are the same. // are the same.
@ -2282,7 +2299,7 @@ impl Tensor {
if left == 0 && right == 0 { if left == 0 && right == 0 {
Ok(self.clone()) Ok(self.clone())
} else if self.elem_count() == 0 { } else if self.elem_count() == 0 {
crate::bail!("cannot use pad_with_same on an empty tensor") bail!("cannot use pad_with_same on an empty tensor")
} else if left == 0 { } else if left == 0 {
let dim = dim.to_index(self.shape(), "pad_with_same")?; let dim = dim.to_index(self.shape(), "pad_with_same")?;
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?; let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
@ -2446,13 +2463,13 @@ impl Tensor {
pub fn normalize_axis(&self, axis: i64) -> Result<usize> { pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
let rank = self.rank() as i64; let rank = self.rank() as i64;
if rank <= axis { if rank <= axis {
crate::bail!("axis {axis} is too large, tensor rank {rank}") bail!("axis {axis} is too large, tensor rank {rank}")
} else if 0 <= axis { } else if 0 <= axis {
Ok(axis as usize) Ok(axis as usize)
} else { } else {
let naxis = rank + axis; let naxis = rank + axis;
if naxis < 0 { if naxis < 0 {
crate::bail!("axis {axis} is too small, tensor rank {rank}") bail!("axis {axis} is too small, tensor rank {rank}")
} }
Ok(naxis as usize) Ok(naxis as usize)
} }
@ -2514,14 +2531,14 @@ impl Tensor {
let src_dims = src.dims(); let src_dims = src.dims();
let self_dims = self.dims(); let self_dims = self.dims();
if self_dims.len() != src_dims.len() { if self_dims.len() != src_dims.len() {
crate::bail!( bail!(
"slice-assign requires input with the same rank {} <> {}", "slice-assign requires input with the same rank {} <> {}",
self_dims.len(), self_dims.len(),
src_dims.len() src_dims.len()
) )
} }
if self_dims.len() != ranges.len() { if self_dims.len() != ranges.len() {
crate::bail!( bail!(
"slice-assign requires input with the same rank as there are ranges {} <> {}", "slice-assign requires input with the same rank as there are ranges {} <> {}",
self_dims.len(), self_dims.len(),
ranges.len() ranges.len()
@ -2541,18 +2558,16 @@ impl Tensor {
std::ops::Bound::Excluded(v) => *v, std::ops::Bound::Excluded(v) => *v,
}; };
if end_excluded <= start_included { if end_excluded <= start_included {
crate::bail!( bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
"slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
)
} }
if self_dims[i] < end_excluded { if self_dims[i] < end_excluded {
crate::bail!( bail!(
"slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}", "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
self_dims[i] self_dims[i]
) )
} }
if end_excluded - start_included != src_dims[i] { if end_excluded - start_included != src_dims[i] {
crate::bail!( bail!(
"slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i] "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
) )
} }
@ -2561,6 +2576,13 @@ impl Tensor {
} }
mask.where_cond(/* on_true= */ &src, /* on_false= */ self) mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
} }
/// Returns log(sum(exp(tensor), dim)).
pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
let exp = self.exp()?;
let sum = exp.sum(sum_dims)?;
sum.log()
}
} }
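Note that `logsumexp` as added here is the direct `log(sum(exp(x)))`, which can overflow to infinity for large inputs. A sketch of the usual max-shifted variant, shown only for illustration (not part of this change):

```rust
use candle_core::{Result, Tensor, D};

// log(sum(exp(x - m))) + m with m = max(x) over the last dim, which keeps the
// exponentials bounded by 1 and so avoids overflow for large inputs.
fn stable_logsumexp(t: &Tensor) -> Result<Tensor> {
    let m = t.max_keepdim(D::Minus1)?;
    let s = t.broadcast_sub(&m)?.exp()?.sum_keepdim(D::Minus1)?;
    (s.log()? + m)?.squeeze(D::Minus1)
}
```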
macro_rules! bin_trait { macro_rules! bin_trait {

View File

@ -270,6 +270,166 @@ fn unary_grad(device: &Device) -> Result<()> {
[0.7358, 2.0000, 0.2707, 1.0000] [0.7358, 2.0000, 0.2707, 1.0000]
); );
// manually checked: see comments
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?;
#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04., 05., 06.,
07., 08., 09., 10., 11., 12.,
13., 14., 15., 16., 17., 18.,
19., 20., 21., 22., 23., 24.,
25., 26., 27., 28., 29., 30.,
31., 32., 33., 34., 35., 36.,
],
device,
)?;
// gradient should be
// row 1
// 1+2+7+8 = 18
// 3+4+9+10 = 26
// 5+6+11+12 = 34
// row 2
// 13+14+19+20 = 66
// 15+16+21+22 = 74
// 17+18+23+24 = 82
// row 3
// 25+26+31+32 = 114
// 27+28+33+34 = 122
// 29+30+35+36 = 130
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
[[18_f32, 26., 34.], [66., 74., 82.], [114., 122., 130.]]
);
// manually checked: see comments
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?;
#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04., 05., 06.,
07., 08., 09., 10., 11., 12.,
13., 14., 15., 16., 17., 18.,
19., 20., 21., 22., 23., 24.,
25., 26., 27., 28., 29., 30.,
31., 32., 33., 34., 35., 36.,
],
device,
)?;
// gradient should be
// row 1
// 1+2+3+7+8+9+13+14+15 = 72
// 4+5+6+10+11+12+16+17+18 = 99
// row 2
// 19+20+21+25+26+27+31+32+33 = 234
// 22+23+24+28+29+30+34+35+36 = 261
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
[[72_f32, 99.], [234., 261.]]
);
// manually checked: see comments
let x = Var::new(&[[[[1f32, 2.], [4., 5.]], [[6f32, 7.], [8., 9.]]]], device)?;
let y = x.interpolate2d(4, 4)?.reshape(32)?;
#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04.,
05., 06., 07., 08.,
09., 10., 11., 12.,
13., 14., 15., 16.,
17., 18., 19., 20.,
21., 22., 23., 24.,
25., 26., 27., 28.,
29., 30., 31., 32.
],
device,
)?;
// gradient should be
// m1r1
// 1+2+5+6=14
// 3+4+7+8=22
// m1r2
// 9+10+13+14=46
// 11+12+15+16=54
// m2r1
// 17+18+21+22=78
// 19+20+23+24=86
// m2r2
// 25+26+29+30=110
// 27+28+31+32=118
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
[[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
);
// manually checked: see comments
let x = Var::new(
&[[[[1f32, 2.], [4., 5.]]], [[[6f32, 7.], [8., 9.]]]],
device,
)?;
let y = x.interpolate2d(4, 4)?.reshape(32)?;
#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04.,
05., 06., 07., 08.,
09., 10., 11., 12.,
13., 14., 15., 16.,
17., 18., 19., 20.,
21., 22., 23., 24.,
25., 26., 27., 28.,
29., 30., 31., 32.
],
device,
)?;
// gradient should be
// m1r1
// 1+2+5+6=14
// 3+4+7+8=22
// m1r2
// 9+10+13+14=46
// 11+12+15+16=54
// m2r1
// 17+18+21+22=78
// 19+20+23+24=86
// m2r2
// 25+26+29+30=110
// 27+28+31+32=118
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
[[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
);
Ok(()) Ok(())
} }

View File

@ -1,4 +1,5 @@
use candle_core::{ use candle_core::{
bail,
quantized::{self, GgmlDType}, quantized::{self, GgmlDType},
test_utils::to_vec2_round, test_utils::to_vec2_round,
Device, Module, Result, Tensor, Device, Module, Result, Tensor,
@ -265,7 +266,8 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
} }
} }
/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30 /// Creates a vector similar to the ones used in GGML unit tests:
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
fn create_ggml_like_vector(offset: f32) -> Vec<f32> { fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
(0..GGML_TEST_SIZE) (0..GGML_TEST_SIZE)
.map(|i| 0.1 + 2.0 * (i as f32 + offset).cos()) .map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
@ -284,14 +286,15 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
sum / a.len() as f32 sum / a.len() as f32
} }
/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50 /// Similar to the GGML quantization unit test:
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> { fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
let src = create_ggml_like_vector(0.0); let src = create_ggml_like_vector(0.0);
let mut dst = vec![0.0; GGML_TEST_SIZE]; let mut dst = vec![0.0; GGML_TEST_SIZE];
let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?; let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
let error = calculate_rmse(src.as_slice(), dst.as_slice()); let error = calculate_rmse(src.as_slice(), dst.as_slice());
if error > max_error { if error > max_error {
candle_core::bail!( bail!(
"Quantization error {} exceeds max error {}", "Quantization error {} exceeds max error {}",
error, error,
max_error max_error
@ -487,54 +490,66 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
GgmlDType::Q5K => 0.000740, GgmlDType::Q5K => 0.000740,
GgmlDType::Q6K => 0.000952, GgmlDType::Q6K => 0.000952,
GgmlDType::Q4_0 => 0.001143, GgmlDType::Q4_0 => 0.001143,
GgmlDType::Q4_1 => 0.007784, GgmlDType::Q4_1 => 0.008,
GgmlDType::Q5_0 => 0.001353, GgmlDType::Q5_0 => 0.001353,
GgmlDType::Q5_1 => 0.001363, GgmlDType::Q5_1 => 0.00149,
GgmlDType::Q8_0 => 0.000092, GgmlDType::Q8_0 => 0.000092,
// Not from the ggml repo. // Not from the ggml repo.
GgmlDType::Q8K => 0.00065, GgmlDType::Q8K => 0.00065,
_ => candle_core::bail!("No GGML results for quantization type {dtype:?}",), _ => bail!("No GGML results for quantization type {dtype:?}",),
}; };
Ok(err) Ok(err)
} }
/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91 /// Similar to the GGML matmul unit test:
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> { fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
let a = create_ggml_like_vector(0.0); let a = create_ggml_like_vector(0.0);
let b = create_ggml_like_vector(1.0); let b = create_ggml_like_vector(1.0);
ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 1.0)?;
// Another example that is more likely to trigger the overflow reported in #1526
let a = (0..GGML_TEST_SIZE)
.map(|i| i as f32 / GGML_TEST_SIZE as f32)
.collect::<Vec<_>>();
let b = (0..GGML_TEST_SIZE)
.map(|i| i as f32 / GGML_TEST_SIZE as f32)
.collect::<Vec<_>>();
ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 2.0)?;
Ok(())
}
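// Note on the second input pair above: an all-positive, monotonically increasing ramp
// makes the partial sums inside the SIMD dot-product kernels much larger than the
// sign-alternating cosine vectors do, which is presumably why it reproduces the
// overflow from #1526 more reliably; hence the relaxed err_m = 2.0 multiplier.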
fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Result<()> {
let length = a.len(); let length = a.len();
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE]; let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE]; let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
T::from_float(&a, &mut a_quant)?; T::from_float(a, &mut a_quant)?;
T::VecDotType::from_float(&b, &mut b_quant)?; T::VecDotType::from_float(b, &mut b_quant)?;
let result = T::vec_dot(length, &a_quant, &b_quant)?; let result = T::vec_dot(length, &a_quant, &b_quant)?;
let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?; let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
let reference_result = vec_dot_reference(&a, &b); let reference_result = vec_dot_reference(a, b);
if (result - result_unopt).abs() / length as f32 > 1e-6 { if (result - result_unopt).abs() / length as f32 > 1e-6 {
candle_core::bail!( bail!(
"the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}" "the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
) )
} }
let error = (result - reference_result).abs() / length as f32; let error = (result - reference_result).abs() / length as f32;
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?; let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m;
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR { if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
candle_core::bail!( bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",);
"Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
);
} }
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML // We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
// => we use a slightly higher error threshold // => we use a slightly higher error threshold
const ERROR_LENIENCY: f32 = 0.00001; const ERROR_LENIENCY: f32 = 0.00001;
if error - ERROR_LENIENCY > ggml_error { if error - ERROR_LENIENCY > ggml_error {
candle_core::bail!( bail!(
"Dot product error {} exceeds ggml reference error {}", "Dot product error {} exceeds ggml reference error {}",
error, error,
ggml_error ggml_error
@ -543,6 +558,36 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
Ok(()) Ok(())
} }
fn get_small_tensors(
m: usize,
k: usize,
n: usize,
device: &Device,
) -> Result<(Tensor, Tensor, Tensor)> {
let lhs = (0..m * k)
.map(|i| i as f32 / (m * k) as f32)
.collect::<Vec<_>>();
let rhs = (0..n * k)
.map(|i| i as f32 / (n * k) as f32)
.collect::<Vec<_>>();
let lhs = Tensor::from_vec(lhs, (m, k), device)?;
let rhs = Tensor::from_vec(rhs, (n, k), device)?;
let mm = lhs.matmul(&rhs.t()?)?;
Ok((lhs, rhs, mm))
}
#[test]
fn quantized_mm() -> Result<()> {
ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
ggml_matmul_error_test::<k_quants::BlockQ4_1>()?;
ggml_matmul_error_test::<k_quants::BlockQ5_0>()?;
ggml_matmul_error_test::<k_quants::BlockQ5_1>()?;
ggml_matmul_error_test::<k_quants::BlockQ8_0>()?;
Ok(())
}
/// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result. /// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
fn get_random_tensors( fn get_random_tensors(
m: usize, m: usize,
@ -598,20 +643,30 @@ fn quantized_matmul_q3k() -> Result<()> {
let cpu = &Device::Cpu; let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21); let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; let (lhs, rhs, mm) = get_small_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]); // assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?; // let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); // let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); // assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?; let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?; let qmm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]); let error: f32 = ((&mm - &qmm)?.abs()? / &mm.abs()?)?
let dst = mm.flatten_all()?.to_vec1::<f32>()?; .sum_all()?
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); .to_scalar()?;
assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]); let error = error / (m * n) as f32;
// assert_eq!(qmm.dims(), [m, n]);
// let dst = qmm.flatten_all()?.to_vec1::<f32>()?;
// let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
// assert_eq!(dst, [1.029, 1.418, -0.314, 1.495]);
assert!(
error < 0.01,
"{error} is too big, shouldn't exceed a few percent. \nGot:{qmm}\nExpected:\n{mm} "
);
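// The assertion above checks a mean relative error, sum(|mm - qmm| / |mm|) / (m * n),
// rather than the hard-coded sample values kept commented out above: exact outputs
// drift slightly across rounding modes and SIMD backends, while the relative error
// against the f32 matmul stays stable.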
ggml_matmul_error_test::<BlockQ3K>()?; ggml_matmul_error_test::<BlockQ3K>()?;
@ -624,20 +679,30 @@ fn quantized_matmul_q4k() -> Result<()> {
let cpu = &Device::Cpu; let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21); let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?; let (lhs, rhs, mm) = get_small_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]); // assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?; // let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); // let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); // assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?; let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?; let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?; let qmm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]); let error: f32 = ((&mm - &qmm)?.abs()? / &mm.abs()?)?
let dst = mm.flatten_all()?.to_vec1::<f32>()?; .sum_all()?
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]); .to_scalar()?;
assert_eq!(dst, [1.125, 1.435, -0.201, 1.589]); let error = error / (m * n) as f32;
assert!(
error < 0.01,
"{error} is too big, shouldn't exceed a few percent. \nGot:{qmm}\nExpected:\n{mm} "
);
// assert_eq!(mm.dims(), [m, n]);
// let dst = mm.flatten_all()?.to_vec1::<f32>()?;
// let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
// assert_eq!(dst, [1.125, 1.435, -0.201, 1.589]);
ggml_matmul_error_test::<BlockQ4K>()?; ggml_matmul_error_test::<BlockQ4K>()?;

View File

@ -1,4 +1,4 @@
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor}; use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};
fn zeros(device: &Device) -> Result<()> { fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?; let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@ -32,6 +32,14 @@ fn ones(device: &Device) -> Result<()> {
Ok(()) Ok(())
} }
fn full(device: &Device) -> Result<()> {
assert_eq!(
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
[[42, 42, 42], [42, 42, 42]],
);
Ok(())
}
fn arange(device: &Device) -> Result<()> { fn arange(device: &Device) -> Result<()> {
assert_eq!( assert_eq!(
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?, Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
@ -900,9 +908,7 @@ fn matmul(device: &Device) -> Result<()> {
let b = Tensor::from_slice(&data, (2, 2), device)?; let b = Tensor::from_slice(&data, (2, 2), device)?;
let c = a.matmul(&b)?; let c = a.matmul(&b)?;
let d = a.matmul(&c)?;
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]); assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
assert_eq!(d.to_vec2::<f32>()?, &[[37.0, 54.0], [81.0, 118.0]]);
let data = vec![1.0f32, 2.0]; let data = vec![1.0f32, 2.0];
let a = Tensor::from_slice(&data, (2, 1), device)?; let a = Tensor::from_slice(&data, (2, 1), device)?;
@ -1074,6 +1080,7 @@ fn randn(device: &Device) -> Result<()> {
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal); test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal); test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal); test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal); test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal); test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
@ -1223,3 +1230,26 @@ fn cumsum() -> Result<()> {
); );
Ok(()) Ok(())
} }
/// A helper function for floating point comparison. Both a and b must be 1D tensors with the same number of elements.
/// The assertion passes if the absolute difference between every pair of elements from a and b is smaller than epsilon.
fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
let a_vec: Vec<f64> = a.to_vec1()?;
let b_vec: Vec<f64> = b.to_vec1()?;
assert_eq!(a_vec.len(), b_vec.len());
for (a, b) in a_vec.iter().zip(b_vec.iter()) {
assert!((a - b).abs() < epsilon);
}
Ok(())
}
#[test]
fn logsumexp() -> Result<()> {
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let output = input.logsumexp(D::Minus1)?;
// The expectations obtained from pytorch.
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
assert_close(&output, &expected, 0.00001)?;
Ok(())
}

View File

@ -11,8 +11,8 @@ readme = "README.md"
[dependencies] [dependencies]
byteorder = { workspace = true } byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" } candle = { workspace = true }
candle-nn = { path = "../candle-nn", version = "0.3.1" } candle-nn = { workspace = true }
hf-hub = { workspace = true} hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true } memmap2 = { workspace = true }

View File

@ -11,14 +11,17 @@ readme = "README.md"
[dependencies] [dependencies]
accelerate-src = { workspace = true, optional = true } accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" } candle = { workspace = true }
candle-datasets = { path = "../candle-datasets", version = "0.3.1" } candle-datasets = { workspace = true }
candle-nn = { path = "../candle-nn", version = "0.3.1" } candle-nn = { workspace = true }
candle-transformers = { path = "../candle-transformers", version = "0.3.1" } candle-transformers = { workspace = true }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true } candle-flash-attn = { workspace = true, optional = true }
candle-onnx = { path = "../candle-onnx", version = "0.3.1", optional = true } candle-onnx = { workspace = true, optional = true }
csv = "1.3.0"
cudarc = { workspace = true, optional = true } cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true } half = { workspace = true, optional = true }
hf-hub = { workspace = true, features=["tokio"]}
image = { workspace = true } image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true } num-traits = { workspace = true }
@ -33,7 +36,6 @@ tokenizers = { workspace = true, features = ["onig"] }
anyhow = { workspace = true } anyhow = { workspace = true }
byteorder = { workspace = true } byteorder = { workspace = true }
clap = { workspace = true } clap = { workspace = true }
hf-hub = { workspace = true, features=["tokio"]}
imageproc = { workspace = true } imageproc = { workspace = true }
memmap2 = { workspace = true } memmap2 = { workspace = true }
rand = { workspace = true } rand = { workspace = true }
@ -47,11 +49,12 @@ tokio = "1.29.1"
[build-dependencies] [build-dependencies]
anyhow = { workspace = true } anyhow = { workspace = true }
bindgen_cuda = { version = "0.1.1", optional = true }
[features] [features]
default = [] default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
cudnn = ["candle/cudnn"] cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]

View File

@ -4,235 +4,34 @@ use std::io::Write;
use std::path::PathBuf; use std::path::PathBuf;
struct KernelDirectories { struct KernelDirectories {
kernel_dir: &'static str, kernel_glob: &'static str,
rust_target: &'static str, rust_target: &'static str,
include_dirs: &'static [&'static str], include_dirs: &'static [&'static str],
} }
const DIRS: [KernelDirectories; 1] = [KernelDirectories { const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories {
kernel_dir: "examples/custom-ops/kernels/", kernel_glob: "examples/custom-ops/kernels/*.cu",
rust_target: "examples/custom-ops/cuda_kernels.rs", rust_target: "examples/custom-ops/cuda_kernels.rs",
include_dirs: &[], include_dirs: &[],
}]; }];
impl KernelDirectories {
fn maybe_build_ptx(
&self,
cu_file: &std::path::Path,
ptx_file: &std::path::Path,
compute_cap: usize,
) -> Result<()> {
let should_compile = if ptx_file.exists() {
let ptx_modified = ptx_file.metadata()?.modified()?;
let cu_modified = cu_file.metadata()?.modified()?;
cu_modified.duration_since(ptx_modified).is_ok()
} else {
true
};
if should_compile {
#[cfg(feature = "cuda")]
{
let mut command = std::process::Command::new("nvcc");
let out_dir = ptx_file.parent().context("no parent for ptx file")?;
let include_dirs: Vec<String> =
self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
command
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--ptx")
.args(["--default-stream", "per-thread"])
.args(["--output-directory", out_dir.to_str().unwrap()])
.arg(format!("-I/{}", self.kernel_dir))
.args(include_dirs)
.arg(cu_file);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
}
#[cfg(not(feature = "cuda"))]
std::fs::OpenOptions::new()
.create(true)
.write(true)
.open(ptx_file)?;
}
Ok(())
}
fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
println!("cargo:rerun-if-changed={}", self.kernel_dir);
let kernel_dir = PathBuf::from(self.kernel_dir);
let out_dir = out_dir.join(self.kernel_dir);
if !out_dir.exists() {
std::fs::create_dir_all(&out_dir)?;
}
let mut cu_files = vec![];
let mut cuh_files = vec![];
for file in std::fs::read_dir(kernel_dir)?.flatten() {
let file = file.path();
match file.extension().and_then(|v| v.to_str()) {
Some("cu") => cu_files.push(file),
Some("cuh") => cuh_files.push(file),
_ => {}
}
}
let mut ptx_paths = vec![];
for cu_file in cu_files.iter() {
let file_stem = cu_file
.file_stem()
.with_context(|| format!("no stem {cu_file:?}"))?;
let file_stem = file_stem.to_string_lossy().into_owned();
let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
ptx_paths.push(ptx_file);
}
let regenerate_rs_file = true;
if regenerate_rs_file {
let mut file = std::fs::File::create(self.rust_target)?;
for ptx_path in ptx_paths {
let name = ptx_path
.file_stem()
.context("empty stem")?
.to_string_lossy();
file.write_all(b"#[rustfmt::skip]\n")?;
let const_definition = format!(
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
name.to_uppercase().replace('.', "_"),
self.kernel_dir,
);
file.write_all(const_definition.as_bytes())?;
file.write_all(b"\n")?;
}
}
Ok(())
}
}
fn main() -> Result<()> { fn main() -> Result<()> {
println!("cargo:rerun-if-changed=build.rs"); println!("cargo:rerun-if-changed=build.rs");
let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
let out_dir = PathBuf::from(out_dir);
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
set_cuda_include_dir()?; {
#[cfg(feature = "cuda")] for kdir in KERNEL_DIRS.iter() {
let compute_cap = compute_cap()?; let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob);
println!("cargo:info={builder:?}");
let bindings = builder.build_ptx().unwrap();
bindings.write(kdir.rust_target).unwrap()
}
}
#[cfg(not(feature = "cuda"))] #[cfg(not(feature = "cuda"))]
let compute_cap = 0; {
for d in DIRS { for kdir in KERNEL_DIRS.iter() {
d.process(&out_dir, compute_cap)? let _file = std::fs::File::create(kdir.rust_target)?;
}
} }
Ok(()) Ok(())
} }
fn set_cuda_include_dir() -> Result<()> {
// NOTE: copied from cudarc build.rs.
let env_vars = [
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
let env_vars = env_vars
.into_iter()
.map(std::env::var)
.filter_map(Result::ok)
.map(Into::<PathBuf>::into);
let roots = [
"/usr",
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/CUDA",
];
let roots = roots.into_iter().map(Into::<PathBuf>::into);
let root = env_vars
.chain(roots)
.find(|path| path.join("include").join("cuda.h").is_file())
.context("cannot find include/cuda.h")?;
println!(
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
root.join("include").display()
);
Ok(())
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
// Grab compute code from nvidia-smi
let mut compute_cap = {
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
cap.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?
};
// Grab available GPU codes from nvcc and select the highest one
let max_nvcc_code = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
if !codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
);
}
*codes.last().unwrap()
};
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
// then choose the highest gpu code in nvcc
if compute_cap > max_nvcc_code {
println!(
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
);
compute_cap = max_nvcc_code;
}
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
compute_cap = compute_cap_str
.parse::<usize>()
.with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
}
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
Ok(compute_cap)
}

View File

@ -2,10 +2,10 @@
Bert is a general large language model. In this example it can be used for two Bert is a general large language model. In this example it can be used for two
different tasks: different tasks:
- Compute sentence embeddings for a prompt. - Compute sentence embeddings for a prompt.
- Compute similarities between a set of sentences. - Compute similarities between a set of sentences.
## Sentence embeddings ## Sentence embeddings
Bert is used to compute the sentence embeddings for a prompt. The model weights Bert is used to compute the sentence embeddings for a prompt. The model weights
@ -24,6 +24,48 @@ cargo run --example bert --release -- --prompt "Here is a test sentence"
> Tensor[[1, 7, 384], f32] > Tensor[[1, 7, 384], f32]
``` ```
### Custom models
You can specify different models, such as BGE, with the `--model-id` flag:
```bash
cargo run --example bert --release -- \
--model-id BAAI/bge-large-zh-v1.5 \
--prompt "Here is a test sentence"
Loaded and encoded 435.70775ms
[[[ 3.0944e-1, -7.8455e-5, -1.2768e0, ..., 1.3755e-2, -3.2371e-1, 2.3819e-1],
[-2.8506e-1, 1.9953e-1, -1.3076e0, ..., 6.9819e-2, 1.0833e-2, -1.1512e0],
[ 3.9892e-1, 2.0000e-1, -9.3178e-1, ..., -4.1393e-1, -4.9644e-2, -3.3786e-1],
...
[ 6.0345e-1, 3.5744e-1, -1.2672e0, ..., -6.9165e-1, -3.4973e-3, -8.4214e-1],
[ 3.9218e-1, -3.2735e-1, -1.3123e0, ..., -4.9318e-1, -5.1334e-1, -3.6391e-1],
[ 3.0978e-1, 2.5662e-4, -1.2773e0, ..., 1.3357e-2, -3.2390e-1, 2.3858e-1]]]
Tensor[[1, 9, 1024], f32]
Took 176.744667ms
```
### Gelu approximation
You can get a speedup by using an approximation of the gelu activation, with a
small loss of precision, by passing the `--approximate-gelu` flag:
```bash
$ cargo run --example bert --release -- \
--model-id BAAI/bge-large-zh-v1.5 \
--prompt "Here is a test sentence" \
--approximate-gelu
Loaded and encoded 244.388042ms
[[[ 3.1048e-1, -6.0339e-4, -1.2758e0, ..., 1.3718e-2, -3.2362e-1, 2.3775e-1],
[-2.8354e-1, 1.9984e-1, -1.3077e0, ..., 6.9390e-2, 9.9681e-3, -1.1531e0],
[ 3.9947e-1, 1.9917e-1, -9.3178e-1, ..., -4.1301e-1, -5.0719e-2, -3.3955e-1],
...
[ 6.0499e-1, 3.5664e-1, -1.2642e0, ..., -6.9134e-1, -3.4581e-3, -8.4471e-1],
[ 3.9311e-1, -3.2812e-1, -1.3105e0, ..., -4.9291e-1, -5.1270e-1, -3.6543e-1],
[ 3.1082e-1, -2.6737e-4, -1.2762e0, ..., 1.3319e-2, -3.2381e-1, 2.3815e-1]]]
Tensor[[1, 9, 1024], f32]
Took 116.840791ms
```
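The flag swaps the exact erf-based gelu for the usual tanh approximation, which for reference is:

```
gelu(x) ≈ 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³)))
```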
## Similarities ## Similarities
In this example, Bert is used to compute the sentence embeddings for a set of In this example, Bert is used to compute the sentence embeddings for a set of

View File

@ -3,7 +3,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")] #[cfg(feature = "accelerate")]
extern crate accelerate_src; extern crate accelerate_src;
use candle_transformers::models::bert::{BertModel, Config, DTYPE}; use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
use anyhow::{Error as E, Result}; use anyhow::{Error as E, Result};
use candle::Tensor; use candle::Tensor;
@ -45,6 +45,10 @@ struct Args {
/// L2 normalization for embeddings. /// L2 normalization for embeddings.
#[arg(long, default_value = "true")] #[arg(long, default_value = "true")]
normalize_embeddings: bool, normalize_embeddings: bool,
/// Use tanh based approximation for Gelu instead of erf implementation.
#[arg(long, default_value = "false")]
approximate_gelu: bool,
} }
impl Args { impl Args {
@ -73,7 +77,7 @@ impl Args {
(config, tokenizer, weights) (config, tokenizer, weights)
}; };
let config = std::fs::read_to_string(config_filename)?; let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?; let mut config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let vb = if self.use_pth { let vb = if self.use_pth {
@ -81,6 +85,9 @@ impl Args {
} else { } else {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? } unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
}; };
if self.approximate_gelu {
config.hidden_act = HiddenAct::GeluApproximate;
}
let model = BertModel::load(vb, &config)?; let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer)) Ok((model, tokenizer))
} }

View File

@ -1,2 +1 @@
#[rustfmt::skip] pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));

View File

@ -6,7 +6,8 @@
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
extern crate intel_mkl_src; extern crate intel_mkl_src;
#[allow(unused)] #[rustfmt::skip]
#[cfg(feature = "cuda")]
mod cuda_kernels; mod cuda_kernels;
use clap::Parser; use clap::Parser;

View File

@ -165,14 +165,7 @@ fn main() -> Result<()> {
args.revision, args.revision,
)); ));
let tokenizer_filename = repo.get("tokenizer.json")?; let tokenizer_filename = repo.get("tokenizer.json")?;
let mut filenames = vec![]; let filenames = candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?;
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
let filename = repo.get(rfilename)?;
filenames.push(filename);
}
println!("retrieved the files in {:?}", start.elapsed()); println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

View File

@ -13,7 +13,7 @@ extern crate accelerate_src;
extern crate intel_mkl_src; extern crate intel_mkl_src;
use anyhow::{bail, Error as E, Result}; use anyhow::{bail, Error as E, Result};
use clap::Parser; use clap::{Parser, ValueEnum};
use candle::{DType, Tensor}; use candle::{DType, Tensor};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
@ -22,11 +22,21 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write; use std::io::Write;
use candle_transformers::models::llama as model; use candle_transformers::models::llama as model;
use model::{Config, Llama, LlamaConfig}; use model::{Llama, LlamaConfig};
const EOS_TOKEN: &str = "</s>"; const EOS_TOKEN: &str = "</s>";
const DEFAULT_PROMPT: &str = "My favorite theorem is "; const DEFAULT_PROMPT: &str = "My favorite theorem is ";
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
V1,
V2,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
TinyLlama1_1BChat,
}
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, long_about = None)]
struct Args { struct Args {
@ -34,10 +44,6 @@ struct Args {
#[arg(long)] #[arg(long)]
cpu: bool, cpu: bool,
/// Use npy instead of safetensors
#[arg(long)]
npy: Option<String>,
/// The temperature used to generate samples. /// The temperature used to generate samples.
#[arg(long)] #[arg(long)]
temperature: Option<f64>, temperature: Option<f64>,
@ -76,17 +82,13 @@ struct Args {
#[arg(long)] #[arg(long)]
revision: Option<String>, revision: Option<String>,
#[arg(long)] /// The model size to use.
v1: bool, #[arg(long, default_value = "v2")]
which: Which,
#[arg(long)] #[arg(long)]
use_flash_attn: bool, use_flash_attn: bool,
/// The folder name that contains safetensor weights and json files
/// (same structure as huggingface online)
#[arg(long)]
local_weights: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty. /// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)] #[arg(long, default_value_t = 1.0)]
repeat_penalty: f32, repeat_penalty: f32,
@ -118,65 +120,34 @@ fn main() -> Result<()> {
Some(dtype) => bail!("Unsupported dtype {dtype}"), Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16, None => DType::F16,
}; };
let (llama, tokenizer_filename, cache) = match args.npy { let (llama, tokenizer_filename, cache) = {
Some(filename) => { let api = Api::new()?;
let config = if args.v1 { let model_id = args.model_id.unwrap_or_else(|| match args.which {
Config::config_7b_v1(args.use_flash_attn) Which::V1 => "Narsil/amall-7b".to_string(),
} else { Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Config::config_7b_v2(args.use_flash_attn) Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
}; Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?; });
let vb = VarBuilder::from_npz(filename, dtype, &device)?; println!("loading the model weights from {model_id}");
let tokenizer = std::path::PathBuf::from("llama-tokenizer.json"); let revision = args.revision.unwrap_or("main".to_string());
(Llama::load(vb, &cache, &config)?, tokenizer, cache) let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
}
None => {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| {
if args.v1 {
"Narsil/amall-7b".to_string()
} else {
"meta-llama/Llama-2-7b-hf".to_string()
}
});
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match &args.local_weights { let tokenizer_filename = api.get("tokenizer.json")?;
Some(path) => (path.to_owned() + "tokenizer.json").into(), let config_filename = api.get("config.json")?;
_ => api.get("tokenizer.json")?, let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
}; let config = config.into_config(args.use_flash_attn);
let config_filename = match &args.local_weights { let filenames = match args.which {
Some(path) => (path.to_owned() + "config.json").into(), Which::V1 | Which::V2 | Which::Solar10_7B => {
_ => api.get("config.json")?, candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
};
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let config = config.into_config(args.use_flash_attn);
let mut filenames = vec![];
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
match &args.local_weights {
Some(path) => {
filenames.push((path.to_owned() + rfilename).into());
}
_ => {
let filename = api.get(rfilename)?;
filenames.push(filename);
}
};
} }
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
};
println!("building the model");
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
println!("building the model"); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?; (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
}
}; };
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN); let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
@ -194,14 +165,14 @@ fn main() -> Result<()> {
let mut index_pos = 0; let mut index_pos = 0;
let mut token_generated = 0; let mut token_generated = 0;
for index in 0..args.sample_len { for index in 0..args.sample_len {
let context_size = if cache.use_kv_cache && index > 0 { let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
1 (1, index_pos)
} else { } else {
tokens.len() (tokens.len(), 0)
}; };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, index_pos)?; let logits = llama.forward(&input, context_index)?;
let logits = logits.squeeze(0)?; let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. { let logits = if args.repeat_penalty == 1. {
logits logits

View File

@ -143,14 +143,7 @@ fn main() -> Result<()> {
let config_filename = api.get("config.json")?; let config_filename = api.get("config.json")?;
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?; let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let tokenizer_filename = api.get("tokenizer.json")?; let tokenizer_filename = api.get("tokenizer.json")?;
let mut filenames = vec![]; let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
let filename = api.get(rfilename)?;
filenames.push(filename);
}
if args.rank.is_none() { if args.rank.is_none() {
let children: Vec<_> = (0..args.num_shards) let children: Vec<_> = (0..args.num_shards)

View File

@ -0,0 +1,12 @@
# candle-mamba-minimal: minimal implementation of Mamba
This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).
## Running the example
```bash
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
Mamba is the most popular and best-selling game in the world. It has been downloaded more than 1,000 times by over 1 million people worldwide since its release on March 18th 2016.
The Mamba series of games are a collection that combines elements from all genres including action, adventure, strategy & puzzle games with some unique gameplay features such as stealth and survival. The game is also known for its innovative graphics and the ability to play in a variety of different modes like single player or multiplayer.
```

View File

@ -0,0 +1,287 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
mod model;
use model::{Config, Model};
use candle::{DType, Device, Module, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for _ in 0..sample_len {
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
enum Which {
Mamba130m,
Mamba370m,
Mamba790m,
Mamba1_4b,
Mamba2_8b,
Mamba2_8bSlimPj,
}
impl std::fmt::Display for Which {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl Which {
fn model_id(&self) -> &'static str {
match self {
Self::Mamba130m => "state-spaces/mamba-130m",
Self::Mamba370m => "state-spaces/mamba-370m",
Self::Mamba790m => "state-spaces/mamba-790m",
Self::Mamba1_4b => "state-spaces/mamba-1.4b",
Self::Mamba2_8b => "state-spaces/mamba-2.8b",
Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj",
}
}
fn revision(&self) -> &'static str {
match self {
Self::Mamba130m
| Self::Mamba370m
| Self::Mamba790m
| Self::Mamba1_4b
| Self::Mamba2_8bSlimPj => "refs/pr/1",
Self::Mamba2_8b => "refs/pr/4",
}
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long, default_value = "mamba130m")]
which: Which,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id
.unwrap_or_else(|| args.which.model_id().to_string()),
RepoType::Model,
args.revision
.unwrap_or_else(|| args.which.revision().to_string()),
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("EleutherAI/gpt-neox-20b".to_string())
.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
vec![repo.get("model.safetensors")?]
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = Model::new(&config, vb.pp("backbone"))?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -0,0 +1,204 @@
/// This follows the lines of:
/// https://github.com/johnma2006/mamba-minimal/blob/master/model.py
/// Simple, minimal implementation of Mamba in one file of PyTorch.
use candle::{IndexOp, Module, Result, Tensor, D};
use candle_nn::{RmsNorm, VarBuilder};
use candle_transformers::models::with_tracing::{linear, linear_no_bias, Linear};
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
d_model: usize,
n_layer: usize,
vocab_size: usize,
pad_vocab_size_multiple: usize,
}
impl Config {
fn vocab_size(&self) -> usize {
let pad = self.pad_vocab_size_multiple;
(self.vocab_size + pad - 1) / pad * pad
}
fn dt_rank(&self) -> usize {
(self.d_model + 15) / 16
}
fn d_conv(&self) -> usize {
4
}
fn d_state(&self) -> usize {
16
}
fn d_inner(&self) -> usize {
self.d_model * 2
}
}
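// As a concrete illustration (numbers from the published mamba-130m config, quoted
// here only as an example): d_model = 768, vocab_size = 50277 and
// pad_vocab_size_multiple = 8 give vocab_size() = 50280, dt_rank() = 48 and
// d_inner() = 1536.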
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L177
#[derive(Clone, Debug)]
pub struct MambaBlock {
in_proj: Linear,
conv1d: candle_nn::Conv1d,
x_proj: Linear,
dt_proj: Linear,
a_log: Tensor,
d: Tensor,
out_proj: Linear,
dt_rank: usize,
}
impl MambaBlock {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let d_inner = cfg.d_inner();
let d_conv = cfg.d_conv();
let d_state = cfg.d_state();
let dt_rank = cfg.dt_rank();
let in_proj = linear_no_bias(cfg.d_model, d_inner * 2, vb.pp("in_proj"))?;
let conv_cfg = candle_nn::Conv1dConfig {
groups: d_inner,
padding: d_conv - 1,
..Default::default()
};
let conv1d = candle_nn::conv1d(d_inner, d_inner, d_conv, conv_cfg, vb.pp("conv1d"))?;
let x_proj = linear_no_bias(d_inner, dt_rank + d_state * 2, vb.pp("x_proj"))?;
let dt_proj = linear(dt_rank, d_inner, vb.pp("dt_proj"))?;
let a_log = vb.get((d_inner, d_state), "A_log")?;
let d = vb.get(d_inner, "D")?;
let out_proj = linear_no_bias(d_inner, cfg.d_model, vb.pp("out_proj"))?;
Ok(Self {
in_proj,
conv1d,
x_proj,
dt_proj,
a_log,
d,
out_proj,
dt_rank,
})
}
fn ssm(&self, xs: &Tensor) -> Result<Tensor> {
let (_d_in, n) = self.a_log.dims2()?;
let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
let d = self.d.to_dtype(candle::DType::F32)?;
let x_dbl = xs.apply(&self.x_proj)?;
let delta = x_dbl.narrow(D::Minus1, 0, self.dt_rank)?;
let b = x_dbl.narrow(D::Minus1, self.dt_rank, n)?;
let c = x_dbl.narrow(D::Minus1, self.dt_rank + n, n)?;
let delta = delta.contiguous()?.apply(&self.dt_proj)?;
        // softplus without the numerical-stability threshold shortcut, i.e. log(1 + exp(delta))
let delta = (delta.exp()? + 1.)?.log()?;
let ss = selective_scan(xs, &delta, &a, &b, &c, &d)?;
Ok(ss)
}
}
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L275
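// Straightforward sequential implementation of the SSM recurrence
//   x(i) = exp(delta(i) * A) * x(i-1) + delta(i) * B(i) * u(i)
//   y(i) = C(i) . x(i) + D * u(i)
// looping over the sequence dimension, so it is simple rather than fast.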
fn selective_scan(
u: &Tensor,
delta: &Tensor,
a: &Tensor,
b: &Tensor,
c: &Tensor,
d: &Tensor,
) -> Result<Tensor> {
let (b_sz, l, d_in) = u.dims3()?;
let n = a.dim(1)?;
let delta = delta.t()?.reshape((b_sz, d_in, l, 1))?; // b d_in l 1
let delta_a = delta.broadcast_mul(&a.reshape((1, d_in, 1, n))?)?.exp()?;
let delta_b_u = delta
.broadcast_mul(&b.reshape((b_sz, 1, l, n))?)?
.broadcast_mul(&u.t()?.reshape((b_sz, d_in, l, 1))?)?;
let mut xs = Tensor::zeros((b_sz, d_in, n), delta_a.dtype(), delta_a.device())?;
let mut ys = Vec::with_capacity(l);
for i in 0..l {
xs = ((delta_a.i((.., .., i))? * xs)? + delta_b_u.i((.., .., i))?)?;
let y = xs.matmul(&c.i((.., i, ..))?.unsqueeze(2)?)?.squeeze(2)?;
ys.push(y)
}
let ys = Tensor::stack(ys.as_slice(), 1)?;
ys + u.broadcast_mul(d)
}
impl Module for MambaBlock {
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L206
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (_b_sz, seq_len, _dim) = xs.dims3()?;
let xs_and_res = xs.apply(&self.in_proj)?.chunk(2, D::Minus1)?;
let (xs, res) = (&xs_and_res[0], &xs_and_res[1]);
let xs = xs
.t()?
.apply(&self.conv1d)?
.narrow(D::Minus1, 0, seq_len)?
.t()?;
let xs = candle_nn::ops::silu(&xs)?;
let ys = (self.ssm(&xs)? * candle_nn::ops::silu(res))?;
ys.apply(&self.out_proj)
}
}
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L143
#[derive(Clone, Debug)]
pub struct ResidualBlock {
mixer: MambaBlock,
norm: RmsNorm,
}
impl ResidualBlock {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?;
let mixer = MambaBlock::new(cfg, vb.pp("mixer"))?;
Ok(Self { mixer, norm })
}
}
impl Module for ResidualBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.norm)?.apply(&self.mixer)? + xs
}
}
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L56
#[derive(Clone, Debug)]
pub struct Model {
embedding: candle_nn::Embedding,
layers: Vec<ResidualBlock>,
norm_f: RmsNorm,
lm_head: Linear,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embedding = candle_nn::embedding(cfg.vocab_size(), cfg.d_model, vb.pp("embedding"))?;
let mut layers = Vec::with_capacity(cfg.n_layer);
let vb_l = vb.pp("layers");
for layer_idx in 0..cfg.n_layer {
let layer = ResidualBlock::new(cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
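        // The lm head is tied to (reuses) the embedding weights.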
let lm_head = Linear::from_weights(embedding.embeddings().clone(), None);
Ok(Self {
embedding,
layers,
norm_f,
lm_head,
})
}
}
impl Module for Model {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let (_b_size, seq_len) = input_ids.dims2()?;
let mut xs = self.embedding.forward(input_ids)?;
for layer in self.layers.iter() {
xs = layer.forward(&xs)?
}
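        // Only the last position's logits are needed to sample the next token.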
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm_f)?
.apply(&self.lm_head)
}
}

View File

@ -155,8 +155,8 @@ struct Args {
     #[arg(long, short = 'n', default_value_t = 100)]
     sample_len: usize,
-    #[arg(long, default_value = "lmz/candle-mistral")]
-    model_id: String,
+    #[arg(long)]
+    model_id: Option<String>,
     #[arg(long, default_value = "main")]
     revision: String,
@ -207,8 +207,18 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let api = Api::new()?;
+    let model_id = match args.model_id {
+        Some(model_id) => model_id,
+        None => {
+            if args.quantized {
+                "lmz/candle-mistral".to_string()
+            } else {
+                "mistralai/Mistral-7B-v0.1".to_string()
+            }
+        }
+    };
     let repo = api.repo(Repo::with_revision(
-        args.model_id,
+        model_id,
         RepoType::Model,
         args.revision,
     ));
@ -225,10 +235,7 @@ fn main() -> Result<()> {
         if args.quantized {
             vec![repo.get("model-q4k.gguf")?]
         } else {
-            vec![
-                repo.get("pytorch_model-00001-of-00002.safetensors")?,
-                repo.get("pytorch_model-00002-of-00002.safetensors")?,
-            ]
+            candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
         }
     };

View File

@ -0,0 +1,25 @@
# candle-mixtral: 8x7b LLM using a sparse mixture of experts.
Mixtral-8x7B-v0.1 is a pretrained generative LLM using a sparse mixture of 8 experts per layer, roughly 47 billion parameters in total (of which about 13 billion are active for any given token).
- [Blog post](https://mistral.ai/news/mixtral-of-experts/) from Mistral announcing the model release.
- [Model card](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) on the HuggingFace Hub.
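Conceptually, each layer routes every token through only 2 of the 8 feed-forward experts: a gating layer scores the experts, the top-2 are evaluated, and their outputs are mixed with softmax weights. The sketch below illustrates that routing on plain `f32` slices; the helper names are hypothetical and this is not the actual `candle_transformers::models::mixtral` code.

```rust
// Top-2 mixture-of-experts routing, sketched without any dependencies.
fn moe_layer(x: &[f32], gate_logits: &[f32], experts: &[fn(&[f32]) -> Vec<f32>]) -> Vec<f32> {
    // Rank the experts by their gate logit and keep the two best.
    let mut order: Vec<usize> = (0..gate_logits.len()).collect();
    order.sort_by(|&a, &b| gate_logits[b].total_cmp(&gate_logits[a]));
    let (e0, e1) = (order[0], order[1]);
    // A softmax over the two selected logits gives the mixing weights.
    let m = gate_logits[e0].max(gate_logits[e1]);
    let (w0, w1) = ((gate_logits[e0] - m).exp(), (gate_logits[e1] - m).exp());
    let z = w0 + w1;
    // Only the two selected experts run for this token, which is why the
    // number of active parameters is far below the total parameter count.
    let (y0, y1) = (experts[e0](x), experts[e1](x));
    y0.iter()
        .zip(y1.iter())
        .map(|(a, b)| (w0 * a + w1 * b) / z)
        .collect()
}
```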
## Running the example
```bash
$ cargo run --example mixtral --release -- --prompt "def print_prime(n): "
def print_prime(n): # n is the number of prime numbers to be printed
i = 2
count = 0
while (count < n):
if (isPrime(i)):
print(i)
count += 1
i += 1
def isPrime(n):
for x in range(2, int(n**0.5)+1):
if (n % x == 0):
...
```

View File

@ -0,0 +1,241 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::mixtral::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("</s>") {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let start_gen = std::time::Instant::now();
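        // The first iteration feeds the full prompt; later iterations feed only
        // the last sampled token, relying on the model's kv cache for the rest.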
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
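            // Optionally penalize tokens that appeared in the last repeat_last_n positions.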
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]
model_id: String,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::v0_1_8x7b(args.use_flash_attn);
let device = candle_examples::device(args.cpu)?;
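    // The Mixtral weights are published in bf16; keep that on CUDA, but fall
    // back to f32 on CPU where bf16 support is limited.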
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -321,7 +321,7 @@ impl MusicgenDecoder {
         let positions = self.embed_positions.forward(&input)?.to_device(dev)?;
         let mut xs = inputs_embeds.broadcast_add(&positions)?;
         let attention_mask = self.prepare_decoder_attention_mask(b_sz, seq_len)?;
-        for (_layer_idx, decoder_layer) in self.layers.iter_mut().enumerate() {
+        for decoder_layer in self.layers.iter_mut() {
             xs = decoder_layer.forward(&xs, &attention_mask, None)?;
         }
         let xs = self.layer_norm.forward(&xs)?;

View File

@ -1,14 +1,33 @@
-# candle-phi: 1.3b LLM with state of the art performance for <10b models.
+# candle-phi: 1.3b and 2.7b LLM with state of the art performance for <10b models.
-[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) is a language model using
-only 1.3 billion parameters but with state of the art performance compared to
+[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) and
+[Phi-2](https://huggingface.co/microsoft/phi-2) are language models using
+only 1.3 and 2.7 billion parameters but with state of the art performance compared to
 models with up to 10 billion parameters.
 The candle implementation provides both the standard version as well as a
 quantized variant.
-## Running some example
+## Running some examples
+For the v2 version.
+```bash
+$ cargo run --example phi --release -- --model 2 \
+  --prompt "A skier slides down a frictionless slope of height 40m and length 80m. What's the skier speed at the bottom?"
+A skier slides down a frictionless slope of height 40m and length 80m. What's the skier speed at the bottom?
+Solution:
+The potential energy of the skier is converted into kinetic energy as it slides down the slope. The formula for potential energy is mgh, where m is mass, g is acceleration due to gravity (9.8 m/s^2), and h is height. Since there's no friction, all the potential energy is converted into kinetic energy at the bottom of the slope. The formula for kinetic energy is 1/2mv^2, where v is velocity. We can equate these two formulas:
+mgh = 1/2mv^2
+Solving for v, we get:
+v = sqrt(2gh)
+Substituting the given values, we get:
+v = sqrt(2*9.8*40) = 28 m/s
+Therefore, the skier speed at the bottom of the slope is 28 m/s.
+```
+For the v1.5 version.
 ```bash
 $ cargo run --example phi --release -- --prompt "def print_prime(n): "

View File

@ -123,6 +123,8 @@ enum WhichModel {
     V1,
     #[value(name = "1.5")]
     V1_5,
+    #[value(name = "2")]
+    V2,
     PuffinPhiV2,
     PhiHermes,
 }
@ -143,7 +145,10 @@ struct Args {
     verbose_prompt: bool,
     #[arg(long)]
-    prompt: String,
+    prompt: Option<String>,
+    #[arg(long)]
+    mmlu_dir: Option<String>,
     /// The temperature used to generate samples.
     #[arg(long)]
@ -158,7 +163,7 @@ struct Args {
     seed: u64,
     /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 100)]
+    #[arg(long, short = 'n', default_value_t = 5000)]
     sample_len: usize,
     #[arg(long)]
@ -225,6 +230,7 @@ fn main() -> Result<()> {
     match args.model {
         WhichModel::V1 => "microsoft/phi-1".to_string(),
         WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
+        WhichModel::V2 => "microsoft/phi-2".to_string(),
         WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
             "lmz/candle-quantized-phi".to_string()
         }
@ -241,7 +247,9 @@ fn main() -> Result<()> {
     match args.model {
         WhichModel::V1 => "refs/pr/2".to_string(),
         WhichModel::V1_5 => "refs/pr/18".to_string(),
-        WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
+        WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
+            "main".to_string()
+        }
     }
 }
 }
@ -250,27 +258,32 @@ fn main() -> Result<()> {
     let tokenizer_filename = match args.tokenizer {
         Some(file) => std::path::PathBuf::from(file),
         None => match args.model {
-            WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
+            WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?,
             WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
                 repo.get("tokenizer-puffin-phi-v2.json")?
             }
         },
     };
-    let filename = match args.weight_file {
-        Some(weight_file) => std::path::PathBuf::from(weight_file),
+    let filenames = match args.weight_file {
+        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
         None => {
             if args.quantized {
                 match args.model {
-                    WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
-                    WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
-                    WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
-                    WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?,
+                    WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
+                    WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
+                    WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?],
+                    WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
+                    WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
                 }
             } else {
                 match args.model {
-                    WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
-                    WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
-                    WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?,
+                    WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
+                    WhichModel::V2 => candle_examples::hub_load_safetensors(
+                        &repo,
+                        "model.safetensors.index.json",
+                    )?,
+                    WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
+                    WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
                 }
            }
        }
    }
@ -282,32 +295,127 @@ fn main() -> Result<()> {
     let config = match args.model {
         WhichModel::V1 => Config::v1(),
         WhichModel::V1_5 => Config::v1_5(),
+        WhichModel::V2 => Config::v2(),
         WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
         WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
     };
     let (model, device) = if args.quantized {
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
-        let model = QMixFormer::new(&config, vb)?;
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
+        let model = match args.model {
+            WhichModel::V2 => QMixFormer::new_v2(&config, vb)?,
+            _ => QMixFormer::new(&config, vb)?,
+        };
         (Model::Quantized(model), Device::Cpu)
     } else {
         let device = candle_examples::device(args.cpu)?;
-        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
-        let model = MixFormer::new(&config, vb)?;
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+        let model = match args.model {
+            WhichModel::V2 => MixFormer::new_v2(&config, vb)?,
+            _ => MixFormer::new(&config, vb)?,
+        };
         (Model::MixFormer(model), device)
     };
     println!("loaded the model in {:?}", start.elapsed());
-    let mut pipeline = TextGeneration::new(
-        model,
-        tokenizer,
-        args.seed,
-        args.temperature,
-        args.top_p,
-        args.repeat_penalty,
-        args.repeat_last_n,
-        args.verbose_prompt,
-        &device,
-    );
-    pipeline.run(&args.prompt, args.sample_len)?;
+    match (args.prompt, args.mmlu_dir) {
+        (None, None) | (Some(_), Some(_)) => {
+            anyhow::bail!("exactly one of --prompt and --mmlu-dir must be specified")
+        }
+        (Some(prompt), None) => {
+            let mut pipeline = TextGeneration::new(
+                model,
+                tokenizer,
+                args.seed,
+                args.temperature,
+                args.top_p,
+                args.repeat_penalty,
+                args.repeat_last_n,
+                args.verbose_prompt,
+                &device,
+            );
+            pipeline.run(&prompt, args.sample_len)?;
+        }
+        (None, Some(mmlu_dir)) => mmlu(model, tokenizer, &device, mmlu_dir)?,
+    }
+    Ok(())
+}
fn mmlu<P: AsRef<std::path::Path>>(
mut model: Model,
tokenizer: Tokenizer,
device: &Device,
mmlu_dir: P,
) -> anyhow::Result<()> {
for dir_entry in mmlu_dir.as_ref().read_dir()?.flatten() {
let dir_entry = dir_entry.path();
let theme = match dir_entry.file_stem().and_then(|v| v.to_str()) {
None => "".to_string(),
Some(v) => match v.strip_suffix("_test") {
None => v.replace('_', " "),
Some(v) => v.replace('_', " "),
},
};
if dir_entry.extension().as_ref().and_then(|v| v.to_str()) != Some("csv") {
continue;
}
println!("reading {dir_entry:?}");
let dir_entry = std::fs::File::open(dir_entry)?;
let mut reader = csv::ReaderBuilder::new()
.has_headers(false)
.from_reader(dir_entry);
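        // Each MMLU test csv row contains: question, the four choices A-D, then the answer letter.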
let token_a = tokenizer.token_to_id("A").unwrap();
let token_b = tokenizer.token_to_id("B").unwrap();
let token_c = tokenizer.token_to_id("C").unwrap();
let token_d = tokenizer.token_to_id("D").unwrap();
for row in reader.records() {
let row = match row {
Err(_) => continue,
Ok(row) => row,
};
if row.len() < 5 {
continue;
}
let question = row.get(0).unwrap();
let answer_a = row.get(1).unwrap();
let answer_b = row.get(2).unwrap();
let answer_c = row.get(3).unwrap();
let answer_d = row.get(4).unwrap();
let answer = row.get(5).unwrap();
let prompt = format!(
"{} {theme}.\n{question}\nA. {answer_a}\nB. {answer_b}\nC. {answer_c}\nD. {answer_d}\nAnswer:\n",
"The following are multiple choice questions (with answers) about"
);
let tokens = tokenizer.encode(prompt.as_str(), true).map_err(E::msg)?;
let tokens = tokens.get_ids().to_vec();
let input = Tensor::new(tokens, device)?.unsqueeze(0)?;
let logits = match &mut model {
Model::MixFormer(m) => {
m.clear_kv_cache();
m.forward(&input)?
}
Model::Quantized(m) => {
m.clear_kv_cache();
m.forward(&input)?
}
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits_v: Vec<f32> = logits.to_vec1()?;
let pr_a = logits_v[token_a as usize];
let pr_b = logits_v[token_b as usize];
let pr_c = logits_v[token_c as usize];
let pr_d = logits_v[token_d as usize];
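            // Argmax over the four answer-letter logits picks the model's answer.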
let model_answer = if pr_a > pr_b && pr_a > pr_c && pr_a > pr_d {
"A"
} else if pr_b > pr_c && pr_b > pr_d {
"B"
} else if pr_c > pr_d {
"C"
} else {
"D"
};
println!("{prompt}\n -> {model_answer} vs {answer}");
}
}
    Ok(())
}

View File

@ -26,6 +26,19 @@ cargo run --example quantized --release -- --prompt "The best thing about coding
 > The best thing about coding in rust is 1.) that I dont need to worry about memory leaks, 2.) speed and 3.) my program will compile even on old machines.
 ```
+Using the mixtral sparse mixture of expert model:
+```bash
+$ cargo run --example quantized --release -- --which mixtral --prompt "Lebesgue's integral is superior to Riemann's because "
+> avx: true, neon: false, simd128: false, f16c: true
+> temp: 0.80 repeat-penalty: 1.10 repeat-last-n: 64
+> loaded 995 tensors (26.44GB) in 0.03s
+Lebesgue's integral is superior to Riemann's because 1. it is defined for a wider class of functions, those which are absolutely integrable; 2. the definition does not involve limits in two variables---one being computed before the other (which makes some computations more difficult); and 3. interchange of order of integration is easier to establish than with Riemann's integral. On the other hand, Lebesgue's integral applies only for bounded functions defined on finite intervals; it does not provide numerical values for improper integrals. The latter are best evaluated using Cauchy's limit definition.
+The reason $f(x) = x^2$ is discontinuous at the ends of its interval of definition, and Riemann's integral requires continuity on the whole of an open interval containing it (see our earlier post), sine no such function exists with this property, is that the endpoints are infinite in measure for Lebesgue's integral.
+```
 ## Command-line flags
 Run with `--help` to see all options.

View File

@ -45,16 +45,28 @@ enum Which {
     L13bCode,
     #[value(name = "32b-code")]
     L34bCode,
+    #[value(name = "7b-leo")]
+    Leo7b,
+    #[value(name = "13b-leo")]
+    Leo13b,
     #[value(name = "7b-mistral")]
     Mistral7b,
     #[value(name = "7b-mistral-instruct")]
     Mistral7bInstruct,
+    #[value(name = "7b-mistral-instruct-v0.2")]
+    Mistral7bInstructV02,
     #[value(name = "7b-zephyr-a")]
     Zephyr7bAlpha,
     #[value(name = "7b-zephyr-b")]
     Zephyr7bBeta,
     #[value(name = "7b-open-chat-3.5")]
     OpenChat35,
+    #[value(name = "7b-starling-a")]
+    Starling7bAlpha,
+    #[value(name = "mixtral")]
+    Mixtral,
+    #[value(name = "mixtral-instruct")]
+    MixtralInstruct,
 }
 impl Which {
@ -68,14 +80,20 @@ impl Which {
     | Self::L70bChat
     | Self::L7bCode
     | Self::L13bCode
-    | Self::L34bCode => false,
+    | Self::L34bCode
+    | Self::Leo7b
+    | Self::Leo13b => false,
     // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
-    // same way.
+    // same way. Starling is a fine tuned version of OpenChat.
     Self::OpenChat35
+    | Self::Starling7bAlpha
     | Self::Zephyr7bAlpha
     | Self::Zephyr7bBeta
+    | Self::Mixtral
+    | Self::MixtralInstruct
     | Self::Mistral7b
-    | Self::Mistral7bInstruct => true,
+    | Self::Mistral7bInstruct
+    | Self::Mistral7bInstructV02 => true,
     }
 }
@ -90,14 +108,44 @@ impl Which {
     | Self::L7bCode
     | Self::L13bCode
     | Self::L34bCode
+    | Self::Leo7b
+    | Self::Leo13b
+    | Self::Mixtral
+    | Self::MixtralInstruct
     | Self::Mistral7b
     | Self::Mistral7bInstruct
-    | Self::OpenChat35 => false,
+    | Self::Mistral7bInstructV02
+    | Self::OpenChat35
+    | Self::Starling7bAlpha => false,
     Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
     }
 }
 fn is_open_chat(&self) -> bool {
+    match self {
+        Self::L7b
+        | Self::L13b
+        | Self::L70b
+        | Self::L7bChat
+        | Self::L13bChat
+        | Self::L70bChat
+        | Self::L7bCode
+        | Self::L13bCode
+        | Self::L34bCode
+        | Self::Leo7b
+        | Self::Leo13b
+        | Self::Mixtral
+        | Self::MixtralInstruct
+        | Self::Mistral7b
+        | Self::Mistral7bInstruct
+        | Self::Mistral7bInstructV02
+        | Self::Zephyr7bAlpha
+        | Self::Zephyr7bBeta => false,
+        Self::OpenChat35 | Self::Starling7bAlpha => true,
+    }
+}
+fn tokenizer_repo(&self) -> &'static str {
     match self {
     Which::L7b
     | Which::L13b
@ -107,12 +155,18 @@ impl Which {
     | Which::L70bChat
     | Which::L7bCode
     | Which::L13bCode
-    | Which::L34bCode
-    | Which::Mistral7b
+    | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
+    Which::Leo7b => "LeoLM/leo-hessianai-7b",
+    Which::Leo13b => "LeoLM/leo-hessianai-13b",
+    Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
+    Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    Which::Mistral7b
     | Which::Mistral7bInstruct
+    | Which::Mistral7bInstructV02
     | Which::Zephyr7bAlpha
-    | Which::Zephyr7bBeta => false,
-    Which::OpenChat35 => true,
+    | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
+    Which::OpenChat35 => "openchat/openchat_3.5",
+    Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
     }
 }
 }
@ -120,7 +174,7 @@ impl Which {
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
-    /// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
+    /// GGML/GGUF file to load, typically a .bin/.gguf file generated by the quantize command from llama.cpp
     #[arg(long)]
     model: Option<String>,
@ -181,13 +235,7 @@ impl Args {
     Some(config) => std::path::PathBuf::from(config),
     None => {
         let api = hf_hub::api::sync::Api::new()?;
-        let repo = if self.which.is_open_chat() {
-            "openchat/openchat_3.5"
-        } else if self.which.is_mistral() {
-            "mistralai/Mistral-7B-v0.1"
-        } else {
-            "hf-internal-testing/llama-tokenizer"
-        };
+        let repo = self.which.tokenizer_repo();
         let api = api.model(repo.to_string());
         api.get("tokenizer.json")?
     }
@ -218,6 +266,22 @@ impl Args {
     Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
     Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
     Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
+    Which::Leo7b => (
+        "TheBloke/leo-hessianai-7B-GGUF",
+        "leo-hessianai-7b.Q4_K_M.gguf",
+    ),
+    Which::Leo13b => (
+        "TheBloke/leo-hessianai-13B-GGUF",
+        "leo-hessianai-13b.Q4_K_M.gguf",
+    ),
+    Which::Mixtral => (
+        "TheBloke/Mixtral-8x7B-v0.1-GGUF",
+        "mixtral-8x7b-v0.1.Q4_K_M.gguf",
+    ),
+    Which::MixtralInstruct => (
+        "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF",
+        "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf",
+    ),
     Which::Mistral7b => (
         "TheBloke/Mistral-7B-v0.1-GGUF",
         "mistral-7b-v0.1.Q4_K_S.gguf",
@ -226,6 +290,10 @@ impl Args {
         "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
         "mistral-7b-instruct-v0.1.Q4_K_S.gguf",
     ),
+    Which::Mistral7bInstructV02 => (
+        "TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
+        "mistral-7b-instruct-v0.2.Q4_K_S.gguf",
+    ),
     Which::Zephyr7bAlpha => (
         "TheBloke/zephyr-7B-alpha-GGUF",
         "zephyr-7b-alpha.Q4_K_M.gguf",
@ -234,6 +302,10 @@ impl Args {
         ("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
     }
     Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"),
+    Which::Starling7bAlpha => (
+        "TheBloke/Starling-LM-7B-alpha-GGUF",
+        "starling-lm-7b-alpha.Q4_K_M.gguf",
+    ),
     };
     let api = hf_hub::api::sync::Api::new()?;
     let api = api.model(repo.to_string());
@ -292,7 +364,7 @@ fn main() -> anyhow::Result<()> {
     let mut model = match model_path.extension().and_then(|v| v.to_str()) {
         Some("gguf") => {
-            let model = gguf_file::Content::read(&mut file)?;
+            let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
             let mut total_size_in_bytes = 0;
             for (_, tensor) in model.tensor_infos.iter() {
                 let elem_count = tensor.shape.elem_count();
@ -308,7 +380,7 @@ fn main() -> anyhow::Result<()> {
             ModelWeights::from_gguf(model, &mut file)?
         }
         Some("ggml" | "bin") | Some(_) | None => {
-            let model = ggml_file::Content::read(&mut file)?;
+            let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
             let mut total_size_in_bytes = 0;
             for (_, tensor) in model.tensors.iter() {
                 let elem_count = tensor.shape().elem_count();
@ -329,14 +401,20 @@ fn main() -> anyhow::Result<()> {
     | Which::L13bChat
     | Which::L7bCode
     | Which::L13bCode
-    | Which::L34bCode => 1,
-    Which::Mistral7b
+    | Which::L34bCode
+    | Which::Leo7b
+    | Which::Leo13b => 1,
+    Which::Mixtral
+    | Which::MixtralInstruct
+    | Which::Mistral7b
     | Which::Mistral7bInstruct
+    | Which::Mistral7bInstructV02
     | Which::Zephyr7bAlpha
     | Which::Zephyr7bBeta
     | Which::L70b
     | Which::L70bChat
-    | Which::OpenChat35 => 8,
+    | Which::OpenChat35
+    | Which::Starling7bAlpha => 8,
     };
     ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
 }
@ -369,7 +447,7 @@ fn main() -> anyhow::Result<()> {
     }
 }
 if args.which.is_open_chat() {
-    format!("User: {prompt}<|end_of_turn|>Assistant: ")
+    format!("GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:")
 } else if args.which.is_zephyr() {
     if prompt_index == 0 || is_interactive {
         format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)

View File

@ -8,9 +8,16 @@ Python package with:
 pip install "gymnasium[accept-rom-license]"
 ```
-In order to run the example, use the following command. Note the additional
+In order to run the examples, use the following commands. Note the additional
 `--package` flag to ensure that there is no conflict with the `candle-pyo3`
 crate.
+For the Policy Gradient example:
 ```bash
-cargo run --example reinforcement-learning --features=pyo3 --package candle-examples
+cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- pg
+```
+For the Deep Deterministic Policy Gradient example:
+```bash
+cargo run --example reinforcement-learning --features=pyo3 --package candle-examples -- ddpg
 ```

View File

@ -78,7 +78,7 @@ class EpisodicLifeEnv(gym.Wrapper):
         # then update lives to handle bonus lives
         lives = self.env.unwrapped.ale.lives()
         if lives < self.lives and lives > 0:
-            # for Qbert somtimes we stay in lives == 0 condtion for a few frames
+            # for Qbert sometimes we stay in lives == 0 condition for a few frames
             # so its important to keep lives > 0, so that we only reset once
             # the environment advertises done.
             done = True

View File

@ -8,6 +8,8 @@ use candle_nn::{
 };
 use rand::{distributions::Uniform, thread_rng, Rng};
+use super::gym_env::GymEnv;
 pub struct OuNoise {
     mu: f64,
     theta: f64,
@ -449,3 +451,106 @@ impl DDPG<'_> {
         Ok(())
     }
 }
// The impact of the q value of the next state on the current state's q value.
const GAMMA: f64 = 0.99;
// The weight for updating the target networks.
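// (the usual soft update: target = tau * online + (1 - tau) * target)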
const TAU: f64 = 0.005;
// The capacity of the replay buffer used for sampling training data.
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
// The training batch size for each training iteration.
const TRAINING_BATCH_SIZE: usize = 100;
// The total number of episodes.
const MAX_EPISODES: usize = 100;
// The maximum length of an episode.
const EPISODE_LENGTH: usize = 200;
// The number of training iterations after one episode finishes.
const TRAINING_ITERATIONS: usize = 200;
// Ornstein-Uhlenbeck process parameters.
const MU: f64 = 0.0;
const THETA: f64 = 0.15;
const SIGMA: f64 = 0.1;
const ACTOR_LEARNING_RATE: f64 = 1e-4;
const CRITIC_LEARNING_RATE: f64 = 1e-3;
pub fn run() -> Result<()> {
let env = GymEnv::new("Pendulum-v1")?;
println!("action space: {}", env.action_space());
println!("observation space: {:?}", env.observation_space());
let size_state = env.observation_space().iter().product::<usize>();
let size_action = env.action_space();
let mut agent = DDPG::new(
&Device::Cpu,
size_state,
size_action,
true,
ACTOR_LEARNING_RATE,
CRITIC_LEARNING_RATE,
GAMMA,
TAU,
REPLAY_BUFFER_CAPACITY,
OuNoise::new(MU, THETA, SIGMA, size_action)?,
)?;
let mut rng = rand::thread_rng();
for episode in 0..MAX_EPISODES {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
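            // Scale the policy's action to Pendulum-v1's torque range of [-2.0, 2.0];
            // the clamp below keeps it in bounds.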
let mut action = 2.0 * agent.actions(&state)?;
action = action.clamp(-2.0, 2.0);
let step = env.step(vec![action])?;
total_reward += step.reward;
agent.remember(
&state,
&Tensor::new(vec![action], &Device::Cpu)?,
&Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
&step.state,
step.terminated,
step.truncated,
);
if step.terminated || step.truncated {
break;
}
state = step.state;
}
println!("episode {episode} with total reward of {total_reward}");
for _ in 0..TRAINING_ITERATIONS {
agent.train(TRAINING_BATCH_SIZE)?;
}
}
println!("Testing...");
agent.train = false;
for episode in 0..10 {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
let mut action = 2.0 * agent.actions(&state)?;
action = action.clamp(-2.0, 2.0);
let step = env.step(vec![action])?;
total_reward += step.reward;
if step.terminated || step.truncated {
break;
}
state = step.state;
}
println!("episode {episode} with total reward of {total_reward}");
}
Ok(())
}

View File

@ -6,139 +6,32 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
+use candle::Result;
+use clap::{Parser, Subcommand};
 mod gym_env;
 mod vec_gym_env;
 mod ddpg;
+mod policy_gradient;
-use candle::{Device, Result, Tensor};
-use clap::Parser;
-use rand::Rng;
-// The impact of the q value of the next state on the current state's q value.
-const GAMMA: f64 = 0.99;
-// The weight for updating the target networks.
-const TAU: f64 = 0.005;
-// The capacity of the replay buffer used for sampling training data.
-const REPLAY_BUFFER_CAPACITY: usize = 100_000;
-// The training batch size for each training iteration.
-const TRAINING_BATCH_SIZE: usize = 100;
-// The total number of episodes.
-const MAX_EPISODES: usize = 100;
-// The maximum length of an episode.
-const EPISODE_LENGTH: usize = 200;
-// The number of training iterations after one episode finishes.
-const TRAINING_ITERATIONS: usize = 200;
-// Ornstein-Uhlenbeck process parameters.
-const MU: f64 = 0.0;
-const THETA: f64 = 0.15;
-const SIGMA: f64 = 0.1;
-const ACTOR_LEARNING_RATE: f64 = 1e-4;
-const CRITIC_LEARNING_RATE: f64 = 1e-3;
-#[derive(Parser, Debug, Clone)]
-#[command(author, version, about, long_about = None)]
+#[derive(Parser)]
 struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(long)]
-    cpu: bool,
-    /// Enable tracing (generates a trace-timestamp.json file).
-    #[arg(long)]
-    tracing: bool,
+    #[command(subcommand)]
+    command: Command,
+}
+#[derive(Subcommand)]
+enum Command {
+    Pg,
+    Ddpg,
 }
 fn main() -> Result<()> {
-    use tracing_chrome::ChromeLayerBuilder;
-    use tracing_subscriber::prelude::*;
     let args = Args::parse();
-    let _guard = if args.tracing {
-        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
-        tracing_subscriber::registry().with(chrome_layer).init();
-        Some(guard)
-    } else {
-        None
-    };
-    let env = gym_env::GymEnv::new("Pendulum-v1")?;
-    println!("action space: {}", env.action_space());
-    println!("observation space: {:?}", env.observation_space());
-    let size_state = env.observation_space().iter().product::<usize>();
-    let size_action = env.action_space();
-    let mut agent = ddpg::DDPG::new(
-        &Device::Cpu,
-        size_state,
-        size_action,
-        true,
-        ACTOR_LEARNING_RATE,
-        CRITIC_LEARNING_RATE,
-        GAMMA,
-        TAU,
-        REPLAY_BUFFER_CAPACITY,
-        ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?,
-    )?;
-    let mut rng = rand::thread_rng();
-    for episode in 0..MAX_EPISODES {
-        // let mut state = env.reset(episode as u64)?;
-        let mut state = env.reset(rng.gen::<u64>())?;
-        let mut total_reward = 0.0;
-        for _ in 0..EPISODE_LENGTH {
-            let mut action = 2.0 * agent.actions(&state)?;
-            action = action.clamp(-2.0, 2.0);
-            let step = env.step(vec![action])?;
-            total_reward += step.reward;
-            agent.remember(
-                &state,
-                &Tensor::new(vec![action], &Device::Cpu)?,
-                &Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
-                &step.state,
-                step.terminated,
-                step.truncated,
-            );
-            if step.terminated || step.truncated {
-                break;
-            }
-            state = step.state;
-        }
-        println!("episode {episode} with total reward of {total_reward}");
-        for _ in 0..TRAINING_ITERATIONS {
-            agent.train(TRAINING_BATCH_SIZE)?;
-        }
-    }
-    println!("Testing...");
-    agent.train = false;
-    for episode in 0..10 {
-        // let mut state = env.reset(episode as u64)?;
-        let mut state = env.reset(rng.gen::<u64>())?;
-        let mut total_reward = 0.0;
-        for _ in 0..EPISODE_LENGTH {
-            let mut action = 2.0 * agent.actions(&state)?;
-            action = action.clamp(-2.0, 2.0);
-            let step = env.step(vec![action])?;
-            total_reward += step.reward;
-            if step.terminated || step.truncated {
-                break;
-            }
-            state = step.state;
-        }
-        println!("episode {episode} with total reward of {total_reward}");
-    }
+    match args.command {
+        Command::Pg => policy_gradient::run()?,
+        Command::Ddpg => ddpg::run()?,
+    }
     Ok(())
 }

View File

@ -0,0 +1,146 @@
use super::gym_env::{GymEnv, Step};
use candle::{DType, Device, Error, Module, Result, Tensor};
use candle_nn::{
linear, ops::log_softmax, ops::softmax, sequential::seq, Activation, AdamW, Optimizer,
ParamsAdamW, VarBuilder, VarMap,
};
use rand::{distributions::Distribution, rngs::ThreadRng, Rng};
fn new_model(
input_shape: &[usize],
num_actions: usize,
dtype: DType,
device: &Device,
) -> Result<(impl Module, VarMap)> {
let input_size = input_shape.iter().product();
let mut varmap = VarMap::new();
let var_builder = VarBuilder::from_varmap(&varmap, dtype, device);
let model = seq()
.add(linear(input_size, 32, var_builder.pp("lin1"))?)
.add(Activation::Relu)
.add(linear(32, num_actions, var_builder.pp("lin2"))?);
Ok((model, varmap))
}
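// Computes the (undiscounted) reward-to-go for each step, resetting the
// running sum at episode boundaries.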
fn accumulate_rewards(steps: &[Step<i64>]) -> Vec<f64> {
let mut rewards: Vec<f64> = steps.iter().map(|s| s.reward).collect();
let mut acc_reward = 0f64;
for (i, reward) in rewards.iter_mut().enumerate().rev() {
if steps[i].terminated {
acc_reward = 0.0;
}
acc_reward += *reward;
*reward = acc_reward;
}
rewards
}
fn weighted_sample(probs: Vec<f32>, rng: &mut ThreadRng) -> Result<usize> {
let distribution = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
let mut rng = rng;
Ok(distribution.sample(&mut rng))
}
pub fn run() -> Result<()> {
let env = GymEnv::new("CartPole-v1")?;
println!("action space: {:?}", env.action_space());
println!("observation space: {:?}", env.observation_space());
let (model, varmap) = new_model(
env.observation_space(),
env.action_space(),
DType::F32,
&Device::Cpu,
)?;
let optimizer_params = ParamsAdamW {
lr: 0.01,
weight_decay: 0.01,
..Default::default()
};
let mut optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;
let mut rng = rand::thread_rng();
for epoch_idx in 0..100 {
let mut state = env.reset(rng.gen::<u64>())?;
let mut steps: Vec<Step<i64>> = vec![];
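        // Collect whole episodes until the batch holds more than 5000 steps.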
loop {
let action = {
let action_probs: Vec<f32> =
softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
.squeeze(0)?
.to_vec1()?;
weighted_sample(action_probs, &mut rng)? as i64
};
let step = env.step(action)?;
steps.push(step.copy_with_obs(&state));
if step.terminated || step.truncated {
state = env.reset(rng.gen::<u64>())?;
if steps.len() > 5000 {
break;
}
} else {
state = step.state;
}
}
let total_reward: f64 = steps.iter().map(|s| s.reward).sum();
let episodes: i64 = steps
.iter()
.map(|s| (s.terminated || s.truncated) as i64)
.sum();
println!(
"epoch: {:<3} episodes: {:<5} avg reward per episode: {:.2}",
epoch_idx,
episodes,
total_reward / episodes as f64
);
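        // REINFORCE update: minimize -mean(reward_to_go * log pi(a|s)), using the
        // one-hot action mask to select the log-prob of the action actually taken.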
let batch_size = steps.len();
let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
.to_dtype(DType::F32)?
.detach()?;
let actions_mask = {
let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
let actions_mask: Vec<Tensor> = actions
.iter()
.map(|&action| {
// One-hot encoding
let mut action_mask = vec![0.0; env.action_space()];
action_mask[action as usize] = 1.0;
Tensor::from_vec(action_mask, env.action_space(), &Device::Cpu)
.unwrap()
.to_dtype(DType::F32)
.unwrap()
})
.collect();
Tensor::stack(&actions_mask, 0)?.detach()?
};
let states = {
let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
Tensor::stack(&states, 0)?.detach()?
};
let log_probs = actions_mask
.mul(&log_softmax(&model.forward(&states)?, 1)?)?
.sum(1)?;
let loss = rewards.mul(&log_probs)?.neg()?.mean_all()?;
optimizer.backward_step(&loss)?;
}
Ok(())
}

View File

@ -8,7 +8,7 @@ XL using Rust and [candle](https://github.com/huggingface/candle).
 The `stable-diffusion` example is a conversion of
 [diffusers-rs](https://github.com/LaurentMazare/diffusers-rs) using candle
 rather than libtorch. This implementation supports Stable Diffusion v1.5, v2.1,
-as well as Stable Diffusion XL 1.0.
+as well as Stable Diffusion XL 1.0, and Turbo.
 ## Getting the weights
@ -23,16 +23,26 @@ cargo run --example stable-diffusion --release --features=cuda,cudnn \
 -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)"
 ```
-The final image is named `sd_final.png` by default.
-The default scheduler is the Denoising Diffusion Implicit Model scheduler (DDIM). The
-original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim).
+The final image is named `sd_final.png` by default. The Turbo version is much
+faster than previous versions, to give it a try add a `--sd-version turbo` flag,
+e.g.:
+```bash
+cargo run --example stable-diffusion --release --features=cuda,cudnn \
+ -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)" --sd-version turbo
+```
+The default scheduler for the v1.5, v2.1 and XL 1.0 version is the Denoising
+Diffusion Implicit Model scheduler (DDIM). The original paper and some code can
+be found in the [associated repo](https://github.com/ermongroup/ddim).
+The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
 ### Command-line flags
 - `--prompt`: the prompt to be used to generate the image.
 - `--uncond-prompt`: the optional unconditional prompt.
-- `--sd-version`: the Stable Diffusion version to use, can be `v1-5`, `v2-1`, or
-  `xl`.
+- `--sd-version`: the Stable Diffusion version to use, can be `v1-5`, `v2-1`,
+  `xl`, or `turbo`.
 - `--cpu`: use the cpu rather than the gpu (much slower).
 - `--height`, `--width`: set the height and width for the generated image.
 - `--n-steps`: the number of steps to be used in the diffusion process.

View File

@ -11,8 +11,6 @@ use candle::{DType, Device, IndexOp, Module, Tensor, D};
 use clap::Parser;
 use tokenizers::Tokenizer;
-const GUIDANCE_SCALE: f64 = 7.5;
 #[derive(Parser)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@ -63,8 +61,8 @@ struct Args {
     sliced_attention_size: Option<usize>,
     /// The number of steps to run the diffusion for.
-    #[arg(long, default_value_t = 30)]
-    n_steps: usize,
+    #[arg(long)]
+    n_steps: Option<usize>,
     /// The number of samples to generate.
     #[arg(long, default_value_t = 1)]
@ -87,6 +85,9 @@ struct Args {
     #[arg(long)]
     use_f16: bool,
+    #[arg(long)]
+    guidance_scale: Option<f64>,
     #[arg(long, value_name = "FILE")]
     img2img: Option<String>,
@ -102,6 +103,7 @@ enum StableDiffusionVersion {
     V1_5,
     V2_1,
     Xl,
+    Turbo,
 }
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -120,12 +122,13 @@ impl StableDiffusionVersion {
     Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
     Self::V2_1 => "stabilityai/stable-diffusion-2-1",
     Self::V1_5 => "runwayml/stable-diffusion-v1-5",
+    Self::Turbo => "stabilityai/sdxl-turbo",
     }
 }
 fn unet_file(&self, use_f16: bool) -> &'static str {
     match self {
-    Self::V1_5 | Self::V2_1 | Self::Xl => {
+    Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
         if use_f16 {
             "unet/diffusion_pytorch_model.fp16.safetensors"
         } else {
@ -137,7 +140,7 @@ impl StableDiffusionVersion {
 fn vae_file(&self, use_f16: bool) -> &'static str {
     match self {
-    Self::V1_5 | Self::V2_1 | Self::Xl => {
+    Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
         if use_f16 {
             "vae/diffusion_pytorch_model.fp16.safetensors"
         } else {
@ -149,7 +152,7 @@ impl StableDiffusionVersion {
 fn clip_file(&self, use_f16: bool) -> &'static str {
     match self {
-    Self::V1_5 | Self::V2_1 | Self::Xl => {
+    Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
         if use_f16 {
             "text_encoder/model.fp16.safetensors"
         } else {
@ -161,7 +164,7 @@ impl StableDiffusionVersion {
 fn clip2_file(&self, use_f16: bool) -> &'static str {
     match self {
-    Self::V1_5 | Self::V2_1 | Self::Xl => {
+    Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
         if use_f16 {
             "text_encoder_2/model.fp16.safetensors"
         } else {
@ -189,7 +192,7 @@ impl ModelFile {
     StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
         "openai/clip-vit-base-patch32"
     }
-    StableDiffusionVersion::Xl => {
+    StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => {
         // This seems similar to the patch32 version except some very small
         // difference in the split regex.
         "openai/clip-vit-large-patch14"
@ -206,7 +209,11 @@ impl ModelFile {
     Self::Vae => {
         // Override for SDXL when using f16 weights.
         // See https://github.com/huggingface/candle/issues/1060
-        if version == StableDiffusionVersion::Xl && use_f16 {
+        if matches!(
+            version,
+            StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo,
+        ) && use_f16
+        {
             (
                 "madebyollin/sdxl-vae-fp16-fix",
                 "diffusion_pytorch_model.safetensors",
@ -261,6 +268,7 @@ fn text_embeddings(
     use_f16: bool,
     device: &Device,
     dtype: DType,
+    use_guide_scale: bool,
     first: bool,
 ) -> Result<Tensor> {
     let tokenizer_file = if first {
@ -285,16 +293,6 @@ fn text_embeddings(
     }
     let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
-    let mut uncond_tokens = tokenizer
-        .encode(uncond_prompt, true)
-        .map_err(E::msg)?
-        .get_ids()
-        .to_vec();
-    while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
-        uncond_tokens.push(pad_id)
-    }
-    let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
     println!("Building the Clip transformer.");
     let clip_weights_file = if first {
         ModelFile::Clip
@ -310,8 +308,24 @@ fn text_embeddings(
     let text_model =
         stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
     let text_embeddings = text_model.forward(&tokens)?;
-    let uncond_embeddings = text_model.forward(&uncond_tokens)?;
-    let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
+    let text_embeddings = if use_guide_scale {
+        let mut uncond_tokens = tokenizer
+            .encode(uncond_prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+        while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
+            uncond_tokens.push(pad_id)
+        }
+        let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
+        let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+        Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
+    } else {
+        text_embeddings.to_dtype(dtype)?
+    };
     Ok(text_embeddings)
 }
@ -356,6 +370,7 @@ fn run(args: Args) -> Result<()> {
         unet_weights,
         tracing,
         use_f16,
+        guidance_scale,
         use_flash_attn,
         img2img,
         img2img_strength,
@ -374,6 +389,24 @@ fn run(args: Args) -> Result<()> {
         None
     };
+    let guidance_scale = match guidance_scale {
+        Some(guidance_scale) => guidance_scale,
+        None => match sd_version {
+            StableDiffusionVersion::V1_5
+            | StableDiffusionVersion::V2_1
+            | StableDiffusionVersion::Xl => 7.5,
+            StableDiffusionVersion::Turbo => 0.,
+        },
+    };
+    let n_steps = match n_steps {
+        Some(n_steps) => n_steps,
+        None => match sd_version {
+            StableDiffusionVersion::V1_5
+            | StableDiffusionVersion::V2_1
+            | StableDiffusionVersion::Xl => 30,
+            StableDiffusionVersion::Turbo => 1,
+        },
+    };
     let dtype = if use_f16 { DType::F16 } else { DType::F32 };
     let sd_config = match sd_version {
         StableDiffusionVersion::V1_5 => {
@ -385,13 +418,19 @@ fn run(args: Args) -> Result<()> {
         StableDiffusionVersion::Xl => {
             stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
         }
+        StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo(
+            sliced_attention_size,
+            height,
+            width,
+        ),
     };
     let scheduler = sd_config.build_scheduler(n_steps)?;
     let device = candle_examples::device(cpu)?;
+    let use_guide_scale = guidance_scale > 1.0;
     let which = match sd_version {
-        StableDiffusionVersion::Xl => vec![true, false],
+        StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false],
         _ => vec![true],
     };
     let text_embeddings = which
@ -407,10 +446,12 @@ fn run(args: Args) -> Result<()> {
             use_f16,
             &device,
             dtype,
+            use_guide_scale,
             *first,
         )
     })
     .collect::<Result<Vec<_>>>()?;
     let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
     println!("{text_embeddings:?}");
@ -434,11 +475,19 @@ fn run(args: Args) -> Result<()> {
         0
     };
     let bsize = 1;
+    let vae_scale = match sd_version {
+        StableDiffusionVersion::V1_5
+        | StableDiffusionVersion::V2_1
+        | StableDiffusionVersion::Xl => 0.18215,
+        StableDiffusionVersion::Turbo => 0.13025,
+    };
     for idx in 0..num_samples {
         let timesteps = scheduler.timesteps();
         let latents = match &init_latent_dist {
             Some(init_latent_dist) => {
-                let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
+                let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
                 if t_start < timesteps.len() {
                     let noise = latents.randn_like(0f64, 1f64)?;
                     scheduler.add_noise(&latents, noise, timesteps[t_start])?
@@ -465,21 +514,31 @@ fn run(args: Args) -> Result<()> {
continue; continue;
} }
let start_time = std::time::Instant::now(); let start_time = std::time::Instant::now();
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?; let latent_model_input = if use_guide_scale {
Tensor::cat(&[&latents, &latents], 0)?
} else {
latents.clone()
};
let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?; let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
let noise_pred = let noise_pred =
unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?; unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
let noise_pred = noise_pred.chunk(2, 0)?;
let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); let noise_pred = if use_guide_scale {
let noise_pred = let noise_pred = noise_pred.chunk(2, 0)?;
(noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?; let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
(noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * guidance_scale)?)?
} else {
noise_pred
};
latents = scheduler.step(&noise_pred, timestep, &latents)?; latents = scheduler.step(&noise_pred, timestep, &latents)?;
let dt = start_time.elapsed().as_secs_f32(); let dt = start_time.elapsed().as_secs_f32();
println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt); println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
if args.intermediary_images { if args.intermediary_images {
let image = vae.decode(&(&latents / 0.18215)?)?; let image = vae.decode(&(&latents / vae_scale)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?; let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = let image_filename =
@@ -493,7 +552,7 @@ fn run(args: Args) -> Result<()> {
idx + 1, idx + 1,
num_samples num_samples
); );
let image = vae.decode(&(&latents / 0.18215)?)?; let image = vae.decode(&(&latents / vae_scale)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?; let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None); let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
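The guidance changes above make classifier-free guidance optional: with `guidance_scale <= 1.0` the unconditional branch is skipped, the latents are no longer duplicated, and the UNet runs a single batch. A minimal sketch of the combination step, mirroring the expression in the diff (the helper name is illustrative):

```rust
use candle::{Result, Tensor};

// Classifier-free guidance: `noise_pred` stacks the unconditional and
// text-conditioned predictions along dim 0 (matching the duplicated latents).
fn apply_cfg(noise_pred: &Tensor, guidance_scale: f64) -> Result<Tensor> {
    let chunks = noise_pred.chunk(2, 0)?;
    let (uncond, text) = (&chunks[0], &chunks[1]);
    // pred = uncond + scale * (text - uncond)
    uncond + ((text - uncond)? * guidance_scale)?
}
```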


@@ -96,25 +96,9 @@ impl T5ModelBuilder {
let api = api.repo(repo); let api = api.repo(repo);
let config_filename = api.get("config.json")?; let config_filename = api.get("config.json")?;
let tokenizer_filename = api.get("tokenizer.json")?; let tokenizer_filename = api.get("tokenizer.json")?;
let weights_filename = if model_id == "google/flan-t5-xxl" { let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
vec![ {
api.get("model-00001-of-00005.safetensors")?, candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
api.get("model-00002-of-00005.safetensors")?,
api.get("model-00003-of-00005.safetensors")?,
api.get("model-00004-of-00005.safetensors")?,
api.get("model-00005-of-00005.safetensors")?,
]
} else if model_id == "google/flan-ul2" {
vec![
api.get("model-00001-of-00008.safetensors")?,
api.get("model-00002-of-00008.safetensors")?,
api.get("model-00003-of-00008.safetensors")?,
api.get("model-00004-of-00008.safetensors")?,
api.get("model-00005-of-00008.safetensors")?,
api.get("model-00006-of-00008.safetensors")?,
api.get("model-00007-of-00008.safetensors")?,
api.get("model-00008-of-00008.safetensors")?,
]
} else { } else {
vec![api.get("model.safetensors")?] vec![api.get("model.safetensors")?]
}; };


@@ -218,21 +218,7 @@ fn main() -> Result<()> {
.split(',') .split(',')
.map(std::path::PathBuf::from) .map(std::path::PathBuf::from)
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
None => match args.which { None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
Which::L6b => vec![
repo.get("model-00001-of-00002.safetensors")?,
repo.get("model-00002-of-00002.safetensors")?,
],
Which::L34b => vec![
repo.get("model-00001-of-00007.safetensors")?,
repo.get("model-00002-of-00007.safetensors")?,
repo.get("model-00003-of-00007.safetensors")?,
repo.get("model-00004-of-00007.safetensors")?,
repo.get("model-00005-of-00007.safetensors")?,
repo.get("model-00006-of-00007.safetensors")?,
repo.get("model-00007-of-00007.safetensors")?,
],
},
}; };
println!("retrieved the files in {:?}", start.elapsed()); println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;


@@ -147,7 +147,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
let func = candle_nn::func(move |xs| { let func = candle_nn::func(move |xs| {
let xs = conv.forward(xs)?; let xs = conv.forward(xs)?;
let xs = match &bn { let xs = match &bn {
Some(bn) => bn.forward(&xs)?, Some(bn) => xs.apply_t(bn, false)?,
None => xs, None => xs,
}; };
let xs = if leaky { let xs = if leaky {
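
The YOLO change above swaps `bn.forward(&xs)` for `xs.apply_t(bn, false)`, i.e. it runs batch norm through the train-aware `ModuleT` path with `train = false`, so inference uses the recorded running statistics rather than per-batch ones. A minimal sketch, assuming candle's inherent `Tensor::apply_t`:

```rust
use candle::{Result, Tensor};
use candle_nn::BatchNorm;

// Batch norm in eval mode: `train = false` selects the running
// mean/variance accumulated during training.
fn bn_eval(xs: &Tensor, bn: &BatchNorm) -> Result<Tensor> {
    xs.apply_t(bn, false)
}
```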


@@ -117,3 +117,30 @@ pub fn save_image_resize<P: AsRef<std::path::Path>>(
image.save(p).map_err(candle::Error::wrap)?; image.save(p).map_err(candle::Error::wrap)?;
Ok(()) Ok(())
} }
/// Loads the safetensors files for a model from the hub based on a json index file.
pub fn hub_load_safetensors(
repo: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>> {
let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value =
serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| repo.get(v).map_err(candle::Error::wrap))
.collect::<Result<Vec<_>>>()?;
Ok(safetensors_files)
}
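
This helper subsumes the per-model shard lists deleted in the T5 and Yi examples above: any repo shipping a `model.safetensors.index.json` can be loaded generically. A usage sketch (the model id is illustrative):

```rust
use hf_hub::api::sync::Api;
use std::path::PathBuf;

fn load_sharded_weights(model_id: &str) -> candle::Result<Vec<PathBuf>> {
    let api = Api::new().map_err(candle::Error::wrap)?;
    let repo = api.model(model_id.to_string());
    // Resolves the index file, then fetches every shard it references.
    candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")
}
```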


@@ -1,6 +1,6 @@
[package] [package]
name = "candle-flash-attn" name = "candle-flash-attn"
version = "0.3.1" version = "0.3.3"
edition = "2021" edition = "2021"
description = "Flash attention layer for the candle ML framework." description = "Flash attention layer for the candle ML framework."
@@ -11,14 +11,14 @@ license = "MIT OR Apache-2.0"
readme = "README.md" readme = "README.md"
[dependencies] [dependencies]
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.1", package = "candle-core" } candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] } half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies] [build-dependencies]
bindgen_cuda = "0.1.1"
anyhow = { version = "1", features = ["backtrace"] } anyhow = { version = "1", features = ["backtrace"] }
num_cpus = "1.15.0"
rayon = "1.7.0"
[dev-dependencies] [dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] } anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", version = "0.3.1", features = ["cuda"] } candle-nn = { path = "../candle-nn", features = ["cuda"] }


@@ -2,44 +2,32 @@
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment // The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
// variable in order to cache the compiled artifacts and avoid recompiling too often. // variable in order to cache the compiled artifacts and avoid recompiling too often.
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use rayon::prelude::*;
use std::path::PathBuf; use std::path::PathBuf;
use std::str::FromStr;
const KERNEL_FILES: [&str; 17] = [ const KERNEL_FILES: [&str; 17] = [
"flash_api.cu", "kernels/flash_api.cu",
"flash_fwd_hdim128_fp16_sm80.cu", "kernels/flash_fwd_hdim128_fp16_sm80.cu",
"flash_fwd_hdim160_fp16_sm80.cu", "kernels/flash_fwd_hdim160_fp16_sm80.cu",
"flash_fwd_hdim192_fp16_sm80.cu", "kernels/flash_fwd_hdim192_fp16_sm80.cu",
"flash_fwd_hdim224_fp16_sm80.cu", "kernels/flash_fwd_hdim224_fp16_sm80.cu",
"flash_fwd_hdim256_fp16_sm80.cu", "kernels/flash_fwd_hdim256_fp16_sm80.cu",
"flash_fwd_hdim32_fp16_sm80.cu", "kernels/flash_fwd_hdim32_fp16_sm80.cu",
"flash_fwd_hdim64_fp16_sm80.cu", "kernels/flash_fwd_hdim64_fp16_sm80.cu",
"flash_fwd_hdim96_fp16_sm80.cu", "kernels/flash_fwd_hdim96_fp16_sm80.cu",
"flash_fwd_hdim128_bf16_sm80.cu", "kernels/flash_fwd_hdim128_bf16_sm80.cu",
"flash_fwd_hdim160_bf16_sm80.cu", "kernels/flash_fwd_hdim160_bf16_sm80.cu",
"flash_fwd_hdim192_bf16_sm80.cu", "kernels/flash_fwd_hdim192_bf16_sm80.cu",
"flash_fwd_hdim224_bf16_sm80.cu", "kernels/flash_fwd_hdim224_bf16_sm80.cu",
"flash_fwd_hdim256_bf16_sm80.cu", "kernels/flash_fwd_hdim256_bf16_sm80.cu",
"flash_fwd_hdim32_bf16_sm80.cu", "kernels/flash_fwd_hdim32_bf16_sm80.cu",
"flash_fwd_hdim64_bf16_sm80.cu", "kernels/flash_fwd_hdim64_bf16_sm80.cu",
"flash_fwd_hdim96_bf16_sm80.cu", "kernels/flash_fwd_hdim96_bf16_sm80.cu",
]; ];
fn main() -> Result<()> { fn main() -> Result<()> {
let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
|_| num_cpus::get_physical(),
|s| usize::from_str(&s).unwrap(),
);
rayon::ThreadPoolBuilder::new()
.num_threads(num_cpus)
.build_global()
.unwrap();
println!("cargo:rerun-if-changed=build.rs"); println!("cargo:rerun-if-changed=build.rs");
for kernel_file in KERNEL_FILES.iter() { for kernel_file in KERNEL_FILES.iter() {
println!("cargo:rerun-if-changed=kernels/{kernel_file}"); println!("cargo:rerun-if-changed={kernel_file}");
} }
println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h"); println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h"); println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
@@ -66,223 +54,30 @@ fn main() -> Result<()> {
)) ))
} }
}; };
set_cuda_include_dir()?;
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN"); let kernels = KERNEL_FILES.iter().collect();
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); let builder = bindgen_cuda::Builder::default()
.kernel_paths(kernels)
let compute_cap = compute_cap()?; .out_dir(build_dir.clone())
.arg("-std=c++17")
.arg("-O3")
.arg("-U__CUDA_NO_HALF_OPERATORS__")
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
.arg("-Icutlass/include")
.arg("--expt-relaxed-constexpr")
.arg("--expt-extended-lambda")
.arg("--use_fast_math")
.arg("--verbose");
let out_file = build_dir.join("libflashattention.a"); let out_file = build_dir.join("libflashattention.a");
builder.build_lib(out_file);
let kernel_dir = PathBuf::from("kernels");
let cu_files: Vec<_> = KERNEL_FILES
.iter()
.map(|f| {
let mut obj_file = out_dir.join(f);
obj_file.set_extension("o");
(kernel_dir.join(f), obj_file)
})
.collect();
let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
let should_compile = if out_file.exists() {
kernel_dir
.read_dir()
.expect("kernels folder should exist")
.any(|entry| {
if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
let in_modified = entry.metadata().unwrap().modified().unwrap();
in_modified.duration_since(*out_modified).is_ok()
} else {
true
}
})
} else {
true
};
if should_compile {
cu_files
.par_iter()
.map(|(cu_file, obj_file)| {
let mut command = std::process::Command::new("nvcc");
command
.arg("-std=c++17")
.arg("-O3")
.arg("-U__CUDA_NO_HALF_OPERATORS__")
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("-c")
.args(["-o", obj_file.to_str().unwrap()])
.args(["--default-stream", "per-thread"])
.arg("-Icutlass/include")
.arg("--expt-relaxed-constexpr")
.arg("--expt-extended-lambda")
.arg("--use_fast_math")
.arg("--verbose");
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
command.arg(cu_file);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
&command,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
Ok(())
})
.collect::<Result<()>>()?;
let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
let mut command = std::process::Command::new("nvcc");
command
.arg("--lib")
.args(["-o", out_file.to_str().unwrap()])
.args(obj_files);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
&command,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
}
println!("cargo:rustc-link-search={}", build_dir.display()); println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=flashattention"); println!("cargo:rustc-link-lib=flashattention");
println!("cargo:rustc-link-lib=dylib=cudart"); println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=stdc++"); println!("cargo:rustc-link-lib=dylib=stdc++");
/* laurent: I tried using the cc cuda integration as below but this lead to ptaxs never
finishing to run for some reason. Calling nvcc manually worked fine.
cc::Build::new()
.cuda(true)
.include("cutlass/include")
.flag("--expt-relaxed-constexpr")
.flag("--default-stream")
.flag("per-thread")
.flag(&format!("--gpu-architecture=sm_{compute_cap}"))
.file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
.compile("flashattn");
*/
Ok(()) Ok(())
} }
fn set_cuda_include_dir() -> Result<()> {
// NOTE: copied from cudarc build.rs.
let env_vars = [
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
let env_vars = env_vars
.into_iter()
.map(std::env::var)
.filter_map(Result::ok)
.map(Into::<PathBuf>::into);
let roots = [
"/usr",
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/CUDA",
];
let roots = roots.into_iter().map(Into::<PathBuf>::into);
let root = env_vars
.chain(roots)
.find(|path| path.join("include").join("cuda.h").is_file())
.context("cannot find include/cuda.h")?;
println!(
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
root.join("include").display()
);
Ok(())
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
// Try to parse compute caps from env
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
compute_cap_str
.parse::<usize>()
.context("Could not parse compute cap")?
} else {
// Use nvidia-smi to get the current compute cap
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
let cap = cap
.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?;
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap
};
// Grab available GPU codes from nvcc and select the highest one
let (supported_nvcc_codes, max_nvcc_code) = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
(codes, max_nvcc_code)
};
// Check that nvcc supports the asked compute caps
if !supported_nvcc_codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
);
}
if compute_cap > max_nvcc_code {
anyhow::bail!(
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
);
}
Ok(compute_cap)
}
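
The entire hand-rolled pipeline removed above (CUDA include discovery, compute-cap probing via nvidia-smi and nvcc, parallel per-file nvcc invocations, archiving) is now delegated to `bindgen_cuda`. A minimal sketch of the resulting build script, reusing only the builder calls visible in this diff; the real script also honours `CANDLE_FLASH_ATTN_BUILD_DIR` to cache compiled artifacts:

```rust
// build.rs: a minimal sketch; the kernel list is shortened for illustration.
fn main() {
    println!("cargo:rerun-if-changed=build.rs");
    let build_dir = std::path::PathBuf::from(std::env::var("OUT_DIR").unwrap());
    let kernels = ["kernels/flash_api.cu"].iter().collect();
    let builder = bindgen_cuda::Builder::default()
        .kernel_paths(kernels)
        .out_dir(build_dir.clone())
        .arg("-O3")
        .arg("--expt-relaxed-constexpr");
    builder.build_lib(build_dir.join("libflashattention.a"));
    println!("cargo:rustc-link-search={}", build_dir.display());
    println!("cargo:rustc-link-lib=flashattention");
}
```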


@@ -0,0 +1,62 @@
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_causal, typename Engine, typename Layout>
inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
const int col_idx_offset_,
const int max_seqlen_k,
const int row_idx_offset,
const int max_seqlen_q,
const int warp_row_stride,
const float alibi_slope) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
}
}
}
} else { // Bias depends on both row_idx and col_idx
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
}
}
}
}
}
}
} // namespace flash
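
The new `alibi.h` applies the standard ALiBi bias: in the causal case every row receives the same `slope * col` ramp (relative offsets cancel inside the row-wise softmax), otherwise the bias decays with the right-aligned distance between query and key. The scalar form of what the kernel adds, as a hedged Rust sketch:

```rust
// ALiBi bias for a single (query, key) score, mirroring the kernel above.
// The `seqlen_k - seqlen_q` term right-aligns queries against keys.
fn alibi_bias(slope: f32, row: i32, col: i32, seqlen_q: i32, seqlen_k: i32, causal: bool) -> f32 {
    if causal {
        slope * col as f32
    } else {
        -slope * (row + seqlen_k - seqlen_q - col).abs() as f32
    }
}
```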


@@ -14,9 +14,12 @@ struct BlockInfo {
template<typename Params> template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb) __device__ BlockInfo(const Params &params, const int bidb)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]) , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
, actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k) // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{ {
} }
@@ -32,8 +35,10 @@ struct BlockInfo {
const int sum_s_q; const int sum_s_q;
const int sum_s_k; const int sum_s_k;
const uint32_t actual_seqlen_q; const int actual_seqlen_q;
const uint32_t actual_seqlen_k; // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
const int seqlen_k_cache;
const int actual_seqlen_k;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
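
The reworked `BlockInfo` separates the cached K length from the effective one: `cu_seqlens_k` may hold either cumulative offsets or raw per-batch lengths (per `is_seqlens_k_cumulative`), `seqused_k` can override the result, and appended K/V rows (`seqlen_knew`) extend it. Restated as a hedged Rust sketch with hypothetical parameter names:

```rust
// Effective K sequence length for batch `bidb` (a sketch of the logic above).
fn effective_seqlen_k(
    cu_seqlens_k: Option<&[i32]>, // offsets (cumulative) or raw lengths
    cumulative: bool,             // params.is_seqlens_k_cumulative
    fixed_seqlen_k: i32,          // params.seqlen_k
    seqused_k: Option<&[i32]>,    // explicit per-batch lengths, if any
    seqlen_knew: i32,             // rows being appended to the KV cache (0 if none)
    bidb: usize,
) -> i32 {
    let cache = match cu_seqlens_k {
        None => fixed_seqlen_k,
        Some(cu) if cumulative => cu[bidb + 1] - cu[bidb],
        Some(cu) => cu[bidb],
    };
    match seqused_k {
        Some(used) => used[bidb], // explicit length wins
        None => cache + seqlen_knew,
    }
}
```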


@@ -7,15 +7,6 @@
#include <cuda.h> #include <cuda.h>
#include <vector> #include <vector>
// #ifdef OLD_GENERATOR_PATH
// #include <ATen/CUDAGeneratorImpl.h>
// #else
// #include <ATen/cuda/CUDAGeneratorImpl.h>
// #endif
//
// #include <ATen/cuda/CUDAGraphsUtils.cuh>
constexpr int TOTAL_DIM = 0; constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1; constexpr int H_DIM = 1;
constexpr int D_DIM = 2; constexpr int D_DIM = 2;
@@ -53,6 +44,7 @@ struct Flash_fwd_params : public Qkv_params {
// The O matrix (output). // The O matrix (output).
void * __restrict__ o_ptr; void * __restrict__ o_ptr;
void * __restrict__ oaccum_ptr;
// The stride between rows of O. // The stride between rows of O.
index_t o_batch_stride; index_t o_batch_stride;
@@ -64,9 +56,10 @@ struct Flash_fwd_params : public Qkv_params {
// The pointer to the softmax sum. // The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr; void * __restrict__ softmax_lse_ptr;
void * __restrict__ softmax_lseaccum_ptr;
// The dimensions. // The dimensions.
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
// The scaling factors for the kernel. // The scaling factors for the kernel.
float scale_softmax; float scale_softmax;
@@ -76,8 +69,30 @@ struct Flash_fwd_params : public Qkv_params {
int * __restrict__ cu_seqlens_q; int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k; int * __restrict__ cu_seqlens_k;
// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;
int *__restrict__ blockmask; int *__restrict__ blockmask;
// The K_new and V_new matrices.
void * __restrict__ knew_ptr;
void * __restrict__ vnew_ptr;
// The stride between rows of the Q, K and V matrices.
index_t knew_batch_stride;
index_t vnew_batch_stride;
index_t knew_row_stride;
index_t vnew_row_stride;
index_t knew_head_stride;
index_t vnew_head_stride;
// The cos and sin matrices for rotary embedding.
void * __restrict__ rotary_cos_ptr;
void * __restrict__ rotary_sin_ptr;
// The indices to index into the KV cache.
int *__restrict__ cache_batch_idx;
// The dropout probability (probability of keeping an activation). // The dropout probability (probability of keeping an activation).
float p_dropout; float p_dropout;
// uint32_t p_dropout_in_uint; // uint32_t p_dropout_in_uint;
@@ -88,11 +103,22 @@ struct Flash_fwd_params : public Qkv_params {
float rp_dropout; float rp_dropout;
float scale_softmax_rp_dropout; float scale_softmax_rp_dropout;
// Random state. // Local window size
// at::PhiloxCudaState philox_args; int window_size_left, window_size_right;
bool is_bf16; bool is_bf16;
bool is_causal; bool is_causal;
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
bool is_seqlens_k_cumulative;
bool is_rotary_interleaved;
int num_splits; // For split-KV version
void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -132,10 +158,14 @@ struct Flash_bwd_params : public Flash_fwd_params {
// The pointer to the softmax d sum. // The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum; void *__restrict__ dsoftmax_sum;
bool deterministic;
index_t dq_accum_split_stride;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream); template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure); template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);

View File

@@ -1,17 +1,15 @@
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
// FWD_HEADDIM_SWITCH(params.d, [&] { FP16_SWITCH(!params.is_bf16, [&] {
// run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream); FWD_HEADDIM_SWITCH(params.d, [&] {
// }); // if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
// } run_mha_fwd_<elem_type, kHeadDim>(params, stream);
// } else {
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) { // run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
FP16_SWITCH(!params.is_bf16, [&] { // }
FWD_HEADDIM_SWITCH(params.d, [&] { });
run_mha_fwd_<elem_type, kHeadDim>(params, stream); });
});
});
} }
extern "C" void run_mha( extern "C" void run_mha(
@@ -20,6 +18,7 @@ extern "C" void run_mha(
void *v_ptr, void *v_ptr,
void *o_ptr, void *o_ptr,
void *softmax_lse_ptr, void *softmax_lse_ptr,
void *alibi_slopes_ptr,
int32_t *cu_seqlens_q_ptr, int32_t *cu_seqlens_q_ptr,
int32_t *cu_seqlens_k_ptr, int32_t *cu_seqlens_k_ptr,
@@ -28,6 +27,7 @@ extern "C" void run_mha(
uint32_t k_batch_stride, uint32_t k_batch_stride,
uint32_t v_batch_stride, uint32_t v_batch_stride,
uint32_t o_batch_stride, uint32_t o_batch_stride,
uint32_t alibi_slopes_batch_stride,
uint32_t q_row_stride, uint32_t q_row_stride,
uint32_t k_row_stride, uint32_t k_row_stride,
@@ -51,8 +51,11 @@ extern "C" void run_mha(
uint32_t seqlen_q_rounded, uint32_t seqlen_q_rounded,
uint32_t seqlen_k_rounded, uint32_t seqlen_k_rounded,
int is_bf16,
int is_causal, int is_causal,
int is_bf16
int window_size_left,
int window_size_right
) { ) {
Flash_fwd_params params; Flash_fwd_params params;
// Reset the parameters // Reset the parameters
@@ -65,12 +68,14 @@ extern "C" void run_mha(
params.o_ptr = o_ptr; params.o_ptr = o_ptr;
params.softmax_lse_ptr = softmax_lse_ptr; params.softmax_lse_ptr = softmax_lse_ptr;
params.alibi_slopes_ptr = alibi_slopes_ptr;
// All stride are in elements, not bytes. // All stride are in elements, not bytes.
params.q_batch_stride = q_batch_stride; params.q_batch_stride = q_batch_stride;
params.k_batch_stride = k_batch_stride; params.k_batch_stride = k_batch_stride;
params.v_batch_stride = v_batch_stride; params.v_batch_stride = v_batch_stride;
params.o_batch_stride = o_batch_stride; params.o_batch_stride = o_batch_stride;
params.alibi_slopes_batch_stride = alibi_slopes_batch_stride;
params.q_row_stride = q_row_stride; params.q_row_stride = q_row_stride;
params.k_row_stride = k_row_stride; params.k_row_stride = k_row_stride;
@@ -92,7 +97,6 @@ extern "C" void run_mha(
params.seqlen_k_rounded = seqlen_k_rounded; params.seqlen_k_rounded = seqlen_k_rounded;
params.d = d; params.d = d;
params.d_rounded = d_rounded; params.d_rounded = d_rounded;
params.is_causal = is_causal;
// Set the different scale values. // Set the different scale values.
params.scale_softmax = softmax_scale; params.scale_softmax = softmax_scale;
@@ -106,6 +110,14 @@ extern "C" void run_mha(
params.cu_seqlens_q = cu_seqlens_q_ptr; params.cu_seqlens_q = cu_seqlens_q_ptr;
params.cu_seqlens_k = cu_seqlens_k_ptr; params.cu_seqlens_k = cu_seqlens_k_ptr;
params.p_ptr = nullptr; // used for `return_softmax`. params.p_ptr = nullptr; // used for `return_softmax`.
params.seqused_k = nullptr;
params.is_causal = is_causal;
params.window_size_left = window_size_left;
params.window_size_right = window_size_right;
params.is_seqlens_k_cumulative = true;
params.num_splits = 1;
cudaStream_t stream = 0; // Use the default stream. cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd(params, stream); run_mha_fwd(params, stream);
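
The added `window_size_left`/`window_size_right` parameters expose sliding-window (local) attention: after right-aligning queries against keys, query `q` may attend to keys in `[q - left, q + right]`. A sketch of the mask predicate, assuming the usual flash-attention convention that a negative window size leaves that side unbounded:

```rust
// Would query position `q` attend to key position `k` under a local window?
fn in_window(q: i64, k: i64, seqlen_q: i64, seqlen_k: i64, left: i64, right: i64) -> bool {
    let q_aligned = q + seqlen_k - seqlen_q; // right-align queries against keys
    (left < 0 || k >= q_aligned - left) && (right < 0 || k <= q_aligned + right)
}
```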


@@ -1,19 +1,10 @@
// Copyright (c) 2023, Tri Dao. // Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation. // Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// if (params.p_dropout == 1.f) {
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
// }
// }
template<> template<>
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream); run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
} }


@@ -1,32 +1,10 @@
// Copyright (c) 2023, Tri Dao. // Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation. // Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// if (params.p_dropout == 1.f) {
// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, false, elem_type>, false>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, true, elem_type>, false>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, true, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 64, 4, false, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 128, 4, false, false, elem_type>, false>(params, stream);
// // 1st ones are good for H100, A100
// // 2nd one is good for A6000 bc we get slightly better occupancy
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, false, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, true, elem_type>, true>(params, stream);
// // 1st one is good for H100, A100, A6000
// }
// }
template<> template<>
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::half_t>(params, stream); run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
} }


@@ -1,17 +1,10 @@
// Copyright (c) 2023, Tri Dao. // Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation. // Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// });
// }
template<> template<>
void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream); run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
} }


@@ -1,27 +1,10 @@
// Copyright (c) 2023, Tri Dao. // Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation. // Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, true, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 128, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 8, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 128, 8, false, elem_type>>(params, stream);
// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest.
// // For A100, H100, 1st is fastest.
// });
// }
template<> template<>
void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim160<cutlass::half_t>(params, stream); run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
} }


@@ -1,16 +1,10 @@
// Copyright (c) 2023, Tri Dao. // Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation. // Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
// template<> template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// });
// }
template<> void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream); run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
} }


@@ -1,27 +1,10 @@
// Copyright (c) 2023, Tri Dao. // Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation. // Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
// // This one is slightly faster for causal?
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 8, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 128, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 128, 8, false, elem_type>>(params, stream);
// });
// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout
// // For A6000, 1st is faster when causal, 3rd is faster when not causal
// }
template<> template<>
void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim192<cutlass::half_t>(params, stream); run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
} }


@@ -1,9 +1,10 @@
// Copyright (c) 2023, Tri Dao. // Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation. // Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
template<> void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) { template<>
void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream); run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
} }


@@ -1,9 +1,10 @@
// Copyright (c) 2023, Tri Dao. // Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation. // Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
template<> void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) { template<>
void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim224<cutlass::half_t>(params, stream); run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
} }


@@ -1,9 +1,10 @@
// Copyright (c) 2023, Tri Dao. // Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation. // Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
template<> void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) { template<>
void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream); run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
} }


@@ -1,9 +1,10 @@
// Copyright (c) 2023, Tri Dao. // Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation. // Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
template<> void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) { template<>
void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::half_t>(params, stream); run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
} }


@@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao. // Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation. // Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
template<> template<>
void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream); run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
} }


@@ -1,23 +1,10 @@
// Copyright (c) 2023, Tri Dao. // Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation. // Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 128, 4, false, false, elem_type>, Is_dropout>(params, stream);
// // For dropout there might be a lot of register spilling?
// // These two are very slow due to register spilling
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 128, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 256, 4, false, elem_type>>(params, stream);
// // This one is slightly slower
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 64, 4, false, elem_type>>(params, stream);
// });
// }
template<> template<>
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::half_t>(params, stream); run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
} }


@@ -1,19 +1,10 @@
// Copyright (c) 2023, Tri Dao. // Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation. // Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// if (params.p_dropout == 1.f) {
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
// }
// }
template<> template<>
void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream); run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
} }


@@ -1,26 +1,10 @@
// Copyright (c) 2023, Tri Dao. // Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation. // Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// if (params.p_dropout == 1.f) {
// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
// // Using block size (64 x 256) is 27% slower for seqlen=2k
// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, false>(params, stream);
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, true>(params, stream);
// }
// }
template<> template<>
void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::half_t>(params, stream); run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
} }


@@ -1,17 +1,10 @@
// Copyright (c) 2023, Tri Dao. // Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation. // Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
// });
// }
template<> template<>
void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream); run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
} }


@@ -1,23 +1,10 @@
// Copyright (c) 2023, Tri Dao. // Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation. // Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
// template<> template<>
// void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, true, elem_type>, Is_dropout>(params, stream);
// // This 3rd one is good for H100, and A100, A6000
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, true, elem_type>, Is_dropout>(params, stream);
// // These two are always slower
// // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, elem_type>>(params, stream);
// });
// }
template<> void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::half_t>(params, stream); run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
} }


@@ -4,20 +4,18 @@
#pragma once #pragma once
#include <cmath>
#include <cute/algorithm/copy.hpp> #include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cutlass/cutlass.h> #include <cutlass/cutlass.h>
#include <cutlass/array.h> #include <cutlass/array.h>
#include <cutlass/numeric_types.h> #include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "block_info.h" #include "block_info.h"
#include "kernel_traits.h" #include "kernel_traits.h"
#include "utils.h" #include "utils.h"
#include "softmax.h" #include "softmax.h"
#include "philox.cuh"
#include "alibi.h"
namespace flash { namespace flash {
@@ -25,49 +23,6 @@ using namespace cute;
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template <int MMA_M,
class... Args,
class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma) {
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
constexpr int MMAStride_M = MMA_M * AtomShape_M;
auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
Stride<_1, Int<MMAStride_M>> >{},
make_layout(size<2>(TileShape_MNK{})));
// if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int MMA_M,
class... Args,
class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma) {
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
constexpr int MMAStride_M = MMA_M * AtomShape_M;
auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
Stride<_1, Int<MMAStride_M>> >{},
// TODO: Shouldn't this be size<1>?
make_layout(size<2>(TileShape_MNK{})));
// if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2> template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum, inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
Tensor2 &acc_o, float softmax_scale_log2) { Tensor2 &acc_o, float softmax_scale_log2) {
@@ -77,7 +32,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
flash::reduce_sum(scores, scores_sum); flash::reduce_sum(scores, scores_sum);
} else { } else {
Tensor scores_max_prev = make_fragment_like(scores_max); Tensor scores_max_prev = make_fragment_like(scores_max);
copy(scores_max, scores_max_prev); cute::copy(scores_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, scores_max); flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
@@ -103,23 +58,22 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy> template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
inline __device__ void write_softmax_to_gmem( inline __device__ void write_softmax_to_gmem(
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
) { ) {
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
Layout l = tOrP.layout(); Layout l = tOrP.layout();
Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{}); CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
// TODO(laurent): reactivate the following CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
// CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
#pragma unroll #pragma unroll
for (int mi = 0; mi < size<1>(tPrP); ++mi) { for (int mi = 0; mi < size<1>(tPrP); ++mi) {
copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
} }
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params> template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) { inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element; using Element = typename Kernel_traits::Element;
@@ -138,16 +92,65 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr int kNWarps = Kernel_traits::kNWarps; constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb); const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
if (Is_causal) { if (Is_causal || Is_local) {
n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN)); n_block_max = std::min(n_block_max,
cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
// } // }
} }
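The two bounds above delimit the only key/value blocks a query block can touch: the local (sliding-window) mask lower-bounds the range via window_size_left, and causal or local masks upper-bound it via window_size_right. A standalone sketch of the same arithmetic, with assumed tile and window sizes:

#include <algorithm>
#include <cstdio>

int main() {
    // Illustrative configuration, not the kernel's.
    const int kBlockM = 128, kBlockN = 64;
    const int seqlen_q = 1000, seqlen_k = 1500;
    const int window_left = 256, window_right = 0;
    const bool is_local = true, is_causal = false;
    auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
    for (int m_block = 0; m_block * kBlockM < seqlen_q; ++m_block) {
        const int n_block_min = !is_local ? 0
            : std::max(0, (m_block * kBlockM + seqlen_k - seqlen_q - window_left) / kBlockN);
        int n_block_max = ceil_div(seqlen_k, kBlockN);
        if (is_causal || is_local) {
            n_block_max = std::min(n_block_max,
                ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q + window_right, kBlockN));
        }
        std::printf("m_block %d -> n_block in [%d, %d)\n", m_block, n_block_min, n_block_max);
    }
    return 0;
}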
// We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
// Otherwise we might read OOB elements from gK and gV.
if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
// Save seed and offset for backward. If we don't have this here, the 0-th thread block might
// exit early and no one saves the rng state.
// if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
// auto seeds = at::cuda::philox::unpack(params.philox_args);
// params.rng_state[0] = std::get<0>(seeds);
// params.rng_state[1] = std::get<1>(seeds);
// params.rng_state[0] = 0;
// params.rng_state[1] = 0;
// }
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.o_row_stride, _1{}));
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor tOrO = make_tensor<Element>(shape(tOgO));
clear(tOrO);
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
#pragma unroll
for (int m = 0; m < size<1>(tOgO); ++m) {
const int row = get<0>(tOcO(0, m, 0));
if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
}
return;
}
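This early exit gives rows with an empty key range a well-defined result instead of letting the main loop read out of bounds from gK/gV: the output tile is zero-filled and the log-sum-exp slot gets an infinity sentinel. The per-row contract, in scalar form (a sketch, not the kernel's vectorized store path):

#include <cmath>

// A query row that attends to no keys: zero output, +inf LSE sentinel.
void write_empty_row(float *o, int head_dim, float &lse) {
    for (int k = 0; k < head_dim; ++k) o[k] = 0.0f;
    lse = INFINITY;
}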
// if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }
// We iterate over the blocks in reverse order. This is because the last block is the only one // We iterate over the blocks in reverse order. This is because the last block is the only one
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
@@ -185,8 +188,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx); typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx); auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -208,16 +213,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Copy Atom retiling // Copy Atom retiling
// //
auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
// auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
// if (cute::thread0()) {smem_thr_copy_Q.print_all();} // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx); auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK); Tensor tSsK = smem_thr_copy_K.partition_S(sK);
auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx); auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
// TODO: this might need to change if we change the mma instruction in SM70 // TODO: this might need to change if we change the mma instruction in SM70
@@ -268,8 +275,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tQrQ = make_fragment_like(tQgQ); Tensor tQrQ = make_fragment_like(tQgQ);
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM); binfo.actual_seqlen_q - m_block * kBlockM);
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
// // Copy rmem to smem // // Copy rmem to smem
@@ -285,14 +292,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads(); __syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
__syncthreads(); __syncthreads();
} }
int n_block = n_block_max - 1; int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway. // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN); binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence(); cute::cp_async_fence();
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
// __syncthreads(); // __syncthreads();
@@ -302,7 +309,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads(); __syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view); cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
} }
// auto seeds = at::cuda::philox::unpack(params.philox_args); // auto seeds = at::cuda::philox::unpack(params.philox_args);
@@ -313,13 +320,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
clear(acc_o); clear(acc_o);
float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
// For performance reasons, we separate out two kinds of iterations: // For performance reasons, we separate out two kinds of iterations:
// those that need masking on S, and those that don't. // those that need masking on S, and those that don't.
// We need masking on S for the very last block when K and V have a length that is not a multiple of kBlockN. // We need masking on S for the very last block when K and V have a length that is not a multiple of kBlockN.
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration. // We will have at least 1 "masking" iteration.
constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1; // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
constexpr int n_masking_steps = (!Is_causal && !Is_local)
? 1
: ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
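The step count follows from tile geometry: the causal/local boundary can cross at most ceil_div(kBlockM, kBlockN) key blocks for one query block, and when seqlen_k does not end on a block boundary one extra, partially valid block needs masking too. A compile-time check of the formula under assumed tile sizes:

// Sanity check of the masking-step count (hypothetical configuration).
constexpr int ceil_div_i(int a, int b) { return (a + b - 1) / b; }
constexpr int masking_steps(bool causal, bool local, bool even_mn, int bm, int bn) {
    return (!causal && !local) ? 1
        : ((even_mn && causal) ? ceil_div_i(bm, bn) : ceil_div_i(bm, bn) + 1);
}
static_assert(masking_steps(false, false, true, 128, 64) == 1);  // only padding mask
static_assert(masking_steps(true, false, true, 128, 64) == 2);   // causal, even seqlen_k
static_assert(masking_steps(true, false, false, 128, 64) == 3);  // causal + ragged tail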
#pragma unroll #pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N) Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
@@ -330,28 +343,42 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Advance gV // Advance gV
if (masking_step > 0) { if (masking_step > 0) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
} else { } else {
// Clear the smem tiles to account for predicated off loads // Clear the smem tiles to account for predicated off loads
flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>( flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
); );
} }
cute::cp_async_fence(); cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>( flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
); );
// if (cute::thread0()) { print(acc_s); } // if (cute::thread0()) { print(acc_s); }
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
// if (cute::thread0()) { print(scores); } // if (cute::thread0()) { print_tensor(scores); }
// We don't put the masking before the matmul S = Q K^T because we don't clear sK // We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN. // can produce Inf / NaN.
if (!Is_causal) {
if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } if (Has_alibi) {
flash::apply_alibi<Is_causal>(
scores,
n_block * kBlockN,
binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q,
kNWarps * 16,
alibi_slope
);
}
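The ALiBi bias is added to the raw scores here, and the slope was pre-divided by params.scale_softmax when it was loaded, so the bias survives the later softmax scaling unchanged. A scalar reference for the bias itself (a sketch; the kernel's apply_alibi works on MMA fragments): in the causal case it is enough to add slope * col, because softmax is invariant to a per-row shift, which makes that equal to the textbook -slope * (row - col) bias up to a per-row constant.

#include <cmath>

// scores is seqlen_q x seqlen_k, row-major; queries are right-aligned to keys.
void add_alibi_bias(float *scores, int seqlen_q, int seqlen_k,
                    float slope, bool causal) {
    for (int i = 0; i < seqlen_q; ++i) {
        const int i_aligned = i + seqlen_k - seqlen_q;
        for (int j = 0; j < seqlen_k; ++j) {
            scores[i * seqlen_k + j] += causal
                ? slope * j                                   // per-row shift trick
                : -slope * std::fabs(float(i_aligned - j));   // symmetric distance
        }
    }
}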
if (!Is_causal && !Is_local) {
if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
} else { } else {
// Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) // Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
// Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N) // Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
@@ -364,20 +391,24 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Idk why it's get<1> and not get<0> of the stride. // Idk why it's get<1> and not get<0> of the stride.
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); } // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
// I can't get the stride from idx_row // I can't get the stride from idx_row
flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, flash::apply_mask_local</*HasWSLeft=*/Is_local>(
// m_block * kBlockM + get<0>(idx_row(0)), scores, n_block * kBlockN, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, // m_block * kBlockM + get<0>(idx_row(0)),
kNWarps * 16); m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16); binfo.actual_seqlen_q, kNWarps * 16,
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16); params.window_size_left, params.window_size_right
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16
);
// if (cute::thread0()) { print_tensor(scores); }
} }
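apply_mask_local generalizes the causal mask to a sliding window: after right-aligning query positions to the keys, a column survives only if it lies within window_size_left behind and window_size_right ahead of the row, and inside actual_seqlen_k. Causal masking is the special case with no left bound and window_size_right = 0 (HasWSLeft = false). The predicate, in scalar form (sketch only):

// True if score (i, j) must be set to -inf under the sliding-window mask.
bool is_masked_local(int i, int j, int seqlen_q, int seqlen_k,
                     int window_left, int window_right, bool has_ws_left) {
    const int i_aligned = i + seqlen_k - seqlen_q;  // right-align q to k
    return j >= seqlen_k
        || j > i_aligned + window_right
        || (has_ws_left && j < i_aligned - window_left);
}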
flash::cp_async_wait<0>(); flash::cp_async_wait<0>();
__syncthreads(); __syncthreads();
if (n_block > 0) { if (n_block > n_block_min) {
// Advance gK // Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization // This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions. // isn't right and we get race conditions.
cute::cp_async_fence(); cute::cp_async_fence();
@@ -385,24 +416,24 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// TODO: when we have key_padding_mask we'll need to Check_inf // TODO: when we have key_padding_mask we'll need to Check_inf
masking_step == 0 masking_step == 0
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
// Convert scores from fp32 to fp16/bf16 // Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(scores); Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout())); Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
uint32_t block_col_idx = n_block * (kBlockN / 32); int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) { if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP); Tensor tOrP_copy = make_fragment_like(tOrP);
copy(tOrP, tOrP_copy); cute::copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>( flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps block_row_idx, block_col_idx, kNWarps
); );
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN); tPgP.data() = tPgP.data() + (-kBlockN);
} }
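When Return_softmax is requested, the probabilities are copied, passed through dropout with the drop decision encoded in the sign bit, and streamed to global memory; negating dropped entries rather than zeroing them lets tests recover the exact dropout mask from the returned matrix. A scalar sketch of that convention (inferred from the encode_dropout_in_sign_bit flag; the 1/(1 - p_dropout) rescale is applied separately in the kernel):

// Sign-bit encoding: kept entries stay positive, dropped entries are negated.
float encode_dropout_in_sign(float p, bool dropped) {
    return dropped ? -p : p;
}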
if (Is_dropout) { if (Is_dropout) {
@@ -411,37 +442,38 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
} }
// if (cute::thread0()) { print(tOrP); } // if (cute::thread0()) { print(tOrP); }
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); } // if (cute::thread0()) { print(scores); }
// This check is at the end of the loop since we always have at least 1 iteration // This check is at the end of the loop since we always have at least 1 iteration
if (n_masking_steps > 1 && n_block <= 0) { if (n_masking_steps > 1 && n_block <= n_block_min) {
--n_block; --n_block;
break; break;
} }
} }
// These are the iterations where we don't need masking on S // These are the iterations where we don't need masking on S
for (; n_block >= 0; --n_block) { for (; n_block >= n_block_min; --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N) Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
clear(acc_s); clear(acc_s);
flash::cp_async_wait<0>(); flash::cp_async_wait<0>();
__syncthreads(); __syncthreads();
// Advance gV // Advance gV
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence(); cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>( flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
); );
flash::cp_async_wait<0>(); flash::cp_async_wait<0>();
__syncthreads(); __syncthreads();
if (n_block > 0) { if (n_block > n_block_min) {
// Advance gK // Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization // This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions. // isn't right and we get race conditions.
cute::cp_async_fence(); cute::cp_async_fence();
@@ -449,22 +481,44 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
if (Has_alibi) {
flash::apply_alibi<Is_causal>(
scores,
n_block * kBlockN,
binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q,
kNWarps * 16,
alibi_slope
);
}
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
flash::apply_mask_local(
scores, n_block * kBlockN, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, kNWarps * 16,
params.window_size_left, params.window_size_right
);
}
softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(scores); Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout())); Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
uint32_t block_col_idx = n_block * (kBlockN / 32); int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) { if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP); Tensor tOrP_copy = make_fragment_like(tOrP);
copy(tOrP, tOrP_copy); cute::copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>( flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps block_row_idx, block_col_idx, kNWarps
); );
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN); tPgP.data() = tPgP.data() + (-kBlockN);
} }
if (Is_dropout) { if (Is_dropout) {
@@ -472,7 +526,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
block_row_idx, block_col_idx, kNWarps); block_row_idx, block_col_idx, kNWarps);
} }
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V); flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
} }
// Epilogue // Epilogue
@@ -496,15 +550,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor rO = flash::convert_type<Element>(acc_o); Tensor rO = flash::convert_type<Element>(acc_o);
Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning // Partition sO to match the accumulator partitioning
auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
// auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// sO has the same size as sQ, so we don't need to sync here. // sO has the same size as sQ, so we don't need to sync here.
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
copy(smem_thr_copy_O, taccOrO, taccOsO); cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
@@ -515,14 +569,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse), Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{}); Shape<Int<kBlockM>>{}, Stride<_1>{});
auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx); typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgO = gmem_thr_copy_O.partition_D(gO); Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
__syncthreads(); __syncthreads();
Tensor tOrO = make_tensor<Element>(shape(tOgO)); Tensor tOrO = make_tensor<Element>(shape(tOgO));
copy(gmem_thr_copy_O, tOsO, tOrO); cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
@@ -548,14 +603,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
} }
// Clear_OOB_K must be false since we don't want to write zeros to gmem // Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>( flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
); );
} }
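The epilogue converts the accumulator back to fp16/bf16, stages it through shared memory to get a gmem-friendly layout, then performs a predicated store: rows are guarded by the remaining query length and columns by the true head dimension, and Clear_OOB_K = false ensures no zeros are written over neighboring memory. The predication in scalar form (a sketch assuming a row-major output tile):

// Store at most rows_left rows and d columns of a kBlockM x kHeadDim tile.
void store_output_tile(float *gO, long o_row_stride, const float *rO,
                       int kBlockM, int kHeadDim, int rows_left, int d) {
    for (int m = 0; m < kBlockM && m < rows_left; ++m) {
        for (int k = 0; k < kHeadDim; ++k) {
            if (k < d) gO[m * o_row_stride + k] = rO[m * kHeadDim + k];
        }
    }
}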
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params> template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) { inline __device__ void compute_attn(const Params &params) {
const int m_block = blockIdx.x; const int m_block = blockIdx.x;
// The block index for the batch. // The block index for the batch.
@@ -571,7 +627,7 @@ inline __device__ void compute_attn(const Params &params) {
// the attention matrix. This way, as long as we have the batch, head, and the location of // the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block); flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////


@@ -4,15 +4,14 @@
#pragma once #pragma once
// #include <ATen/cuda/CUDAContext.h>
#include "static_switch.h" #include "static_switch.h"
#include "flash.h" #include "flash.h"
#include "flash_fwd_kernel.h" #include "flash_fwd_kernel.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax> template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) { __global__ void flash_fwd_kernel(Flash_fwd_params params) {
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params); static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
} }
template<typename Kernel_traits, bool Is_dropout, bool Is_causal> template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
@@ -26,35 +25,39 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.b, params.h); dim3 grid(num_m_block, params.b, params.h);
// We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
// for cu_seqlens_q as well.
const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim; const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr; const bool return_softmax = params.p_ptr != nullptr;
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
// Will only return softmax if dropout, to reduce compilation time. BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>; BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>; // Will only return softmax if dropout, to reduce compilation time.
// if (smem_size >= 48 * 1024) { // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// C10_CUDA_CHECK(cudaFuncSetAttribute( // If return_softmax, set IsEvenMNConst to false to reduce number of templates
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// } // If Is_local, set Is_causal to false
int ctas_per_sm; auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
&ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params); // int ctas_per_sm;
// C10_CUDA_KERNEL_LAUNCH_CHECK(); // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
});
}); });
}); });
}); });
} }
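Each BOOL_SWITCH lifts a runtime boolean into a constexpr template argument, so the nested lambdas instantiate one fully specialized kernel per reachable combination; the comments above spell out which combinations are deliberately collapsed (forcing IsEvenMNConst to false in several cases) to bound compile time and binary size. The helper follows the usual static-switch pattern, sketched here (consistent with a static_switch.h-style macro, not necessarily the exact header):

#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()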
template<typename T> template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 32; constexpr static int Headdim = 32;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
@@ -64,7 +67,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 64; constexpr static int Headdim = 64;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) { if constexpr(!Is_dropout) {
@@ -86,7 +89,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 96; constexpr static int Headdim = 96;
// auto dprops = at::cuda::getCurrentDeviceProperties(); // auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -112,7 +115,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 128; constexpr static int Headdim = 128;
// auto dprops = at::cuda::getCurrentDeviceProperties(); // auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -149,7 +152,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 160; constexpr static int Headdim = 160;
// auto dprops = at::cuda::getCurrentDeviceProperties(); // auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0; bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -179,7 +182,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 192; constexpr static int Headdim = 192;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) { if constexpr(!Is_dropout) {
@@ -198,7 +201,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 224; constexpr static int Headdim = 224;
int device; int device;
cudaGetDevice(&device); cudaGetDevice(&device);
int max_smem_per_block; int max_smem_per_block;
@@ -224,7 +227,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T> template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 256; constexpr static int Headdim = 256;
int device; int device;
cudaGetDevice(&device); cudaGetDevice(&device);
int max_smem_per_sm, max_smem_per_block; int max_smem_per_sm, max_smem_per_block;
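The hdim224/hdim256 paths size their tiles against the device's opt-in shared-memory limits, queried through the standard CUDA runtime attributes (as the variables above suggest; error handling omitted in this sketch):

int device, max_smem_per_sm, max_smem_per_block;
cudaGetDevice(&device);
cudaDeviceGetAttribute(&max_smem_per_sm,
                       cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
cudaDeviceGetAttribute(&max_smem_per_block,
                       cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
// A kernel variant is only eligible if its kSmemSize fits max_smem_per_block.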


@@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base {
SmemLayoutAtomQ{}, SmemLayoutAtomQ{},
Shape<Int<kBlockN>, Int<kHeadDim>>{})); Shape<Int<kBlockN>, Int<kHeadDim>>{}));
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomVtransposed = decltype( using SmemLayoutAtomVtransposed = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutVtransposed = decltype(tile_to_shape( using SmemLayoutVtransposed = decltype(tile_to_shape(
SmemLayoutAtomVtransposed{}, SmemLayoutAtomVtransposed{},
Shape<Int<kHeadDim>, Int<kBlockN>>{})); Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// Maybe the VtransposeNoSwizzle just needs to have the right shape // Maybe the VtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter? // And the strides don't matter?
using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomVtransposedNoSwizzle{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
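The change builds the no-swizzle transposed layout explicitly instead of unwrapping the swizzled one with layout_fn(). The swizzle itself exists to avoid shared-memory bank conflicts: it XORs a few row bits into the column index so a warp's strided accesses land in distinct banks. A toy model of the idea (illustrative only; cute::Swizzle<B, M, S> expresses this as a layout composition):

// Permute 8-element column chunks by low row bits to spread bank usage.
int swizzled_col(int row, int col, int swizzle_bits) {
    const int chunk = col / 8, lane = col % 8;
    return (chunk ^ (row & ((1 << swizzle_bits) - 1))) * 8 + lane;
}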
using SmemLayoutAtomO = decltype( using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{},
@@ -110,7 +113,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemLayoutO = decltype(tile_to_shape( using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{}, SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{})); Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>; using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
static constexpr int kSmemQCount = size(SmemLayoutQ{}); static constexpr int kSmemQCount = size(SmemLayoutQ{});
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
@@ -138,11 +142,11 @@ struct Flash_fwd_kernel_traits : public Base {
DefaultCopy DefaultCopy
>; >;
using GmemTiledCopyQKV = decltype( using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{}, make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{}, GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype( using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{}, make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{}, GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
@@ -151,10 +155,30 @@ struct Flash_fwd_kernel_traits : public Base {
Stride<Int<kGmemThreadsPerRowP>, _1>>; Stride<Int<kGmemThreadsPerRowP>, _1>>;
using GmemTiledCopyP = decltype( using GmemTiledCopyP = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{}, make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomP{}, GmemLayoutAtomP{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomOaccum = std::conditional_t<
kBlockKSmem == 32,
Layout<Shape <_16, _8>, // Thread layout, 8 threads per row
Stride< _8, _1>>,
Layout<Shape <_8, _16>, // Thread layout, 16 threads per row
Stride< _16, _1>>
>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
using GmemTiledCopyRotcossin = decltype(
make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinCont = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
}; };
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure. // Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
@@ -223,16 +247,19 @@ struct Flash_bwd_kernel_traits : public Base {
SmemLayoutAtomKV{}, SmemLayoutAtomKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{}))); make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomKtransposed = decltype( using SmemLayoutAtomKtransposed = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutKtransposed = decltype(tile_to_shape( using SmemLayoutKtransposed = decltype(tile_to_shape(
SmemLayoutAtomKtransposed{}, SmemLayoutAtomKtransposed{},
make_shape(Int<kHeadDim>{}, Int<kBlockN>{}))); make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
// Maybe the KtransposeNoSwizzle just needs to have the right shape // Maybe the KtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter? // And the strides don't matter?
using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomKtransposedNoSwizzle{},
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
// using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
// TODO: generalize to other values of kBlockN // TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
@@ -250,24 +277,30 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutPdS = decltype(tile_to_shape( using SmemLayoutPdS = decltype(tile_to_shape(
SmemLayoutAtomPdS{}, SmemLayoutAtomPdS{},
make_shape(Int<kBlockM>{}, Int<kBlockN>{}))); make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
Stride<_1, Int<kPBlockN>>>;
using SmemLayoutAtomPdStransposed = decltype( using SmemLayoutAtomPdStransposed = decltype(
composition(Swizzle<kSwizzlePdS, 3, 3>{}, composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
Stride<_1, Int<kPBlockN>>>{}));
using SmemLayoutPdStransposed = decltype(tile_to_shape( using SmemLayoutPdStransposed = decltype(tile_to_shape(
SmemLayoutAtomPdStransposed{}, SmemLayoutAtomPdStransposed{},
make_shape(Int<kBlockN>{}, Int<kBlockM>{}))); make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomPdStransposedNoSwizzle{},
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
// using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>; using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomQdOtransposed = decltype( using SmemLayoutAtomQdOtransposed = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutQdOtransposed = decltype(tile_to_shape( using SmemLayoutQdOtransposed = decltype(tile_to_shape(
SmemLayoutAtomQdOtransposed{}, SmemLayoutAtomQdOtransposed{},
make_shape(Int<kHeadDim>{}, Int<kBlockM>{}))); make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
SmemLayoutAtomQdOtransposedNoSwizzle{},
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
// using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
using SmemLayoutAtomdKV = decltype( using SmemLayoutAtomdKV = decltype(
composition(Swizzle<kSwizzle, 3, 3>{}, composition(Swizzle<kSwizzle, 3, 3>{},
@@ -292,13 +325,11 @@ struct Flash_bwd_kernel_traits : public Base {
static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
static constexpr int kSmemPCount = size(SmemLayoutPdS{}); static constexpr int kSmemPCount = size(SmemLayoutPdS{});
static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
static constexpr int kSmemdPsumCount = kBlockM;
static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
static constexpr int kSmemSize = kSmemQdOSize static constexpr int kSmemSize = kSmemQdOSize
+ (!Is_V_in_regs + (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)


@@ -0,0 +1,159 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/algorithm/copy.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include <cutlass/numeric_types.h>
using namespace cute;
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
struct Flash_kernel_traits_sm90 {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using Element = elem_type;
static constexpr bool Has_cp_async = true;
#else
using Element = cutlass::half_t;
static constexpr bool Has_cp_async = false;
#endif
using ElementAccum = float;
using index_t = uint32_t;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using MMA_Atom_Arch = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
>;
using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
#else
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
typename Base=Flash_kernel_traits_sm90<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_fwd_kernel_traits : public Base {
using Element = typename Base::Element;
using ElementAccum = typename Base::ElementAccum;
using index_t = typename Base::index_t;
static constexpr bool Has_cp_async = Base::Has_cp_async;
using SmemCopyAtom = typename Base::SmemCopyAtom;
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
// The number of threads.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * 32;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
using TiledMma = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
using SmemLayoutAtomQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
// This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
Layout<Shape<_8, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemLayoutKV = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
using SmemLayoutAtomVtransposed = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutVtransposed = decltype(tile_to_shape(
SmemLayoutAtomVtransposed{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// Maybe the VtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
static constexpr int kSmemQCount = size(SmemLayoutQ{});
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
// Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
// For example, for d=128, smem is split into 2 "pages", each page takes care of columns
// 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
// thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
// to the same banks.
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
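// Spelled-out arithmetic for the layout above, under assumed values
// (fp16 elements, kBlockKSmem = 64, kNThreads = 256; illustrative only):
//   kGmemElemsPerLoad  = 16 / sizeof(half) = 8   (one 128-bit load per thread)
//   kGmemThreadsPerRow = 64 / 8            = 8   (threads covering one smem page row)
//   rows per copy tile = 256 / 8           = 32  (kNThreads / kGmemThreadsPerRow)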
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
// from the same address by the same threadblock. This is slightly faster.
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
Stride<Int<kGmemThreadsPerRowP>, _1>>;
using GmemTiledCopyP = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtomP{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
};
////////////////////////////////////////////////////////////////////////////////////////////////////
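The shared-memory accounting above (kSmemQSize, kSmemKVSize, kSmemSize) is easy to sanity-check outside the template machinery. A minimal Rust sketch of the same arithmetic, assuming 2-byte elements (f16/bf16); the function name and the concrete tile sizes are illustrative only:

// Mirrors the kSmemQSize / kSmemKVSize / kSmemSize computation above.
fn smem_size_bytes(head_dim: usize, block_m: usize, block_n: usize, share_q_k: bool) -> usize {
    let elem = 2; // sizeof(half) or sizeof(bfloat)
    let q = block_m * head_dim * elem;      // kSmemQSize
    let kv = 2 * block_n * head_dim * elem; // kSmemKVSize: one K tile and one V tile
    if share_q_k { q.max(kv) } else { q + kv } // Share_Q_K_smem overlaps the two regions
}

fn main() {
    // e.g. d=128 with 128x64 tiles: 32 KiB for Q plus 32 KiB for K+V,
    // or 32 KiB total when the Q and K/V regions are shared.
    assert_eq!(smem_size_bytes(128, 128, 64, false), 65536);
    assert_eq!(smem_size_bytes(128, 128, 64, true), 32768);
}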

View File

@@ -8,8 +8,7 @@
 #include <cute/tensor.hpp>
-#include <cutlass/cutlass.h>
-#include <cutlass/array.h>
+#include <cutlass/numeric_types.h>
 #include "philox.cuh"
 #include "utils.h"
@@ -117,15 +116,18 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
 }

 template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k) {
+inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
+                                  const int col_idx_offset_ = 0) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
-    const uint32_t lane_id = threadIdx.x % 32;
+    const int lane_id = threadIdx.x % 32;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
     #pragma unroll
     for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+        const int col_idx_base = col_idx_offset + nj * 8;
         #pragma unroll
         for (int j = 0; j < size<1, 0>(tensor); ++j) {
-            const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2;
+            const int col_idx = col_idx_base + j;
             if (col_idx >= max_seqlen_k) {
                 // Without the "make_coord" we get wrong results
                 #pragma unroll
@@ -137,30 +139,30 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t
     }
 }

-template <typename Engine, typename Layout>
-inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
-                                         const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
-                                         const uint32_t warp_row_stride) {
+template <bool HasWSLeft=true, typename Engine, typename Layout>
+inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+                                        const int max_seqlen_k, const int row_idx_offset,
+                                        const int max_seqlen_q, const int warp_row_stride,
+                                        const int window_size_left, const int window_size_right) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
-    const uint32_t lane_id = threadIdx.x % 32;
-    // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
-    const uint32_t row_idx_offset = row_idx_offset_;
-    const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    const int lane_id = threadIdx.x % 32;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
     #pragma unroll
     for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-        const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
+        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
         #pragma unroll
         for (int i = 0; i < size<0, 0>(tensor); ++i) {
-            const uint32_t row_idx = row_idx_base + i * 8;
-            const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
+            const int row_idx = row_idx_base + i * 8;
+            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
+            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
             #pragma unroll
             for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                const uint32_t col_idx_base = col_idx_offset + nj * 8;
+                const int col_idx_base = col_idx_offset + nj * 8;
                 #pragma unroll
                 for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                    const uint32_t col_idx = col_idx_base + j;
-                    if (col_idx >= col_idx_limit) {
+                    const int col_idx = col_idx_base + j;
+                    if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
                         tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                     }
                 }
@@ -174,10 +176,19 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const u
     }
 }

+template <typename Engine, typename Layout>
+inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+                                         const int max_seqlen_k, const int row_idx_offset,
+                                         const int max_seqlen_q, const int warp_row_stride) {
+    // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
+    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
+                                          max_seqlen_q, warp_row_stride, -1, 0);
+}
+
 template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 inline __device__ void apply_mask_causal_w_idx(
     Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
-    const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
+    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
 {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
@@ -186,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx(
     CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
     #pragma unroll
     for (int mi = 0; mi < size<0>(tensor); ++mi) {
-        const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
         #pragma unroll
         for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
             if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
@@ -204,8 +215,8 @@ inline __device__ void apply_mask_causal_w_idx(
 template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
 inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
                                      unsigned long long seed, unsigned long long offset,
-                                     uint32_t block_row_start, uint32_t block_col_start,
-                                     uint32_t block_row_stride) {
+                                     int block_row_start, int block_col_start,
+                                     int block_row_stride) {
     // tensor has shape (8, MMA_M, MMA_N / 2)
     using T = typename Engine::value_type;
     auto encode_dropout = [](bool keep, T val) {
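The window arithmetic in apply_mask_local is easier to read with the diagonal shift factored out. A small Rust sketch of the per-row bounds (hypothetical function name; it mirrors the col_idx_limit_left/right expressions above):

// Columns in [left, right) survive; everything else gets -inf.
// Window (-1, 0), i.e. HasWSLeft=false and window_size_right=0, is the causal case.
fn col_bounds(row_idx: i64, seqlen_q: i64, seqlen_k: i64,
              win_left: i64, win_right: i64, has_ws_left: bool) -> (i64, i64) {
    let shift = seqlen_k - seqlen_q; // aligns the last query row with the last key
    let left = if has_ws_left { (row_idx + shift - win_left).max(0) } else { 0 };
    let right = (row_idx + 1 + shift + win_right).min(seqlen_k);
    (left, right)
}

fn main() {
    // Causal mask for a 4-token query over 6 keys: row 0 may see keys 0..3.
    assert_eq!(col_bounds(0, 4, 6, -1, 0, false), (0, 3));
}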

View File

@@ -87,46 +87,6 @@ inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename T>
-inline __device__ float2 half2_unpack(uint32_t a);
-
-template <>
-inline __device__ float2 half2_unpack<__half>(uint32_t a) {
-    return __half22float2(reinterpret_cast<__half2 (&)>(a));
-}
-
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-template <>
-inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
-    return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
-}
-#endif
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Convert two half2's or bf162's into float, then take their dot product.
-template <typename T>
-inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
-    float2 af = flash::half2_unpack<T>(a);
-    float2 bf = flash::half2_unpack<T>(b);
-    return af.x * bf.x + af.y * bf.y;
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
-template<typename T>
-inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
-    float sum;
-    sum = flash::hfma2_to_float<T>(a.x, b.x);
-    sum += flash::hfma2_to_float<T>(a.y, b.y);
-    sum += flash::hfma2_to_float<T>(a.z, b.z);
-    sum += flash::hfma2_to_float<T>(a.w, b.w);
-    return sum;
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 template<typename T>
 struct MaxOp {
     __device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
@@ -173,10 +133,12 @@ static __device__ inline T run(T x, Operator &op) {
 template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
          typename Tensor2, typename Tensor3, typename Tensor4,
-         typename TiledMma, typename TiledCopy0, typename TiledCopy1>
+         typename TiledMma, typename TiledCopyA, typename TiledCopyB,
+         typename ThrCopyA, typename ThrCopyB>
 inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
                             Tensor4 const& tCsB, TiledMma tiled_mma,
-                            TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) {
+                            TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
+                            ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
     CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
     CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
     CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                    // MMA_K
@@ -184,13 +146,13 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
     CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));          // M
     Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
     CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));          // N
-    if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
-    if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
+    if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
+    if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
     #pragma unroll
     for (int i = 0; i < size<2>(tCrA); ++i) {
         if (i < size<2>(tCrA) - 1) {
-            if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
-            if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
+            if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
+            if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
         }
         cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
     }
@@ -199,19 +161,20 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
-         typename TiledMma, typename TiledCopy>
+         typename TiledMma, typename TiledCopy, typename ThrCopy>
 inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
-                                      TiledMma tiled_mma, TiledCopy smem_thr_copy_B) {
+                                      TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
+                                      ThrCopy smem_thr_copy_B) {
     CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
     CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
     CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                    // MMA_K
     Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
     CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));          // N
-    copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
+    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
     #pragma unroll
     for (int i = 0; i < size<2>(tCrA); ++i) {
         if (i < size<2>(tCrA) - 1) {
-            copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
+            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
         }
         cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
     }
@@ -225,7 +188,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
     static_assert(decltype(size<0>(acc_layout))::value == 4);
     static_assert(decltype(rank(acc_layout))::value == 3);
     auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
-    return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
+    // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
+    // "int_tuple.hpp(74): error: conversion to inaccessible base class"
+    // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
+    return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -241,9 +207,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
     static_assert(mma_shape_K == 8 || mma_shape_K == 16);
     constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
     auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
-    return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
-                       get<0, 1>(l),
-                       get<1, 1, 1>(l));
+    // TD [2023-08-13]: Same error as above on Cutlass 3.2
+    // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
+    //                    get<0, 1>(l),
+    //                    get<1, 1, 1>(l));
+    return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
+                       get<1>(get<0>(l)),
+                       get<1>(get<1>(get<1>(l))));
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -319,9 +289,9 @@ void cp_async_wait() {
 template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
           typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &S,
+inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                             Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
-                            Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
+                            Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
     CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
     CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
     CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
@@ -335,13 +305,13 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
             #pragma unroll
             for (int k = 0; k < size<2>(S); ++k) {
                 if (Is_even_K || predicate_K(k)) {
-                    copy(thr_copy, S(_, m, k), D(_, m, k));
+                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
                 } else if (Clear_OOB_K) {
-                    clear(D(_, m, k));
+                    cute::clear(D(_, m, k));
                 }
             }
         } else if (Clear_OOB_MN) {
-            clear(D(_, m, _));
+            cute::clear(D(_, m, _));
        }
     }
     // TD [2023-04-13]: Strange that the code below can cause race condition.
@@ -350,7 +320,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
     // #pragma unroll
     // for (int m = 0; m < size<1>(S); ++m) {
     //     if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
-    //         copy(thr_copy, S(_, m, _), D(_, m, _));
+    //         copy(tiled_copy, S(_, m, _), D(_, m, _));
     //     } else if (Clear_OOB_MN) {
     //         clear(D(_, m, _));
     //     }
@@ -362,7 +332,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
     // #pragma unroll
     // for (int m = 0; m < size<1>(S); ++m) {
     //     if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
-    //         copy(thr_copy, S(_, m, k), D(_, m, k));
+    //         copy(tiled_copy, S(_, m, k), D(_, m, k));
     //     } else if (Clear_OOB_MN) {
     //         clear(D(_, m, k));
     //     }
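One detail worth calling out in flash::gemm above: the smem-to-register copy for K-slice i+1 is issued before the MMA on slice i, so the copy latency hides behind compute. The control flow, reduced to a Rust sketch (names illustrative):

use std::cell::RefCell;

// Prefetch-one-ahead pipeline: issue stage i+1's copy, then run stage i's MMA.
fn pipelined_gemm(k_slices: usize, mut copy: impl FnMut(usize), mut mma: impl FnMut(usize)) {
    copy(0); // prologue: the first slice must be resident before the loop
    for i in 0..k_slices {
        if i + 1 < k_slices {
            copy(i + 1); // overlaps with mma(i)
        }
        mma(i);
    }
}

fn main() {
    let log = RefCell::new(Vec::new());
    pipelined_gemm(3, |i| log.borrow_mut().push(("copy", i)), |i| log.borrow_mut().push(("mma", i)));
    assert_eq!(
        log.into_inner(),
        [("copy", 0), ("copy", 1), ("mma", 0), ("copy", 2), ("mma", 1), ("mma", 2)]
    );
}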

View File

@@ -7,6 +7,8 @@ extern "C" {
         v_ptr: *const c_void,
         o_ptr: *const c_void,
         softmax_lse_ptr: *const c_void,
+        alibi_slopes_ptr: *const c_void,
         cu_seqlens_q_ptr: *const i32,
         cu_seqlens_k_ptr: *const i32,
@@ -14,6 +16,7 @@
         k_batch_stride: u32,
         v_batch_stride: u32,
         o_batch_stride: u32,
+        alibi_slopes_batch_stride: u32,
         q_row_stride: u32,
         k_row_stride: u32,
@@ -37,8 +40,11 @@
         seqlen_q_rounded: u32,
         seqlen_k_rounded: u32,
-        is_causal: c_int,
         is_bf16: c_int,
+        is_causal: c_int,
+        window_size_left: c_int,
+        window_size_right: c_int,
     );
 }

View File

@@ -3,12 +3,14 @@ mod ffi;
 use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
-use candle::{CpuStorage, Layout, Result, Shape, Tensor};
+use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
 use half::{bf16, f16};

 pub struct FlashAttn {
     pub softmax_scale: f32,
-    pub causal: bool,
+    pub alibi_slopes: Option<Tensor>,
+    pub window_size_left: Option<usize>,
+    pub window_size_right: Option<usize>,
 }

 fn round_multiple(x: usize, m: usize) -> usize {
@@ -85,6 +87,51 @@
             candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
         }

let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
if alibi_slopes.dtype() != DType::F32 {
candle::bail!(
"DType mismatch alibi_slopes {:?}, expected {:?}",
alibi_slopes.dtype(),
DType::F32
);
}
let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();
if num_heads != alibi_slopes_layout.shape().dims1()? {
candle::bail!(
"shape mismatch alibi_slopes {:?}, expected {:?}",
alibi_slopes_layout.shape(),
(num_heads)
);
}
let alibi_slopes = match &*alibi_slopes {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("alibi_slopes must be a cuda tensor"),
};
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
*alibi_slopes.device_ptr() as *const core::ffi::c_void
} else {
std::ptr::null()
};
// if window_size_left > self.max_seqlen_k or None => -1
let mut window_size_left = self
.window_size_left
.filter(|v| v <= &seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
// if window_size_right > self.max_seqlen_k or None => -1
let mut window_size_right = self
.window_size_right
.filter(|v| v <= &seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
        let head_size = round_multiple(head_size_og, 8);
        let head_size_rounded = round_multiple(head_size, 32);
        let seqlen_q_rounded = round_multiple(seqlen_q, 128);
@@ -94,9 +141,22 @@
        let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
        let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;

-        let causal = if self.causal { 1 } else { 0 };
        let is_bf16 = if is_bf16 { 1 } else { 0 };

// Causal is the special case where window_size_right == 0 and window_size_left < 0.
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
let is_causal = if window_size_left < 0 && window_size_right == 0 {
1
} else {
0
};
if window_size_left < 0 && window_size_right >= 0 {
window_size_left = seqlen_k as i32;
}
if window_size_left >= 0 && window_size_right < 0 {
window_size_right = seqlen_k as i32;
}
        unsafe {
            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
            let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
@@ -109,12 +169,14 @@
                v_ptr,
                dst_ptr,
                softmax_lse_ptr,
+                /* alibi_slopes_ptr */ alibi_slopes_ptr,
                /* cu_seqlens_q_ptr */ std::ptr::null(),
                /* cu_seqlens_k_ptr */ std::ptr::null(),
                /* q_batch_stride */ q_stride[0] as u32,
                /* k_batch_stride */ k_stride[0] as u32,
                /* v_batch_stride */ v_stride[0] as u32,
                /* o_batch_stride */ o_stride[0] as u32,
+                /* alibi_slopes_batch_stride */ 0,
                /* q_row_stride */ q_stride[q_rank - 3] as u32,
                /* k_row_stride */ k_stride[k_rank - 3] as u32,
                /* v_row_stride */ v_stride[v_rank - 3] as u32,
@@ -133,8 +195,10 @@
                /* seqlen_k */ seqlen_k as u32,
                /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                /* seqlen_k_rounded */ seqlen_k_rounded as u32,
-                /* is_causal */ causal,
                /* is_bf16 */ is_bf16,
+                /* is_causal */ is_causal,
+                /* window_size_left */ window_size_left,
+                /* window_size_right */ window_size_right,
            )
        }
@@ -197,20 +261,137 @@ pub fn flash_attn(
     softmax_scale: f32,
     causal: bool,
 ) -> Result<Tensor> {
+    let window_size_left = None;
+    let window_size_right = if causal { Some(0) } else { None };
+
     let op = FlashAttn {
         softmax_scale,
-        causal,
+        alibi_slopes: None,
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `window_size_left` - Limit left attention to value tokens.
/// * `window_size_right` - Limit right attention to value tokens.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_windowed(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor> {
let op = FlashAttn {
softmax_scale,
alibi_slopes: None,
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_alibi(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
let window_size_left = None;
let window_size_right = if causal { Some(0) } else { None };
let op = FlashAttn {
softmax_scale,
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
/// * `window_size_left` - Limit left attention to value tokens.
/// * `window_size_right` - Limit right attention to value tokens.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_alibi_windowed(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: &Tensor,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor> {
let op = FlashAttn {
softmax_scale,
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
    };
    q.apply_op3(k, v, op)
}

struct FlashAttnVarLen {
-    softmax_scale: f32,
-    causal: bool,
-    max_seqlen_q: usize,
-    max_seqlen_k: usize,
-    seqlens_q: Tensor,
-    seqlens_k: Tensor,
+    pub softmax_scale: f32,
+    pub max_seqlen_q: usize,
+    pub max_seqlen_k: usize,
+    pub seqlens_q: Tensor,
+    pub seqlens_k: Tensor,
+    pub alibi_slopes: Option<Tensor>,
+    pub window_size_left: Option<usize>,
+    pub window_size_right: Option<usize>,
}

impl FlashAttnVarLen {
@@ -311,7 +492,54 @@
        if nseqlens_k != nseqlens_q {
            candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}")
        }

        let batch_size = nseqlens_q - 1;

let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
if alibi_slopes.dtype() != DType::F32 {
candle::bail!(
"DType mismatch alibi_slopes {:?}, expected {:?}",
alibi_slopes.dtype(),
DType::F32
);
}
let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();
if num_heads != alibi_slopes_layout.shape().dims1()? {
candle::bail!(
"shape mismatch alibi_slopes {:?}, expected {:?}",
alibi_slopes_layout.shape(),
(num_heads)
);
}
let alibi_slopes = match &*alibi_slopes {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("alibi_slopes must be a cuda tensor"),
};
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
*alibi_slopes.device_ptr() as *const core::ffi::c_void
} else {
std::ptr::null()
};
// if window_size_left > self.max_seqlen_k or None => -1
let mut window_size_left = self
.window_size_left
.filter(|v| v <= &self.max_seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
// if window_size_right > self.max_seqlen_k or None => -1
let mut window_size_right = self
.window_size_right
.filter(|v| v <= &self.max_seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
        let head_size = round_multiple(head_size_og, 8);
        let head_size_rounded = round_multiple(head_size, 32);
        let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128);
@@ -323,9 +551,22 @@
            .alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
            .w()?;

-        let causal = if self.causal { 1 } else { 0 };
        let is_bf16 = if is_bf16 { 1 } else { 0 };

// Causal is the special case where window_size_right == 0 and window_size_left < 0.
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
let is_causal = if window_size_left < 0 && window_size_right == 0 {
1
} else {
0
};
if window_size_left < 0 && window_size_right >= 0 {
window_size_left = self.max_seqlen_k as i32;
}
if window_size_left >= 0 && window_size_right < 0 {
window_size_right = self.max_seqlen_k as i32;
}
        unsafe {
            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
            let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
@@ -340,12 +581,14 @@
                v_ptr,
                dst_ptr,
                softmax_lse_ptr,
+                /* alibi_slopes_ptr */ alibi_slopes_ptr,
                /* cu_seqlens_q_ptr */ seqlens_q_ptr,
                /* cu_seqlens_k_ptr */ seqlens_k_ptr,
                /* q_batch_stride */ 0,
                /* k_batch_stride */ 0,
                /* v_batch_stride */ 0,
                /* o_batch_stride */ 0,
+                /* alibi_slopes_batch_stride */ 0,
                /* q_row_stride */ q_stride[q_rank - 3] as u32,
                /* k_row_stride */ k_stride[k_rank - 3] as u32,
                /* v_row_stride */ v_stride[v_rank - 3] as u32,
@@ -364,8 +607,10 @@
                /* seqlen_k */ self.max_seqlen_k as u32,
                /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                /* seqlen_k_rounded */ seqlen_k_rounded as u32,
-                /* is_causal */ causal,
                /* is_bf16 */ is_bf16,
+                /* is_causal */ is_causal,
+                /* window_size_left */ window_size_left,
+                /* window_size_right */ window_size_right,
            )
        }
@@ -440,13 +685,176 @@ pub fn flash_attn_varlen(
     softmax_scale: f32,
     causal: bool,
 ) -> Result<Tensor> {
+    let window_size_left = None;
+    let window_size_right = if causal { Some(0) } else { None };
+
     let op = FlashAttnVarLen {
         softmax_scale,
-        causal,
         max_seqlen_q,
         max_seqlen_k,
         seqlens_q: seqlens_q.clone(),
         seqlens_k: seqlens_k.clone(),
+        alibi_slopes: None,
+        window_size_left,
+        window_size_right,
};
q.apply_op3(k, v, op)
}
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
/// * `window_size_left` - Limit left attention to value tokens.
/// * `window_size_right` - Limit right attention to value tokens.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
pub fn flash_attn_varlen_windowed(
q: &Tensor,
k: &Tensor,
v: &Tensor,
seqlens_q: &Tensor,
seqlens_k: &Tensor,
max_seqlen_q: usize,
max_seqlen_k: usize,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor> {
let op = FlashAttnVarLen {
softmax_scale,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: None,
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
pub fn flash_attn_varlen_alibi(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: &Tensor,
seqlens_q: &Tensor,
seqlens_k: &Tensor,
max_seqlen_q: usize,
max_seqlen_k: usize,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
let window_size_left = None;
let window_size_right = if causal { Some(0) } else { None };
let op = FlashAttnVarLen {
softmax_scale,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
/// * `window_size_left` - Limit left attention to value tokens.
/// * `window_size_right` - Limit right attention to value tokens.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
pub fn flash_attn_varlen_alibi_windowed(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: &Tensor,
seqlens_q: &Tensor,
seqlens_k: &Tensor,
max_seqlen_q: usize,
max_seqlen_k: usize,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor> {
let op = FlashAttnVarLen {
softmax_scale,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
    };
    q.apply_op3(k, v, op)
}

View File

@@ -1,6 +1,6 @@
 [package]
 name = "candle-kernels"
-version = "0.3.1"
+version = "0.3.3"
 edition = "2021"
 description = "CUDA kernels for Candle"
@@ -12,6 +12,4 @@ license = "MIT OR Apache-2.0"
 [dependencies]

 [build-dependencies]
-anyhow = { version = "1", features = ["backtrace"] }
-glob = "0.3.1"
-rayon = "1.7.0"
+bindgen_cuda = "0.1.1"

View File

@@ -1,243 +1,8 @@
-use std::io::Write;
-
 fn main() {
     println!("cargo:rerun-if-changed=build.rs");
-    cuda::set_include_dir();
-    let (write, kernel_paths) = cuda::build_ptx();
-    if write {
-        let mut file = std::fs::File::create("src/lib.rs").unwrap();
-        for kernel_path in kernel_paths {
-            let name = kernel_path.file_stem().unwrap().to_str().unwrap();
-            file.write_all(
-                format!(
-                    r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#,
-                    name.to_uppercase().replace('.', "_"),
-                    name
-                )
-                .as_bytes(),
-            )
-            .unwrap();
-            file.write_all(&[b'\n']).unwrap();
-        }
-    }
+    let builder = bindgen_cuda::Builder::default();
+    println!("cargo:info={builder:?}");
+    let bindings = builder.build_ptx().unwrap();
+    bindings.write("src/lib.rs").unwrap();

(The remainder of the old build script, removed by this change, follows.)
mod cuda {
use anyhow::{Context, Result};
pub fn set_include_dir() {
use std::path::PathBuf;
// NOTE: copied from cudarc build.rs.
// We can't actually set a env!() value from another crate,
// so we have to do that here.
// use PathBuf;
let env_vars = [
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
#[allow(unused)]
let env_vars = env_vars
.into_iter()
.map(std::env::var)
.filter_map(Result::ok)
.map(Into::<PathBuf>::into);
let roots = [
"/usr",
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/CUDA",
];
#[allow(unused)]
let roots = roots.into_iter().map(Into::<PathBuf>::into);
#[cfg(feature = "ci-check")]
let root: PathBuf = "ci".into();
#[cfg(not(feature = "ci-check"))]
let root = env_vars
.chain(roots)
.find(|path| path.join("include").join("cuda.h").is_file())
.unwrap();
println!(
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
root.join("include").display()
);
}
pub fn build_ptx() -> (bool, Vec<std::path::PathBuf>) {
use rayon::prelude::*;
use std::path::PathBuf;
let out_dir = std::env::var("OUT_DIR").unwrap();
let kernel_paths: Vec<PathBuf> = glob::glob("src/*.cu")
.unwrap()
.map(|p| p.unwrap())
.collect();
let mut include_directories: Vec<PathBuf> = glob::glob("src/**/*.cuh")
.unwrap()
.map(|p| p.unwrap())
.collect();
println!("cargo:rerun-if-changed=src/");
// for path in &kernel_paths {
// println!("cargo:rerun-if-changed={}", path.display());
// }
for path in &mut include_directories {
// println!("cargo:rerun-if-changed={}", path.display());
let destination =
std::format!("{out_dir}/{}", path.file_name().unwrap().to_str().unwrap());
std::fs::copy(path.clone(), destination).unwrap();
// remove the filename from the path so it's just the directory
path.pop();
}
include_directories.sort();
include_directories.dedup();
let compute_cap = compute_cap().expect("Could not get Cuda compute cap");
#[allow(unused)]
let include_options: Vec<String> = include_directories
.into_iter()
.map(|s| "-I".to_string() + &s.into_os_string().into_string().unwrap())
.collect::<Vec<_>>();
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
let children = kernel_paths
.par_iter()
.flat_map(|p| {
let mut output = p.clone();
output.set_extension("ptx");
let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap());
let ignore = if output_filename.exists() {
let out_modified = output_filename.metadata().unwrap().modified().unwrap();
let in_modified = p.metadata().unwrap().modified().unwrap();
out_modified.duration_since(in_modified).is_ok()
} else {
false
};
if ignore {
None
} else {
let mut command = std::process::Command::new("nvcc");
command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--ptx")
.args(["--default-stream", "per-thread"])
.args(["--output-directory", &out_dir])
// Flash attention only
// .arg("--expt-relaxed-constexpr")
.args(&include_options);
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
command.arg(p);
Some((p, command.spawn()
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
}
})
.collect::<Vec<_>>();
let ptx_paths: Vec<PathBuf> = glob::glob(&format!("{out_dir}/**/*.ptx"))
.unwrap()
.map(|p| p.unwrap())
.collect();
// We should rewrite `src/lib.rs` only if there are some newly compiled kernels, or removed
// some old ones
let write = !children.is_empty() || kernel_paths.len() < ptx_paths.len();
for (kernel_path, child) in children {
let output = child.expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
assert!(
output.status.success(),
"nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
}
(write, kernel_paths)
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
// Try to parse compute caps from env
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
compute_cap_str
.parse::<usize>()
.context("Could not parse code")?
} else {
// Use nvidia-smi to get the current compute cap
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
let cap = cap
.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?;
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap
};
// Grab available GPU codes from nvcc and select the highest one
let (supported_nvcc_codes, max_nvcc_code) = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
(codes, max_nvcc_code)
};
// Check that nvcc supports the asked compute caps
if !supported_nvcc_codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
);
}
if compute_cap > max_nvcc_code {
anyhow::bail!(
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
);
}
Ok(compute_cap)
}
}
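For reference, the removed generator wrote src/lib.rs as one PTX string constant per .cu file, and the bindgen_cuda builder is assumed to produce an equivalent shape (the constant names below are examples, not an exhaustive list):

// Shape of the generated candle-kernels src/lib.rs:
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));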

View File

@@ -1,6 +1,6 @@
 [package]
 name = "candle-metal-kernels"
-version = "0.3.1"
+version = "0.3.3"
 edition = "2021"
 description = "Metal kernels for Candle"
@@ -10,7 +10,7 @@ categories = ["science"]
 license = "MIT OR Apache-2.0"

 [dependencies]
-metal = { version = "0.27.0", features = ["mps"], package="candle-metal" }
+metal = { version = "0.27.0", features = ["mps"]}
 once_cell = "1.18.0"
 thiserror = "1"
 tracing = "0.1.37"

View File

@@ -109,16 +109,16 @@ kernel void FN_NAME##_strided( \
 } \

-AFFINE(affine_float, float)
-AFFINE(affine_half, half)
-POWF(powf_float, float)
-POWF(powf_half, half)
-ELU(elu_float, float)
-ELU(elu_half, half)
+AFFINE(affine_f32, float)
+AFFINE(affine_f16, half)
+POWF(powf_f32, float)
+POWF(powf_f16, half)
+ELU(elu_f32, float)
+ELU(elu_f16, half)

 #if __METAL_VERSION__ >= 310
-AFFINE(affine_bfloat, bfloat);
-POWF(powf_bfloat, bfloat);
-ELU(elu_bfloat, bfloat);
+AFFINE(affine_bf16, bfloat);
+POWF(powf_bf16, bfloat);
+ELU(elu_bf16, bfloat);
 #endif

View File

@@ -1,5 +1,8 @@
 #include <metal_stdlib>

+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+
 METAL_FUNC uint get_strided_index(
     uint idx,
     constant size_t &num_dims,
@@ -22,15 +25,15 @@ kernel void FN_NAME( \
     constant size_t &dim, \
     device const TYPENAME *left, \
     device const TYPENAME *right, \
-    device TYPENAME *output, \
-    uint thread_position_in_grid [[ thread_position_in_grid ]] \
+    device OUT_TYPENAME *output, \
+    uint tid [[ thread_position_in_grid ]] \
 ) { \
-    if (thread_position_in_grid >= dim) { \
+    if (tid >= dim) { \
         return; \
     } \
-    TYPENAME x = left[thread_position_in_grid]; \
-    TYPENAME y = right[thread_position_in_grid]; \
-    output[thread_position_in_grid] = OUT_TYPENAME(FN); \
+    TYPENAME x = left[tid]; \
+    TYPENAME y = right[tid]; \
+    output[tid] = OUT_TYPENAME(FN); \
 }\
 kernel void FN_NAME_STRIDED( \
     constant size_t &dim, \
@@ -40,33 +43,73 @@ kernel void FN_NAME_STRIDED( \
     constant size_t *right_strides, \
     device const TYPENAME *left, \
     device const TYPENAME *right, \
-    device TYPENAME *output, \
-    uint thread_position_in_grid [[ thread_position_in_grid ]] \
+    device OUT_TYPENAME *output, \
+    uint tid [[ thread_position_in_grid ]] \
 ) { \
-    if (thread_position_in_grid >= dim) { \
+    if (tid >= dim) { \
         return; \
     } \
-    TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
-    TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
-    output[thread_position_in_grid] = OUT_TYPENAME(FN); \
+    TYPENAME x = left[get_strided_index(tid, num_dims, dims, left_strides)]; \
+    TYPENAME y = right[get_strided_index(tid, num_dims, dims, right_strides)]; \
+    output[tid] = OUT_TYPENAME(FN); \
 }

 #define BINARY_OP(FN, NAME) \
-BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
-BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
+BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
+BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \
+BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \
+BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
+
+#define INT64_BINARY_OP(NAME, FN) \
+BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);

 #define BFLOAT_BINARY_OP(FN, NAME) \
-BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
+BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
+
+#define BINARY_OP_OUT(NAME, FN) \
+BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
+BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \
+BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
+BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
+
+#define INT64_BINARY_OP_OUT(NAME, FN) \
+BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided);

 BINARY_OP(x + y, add)
 BINARY_OP(x - y, sub)
 BINARY_OP(x * y, mul)
 BINARY_OP(x / y, div)
+BINARY_OP(MIN(x, y), min)
+BINARY_OP(MAX(x, y), max)
+
+BINARY_OP_OUT(eq, x == y)
+BINARY_OP_OUT(ne, x != y)
+BINARY_OP_OUT(le, x <= y)
+BINARY_OP_OUT(lt, x < y)
+BINARY_OP_OUT(ge, x >= y)
+BINARY_OP_OUT(gt, x > y)
+
+#if __METAL_VERSION__ >= 220
+INT64_BINARY_OP(add, x + y)
+INT64_BINARY_OP(sub, x - y)
+INT64_BINARY_OP(mul, x * y)
+INT64_BINARY_OP(div, x / y)
+INT64_BINARY_OP(min, MIN(x, y))
+INT64_BINARY_OP(max, MAX(x, y))
+INT64_BINARY_OP_OUT(eq, x == y)
+INT64_BINARY_OP_OUT(ne, x != y)
+INT64_BINARY_OP_OUT(le, x <= y)
+INT64_BINARY_OP_OUT(lt, x < y)
+INT64_BINARY_OP_OUT(ge, x >= y)
+INT64_BINARY_OP_OUT(gt, x > y)
+#endif

 #if __METAL_VERSION__ >= 310
 BFLOAT_BINARY_OP(x + y, add)
 BFLOAT_BINARY_OP(x - y, sub)
 BFLOAT_BINARY_OP(x * y, mul)
 BFLOAT_BINARY_OP(x / y, div)
+BFLOAT_BINARY_OP(MIN(x, y), min)
+BFLOAT_BINARY_OP(MAX(x, y), max)
 #endif
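Host-side, the new comparison kernels are what candle's tensor comparison ops dispatch to on Metal, producing u8 masks. A small usage sketch (assumes a Metal-enabled build of candle):

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::new_metal(0)?;
    let a = Tensor::new(&[1f32, 2., 3.], &dev)?;
    let b = Tensor::new(&[2f32, 2., 2.], &dev)?;
    let mask = a.le(&b)?; // dispatches le_f32 / le_f32_strided; result dtype is U8
    println!("{mask}");
    Ok(())
}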

View File

@@ -48,8 +48,17 @@ kernel void FN_NAME_STRIDED( \
 CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
 CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
 CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
+CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
 CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
 CAST(cast_f32_f16, cast_f32_f16_strided, float, half)

-#if __METAL_VERSION__ >= 310
+#if __METAL_VERSION__ >= 220
+CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
+CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
+CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
+#endif
+
+#if __METAL_VERSION__ >= 310
+CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
+CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
 #endif
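The new cast entries likewise back dtype conversions on Metal; e.g. a u8 tensor cast to f32 lands in cast_u8_f32 (or its strided variant). A sketch under the same assumptions as the previous example:

use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::new_metal(0)?;
    let x = Tensor::new(&[1u8, 2, 3], &dev)?;
    let y = x.to_dtype(DType::F32)?; // uses cast_u8_f32 / cast_u8_f32_strided
    println!("{y}");
    Ok(())
}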

View File

@@ -0,0 +1,213 @@
template <typename T>
METAL_FUNC void im2col(
constant size_t &dst_numel,
constant size_t &h_out,
constant size_t &w_out,
constant size_t &h_k,
constant size_t &w_k,
constant size_t &stride,
constant size_t &padding,
constant size_t &dilation,
constant size_t *src_dims,
constant size_t *src_strides,
device const T *src,
device T *dst,
uint tid [[ thread_position_in_grid ]]
) {
// dst: (b_size, h_out, w_out, c_in, h_k, w_k)
// src: (b_size, c_in, h_in, w_in)
if (tid >= dst_numel) {
return;
}
const size_t b_in = src_dims[0];
const size_t c_in = src_dims[1];
const size_t h_in = src_dims[2];
const size_t w_in = src_dims[3];
const size_t dst_s4 = w_k;
const size_t dst_s3 = h_k * dst_s4;
const size_t dst_s2 = c_in * dst_s3;
const size_t dst_s1 = w_out * dst_s2;
const size_t dst_s0 = h_out * dst_s1;
size_t tmp_tid = tid;
const size_t b_idx = tmp_tid / dst_s0;
tmp_tid -= b_idx * dst_s0;
const size_t h_idx = tmp_tid / dst_s1;
tmp_tid -= h_idx * dst_s1;
const size_t w_idx = tmp_tid / dst_s2;
tmp_tid -= w_idx * dst_s2;
const size_t c_idx = tmp_tid / dst_s3;
tmp_tid -= c_idx * dst_s3;
const size_t h_k_idx = tmp_tid / dst_s4;
tmp_tid -= h_k_idx * dst_s4;
const size_t w_k_idx = tmp_tid;
size_t src_h_idx = h_idx * stride + h_k_idx * dilation;
size_t src_w_idx = w_idx * stride + w_k_idx * dilation;
if (src_h_idx < padding || src_h_idx >= h_in + padding) {
dst[tid] = static_cast<T>(0);
}
else if (src_w_idx < padding || src_w_idx >= w_in + padding) {
dst[tid] = static_cast<T>(0);
}
else {
src_h_idx -= padding;
src_w_idx -= padding;
const size_t src_i =
b_idx * src_strides[0]
+ c_idx * src_strides[1]
+ src_h_idx * src_strides[2]
+ src_w_idx * src_strides[3];
dst[tid] = src[src_i];
}
}
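The body above unflattens the destination index into (b, h_out, w_out, c_in, h_k, w_k) coordinates via the dst_s* strides, then maps back into the padded source. A small Rust sketch of the same decomposition (our illustration; unflatten_dst is a hypothetical name):

// Mirrors the kernel's tid -> (b, h, w, c, h_k, w_k) decomposition on the CPU.
fn unflatten_dst(
    tid: usize, h_out: usize, w_out: usize, c_in: usize, h_k: usize, w_k: usize,
) -> (usize, usize, usize, usize, usize, usize) {
    let s4 = w_k;
    let s3 = h_k * s4;
    let s2 = c_in * s3;
    let s1 = w_out * s2;
    let s0 = h_out * s1;
    let (b, r) = (tid / s0, tid % s0);
    let (h, r) = (r / s1, r % s1);
    let (w, r) = (r / s2, r % s2);
    let (c, r) = (r / s3, r % s3);
    (b, h, w, c, r / s4, r % s4)
}

fn main() {
    // 1x1 spatial output, 2 input channels, 3x3 kernel: tid 10 lands on
    // channel 1, kernel offset (0, 1).
    assert_eq!(unflatten_dst(10, 1, 1, 2, 3, 3), (0, 0, 0, 1, 0, 1));
}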
template <typename T>
METAL_FUNC void im2col1d(
constant size_t &dst_numel,
constant size_t &l_out,
constant size_t &l_k,
constant size_t &stride,
constant size_t &padding,
constant size_t &dilation,
constant size_t *src_dims,
constant size_t *src_strides,
device const T *src,
device T *dst,
uint tid [[ thread_position_in_grid ]]
) {
// dst: (b_size, l_out, c_in, l_k)
// src: (b_size, c_in, l_in)
if (tid >= dst_numel) {
return;
}
const size_t b_in = src_dims[0];
const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2];
const size_t dst_s2 = l_k;
const size_t dst_s1 = c_in * dst_s2;
const size_t dst_s0 = l_out * dst_s1;
size_t tmp_dst_i = tid;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t l_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= l_idx * dst_s1;
const size_t c_idx = tmp_dst_i / dst_s2;
tmp_dst_i -= c_idx * dst_s2;
const size_t l_k_idx = tmp_dst_i;
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
dst[tid] = static_cast<T>(0);
}
else {
src_l_idx -= padding;
const size_t src_i = b_idx * src_strides[0] + c_idx * src_strides[1] + src_l_idx * src_strides[2];
dst[tid] = src[src_i];
}
}
template <typename T>
METAL_FUNC void upsample_nearest2d(
constant size_t &w_out,
constant size_t &h_out,
constant float &w_scale,
constant float &h_scale,
constant size_t *src_dims,
constant size_t *src_s,
device const T *src,
device T *dst,
uint tid [[ thread_position_in_grid ]]
) {
// src: (b_size, c_in, w_in, h_in)
const size_t c = src_dims[1];
const size_t w_in = src_dims[2];
const size_t h_in = src_dims[3];
if (tid >= src_dims[0] * c * w_out * h_out) {
return;
}
// TODO: Improve this.
const size_t b_idx = tid / (w_out * h_out * c);
const size_t c_idx = (tid / (w_out * h_out)) % c;
const size_t dst_w = (tid / h_out) % w_out;
const size_t dst_h = tid % h_out;
size_t src_w = static_cast<size_t>(dst_w * w_scale);
size_t src_h = static_cast<size_t>(dst_h * h_scale);
if (src_w >= w_in) {
src_w = w_in - 1;
}
if (src_h >= h_in) {
src_h = h_in - 1;
}
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3];
dst[tid] = src[src_i];
}
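Nearest-neighbour upsampling scales each destination coordinate back into the source and clamps at the border; the host side (call_upsample_nearest_2d further down) passes w_scale = w_in / w_out and h_scale = h_in / h_out. A tiny Rust sketch of that mapping (our illustration):

// Mirrors the kernel's source index computation: scale down, truncate, clamp.
fn nearest_src(dst: usize, scale: f32, src_len: usize) -> usize {
    ((dst as f32 * scale) as usize).min(src_len - 1)
}

fn main() {
    // Upsampling 4 -> 8 gives scale 0.5: outputs map to inputs 0,0,1,1,2,2,3,3.
    let src: Vec<usize> = (0..8).map(|d| nearest_src(d, 4.0 / 8.0, 4)).collect();
    assert_eq!(src, vec![0, 0, 1, 1, 2, 2, 3, 3]);
}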
#define IM2COL_OP(T, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dst_numel, \
constant size_t &h_out, \
constant size_t &w_out, \
constant size_t &h_k, \
constant size_t &w_k, \
constant size_t &stride, \
constant size_t &padding, \
constant size_t &dilation, \
constant size_t *src_dims, \
constant size_t *src_strides, \
device const T *src, \
device T *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
im2col<T>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
} \

#define IM2COL1D_OP(T, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dst_numel, \
constant size_t &l_out, \
constant size_t &l_k, \
constant size_t &stride, \
constant size_t &padding, \
constant size_t &dilation, \
constant size_t *src_dims, \
constant size_t *src_strides, \
device const T *src, \
device T *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
} \

#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &w_out, \
constant size_t &h_out, \
constant float &w_scale, \
constant float &h_scale, \
constant size_t *dims, \
constant size_t *strides, \
device const TYPENAME *src, \
device TYPENAME *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, dims, strides, src, dst, tid); \
} \

IM2COL_OP(float, im2col_f32)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)

View File

@ -1,6 +1,34 @@
#include <metal_stdlib>
using namespace metal;
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void index(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &ids_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t id_i = (tid / right_size) % ids_size;
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
/*
// Clamp the index to force-prevent out-of-bounds access,
// since there doesn't seem to be a good way to force a crash.
// No need to check for zero: only unsigned index types are allowed.
*/
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
output[tid] = input[src_i];
}
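The shared index() body treats the destination as a (left_size, ids_size, right_size) block and clamps every id into bounds rather than trapping. A CPU mirror in Rust (our illustration, hypothetical names):

// index_select on the CPU, with the same flattening and clamping as index().
fn index_select(input: &[f32], ids: &[usize], left: usize, dim: usize, right: usize) -> Vec<f32> {
    (0..left * ids.len() * right)
        .map(|tid| {
            let id_i = (tid / right) % ids.len();
            let right_i = tid % right;
            let left_i = tid / right / ids.len();
            input[left_i * dim * right + ids[id_i].min(dim - 1) * right + right_i]
        })
        .collect()
}

fn main() {
    // input viewed as (1, 3, 2); select rows 2 and 0 along the middle dim.
    let input = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    assert_eq!(index_select(&input, &[2, 0], 1, 3, 2), vec![5.0, 6.0, 1.0, 2.0]);
}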
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
@ -11,93 +39,160 @@ kernel void NAME( \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void gather(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &ids_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const INDEX_TYPENAME input_i = input_ids[tid];
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
output[tid] = input[src_i];
}
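gather() differs from index() in that input_ids already has the full destination shape, so each output element looks up its own index along the gathered dimension. A Rust mirror (our illustration):

// gather on the CPU: ids[tid] picks the source row for output position tid.
fn gather(input: &[f32], ids: &[usize], ids_size: usize, src_dim: usize, right: usize) -> Vec<f32> {
    (0..ids.len())
        .map(|tid| {
            let right_i = tid % right;
            let left_i = tid / right / ids_size;
            input[(left_i * src_dim + ids[tid]) * right + right_i]
        })
        .collect()
}

fn main() {
    // input (2, 3); gather one column index per row along dim 1.
    let input = [10.0, 11.0, 12.0, 20.0, 21.0, 22.0];
    assert_eq!(gather(&input, &[2, 0], 1, 3, 1), vec![12.0, 20.0]);
}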
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &ids_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void scatter_add(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &dst_dim_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const INDEX_TYPENAME idx = input_ids[src_i];
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
}
}
# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &dst_dim_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
scatter_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
}
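scatter_add() parallelises over destination (left, right) slots and walks the source dimension serially, so no two threads ever accumulate into the same output element and no atomics are needed. A Rust mirror (our illustration):

// scatter_add on the CPU: ids has the source shape and redirects each source
// row into a destination row; the per-slot loop matches one kernel thread.
fn scatter_add(
    src: &[f32], ids: &[usize],
    left: usize, src_dim: usize, right: usize, dst_dim: usize,
) -> Vec<f32> {
    let mut dst = vec![0.0f32; left * dst_dim * right];
    for tid in 0..left * right {
        let right_i = tid % right;
        let left_i = tid / right;
        for j in 0..src_dim {
            let src_i = (left_i * src_dim + j) * right + right_i;
            dst[(left_i * dst_dim + ids[src_i]) * right + right_i] += src[src_i];
        }
    }
    dst
}

fn main() {
    // src (1, 3) scattered into dst (1, 2): rows 0 and 1 collect 1+3 and 2.
    assert_eq!(scatter_add(&[1.0, 2.0, 3.0], &[0, 1, 0], 1, 3, 1, 2), vec![4.0, 2.0]);
}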
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void index_add(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &dst_dim_size,
constant size_t &ids_dim_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < ids_dim_size; ++j) {
const INDEX_TYPENAME idx = input_ids[j];
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
}
}
# define INDEX_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &dst_dim_size, \
constant size_t &ids_dim_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
index_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, ids_dim_size, input, input_ids, output, tid); \
}
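index_add() instead shares one flat ids vector of length ids_dim_size across every (left, right) slot: entry j sends source row j into destination row ids[j]. A Rust mirror (our illustration) that reproduces the expected values of the (removed) ia_u32_f32 test further down in this diff:

// index_add on the CPU, with the same left/right decomposition as the kernel.
fn index_add(
    src: &[f32], ids: &[usize],
    left: usize, src_dim: usize, right: usize, dst_dim: usize,
) -> Vec<f32> {
    let mut dst = vec![0.0f32; left * dst_dim * right];
    for tid in 0..left * right {
        let right_i = tid % right;
        let left_i = tid / right;
        for (j, &idx) in ids.iter().enumerate() {
            dst[(left_i * dst_dim + idx) * right + right_i] +=
                src[(left_i * src_dim + j) * right + right_i];
        }
    }
    dst
}

fn main() {
    // 3x3 source rows added into rows 0, 4 and 2 of a 5x3 destination of ones.
    let src: Vec<f32> = (1..=9).map(|x| x as f32).collect();
    let mut dst = index_add(&src, &[0, 4, 2], 1, 3, 3, 5);
    for v in dst.iter_mut() {
        *v += 1.0; // the test initialises the destination to 1.0
    }
    assert_eq!(
        dst,
        vec![2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0]
    );
}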
INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)
GATHER_OP(gather_u32_f32, uint, float)
GATHER_OP(gather_u32_f16, uint, half)
SCATTER_ADD_OP(sa_u32_f32, uint, float)
SCATTER_ADD_OP(sa_u32_f16, uint, half)
#if __METAL_VERSION__ >= 310
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
#endif
INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
INDEX_ADD_OP(ia_i64_f32, int64_t, float)
INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t)
INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t)
INDEX_ADD_OP(ia_u32_f32, uint32_t, float)
INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t)
INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t)
INDEX_ADD_OP(ia_u8_f32, uint8_t, float)
INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t)


@ -5,7 +5,6 @@ use metal::{
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;

const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
@ -14,8 +13,13 @@ const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
const CONV: &str = include_str!("conv.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
/// Most kernels apply similarly across the tensors.
/// This creates a dispatch strategy that uses the maximum number of threads per threadgroup
/// (capped at the actual total buffer length).
/// Kernels can then just do their op on their single point in the buffer.
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64;
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
@ -37,6 +41,10 @@ fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTL
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}
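A worked example of the linear_split strategy described above (our sketch; the tail of linear_split elided by this hunk is assumed to round the threadgroup count up with a ceil-division):

// Sketch of the dispatch strategy: cap the threadgroup width at the pipeline
// limit, then cover the whole buffer with ceil(length / width) threadgroups.
fn linear_split_sketch(max_threads_per_group: u64, length: u64) -> (u64, u64) {
    let width = max_threads_per_group.min(length);
    let group_count = (length + width - 1) / width;
    (group_count, width)
}

fn main() {
    // 10_000 elements on a pipeline allowing 1024 threads per group: 10 groups.
    assert_eq!(linear_split_sketch(1024, 10_000), (10, 1024));
}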
/// Helper trait for setting the various arguments on the compute command encoder
/// on a single line.
/// Prevents getting the number of arguments wrong or mixing up a length with a size in bytes.
trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
@ -108,6 +116,7 @@ pub enum Source {
Cast,
Reduce,
Mfa,
Conv,
}
macro_rules! ops{
@ -118,16 +127,20 @@ macro_rules! ops{
$(
pub mod $name {
use super::Kernel;
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64"));
pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32"));
pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8"));
}
)+
pub mod copy {
use super::Kernel;
pub const FLOAT: Kernel = Kernel("copy_f32");
pub const HALF: Kernel = Kernel("copy_f16");
pub const BFLOAT: Kernel = Kernel("copy_bf16");
pub const I64: Kernel = Kernel("copy_i64");
pub const U32: Kernel = Kernel("copy_u32");
pub const U8: Kernel = Kernel("copy_u8");
}
@ -138,16 +151,20 @@ macro_rules! ops{
$(
pub mod $name {
use super::Kernel;
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided"));
pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided"));
pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided"));
}
)+
pub mod copy {
use super::Kernel;
pub const FLOAT: Kernel = Kernel("copy_f32_strided");
pub const HALF: Kernel = Kernel("copy_f16_strided");
pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
pub const I64: Kernel = Kernel("copy_i64_strided");
pub const U32: Kernel = Kernel("copy_u32_strided");
pub const U8: Kernel = Kernel("copy_u8_strided");
}
@ -156,10 +173,13 @@ macro_rules! ops{
}

pub mod unary {
ops!(
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
recip
);
}
pub mod binary {
ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
}
#[derive(thiserror::Error, Debug)]
@ -218,9 +238,13 @@ impl Kernels {
Source::Indexing => INDEXING,
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Conv => CONV,
Source::Mfa => panic!("Invalid lib"),
}
}
/// Load the given library from its [`source`].
/// If it has previously been loaded, it is just fetched from the cache.
pub fn load_library(
&self,
device: &Device,
@ -233,9 +257,11 @@ impl Kernels {
let lib = match source {
Source::Mfa => {
let source_data = MFA;
device.new_library_with_data(source_data).map_err(|e| {
MetalKernelError::LoadLibraryError(format!(
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
))
})?
}
source => {
let source_content = self.get_library_source(source);
@ -263,6 +289,9 @@ impl Kernels {
Ok(func)
}
/// Load the given pipeline:
/// loads the library from source, then gets the function [`name`] from
/// that library.
fn load_pipeline_with_constants(
&self,
device: &Device,
@ -291,6 +320,9 @@ impl Kernels {
}
}
/// Load the given pipeline:
/// loads the library from source, then gets the function [`name`] from
/// that library (without constants).
pub fn load_pipeline(
&self,
device: &Device,
@ -570,6 +602,64 @@ pub fn call_reduce_contiguous(
Ok(())
}
pub fn call_reduce_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: &'static str,
shape: &[usize],
strides: &[usize],
out_length: usize,
input: &Buffer,
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let length: usize = shape.iter().product();
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
shape.len(),
shape,
strides,
elements_to_sum,
(input, input_offset),
output
)
);
let thread_group_count = MTLSize {
width: out_length as u64,
height: 1,
depth: 1,
};
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
elements_to_sum as u64,
)
.next_power_of_two();
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
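Each threadgroup here tree-reduces elements_to_sum (possibly strided) elements into one output slot, and out_length threadgroups are dispatched. A contiguous CPU mirror (our illustration) matching the reduce_sum tests near the end of this diff:

// What one fast_sum_*_strided dispatch computes, for a contiguous layout:
// split the input into out_length blocks and reduce each block to one value.
fn reduce_sum_blocks(v: &[f32], out_length: usize) -> Vec<f32> {
    let per_block = v.len() / out_length;
    v.chunks(per_block).map(|c| c.iter().sum::<f32>()).collect()
}

fn main() {
    let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    assert_eq!(reduce_sum_blocks(&v, 1), vec![21.0]); // reduce_sum
    assert_eq!(reduce_sum_blocks(&v, 2), vec![6.0, 15.0]); // reduce_sum2
}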
#[allow(clippy::too_many_arguments)]
pub fn call_last_softmax(
device: &Device,
@ -579,6 +669,7 @@ pub fn call_last_softmax(
length: usize,
elements_to_sum: usize,
input: &Buffer,
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@ -586,7 +677,10 @@ pub fn call_last_softmax(
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(length, elements_to_sum, (input, input_offset), output)
);

let out_length = length / elements_to_sum;
@ -930,6 +1024,164 @@ pub fn call_index_select(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_gather(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
ids_size: usize,
dim: usize,
input: &Buffer,
input_offset: usize,
ids: &Buffer,
ids_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let src_dim_size = shape[dim];
let dst_el = ids_size * left_size * right_size;
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
dst_el,
left_size,
src_dim_size,
right_size,
ids_size,
(input, input_offset),
(ids, ids_offset),
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
pub fn call_scatter_add(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
src_shape: &[usize],
dst_shape: &[usize],
dim: usize,
input: &Buffer,
input_offset: usize,
ids: &Buffer,
ids_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = src_shape[..dim].iter().product();
let right_size: usize = src_shape[dim + 1..].iter().product();
let src_dim_size = src_shape[dim];
let dst_el = left_size * right_size;
let dst_dim_size = dst_shape[dim];
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
dst_el,
left_size,
src_dim_size,
right_size,
dst_dim_size,
(input, input_offset),
(ids, ids_offset),
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
pub fn call_index_add(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
src_shape: &[usize],
dst_shape: &[usize],
ids_shape: &[usize],
dim: usize,
input: &Buffer,
input_offset: usize,
ids: &Buffer,
ids_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = src_shape[..dim].iter().product();
let right_size: usize = src_shape[dim + 1..].iter().product();
let src_dim_size = src_shape[dim];
let dst_el = left_size * right_size;
let dst_dim_size = dst_shape[dim];
let ids_dim_size = ids_shape[0];
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
dst_el,
left_size,
src_dim_size,
right_size,
dst_dim_size,
ids_dim_size,
(input, input_offset),
(ids, ids_offset),
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
#[derive(Debug, PartialEq)]
pub enum Value {
USize(usize),
@ -1053,204 +1305,272 @@ pub fn call_gemm(
mnk: (m, n, k),
})?;
};
let d_trans = false;
let alpha = 1.0f32;
let beta = 0.0f32;
let batched = b > 1;
let fused_activation = false;
let fused_bias = false;
let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
let m_simd = 8;
let n_simd = 8;
let k_simd = 64;
let m_splits = 1;
let n_splits = 1;
(m_simd, n_simd, k_simd, m_splits, n_splits)
} else {
let m_simd = 40;
let n_simd = 40;
let k_simd = 32;
let m_splits = 1;
let n_splits = 1;
(m_simd, n_simd, k_simd, m_splits, n_splits)
};
let constants = Some(ConstantValues::new(vec![
(0, Value::USize(m)),
(1, Value::USize(n)),
(2, Value::USize(k)),
(10, Value::Bool(a_trans)),
(11, Value::Bool(b_trans)),
(13, Value::Bool(d_trans)),
(20, Value::F32(alpha)),
(21, Value::F32(beta)),
(100, Value::Bool(batched)),
(101, Value::Bool(fused_activation)),
// Garbage
(102, Value::Bool(false)),
(103, Value::Bool(false)),
(113, Value::Bool(false)),
(50_000, Value::Bool(false)),
// End garbage
(200, Value::U16(m_simd)),
(201, Value::U16(n_simd)),
(202, Value::U16(k_simd)),
(210, Value::U16(m_splits)),
(211, Value::U16(n_splits)),
(50_001, Value::Bool(fused_bias)),
]));
let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
let m_group = m_simd * m_splits;
let n_group = n_simd * n_splits;

let a_block_length = m_group * k_simd;
let b_block_length = k_simd * n_group;

let mut block_elements = a_block_length + b_block_length;
if (m % 8 != 0) && (n % 8 != 0) {
let c_block_length = m_group * n_group;
block_elements = std::cmp::max(c_block_length, block_elements)
}
if fused_bias {
if d_trans {
block_elements = std::cmp::max(block_elements, m_group);
} else {
block_elements = std::cmp::max(block_elements, n_group);
}
}
let bytes = match name {
"sgemm" => 4,
"hgemm" => 2,
other => {
return Err(MetalKernelError::LoadLibraryError(format!(
"{other} is not a valid kernel for gemm"
)));
}
};
let block_bytes = block_elements * bytes;

let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
encoder.set_buffer(2, Some(output), 0);
// TODO Tensor D

let grid_z = b;
if batched {
let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
let byte_stride_c = m * n * bytes as usize;
// TODO byte_stride_d
let byte_stride_d = 0;

let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
for i in 0..b {
buffer.push((i * byte_stride_a) as u64);
buffer.push((i * byte_stride_b) as u64);
buffer.push((i * byte_stride_c) as u64);
buffer.push((i * byte_stride_d) as u64);
}
encoder.set_bytes(
10,
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
buffer.as_ptr() as *const NSUInteger as *const c_void,
);
}

let grid_size = MTLSize {
width: divide(n, n_group.into()),
height: divide(m, m_group.into()),
depth: grid_z as NSUInteger,
};
let group_size = MTLSize {
width: 32 * (m_splits as u64) * (n_splits as u64),
height: 1,
depth: 1,
};
// println!("grid size {grid_size:?} group size {group_size:?}");
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();

Ok(())
}
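A worked example of the threadgroup-memory sizing above (our arithmetic), for sgemm on the non-skinny path where m_simd = n_simd = 40, k_simd = 32 and both splits are 1:

fn main() {
    let (m_simd, n_simd, k_simd, m_splits, n_splits) = (40u64, 40, 32, 1, 1);
    let (m_group, n_group) = (m_simd * m_splits, n_simd * n_splits);
    // One A tile plus one B tile in threadgroup memory.
    let block_elements = m_group * k_simd + k_simd * n_group; // 2560
    let block_bytes = block_elements * 4; // 4 bytes per element for "sgemm"
    assert_eq!(block_bytes, 10_240);
}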
#[allow(clippy::too_many_arguments)]
pub fn call_im2col1d_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
strides: &[usize],
(k_size, stride, padding, dilation): (usize, usize, usize, usize),
input: &Buffer,
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
let dst_el = shape[0] * l_out * shape[1] * k_size;

let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
dst_el,
l_out,
k_size,
stride,
padding,
dilation,
shape,
strides,
(input, input_offset),
output
)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_im2col_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
strides: &[usize],
(h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
input: &Buffer,
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
let h = shape[2];
let w = shape[3];
let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k;
let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
dst_el,
h_out,
w_out,
h_k,
w_k,
stride,
padding,
dilation,
shape,
strides,
(input, input_offset),
output
)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_upsample_nearest_2d(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
strides: &[usize],
out_w: usize,
out_h: usize,
input: &Buffer,
input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
let dst_el = out_w * out_h * shape[0] * shape[1];
let scale_w = shape[2] as f32 / out_w as f32;
let scale_h = shape[3] as f32 / out_h as f32;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
out_w,
out_h,
scale_w,
scale_h,
shape,
strides,
(input, input_offset),
output
)
);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
fn divide(m: usize, b: usize) -> NSUInteger {


@ -2,6 +2,7 @@
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
METAL_FUNC uint get_strided_index(
uint idx,
@ -20,9 +21,130 @@ METAL_FUNC uint get_strided_index(
constant int THREADGROUP_SIZE = 2048;
#define ARGMIN(NAME, T, MAXVALUE) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MAXVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
bool notset = true; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || src[strided_i] < shared_memory[tid]) { \
shared_memory[tid] = src[strided_i]; \
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \

#define ARGMAX(NAME, T, MINVALUE) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MINVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
bool notset = true; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || shared_memory[tid] < src[strided_i]) { \
shared_memory[tid] = src[strided_i]; \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \

#define REDUCE(FN, NAME, T, START) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
@ -34,21 +156,21 @@ kernel void NAME( \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
\
shared_memory[tid] = START; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
T x = shared_memory[tid]; \
T y = src[strided_i]; \
shared_memory[tid] = FN; \
idx += block_dim; \
} \
@ -71,10 +193,6 @@ kernel void NAME( \
} \

#define SOFTMAX(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
@ -142,8 +260,47 @@ kernel void NAME(
} \
} \

REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
REDUCE(x + y, fast_sum_f16_strided, half, 0)
REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
REDUCE(x * y, fast_mul_f32_strided, float, 1)
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
REDUCE(x * y, fast_mul_f16_strided, half, 1)
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
ARGMAX(fast_argmax_u32_strided, uint, 0)
ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
#if __METAL_VERSION__ >= 220
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
#endif
#if __METAL_VERSION__ >= 310
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
#endif
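ARGMIN and ARGMAX track both the best value and its index along the reduced (last, contiguous) dimension, seeding shared memory with a MAXVALUE/MINVALUE sentinel such as HUGE_VALF. A CPU mirror of argmin (our illustration):

// argmin over the last dimension, with the same sentinel-seeded scan the
// kernel performs before its shared-memory tree reduction.
fn argmin_last_dim(v: &[f32], last_dim: usize) -> Vec<u32> {
    v.chunks(last_dim)
        .map(|row| {
            let mut best = (f32::INFINITY, 0u32); // MAXVALUE seed
            for (i, &x) in row.iter().enumerate() {
                if x < best.0 {
                    best = (x, i as u32);
                }
            }
            best.1
        })
        .collect()
}

fn main() {
    let v = [3.0f32, 1.0, 2.0, 0.5, 4.0, 6.0];
    assert_eq!(argmin_last_dim(&v, 3), vec![1, 0]);
}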


@ -55,6 +55,9 @@ kernel void FN_NAME( \
WHERE_OP(float, uint8_t, where_u8_f32)
// WHERE_OP(double, uint8_t, where_u8_f64)
WHERE_OP(uint8_t, uint8_t, where_u8_u8)
WHERE_OP(uint32_t, uint8_t, where_u8_u32)
#if __METAL_VERSION__ >= 220
WHERE_OP(int64_t, uint8_t, where_u8_i64)
#endif


@ -1,209 +0,0 @@
import Metal
import MetalPerformanceShadersGraph
let type = MTLDataType.float;
let dataType = type;
var B = 2;
var M = 2;
var N = 2;
var K = 2;
var A_trans = false;
var B_trans = false;
var D_trans = false;
var alpha = Float(1.0);
var beta = Float(0.0);
var batched = B > 1;
var fused_activation = false;
var fused_bias = false;
let constants = MTLFunctionConstantValues()
constants.setConstantValue(&M, type: .uint, index: 0)
constants.setConstantValue(&N, type: .uint, index: 1)
constants.setConstantValue(&K, type: .uint, index: 2)
constants.setConstantValue(&A_trans, type: .bool, index: 10)
constants.setConstantValue(&B_trans, type: .bool, index: 11)
constants.setConstantValue(&D_trans, type: .bool, index: 13)
constants.setConstantValue(&alpha, type: .float, index: 20)
constants.setConstantValue(&beta, type: .float, index: 21)
constants.setConstantValue(&batched, type: .bool, index: 100)
constants.setConstantValue(&fused_activation, type: .bool, index: 101)
constants.setConstantValue(&fused_bias, type: .bool, index: 50001)
var M_simd = UInt16(16)
var N_simd = UInt16(16)
var K_simd = UInt16(32)
var M_splits = UInt16(2)
var N_splits = UInt16(2)
constants.setConstantValue(&M_simd, type: .ushort, index: 200)
constants.setConstantValue(&N_simd, type: .ushort, index: 201)
constants.setConstantValue(&K_simd, type: .ushort, index: 202)
constants.setConstantValue(&M_splits, type: .ushort, index: 210)
constants.setConstantValue(&N_splits, type: .ushort, index: 211)
let M_group = M_simd * M_splits
let N_group = N_simd * N_splits
// Satisfy Metal API validation.
#if DEBUG
do {
var garbage: SIMD4<UInt64> = .zero
constants.setConstantValue(&garbage, type: .bool, index: 102)
constants.setConstantValue(&garbage, type: .bool, index: 103)
constants.setConstantValue(&garbage, type: .bool, index: 113)
constants.setConstantValue(&garbage, type: .bool, index: 50000)
}
#endif
let device = MTLCopyAllDevices().first!
device.shouldMaximizeConcurrentCompilation = true
var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!;
libraryURL.append(component: "src")
libraryURL.append(component: "libMetalFlashAttention.metallib")
let library = try! device.makeLibrary(URL: libraryURL)
var name: String
switch dataType {
case .half: name = "hgemm"
case .float: name = "sgemm"
default: fatalError()
}
let function = try! library.makeFunction(
name: name, constantValues: constants)
let A_block_length = M_group * K_simd
let B_block_length = K_simd * N_group
var blockElements = A_block_length + B_block_length;
if (M % 8 != 0) && (N % 8 != 0) {
let C_block_length = M_group * N_group;
blockElements = max(C_block_length, blockElements)
}
if fused_bias {
if D_trans {
blockElements = max(blockElements, M_group)
} else {
blockElements = max(blockElements, N_group)
}
}
// let blockBytes = blockElements * UInt16(dataType.size)
let elementSize = 4
let blockBytes = blockElements * UInt16(elementSize)
func ceilDivide(target: Int, granularity: UInt16) -> Int {
(target + Int(granularity) - 1) / Int(granularity)
}
var gridSize = MTLSize(
width: ceilDivide(target: N, granularity: N_group),
height: ceilDivide(target: M, granularity: M_group),
depth: 1)
let groupSize = MTLSize(
width: Int(32 * M_splits * N_splits),
height: 1,
depth: 1)
let commandQueue = device.makeCommandQueue()!
let threadgroupMemoryLength = blockBytes;
let rowsA = M;
let columnsA = K;
let rowsB = K;
let columnsB = N;
let rowsC = M;
let columnsC = N;
var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA)
var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB)
var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC)
var arrayD = [Float](repeating: 0, count: B * rowsC * columnsC)
for i in 0..<arrayA.count {
arrayA[i] = Float(i)
}
for i in 0..<arrayB.count {
arrayB[i] = Float(i)
}
let bufferA = device.makeBuffer(bytes: arrayA, length: B * rowsA * columnsA * MemoryLayout<Float>.stride, options: [])!
let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout<Float>.stride, options: [])!
let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])!
let bufferD = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])!
let pipeline = try device.makeComputePipelineState(function: function)
func call(bufferA: MTLBuffer, bufferB: MTLBuffer, bufferC: MTLBuffer){
let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)!
encoder.setComputePipelineState(pipeline)
encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0)
encoder.setBuffer(bufferA, offset: 0, index: 0)
encoder.setBuffer(bufferB, offset: 0, index: 1)
encoder.setBuffer(bufferC, offset: 0, index: 2)
let gridZ: Int = B
if batched{
func byteStride(shape: [Int]) -> Int {
let rank = shape.count
var output = elementSize * shape[rank - 2] * shape[rank - 1]
if shape.dropLast(2).reduce(1, *) == 1 {
output = 0
}
return output
}
let byteStrideA = M*K*elementSize
let byteStrideB = N*K*elementSize
let byteStrideC = M*N*elementSize
let byteStrideD = 0
withUnsafeTemporaryAllocation(
of: SIMD4<UInt64>.self, capacity: gridZ
) { buffer in
for i in 0..<buffer.count {
buffer[i] = SIMD4(
UInt64(truncatingIfNeeded: i * byteStrideA),
UInt64(truncatingIfNeeded: i * byteStrideB),
UInt64(truncatingIfNeeded: i * byteStrideC),
UInt64(truncatingIfNeeded: i * byteStrideD))
}
let bufferLength = buffer.count * MemoryLayout<SIMD4<UInt64>>.stride
assert(MemoryLayout<SIMD4<UInt64>>.stride == 8 * 4)
encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10)
}
}
gridSize.depth = gridZ
encoder.dispatchThreadgroups(
gridSize, threadsPerThreadgroup: groupSize
)
encoder.endEncoding()
}
var commandBuffer = commandQueue.makeCommandBuffer()!
call(bufferA:bufferA, bufferB:bufferB, bufferC:bufferC)
commandBuffer.commit()
commandBuffer = commandQueue.makeCommandBuffer()!
commandBuffer.encodeWaitForEvent(event, value: 2)
call(bufferA:bufferA, bufferB:bufferC, bufferC:bufferD)
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
var contents = bufferC.contents();
var count = B * rowsA * columnsB;
var typedPointer = contents.bindMemory(to: Float.self, capacity: count)
var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
print("First matmul is OK", Array(bufferedPointer))
contents = bufferD.contents();
count = B * rowsA * columnsB;
typedPointer = contents.bindMemory(to: Float.self, capacity: count)
bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
print("This should be filled", Array(bufferedPointer))


@ -1,6 +1,6 @@
use super::*;
use half::{bf16, f16};
use metal::{Device, MTLResourceOptions};

fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
@ -312,7 +312,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
&device,
command_buffer,
&kernels,
"affine_f32",
size,
&input,
&output,
@ -346,7 +346,7 @@ fn run_affine_strided<T: Clone>(
&device,
command_buffer,
&kernels,
"affine_f32_strided",
shape,
&input,
strides,
@ -485,73 +485,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
read_to_vec(&dst_buffer, dst_el)
}
#[test]
fn index_add() {
let device = Device::system_default().expect("no device found");
let options = CompileOptions::new();
let library = device.new_library_with_source(INDEXING, &options).unwrap();
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let right = [1.0f32; 15];
let index = [0u32, 4, 2];
let ids_dim_size = index.len() as u32;
let dst_dim_size: u32 = 15;
let left_size: u32 = 3;
let right_size: u32 = 3;
let function = library.get_function("ia_u32_f32", None).unwrap();
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let index_buffer = new_buffer(&device, &index);
let inputs_buffer = new_buffer(&device, &left);
let outputs_buffer = new_buffer(&device, &right);
set_params!(
encoder,
(
&index_buffer,
&inputs_buffer,
&outputs_buffer,
ids_dim_size,
left_size,
dst_dim_size,
right_size
)
);
let grid_size = MTLSize {
width: right.len() as NSUInteger,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: pipeline.max_total_threads_per_threadgroup(),
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(grid_size, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
let expected = vec![
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
];
let result: Vec<f32> = read_to_vec(&outputs_buffer, right.len());
assert_eq!(result, expected);
}
#[test] #[test]
fn cos_f16() { fn cos_f16() {
let v: Vec<f16> = [1.0f32, 2.0, 3.0] let v: Vec<f16> = [1.0f32, 2.0, 3.0]
@ -574,12 +507,15 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
let options = MTLResourceOptions::StorageModeManaged; let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options); let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
call_reduce_contiguous( let dims = vec![v.len()];
let strides = vec![1];
call_reduce_strided(
&device, &device,
command_buffer, command_buffer,
&kernels, &kernels,
name, name,
v.len(), &dims,
&strides,
out_length, out_length,
&input, &input,
0, 0,
@ -608,6 +544,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
v.len(), v.len(),
last_dim, last_dim,
&input, &input,
0,
&output, &output,
) )
.unwrap(); .unwrap();
@ -622,7 +559,7 @@ fn reduce_sum() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 1; let out_length = 1;
let results = run_reduce(&v, out_length, "fast_sum_float"); let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
assert_eq!(approx(results, 4), vec![21.0]); assert_eq!(approx(results, 4), vec![21.0]);
} }
@ -631,7 +568,7 @@ fn reduce_sum2() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 2; let out_length = 2;
let results = run_reduce(&v, out_length, "fast_sum_float"); let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
assert_eq!(approx(results, 4), vec![6.0, 15.0]); assert_eq!(approx(results, 4), vec![6.0, 15.0]);
} }
@ -639,7 +576,7 @@ fn reduce_sum2() {
fn softmax() { fn softmax() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let last_dim = 6; let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_float"); let results = run_softmax(&v, last_dim, "softmax_f32");
assert_eq!( assert_eq!(
approx(results, 4), approx(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
@ -651,7 +588,7 @@ fn softmax() {
for i in 0..n { for i in 0..n {
v[i * last_dim] = 20.0; v[i * last_dim] = 20.0;
} }
let results = run_softmax(&v, last_dim, "softmax_float"); let results = run_softmax(&v, last_dim, "softmax_f32");
let results = approx(results, 4); let results = approx(results, 4);
println!("{results:?}"); println!("{results:?}");
assert_eq!( assert_eq!(
@ -665,7 +602,7 @@ fn softmax() {
let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]; let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
let last_dim = 6; let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_float"); let results = run_softmax(&v, last_dim, "softmax_f32");
assert_eq!( assert_eq!(
approx(results, 4), approx(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
@ -673,7 +610,7 @@ fn softmax() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let last_dim = 3; let last_dim = 3;
let results = run_softmax(&v, last_dim, "softmax_float"); let results = run_softmax(&v, last_dim, "softmax_f32");
assert_eq!( assert_eq!(
approx(results, 4), approx(results, 4),
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652] vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
@ -684,7 +621,7 @@ fn softmax() {
.map(|v| f16::from_f32(*v)) .map(|v| f16::from_f32(*v))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let last_dim = 6; let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_half"); let results = run_softmax(&v, last_dim, "softmax_f16");
assert_eq!( assert_eq!(
approx_f16(results, 4), approx_f16(results, 4),
vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338] vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
@ -695,7 +632,7 @@ fn softmax() {
.map(|v| bf16::from_f32(*v)) .map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let last_dim = 6; let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_bfloat"); let results = run_softmax(&v, last_dim, "softmax_bf16");
assert_eq!( assert_eq!(
approx_bf16(results, 4), approx_bf16(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328] vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
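Note how the contiguous reduce helper is folded into the strided entry point above: a contiguous rank-1 buffer is just the special case dims = [v.len()], strides = [1]. A minimal sketch of that mapping, with the parameter list inferred from the call sites in the hunks above (illustrative, not part of the commit; the trailing output argument is assumed):

use candle_metal_kernels::{call_reduce_strided, Kernels, MetalKernelError};
use metal::{Buffer, CommandBufferRef, Device};

// Hypothetical wrapper: run a contiguous reduction through the strided
// kernel entry point. The argument order mirrors the run_reduce hunk.
fn reduce_contiguous_via_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    len: usize,
    out_length: usize,
    input: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    // A contiguous rank-1 buffer of `len` elements.
    let dims = vec![len];
    let strides = vec![1];
    call_reduce_strided(
        device,
        command_buffer,
        kernels,
        name, // e.g. "fast_sum_f32_strided"
        &dims,
        &strides,
        out_length,
        input,
        0, // offset into the input buffer
        output,
    )
}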

View File

@@ -19,7 +19,9 @@ METAL_FUNC uint get_strided_index(
 }

 template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
+template <typename T> METAL_FUNC T recip(T in){ return T(1.0 / in); }
 template <typename T> METAL_FUNC T neg(T in){ return -in; }
+
 template <typename T> METAL_FUNC T erf(T in){
     float x = (float) in;
     // constants
@@ -57,19 +59,17 @@ template <typename T> METAL_FUNC T gelu(T x) {
     return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
 }

 #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
 kernel void FN_NAME( \
     constant size_t &dim, \
     device const TYPENAME *input, \
     device TYPENAME *output, \
-    uint thread_position_in_grid [[ thread_position_in_grid ]] \
+    uint tid [[ thread_position_in_grid ]] \
 ) { \
-    if (thread_position_in_grid >= dim) { \
+    if (tid >= dim) { \
         return; \
     } \
-    output[thread_position_in_grid] = TYPENAME(FN(float(input[thread_position_in_grid]))); \
+    output[tid] = TYPENAME(FN(float(input[tid]))); \
 }\
 kernel void FN_NAME_STRIDED( \
     constant size_t &dim, \
@@ -78,20 +78,20 @@ kernel void FN_NAME_STRIDED( \
     constant size_t *strides, \
     device const TYPENAME *input, \
     device TYPENAME *output, \
-    uint thread_position_in_grid [[ thread_position_in_grid ]] \
+    uint tid [[ thread_position_in_grid ]] \
 ) { \
-    if (thread_position_in_grid >= dim) { \
+    if (tid >= dim) { \
         return; \
     } \
-    output[thread_position_in_grid] = TYPENAME(FN(float(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)]))); \
+    output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \
 }

 #define UNARY_OP(NAME) \
-UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
-UNARY(NAME, half, NAME##_half, NAME##_half_strided);
+UNARY(NAME, float, NAME##_f32, NAME##_f32_strided); \
+UNARY(NAME, half, NAME##_f16, NAME##_f16_strided);

 #define BFLOAT_UNARY_OP(NAME) \
-UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
+UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);

 UNARY_OP(cos)
@@ -102,17 +102,24 @@ UNARY_OP(neg)
 UNARY_OP(exp)
 UNARY_OP(log)
 UNARY_OP(gelu)
+UNARY_OP(abs)
 UNARY_OP(ceil)
 UNARY_OP(floor)
 UNARY_OP(round)
 UNARY_OP(gelu_erf)
 UNARY_OP(erf)
 UNARY_OP(tanh)
+UNARY_OP(recip)

-UNARY(id, float, copy_float, copy_float_strided)
-UNARY(id, half, copy_half, copy_half_strided)
+UNARY(id, float, copy_f32, copy_f32_strided)
+UNARY(id, half, copy_f16, copy_f16_strided)
 UNARY(id, uint8_t, copy_u8, copy_u8_strided)
 UNARY(id, uint32_t, copy_u32, copy_u32_strided)
+
+#if __METAL_VERSION__ >= 220
+UNARY(id, int64_t, copy_i64, copy_i64_strided)
+#endif

 #if __METAL_VERSION__ >= 310
 BFLOAT_UNARY_OP(cos)
 BFLOAT_UNARY_OP(sin)
@@ -128,6 +135,7 @@ BFLOAT_UNARY_OP(round)
 BFLOAT_UNARY_OP(gelu_erf)
 BFLOAT_UNARY_OP(erf)
 BFLOAT_UNARY_OP(tanh)
+BFLOAT_UNARY_OP(recip)
-UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
+UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
 #endif
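The rename above settles on dtype-style kernel suffixes (_f32/_f16/_bf16 instead of _float/_half/_bfloat), matching the names the Rust tests now pass. A standalone sketch of the naming rule (the DType enum here is a stand-in for candle's dtype, listed only for the float types the UNARY_OP/BFLOAT_UNARY_OP macros instantiate; not part of the commit):

// Stand-in dtype enum for illustration.
enum DType {
    F32,
    F16,
    BF16,
}

// Build the kernel symbol for a unary op, e.g. "recip_f32" or
// "tanh_bf16_strided", following the macro expansion above.
fn unary_kernel_name(op: &str, dtype: DType, strided: bool) -> String {
    let dtype = match dtype {
        DType::F32 => "f32",
        DType::F16 => "f16",
        DType::BF16 => "bf16",
    };
    let suffix = if strided { "_strided" } else { "" };
    format!("{op}_{dtype}{suffix}")
}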

View File

@@ -11,7 +11,7 @@ readme = "README.md"

 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
+candle = { workspace = true }
 half = { workspace = true }
 thiserror = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
@@ -20,7 +20,7 @@ rayon = { workspace = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
 metal = { workspace = true, optional = true }
-candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
+candle-metal-kernels = { workspace = true, optional = true }

 [dev-dependencies]
 anyhow = { workspace = true }

View File

@@ -1,4 +1,4 @@
-use candle::Tensor;
+use candle::{Result, Tensor};
 use serde::Deserialize;

 #[derive(Debug, Clone, Copy, PartialEq, Deserialize, Default)]
@@ -21,7 +21,7 @@ pub enum Activation {
 }

 impl super::Module for Activation {
-    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Self::Gelu => xs.gelu_erf(),
             // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
@@ -40,3 +40,60 @@ impl super::Module for Activation {
         }
     }
 }
+
+#[derive(Clone, Debug)]
+pub struct PReLU {
+    weight: Tensor,
+    is_scalar: bool,
+}
+
+impl PReLU {
+    pub fn new(weight: Tensor, is_scalar: bool) -> Self {
+        Self { weight, is_scalar }
+    }
+
+    pub fn weight(&self) -> &Tensor {
+        &self.weight
+    }
+
+    pub fn is_scalar(&self) -> bool {
+        self.is_scalar
+    }
+}
+
+impl candle::Module for PReLU {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let weight = if self.is_scalar {
+            self.weight.reshape(())?
+        } else if xs.rank() >= 2 {
+            let num_channels = xs.dim(1)?;
+            let num_weights = self.weight.elem_count();
+            if num_weights != num_channels {
+                candle::bail!("error in prelu: unexpected number of channels for the input, got {num_channels}, weight dim is {num_weights}")
+            }
+            let mut s = vec![1; xs.rank()];
+            s[1] = self.weight.elem_count();
+            self.weight.reshape(s)?
+        } else {
+            self.weight.clone()
+        };
+        let zeros = xs.zeros_like()?;
+        xs.maximum(&zeros)? + xs.minimum(&zeros)?.broadcast_mul(&weight)?
+    }
+}
+
+/// Create or initialize a new PReLU layer.
+///
+/// This uses some default name for weights, namely `"weight"`.
+/// # Arguments
+///
+/// * `num_channels` - The number of channels. Use `None` to have a single trainable value and
+///   `Some` for a 1D vector with the appropriate number of channels. When applying the `forward`
+///   function, the input tensor shape `s` should either be one dimension with this number of
+///   channels or if `s.len() >= 2` it should have `s[1]` equal to this number.
+pub fn prelu(num_channels: Option<usize>, vs: crate::VarBuilder) -> Result<PReLU> {
+    let init_ws = crate::init::Init::Const(0.25);
+    // When using a scalar weight, the PyTorch encoding is to use a 1d vector of length 1.
+    let ws = vs.get_with_hints((num_channels.unwrap_or(1),), "weight", init_ws)?;
+    Ok(PReLU::new(ws, num_channels.is_none()))
+}

Some files were not shown because too many files have changed in this diff.