Compare commits

...

118 Commits

Author SHA1 Message Date
a3dd87f15e Adding Gemm and ArgMax operators to candle-onnx (#2231)
* feat(gemm): implement Gemm operator in candle-onnx

* feat(onnx): Add support for ArgMax operator in candle-onnx

* Apply rustfmt.

* Remove argmax as it was already present.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-06-28 21:40:31 +02:00
242e006bbb Depth Anything v2 (#2279)
* define structs

* construct ResidualConvUnit

* forward() for ResidualConvUnit

* implement FeatureFusionBlock

* implement Scratch

* implement DPTHead

* add identity module

* implement forward for DTPHead

* add get_intermediate_layers to DinoVisionTransformer

* implement DepthAnythingV2

* some minor tweaks

* fix compile errors

* fix var builder prefixes

* setup initial example

* use fixed patch size of 37 (518 / 14)

* debugged until output

* print min and max values

* add some dynamism to the output location

* scale input image

* extract prep function

* extract output path function

* normalize image with magic mean and std

* add spectral coloring

* squeeze in the right place

* make enterpolation optional

* use bail instead of panic

* omit unnecessary Shape call

* remove empty curly braces

* use bail instead of assert

* use vb and pp

* remove closures

* extract config object

* Apply rustfmt.

* Fix some clippy lints.

* More lints.

* Use the array methods.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-06-24 19:12:52 +02:00
6baa1d486b Fix a bug in the metal implemtation of col2im1d. (#2284) 2024-06-22 23:21:20 +02:00
36cf54525d Fix the fast bf16 gemm cublas kernels. (#2274)
* Use flash-attn in gemma.

* Fix for the fast bf16 cublas gemm.

* Fix some clippy lints.

* Fix another lint.

* Proper clippy fix.
2024-06-18 23:46:58 +02:00
2b10aaa05d implement Slice op (#2260) 2024-06-12 07:15:32 +01:00
9f804af29d feat(ci): add trufflehog secrets detection (#2262)
* feat(ci): add trufflehog secrets detection

* fix(ci): remove unnecessary permissions
2024-06-10 21:03:54 +01:00
54ff971e35 Support for the new Qwen2 models. (#2257)
* Support for the new Qwen2 models.

* Add more models.
2024-06-07 10:51:50 +01:00
b9fac7ec00 implement if, and pad reflect mode (#2251)
* implement if, and pad reflect mode

The intent of this change is to allow eval of the current silero_vad.onnx (v4).
This onnx file uses 'If' and 'Pad' nodes, which had not been supported
by simple_eval until now

* Cleanup (fmt, clippy, minor test tweaks).

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-06-06 22:36:23 +02:00
f65e90e7ef Bump the crate version. (#2248) 2024-06-05 15:49:15 +02:00
d39462856b Apply rustfmt. (#2247) 2024-06-04 22:54:09 +02:00
cb180eb23a ONNX: add ArgMin, ArgMax and LeakyRelu (#2246)
* Add basic RandomUniform implementation

* Use is_some to check if seed is present

* Added Exp operator implementation

* Added ArgMin operator implementation

* Added tests for ArgMin

* ArgMin now returns a tensor with i64

* Added tests from pytorch examples

* Added ArgMax operator implementation

* Added tests for ArgMax

* Added LeakyRelu implementation

* Added a test for LeakyRelu

* Typo fix

* Fix a weird automatic RustRover change

---------

Co-authored-by: Mateusz Okulus <mmokulus@gmail.com>
2024-06-04 22:49:02 +02:00
9182c828e6 Automatically upcast for to_u64 (#2244) 2024-06-04 11:32:36 +02:00
3f13ad3d79 Fix dataset id for MNIST (#2238) 2024-06-04 06:27:24 +02:00
cd4d941ed1 Add LLaVA support (#2234)
* first commit

* llava

* clippy and fmt

* some fixes

* minor fixes

* remove useless file

* refactor: Remove llava/constants.rs and update llava/mod.rs

* modify variable name

* modify code after clippy

* Minor tweaks.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-06-03 11:54:09 +02:00
03344d3c19 ONNX: Add Floor and Ceil (#2235) 2024-06-02 21:45:20 +02:00
1ec3b2cc18 add where_cond f32 for metal (#2236) 2024-06-02 14:30:06 +02:00
f7773d498a Deactivate some book test that breaks the CI. (#2233)
* Deactivate some book test that breaks the CI.

* Clippy fix.
2024-06-01 09:44:22 +02:00
7abc3b8cd7 Bump cudarc version to 0.11.4 (#2230) 2024-06-01 08:18:35 +02:00
46012ed31f Another cudarc update. (#2229) 2024-05-30 22:27:06 +02:00
f3fade3b03 Update cudarc to 0.11.2. (#2227) 2024-05-29 18:50:52 +02:00
ea260aeffd Add Debug, Clone, Deserialize to moondream config (#2222) 2024-05-28 06:08:00 +02:00
0814dfd148 Add a metal kernel for col2im1d. (#2214)
* Add a metal kernel for col2im1d.

* Enable the col2im variant.

* Bugfix.

* Revert the quantized tweak.
2024-05-25 11:03:23 +02:00
3ceca9901a Enable the new layer-norm. (#2213)
* Enable the new layer-norm.

* Shape fixes.
2024-05-24 16:48:21 +02:00
1df2bddccf Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.

* Dedicated layer norm op.

* Add the slower variant.

* Plug the cuda implementation.

* Add the metal variant.

* Add a dedicated test.

* Bugfix.
2024-05-24 15:58:01 +02:00
6f0b807ffd More efficient cuda implementation for ConvTranspose1d. (#2211)
* More efficient cuda implementation for ConvTranspose1d.

* Small tweak.
2024-05-24 11:05:43 +02:00
d54e02d73d Avoid a contiguous call in the quantized phi 3 model. (#2209)
* Simplify the KvCache api.

* Avoid a contiguous call in the quantized phi3 model.
2024-05-23 21:24:55 +02:00
45e235a747 Simplify the KvCache api. (#2207) 2024-05-23 17:07:21 +02:00
31cf64147b Add a couple kv-cache helper functions. (#2206) 2024-05-23 16:21:47 +02:00
77ea479a18 Add Phi-3 Medium (#2205) 2024-05-23 13:33:17 +02:00
72e7ca529a Add some missing where-cond kernels for metal. (#2203) 2024-05-22 09:44:52 +02:00
7ff921c538 Add RandomNormal ONNX operator (#2200) 2024-05-21 21:47:32 +02:00
9b8537a62f Remove the deprecated wav crate in favor of hound. (#2202) 2024-05-21 21:43:35 +02:00
7ebc3548e1 Use flash-attn in gemma. (#2195)
* Use flash-attn in gemma.

* Fix flash-attn for head dim 256.
2024-05-18 19:18:59 +02:00
eefc1c77ef Support flash-attn in quantized phi3. (#2194) 2024-05-18 17:12:56 +02:00
01545f7303 Add a slice_set op. (#2193)
* Add a slice_set op.

* Add some testing.

* Add the dedicated kv-cache module.

* Derive debug and clone.

* Expose more kv-cache functions.

* Return the current data when appending.

* Use the new cache in the quantized phi3 model.
2024-05-18 15:58:18 +02:00
349c3e806a Support embedding model gte-Qwen1.5-7B-instruct (#2190)
* Support embedding model gte-Qwen1.5-7B-instruct

This is a text embedding model based on Qwen2. They share same
model architecture except the last MLP module. This commit brings in
minimal modification of the old Qwen2 implementation to support both
models.

An example is provided, and had been verified according to the official
PyTorch implementation.

* Avoid doing the 'last-token filtering' based on the absence of attention mask.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-05-16 21:34:10 +02:00
bdaa34216a chore: add fix for windows cudarc into the readme (#2189) 2024-05-16 14:32:50 +02:00
cc80e065e5 Allow the threshold argumet to be negative in the segment-anything example (#2187)
Threshold is 0.0 by default, negative values make more points included,
expanding the mask. Positive values make it more picky, making the mask
smaller.

Negative numbers start with a minus sign, which normally makes clap
consider it a flag.
2024-05-15 13:17:20 +02:00
13c64f6828 Fix VarBuilder::from_slice_safetensors (#2180)
Also implement SimpleBackend for SliceSafetensors

Signed-off-by: Harry Stern <harry@harrystern.net>
2024-05-12 07:26:06 +02:00
21f82a5155 Add SliceSafetensors. (#2179)
* Add SlicedSafetensors.

* And add some testing.
2024-05-11 13:15:42 +02:00
9cff7bc3f4 Make it possible to use TF32 accumulation in F32 matmuls. (#2178)
* Allow the use of tf32 accumulation in matmul.

* Better timings.

* Dummy versions for use when cuda is not enabled.
2024-05-11 12:28:39 +02:00
d9bc5ec151 Switch cudarc back to dynamic linking. (#2176) 2024-05-09 10:35:44 +02:00
84328e2b60 Update cudarc requirement from 0.11.0 to 0.11.1 (#2174)
* Upgrading cudarc dependency from v0.11.0 to v0.11.1 due to that version having resolved a compile-time bug.

See: https://github.com/huggingface/candle/issues/2173
2024-05-08 20:40:36 +02:00
82b641fd27 Update cudarc requirement from 0.10.0 to 0.11.0 (#2165)
* Update cudarc requirement from 0.10.0 to 0.11.0

Updates the requirements on [cudarc](https://github.com/coreylowman/cudarc) to permit the latest version.
- [Release notes](https://github.com/coreylowman/cudarc/releases)
- [Commits](https://github.com/coreylowman/cudarc/compare/v0.10.0...v0.10.0)

---
updated-dependencies:
- dependency-name: cudarc
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

* Use the default cuda version.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-05-06 17:12:14 +02:00
01794dc16e Use write rather than try-write on the metal rw-locks. (#2162) 2024-05-05 07:22:46 +02:00
a75cd8164f Force the revision for the phi3-llama quantized models. (#2159) 2024-05-04 10:41:18 +02:00
b13a82a438 Separate quantized phi-3 implementation. (#2157)
* Separate quantized phi-3 implementation.

* Integrate the quantized phi3 model.=

* Small fixes, get the generation to work properly.

* Keep the old llama implementation around.

* Change the default.
2024-05-04 10:14:57 +02:00
59b18d974e Pin the version used for the quantized phi 3 gguf file. (#2156) 2024-05-03 15:03:22 +02:00
89f53b9d7b Bump the version number to 0.5.1. (#2155)
* Bump the version number to 0.5.1.

* Fix clippy lints for 1.78.

* More clippy fixes.
2024-05-03 11:17:05 +02:00
a09d451d11 Support top-k in tthe llama example. (#2150) 2024-05-01 22:25:47 +02:00
fa06f5f5f9 F16/BF16 bugfix (bis). (#2143)
* F16/BF16 bugfix (bis).

* Another fix.

* Yet another fix.
2024-04-29 14:08:44 +02:00
09d4845aa8 Bugfix the recent f16/bf16 changes. (#2142) 2024-04-29 13:30:11 +02:00
a0d03aded1 Bug Fix: When converting a tensor to a variable, clone if the tensor is already a variable. (#2124)
* When converting a tensor to a variable, clone if the tensor is already a variable.

* Add a test to ensure training a batch norm works with VarMaps

---------

Co-authored-by: Jeffrey Dallatezza <jeffreydallatezza@Jeffreys-Laptop.local>
2024-04-29 11:21:53 +02:00
3bbb88fcb4 Fix sigmoid gradient calculation and move sigmoid into a specialized op (#2114)
* add sigmoid op

* small fix

* add as a method on `Tensor`

* implement gradient calculation for sigmoid

* add sigmoid tests

* we should have a specialized op for this

* fix clippy

* fix clippy 2

* Revert all previous commits in favor of a `CustomOp` based solution

* use `CustomOp1` implementation

* fix rustfmt

* experimental add metal impl

* add cuda kernel impl

* fix fmt

* Add a test + reduce some cuda duplication.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-29 11:04:43 +02:00
ed7b99f525 Add a toggle for F16/BF16 accumulation in gemm. (#2141)
* Add a toggle to control f16/bf16 gemm precision.

* Use the faster variant in the quantized example.

* Bugfix.
2024-04-29 09:21:07 +02:00
287013ef28 Add a forward_via_f16 method to the qmatmul op. (#2138) 2024-04-28 20:35:01 +02:00
eb26e2467e Add the cuda dequantize f16 kernels. (#2137)
* Add the cuda dequantize f16 kernels.

* Expose the cuda kernels.

* Add some testing + fix.

* Test the other cases too.

* A few more tests.

* Add an environment variable to enable the dequantize f16 + matmul behavior.
2024-04-28 20:05:05 +02:00
c68ed8963f chore: fix some typos in comments (#2121)
Signed-off-by: hardlydearly <799511800@qq.com>
2024-04-28 08:34:32 +02:00
e5c8b88f90 Apply the cast before the scaling. (#2135) 2024-04-28 08:30:35 +02:00
805f3be8e1 Add a sort function. (#2134) 2024-04-28 08:18:04 +02:00
3b429f3023 Make the dtype configurable for phi. (#2133) 2024-04-27 21:32:49 +02:00
96a48e5cc4 Add argsort. (#2132)
* Add the argsort cuda kernels.

* CPU version of arg-sort.

* Hook the cuda kernel + rework the cpu bits.

* Add some dedicated test.

* Working cuda kernel.

* Metal kernel.

* Metal adjustments.

* Bugfix.

* Use the fast rope in qwen.

* Rework the expert selection in qwen.
2024-04-27 20:17:35 +02:00
6cf82fd7a3 Add Olmo models (#2127)
* add olmo support

* add olmo readme

* Fix fmt.

* Fix clippy.

* Get olmo to work on cuda.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-26 11:02:51 +02:00
cfab6e7616 Mention phi-v3 in the readmes. (#2122) 2024-04-24 20:54:24 +02:00
11d4a3c588 Add the phi-3 model. (#2120)
* Add the phi-3 model.

* Faster rope.

* Bugfix.

* Fix the detokenization.
2024-04-24 09:48:13 +02:00
9d3f1c8af5 Add the phi-v3 quantized model. (#2118)
* Add the phi-v3 quantized model.

* Also include phi-3 in the main phi example.
2024-04-24 08:22:23 +02:00
7211009179 Fix for rustfmt. (#2117) 2024-04-23 19:09:33 +02:00
6fadaf2eff candle-onnx: add operators RandomUniform and Exp (#2116)
* Add basic RandomUniform implementation

* Use is_some to check if seed is present

* Added Exp operator implementation

---------

Co-authored-by: Mateusz Okulus <mmokulus@gmail.com>
2024-04-23 19:02:19 +02:00
8a05743a21 Add StorageRef. (#2113)
* Add the storage-ref bits.

* Add the metal implementation.
2024-04-23 13:23:27 +02:00
b2e816752b Use the faster rms-norm kernel for llama. (#2107)
* Use the faster rms-norm kernel for llama.

* Use the fast variant by default.
2024-04-22 18:52:00 +02:00
618ecf5e23 Better time measurement for the llama example. (#2106) 2024-04-22 17:54:27 +02:00
267601eec1 Update tokenizers requirement from 0.15.0 to 0.19.1 (#2104)
Updates the requirements on [tokenizers](https://github.com/huggingface/tokenizers) to permit the latest version.
- [Release notes](https://github.com/huggingface/tokenizers/releases)
- [Changelog](https://github.com/huggingface/tokenizers/blob/main/RELEASE.md)
- [Commits](https://github.com/huggingface/tokenizers/compare/v0.15.0...v0.15.2)

---
updated-dependencies:
- dependency-name: tokenizers
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-04-22 17:10:46 +02:00
08a15cb79e Update zip requirement from 0.6.6 to 1.1.1 (#2103)
* Update zip requirement from 0.6.6 to 1.1.1

---
updated-dependencies:
- dependency-name: zip
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

* Fix for the zip crate update.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-22 16:23:27 +02:00
c388be93e7 Updated quantized phi model (#2099)
* Quantized phi in a separate file.

* Add the quantized phi example + rework the model code.

* Improve the phi model.

* Get some generation out.

* Use the appropriate rope shape.

* Tweak the default prompt.

---------

Co-authored-by: Jane Doe <jane.doe@example.org>
2024-04-21 07:37:07 +02:00
d22f1d4f4e Derive clone and debug traits for Moondream model (#2100)
* moondream implementation

* add moondream example

* change config default activation

* Add assets and integrate phi mixformer with example

* Make use of kv cache and fix seq_len bug; Clean up example code

* Add README link to example

* Remove pos_embed scaling; Remove assets; Add to README; Expand VisionConfig

* Delete image

* Use apply instead of forward

* Use latest release special token; Fix token/s accuracy; Use GeluPytorchTanh in VisionConfig v2

* Derive debug and clone traits for Moondream model.
2024-04-21 07:08:28 +02:00
0067fe00a8 Metal Unary: Add benchmarks and process kernels in a tile based fashion (#2056)
* add basic unary bench for sqrt

* process unary commands in tiles of 4

* re-enable all benchmarks

* rename helper to unary

* modify approach to split up tiled and non-tiled operations

* undo bench ignore for other tests

* update tile size to 2

* only perform the optimization on the contiguous even numbered element case
2024-04-21 00:10:33 +02:00
587ee3bb6f Small cleanups to the llama multi-process example. (#2098) 2024-04-20 22:19:46 +02:00
dd78422701 Handle multiple dimensions in metal QMM + two fixes. (#2097) 2024-04-20 18:55:45 +02:00
9215e9ce8c Add missing onnx operations (#2096)
* Add missing onnx operations

* Add tests and fix errors

* Run rustfmt
2024-04-20 18:44:22 +02:00
52ae332910 Use llama v3 by default + add to readme. (#2094) 2024-04-20 16:11:24 +02:00
8b390ddd29 Only download the weights in the main process (and not in the child processes). (#2093) 2024-04-20 13:01:23 +02:00
c97d639fa0 Multiprocess/multi-GPU support for llama 3. (#2092)
* Multiprocess/multi-GPU support for llama 3.

* Modernize the mp example a bit.
2024-04-20 12:49:21 +02:00
b45c710dbf Fix for gemma MQA. (#2091) 2024-04-19 21:49:55 +02:00
9c532aef47 Also enable llama-v3 8b instruct. (#2088) 2024-04-19 08:50:06 +02:00
f7a6468238 Add support for llama3 on the quantized example (#2086)
* add support for l3b, new tokenizer

* add todo

* Add todo and use k_s model

* Use the official tokenizers.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-18 22:52:00 +02:00
2b93dffe64 Use faster rotary embeddings for llama like models. (#2087) 2024-04-18 22:34:29 +02:00
e6ee7ba4d4 Llama v3. (#2085)
* Llama v3.

* Tweak the default params + handle special tokens.

* Small tweak.
2024-04-18 22:19:54 +02:00
1690ab45d2 Fix the silu gradient issue on 0. (#2083) 2024-04-18 14:31:41 +02:00
8de0ce6cba Add more QMMV cuda kernels. (#2077)
* Add more QMMV cuda kernels.

* Enable the new kernels.

* Adapt the testing.
2024-04-18 08:36:43 +02:00
ce6d08df94 Minor fix to the readme. (#2080)
Co-authored-by: Jane Doe <jane.doe@example.org>
2024-04-17 22:43:00 +02:00
2817643db9 Add the mmv kernels for small batch sizes. (#2075)
* Add the mmv kernels for smaller sizes.

* Support more mmv kernels.

* Use the new kernels.

* Fix the call.

* Silly fix.

* Improve the testing.

* Fix for dmmv.

* Add another dedicated test for the batching mmv.
2024-04-16 21:30:51 +02:00
4d14777673 Utilize batches in Stable Diffusion (#2071)
* Utilize batches in Stable Diffusion that were already there, but unutilized.

Also refactor out the `save_image` function.

* Clippy + cosmetic fixes.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-16 06:49:04 +02:00
f135b7963d Fix for the batch dim in the quantized matmul example. (#2073)
* Fix for the batch dim in the quantized matmul example.

* Enable more tests on cuda.

* Add a test for qmm with a batch.

* Fix the zeros-dim test on metal.
2024-04-15 20:00:28 +02:00
af955f260c Make the falcon model cloneable. (#2067) 2024-04-15 09:39:03 +02:00
8ad822a983 Add a function to clear the KV cache in falcon. (#2066)
* Add a function to clear the KV cache in falcon.

* Clippy.
2024-04-15 09:29:25 +02:00
e198bb0816 Handle zero dims in some simple operations. (#2064)
* Handle zero dims in some simple operations.

* Handle zero-dims in matmul.

* More testing.
2024-04-15 09:18:54 +02:00
f7d5bf5b97 Faster kernels for quantized matmul on cuda (#2060)
* Hook the quantized matmul cuda kernels.

* Add a (currently broken) test.

* Kernel fixes.

* Fix by transposing the rhs matrix.

* Add the q4-1 kernels.

* Proper block sizes.

* More details in the tests.
2024-04-15 08:32:47 +02:00
c119600d6e Move image tensor to device in trocr example (#2063)
Signed-off-by: Harry Stern <harry@harrystern.net>
2024-04-15 06:50:32 +02:00
c449f65b12 Expose the synchronize function on the generic device. (#2062) 2024-04-14 23:02:03 +02:00
db7dbf3071 Add missing bfloat unary strided kernels and fix typo (#2058) 2024-04-14 20:01:13 +02:00
4ecedb1598 Add the full quantized matmul kernels for cuda. (#2057) 2024-04-14 17:52:08 +02:00
53e5380bf6 Add a synchronize method to devices. (#2055)
* Add a synchronize method to devices.

* Metal version.
2024-04-14 16:32:55 +02:00
50e49ecc5f Add a quantized version of recurrent-gemma. (#2054)
* Add a quantized version of recurrent-gemma.

* Share the rglru part.

* Get the quantized gemma model to work.
2024-04-13 20:07:01 +02:00
4c88c3ce06 Add benchmarks for qmatmul operations (#2048)
* Add qmatmul bench

* add all dtypes
2024-04-13 12:30:14 +02:00
8b8fb630df Add a convenient way to rename tensors accessed through a varbuilder. (#2052) 2024-04-13 12:09:41 +02:00
fb805b8ca2 Avoid crashes when running T5 models with F16 tensors on CPU (#2047)
* This change avoids crashes when running T5 models with F16 tensors on CPU.

* This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores it's embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN. You could write more, like: This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point.

* Revert "This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores it's embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN. You could write more, like: This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point."

This reverts commit d886d3ce5e.
2024-04-13 11:07:28 +02:00
79e3bec789 Change for the encoder-only ProstT5 model (#2045)
* This change avoids crashes when running T5 models with F16 tensors on CPU.

* This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores it's embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN.  This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point.
2024-04-13 11:06:24 +02:00
e6d412b156 Add ReduceMean onnx operation (#2049)
* Add ReduceMean onnx operation

* Format code with rustfmt
2024-04-13 11:00:25 +02:00
26cbbf8d84 Mandatory topk sampling for recurrent-gemma. (#2051) 2024-04-13 10:31:39 +02:00
2bf413caa3 Add the recurrent-gemma model. (#2039)
* Start adding the recurrent-gemma model.

* More griffin.

* Add the example + get the weights to load from the HF version.

* More inference code.

* Rope + kv-cache on the attention side.

* Add to the inference code.

* Add more to the recurrent gemma inference.

* Get some first inference to run.

* Add the softcap on logits.

* Fixes.

* Use partial rotary embeddings.

* Get inference to work.

* Add a comment.

* And add a readme.
2024-04-13 00:05:21 +02:00
3ad4770eb6 Use cat for faster MQA computation. (#2043)
* Use cat for faster MQA computation.

* Move the function to utils + use it in mistral.

* Use the shared repeat-kv in a few more models.

* Fix.
2024-04-12 09:15:10 +02:00
a0460cd2b1 Add the code-gemma models. (#2038)
* Add the code-gemma models.

* Tweak to the gemma config.
2024-04-10 21:19:21 +02:00
b81ecf712d Support alternative dtypes for mamba (#2036)
* Allow different dtypes in mamba.

* Add a dtype flag.
2024-04-10 18:10:01 +02:00
a4d5a414e3 Support gather on bf16 for metal. (#2035) 2024-04-10 12:49:25 +02:00
798e0335cd Handle more tensor shapes in onnx "Gather" operation (#2026)
* Handle more tensor shapes in onnx "Gather" operation

* Add more tests

* Add comment

* Fix typo
2024-04-08 14:06:14 +02:00
718671a0d5 Use BufferOffset in metal backend ops. (#2029)
* Use BufferOffset in the metal backend.

* More BufferOffset usage.

* Use in where-cond.
2024-04-08 09:37:25 +02:00
c5fe4a7f89 Rework the buffer offset logic for metal kernels (#2028)
* Move the metal kernels utils in a separate module.

* Use the BufferOffset for unary ops.

* Fix clippy lints.

* Use the new BufferOffset.

* Adapt the binary ops.

* Affine.

* More ops (powf, elu, cast).
2024-04-07 22:37:53 +02:00
7f354473cf Optimize copy-2d for metal. (#2024)
* Optimize copy-2d for metal.

* Add a hacky stopping rule for moondream.
2024-04-07 12:34:16 +02:00
148 changed files with 15360 additions and 1673 deletions

15
.github/workflows/trufflehog.yml vendored Normal file
View File

@ -0,0 +1,15 @@
on:
push:
name: Secret Leaks
jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.5.0"
version = "0.6.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,21 +33,22 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.5.0" }
candle-datasets = { path = "./candle-datasets", version = "0.5.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.0" }
candle-kernels = { path = "./candle-kernels", version = "0.5.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.0" }
candle-nn = { path = "./candle-nn", version = "0.5.0" }
candle-onnx = { path = "./candle-onnx", version = "0.5.0" }
candle-transformers = { path = "./candle-transformers", version = "0.5.0" }
candle = { path = "./candle-core", package = "candle-core", version = "0.6.0" }
candle-datasets = { path = "./candle-datasets", version = "0.6.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.6.0" }
candle-kernels = { path = "./candle-kernels", version = "0.6.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.6.0" }
candle-nn = { path = "./candle-nn", version = "0.6.0" }
candle-onnx = { path = "./candle-onnx", version = "0.6.0" }
candle-transformers = { path = "./candle-transformers", version = "0.6.0" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] }
cudarc = { version = "0.11.4", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
hound = "3.5.1"
image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.24.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
@ -65,13 +66,12 @@ serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99"
thiserror = "1"
tokenizers = { version = "0.15.0", default-features = false }
tokenizers = { version = "0.19.1", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}
[profile.release-with-debug]

View File

@ -60,12 +60,14 @@ These online demos run entirely in your browser:
We also provide a some command line based examples using state of the art models:
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
- [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
the SOLAR-10.7B variant.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google
Deepmind.
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
Griffin based models from Google that mix attention with a RNN like state.
- [Phi-1, Phi-1.5, Phi-2, and Phi-3](./candle-examples/examples/phi/): 1.3b,
2.7b, and 3.8b general LLMs with performance on par with 7b models.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets. Also supports
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
@ -110,7 +112,7 @@ We also provide a some command line based examples using state of the art models
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmantation model.
- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmentation model.
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
model using residual vector quantization.
@ -199,10 +201,10 @@ If you have an addition to this list, please submit a pull request.
- WASM support, run your models in a browser.
- Included models.
- Language Models.
- LLaMA v1 and v2 with variants such as SOLAR-10.7B.
- LLaMA v1, v2, and v3 with variants such as SOLAR-10.7B.
- Falcon.
- StarCoder, StarCoder2.
- Phi 1, 1.5, and 2.
- Phi 1, 1.5, 2, and 3.
- Mamba, Minimal Mamba
- Gemma 2b and 7b.
- Mistral 7b v0.1.
@ -374,9 +376,9 @@ git submodule update --init
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ...:
```
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the NVCC_CCBIN environment variable.
```
env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
```
#### Linking error on windows when running rustdoc or mdbook tests
@ -406,3 +408,10 @@ This may be caused by the models being loaded from `/mnt/c`, more details on
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
error is generated.
#### CudaRC error
If you encounter an error like this one `called `Result::unwrap()` on an `Err` value: LoadLibraryExW { source: Os { code: 126, kind: Uncategorized, message: "The specified module could not be found." } }` on windows. To fix copy and rename these 3 files (make sure they are in path). The paths depend on your cuda version.
`c:\Windows\System32\nvcuda.dll` -> `cuda.dll`
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cublas64_12.dll` -> `cublas.dll`
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\curand64_10.dll` -> `curand.dll`

View File

@ -37,7 +37,6 @@ tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
parquet = { workspace = true }
image = { workspace = true }

View File

@ -81,7 +81,7 @@ let mut tp_shape = view.shape().to_vec();
let size = tp_shape[0];
if size % world_size != 0 {
panic!("The dimension is not divisble by `world_size`");
panic!("The dimension is not divisible by `world_size`");
}
let block_size = size / world_size;
let start = rank * block_size;
@ -106,8 +106,8 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
}
}
#[allow(unused)]
#[rustfmt::skip]
#[test]
fn book_training_1() -> Result<()>{
// ANCHOR: book_training_1
use hf_hub::{api::sync::Api, Repo, RepoType};

View File

@ -7,4 +7,6 @@ criterion_main!(
benchmarks::random::benches,
benchmarks::where_cond::benches,
benchmarks::conv_transpose2d::benches,
benchmarks::qmatmul::benches,
benchmarks::unary::benches
);

View File

@ -1,7 +1,9 @@
pub(crate) mod affine;
pub(crate) mod conv_transpose2d;
pub(crate) mod matmul;
pub(crate) mod qmatmul;
pub(crate) mod random;
pub(crate) mod unary;
pub(crate) mod where_cond;
use candle_core::{Device, Result};

View File

@ -0,0 +1,72 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{
quantized::{self, GgmlDType, QMatMul},
Device, Module, Tensor,
};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(matmul: &QMatMul, x: &Tensor) {
matmul.forward(&x).unwrap();
}
fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
let b = 1;
let m = 1;
let n = 1024;
let k = 1024;
let lhs = (0..(m * k))
.map(|v| v as f32 / (m * k) as f32)
.collect::<Vec<_>>();
let rhs = (0..(k * n))
.map(|v| v as f32 / (n * k) as f32)
.collect::<Vec<_>>();
let lhs = Tensor::from_slice(&lhs, (m, k), device).unwrap();
let rhs = Tensor::from_slice(&rhs, (k, n), device).unwrap();
let qtensor = quantized::QTensor::quantize(&rhs.t().unwrap(), dtype).unwrap();
let matmul = quantized::QMatMul::from_qtensor(qtensor).unwrap();
let flops = b * m * n * k;
let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype)));
group.sample_size(200);
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&matmul), black_box(&lhs));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
for dtype in vec![
GgmlDType::F32,
GgmlDType::F16,
GgmlDType::Q4_0,
GgmlDType::Q4_1,
GgmlDType::Q5_0,
GgmlDType::Q5_1,
GgmlDType::Q8_0,
GgmlDType::Q2K,
GgmlDType::Q3K,
GgmlDType::Q4K,
GgmlDType::Q5K,
GgmlDType::Q6K,
] {
run_bench(c, &device, dtype);
}
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -0,0 +1,49 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor) {
a.sqrt().unwrap();
}
fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let b = 1;
let m = 1024;
let k = 1024;
let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, &device)
.unwrap()
.to_dtype(dtype)
.unwrap()
.reshape((b, m, k))
.unwrap();
let flops = b * m * k * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
for dtype in [DType::F32, DType::BF16, DType::F16] {
let name = format!("sqrt_{:?}", dtype);
run_unary_benchmark(c, &device, dtype, &name);
}
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -5,32 +5,29 @@ extern crate accelerate_src;
extern crate intel_mkl_src;
use anyhow::Result;
use candle_core::{Device, Module, Tensor};
use candle_core::quantized::{QMatMul, QTensor};
use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
let q_cpu = q.to_device(&Device::Cpu)?;
let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
let q = QMatMul::from_qtensor(q)?;
let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
let res_q_cuda = q.forward(&x)?;
println!("{res_q_cuda}");
let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
let q_cpu = QMatMul::from_qtensor(q_cpu)?;
let x_cpu = x.to_device(&Device::Cpu)?;
let res_q_cpu = q_cpu.forward(&x_cpu)?;
println!("{res_q_cpu}");
let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
.abs()?
.flatten_all()?
.max(0)?;
println!("{diff}");
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
.to_dtype(candle_core::DType::BF16)?;
candle_core::cuda::set_gemm_reduced_precision_f32(false);
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("fp32: {:?}", start_time.elapsed());
drop(_x1);
candle_core::cuda::set_gemm_reduced_precision_f32(true);
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("tf32: {:?}", start_time.elapsed());
drop(_x1);
Ok(())
}

View File

@ -133,6 +133,8 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
/// after this call.
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
@ -142,4 +144,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn set_seed(&self, _: u64) -> Result<()>;
/// Synchronize should block until all the operations on the device are completed.
fn synchronize(&self) -> Result<()>;
}

View File

@ -624,7 +624,7 @@ impl Tensor {
Op::Unary(arg, UnaryOp::Silu) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
let sigmoid_arg = (*node / arg)?;
let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
}

View File

@ -1,6 +1,7 @@
pub mod erf;
pub mod kernels;
#[allow(unused)]
trait Cpu<const ARR: usize> {
type Unit;
type Array;
@ -18,6 +19,7 @@ trait Cpu<const ARR: usize> {
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
}
#[allow(unused)]
trait CpuF16<const ARR: usize> {
type Unit;
type Array;

View File

@ -10,7 +10,7 @@ pub use utils::{
};
const USE_IM2COL_CONV1D: bool = true;
const USE_IM2COL_CONV1D_TR: bool = true;
const USE_COL2IM_CONV1D_TR: bool = true;
const USE_IM2COL_CONV2D: bool = true;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@ -26,6 +26,17 @@ pub enum CpuStorage {
F64(Vec<f64>),
}
#[derive(Debug, Clone)]
pub enum CpuStorageRef<'a> {
U8(&'a [u8]),
U32(&'a [u32]),
I64(&'a [i64]),
BF16(&'a [bf16]),
F16(&'a [f16]),
F32(&'a [f32]),
F64(&'a [f64]),
}
#[derive(Debug, Clone)]
pub struct CpuDevice;
@ -110,7 +121,8 @@ impl ReduceIndex {
let dst_len = src_l.shape().elem_count() / reduce_dim_size;
let mut dst: Vec<U> = Vec::with_capacity(dst_len);
let dst_to_set = dst.spare_capacity_mut();
let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) };
let dst_to_set =
unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };
match src_l.contiguous_offsets() {
Some((o1, o2)) => {
let src = &src[o1..o2];
@ -2238,7 +2250,7 @@ impl BackendStorage for CpuStorage {
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
if USE_IM2COL_CONV1D_TR && can_use_col2im {
if USE_COL2IM_CONV1D_TR && can_use_col2im {
let (b_size, c_in, l_in) = l.shape().dims3()?;
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
if !kernel_l.is_contiguous() {
@ -2445,6 +2457,10 @@ impl BackendDevice for CpuDevice {
true
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
Ok(T::to_cpu_storage(s))
}
fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
Ok(s.clone())
}
@ -2628,6 +2644,10 @@ impl BackendDevice for CpuDevice {
};
Ok(storage)
}
fn synchronize(&self) -> Result<()> {
Ok(())
}
}
#[macro_export]

View File

@ -174,7 +174,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
};
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
@ -185,7 +187,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
};
let mut dst_i = 0;
for src_i in (o_l1..o_l2).step_by(ob.len) {
f_vec(
@ -224,7 +228,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
};
let mut dst_i = 0;
for src_i in (o_r1..o_r2).step_by(ob.len) {
f_vec(
@ -311,7 +317,9 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
crate::StridedBlocks::SingleBlock { start_offset, len } => {
let mut ys: Vec<U> = Vec::with_capacity(len);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
};
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(len) };
@ -333,7 +341,9 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
} else {
let mut ys: Vec<U> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
};
let mut dst_index = 0;
for src_index in block_start_index {
let vs = &vs[src_index..src_index + block_len];

View File

@ -1,5 +1,5 @@
use crate::backend::BackendDevice;
use crate::{CpuStorage, DType, Layout, Result, Shape};
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
@ -334,6 +334,43 @@ impl BackendDevice for CudaDevice {
})
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
let slice = match T::cpu_storage_ref(s) {
CpuStorageRef::U8(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorageRef::U32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorageRef::I64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorageRef::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorageRef::F16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorageRef::F32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorageRef::F64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
@ -407,4 +444,9 @@ impl BackendDevice for CudaDevice {
device: self.clone(),
})
}
fn synchronize(&self) -> Result<()> {
self.device.synchronize().map_err(crate::Error::wrap)?;
Ok(())
}
}

View File

@ -16,9 +16,9 @@ mod error;
mod utils;
pub use device::{CudaDevice, DeviceId};
pub use error::{CudaError, WrapErr};
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, Map3, S};
enum SlicePtrOrNull<T> {
pub enum SlicePtrOrNull<T> {
Ptr(CudaSlice<T>),
Null,
}
@ -33,7 +33,7 @@ unsafe impl<T: DeviceRepr> DeviceRepr for &SlicePtrOrNull<T> {
}
impl SlicePtrOrNull<usize> {
fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
let ds = if l.is_contiguous() {
SlicePtrOrNull::Null
} else {
@ -250,44 +250,6 @@ impl Map1 for Powf {
}
}
struct Sum<'a>(&'a [usize]);
impl<'a> Map1 for Sum<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let src_dims = shape.dims();
let el = shape.elem_count();
let mut dst_el = el;
for &sum_dim in self.0.iter() {
dst_el /= src_dims[sum_dim];
}
let mut sum_dims = self.0.to_vec();
// Sort the sum_dims as they have to be processed from left to right when converting the
// indexes.
sum_dims.sort();
let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
let sum_dims_s: Vec<usize> = sum_dims
.iter()
.map(|&d| src_dims[d + 1..].iter().product::<usize>())
.collect();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev
.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())
.w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("sum"), kernels::REDUCE)?;
let out = dev.alloc_zeros::<T>(dst_el).w()?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
struct FastReduce<'a>(&'a [usize], ReduceOp);
impl<'a> Map1Any for FastReduce<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
@ -668,6 +630,31 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
struct Col2Im1D {
stride: usize,
}
impl Map1 for Col2Im1D {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
col: &CudaSlice<T>,
dev: &CudaDevice,
l: &Layout,
) -> Result<CudaSlice<T>> {
let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
let stride = self.stride;
let l_out = (l_in - 1) * stride + k_size;
let dst_el = b_size * c_out * l_out;
let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let params = (dst_el, l_out, l_in, c_out, k_size, stride, col, &mut im);
let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
unsafe { func.launch(cfg, params) }.w()?;
Ok(im)
}
}
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@ -1404,9 +1391,55 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
const USE_COL2IM_CONV1D_TR: bool = true;
let device = self.device().clone();
let slice =
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
let can_use_col2im = kernel_l.is_contiguous()
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
let slice = if USE_COL2IM_CONV1D_TR && can_use_col2im {
let (b_size, c_in, l_in) = l.shape().dims3()?;
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
if !kernel_l.is_contiguous() {
crate::bail!(
"convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
)
}
if c_in != c_in2 {
crate::bail!(
"convtr1d: shape mismatch on c_in {:?} {:?}",
l.shape(),
kernel_l.shape()
)
}
let col = {
// This merges the last two dimensions of the kernel together.
let kernel_l_mm = Layout::new(
(b_size, c_in, k_size * c_out).into(),
vec![0, k_size * c_out, 1],
kernel_l.start_offset(),
);
self.matmul(
kernel,
(
b_size,
/* m */ l_in,
/* n */ c_out * k_size,
/* k */ c_in,
),
&l.transpose(1, 2)?,
&kernel_l_mm,
)?
};
let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
Col2Im1D {
stride: params.stride,
}
.map(&col.slice, &device, &col_l)?
} else {
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?
};
Ok(Self { slice, device })
}
@ -1635,12 +1668,8 @@ impl BackendStorage for CudaStorage {
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
unsafe {
self.device
.blas
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
}
.w()?;
unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, &mut out) }
.w()?;
CudaStorageSlice::BF16(out)
}
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
@ -1648,12 +1677,8 @@ impl BackendStorage for CudaStorage {
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
unsafe {
self.device
.blas
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
}
.w()?;
unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, &mut out) }
.w()?;
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
@ -1661,12 +1686,8 @@ impl BackendStorage for CudaStorage {
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
unsafe {
self.device
.blas
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
}
.w()?;
unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) }
.w()?;
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
@ -1856,3 +1877,203 @@ impl BackendStorage for CudaStorage {
Ok(())
}
}
// Default for the reduced precision setting is false, similar to pytorch.
// https://github.com/pytorch/pytorch/issues/123157
static MM_F16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
static MM_BF16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
static MM_F32_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// allowed with f32 GEMMs.
pub fn gemm_reduced_precision_f32() -> bool {
MM_F32_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
}
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// allowed with f32 GEMMs.
pub fn set_gemm_reduced_precision_f32(b: bool) {
MM_F32_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with f16 GEMMs.
pub fn gemm_reduced_precision_f16() -> bool {
MM_F16_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with f16 GEMMs.
pub fn set_gemm_reduced_precision_f16(b: bool) {
MM_F16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with bf16 GEMMs.
pub fn gemm_reduced_precision_bf16() -> bool {
MM_BF16_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with bf16 GEMMs.
pub fn set_gemm_reduced_precision_bf16(b: bool) {
MM_BF16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
}
unsafe fn gemm_strided_batched_f32(
cublas: &cudarc::cublas::CudaBlas,
cfg: StridedBatchedConfig<f32>,
a: &cudarc::driver::CudaView<f32>,
b: &cudarc::driver::CudaView<f32>,
c: &mut CudaSlice<f32>,
) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
use cudarc::cublas::sys;
use cudarc::driver::DevicePtrMut;
let compute_type = if gemm_reduced_precision_f32() {
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32
} else {
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
};
let alpha = &cfg.gemm.alpha as *const f32 as *const _;
let beta = &cfg.gemm.beta as *const f32 as *const _;
cudarc::cublas::result::gemm_strided_batched_ex(
*cublas.handle(),
cfg.gemm.transa,
cfg.gemm.transb,
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
alpha,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_32F,
cfg.gemm.lda,
cfg.stride_a,
*b.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_32F,
cfg.gemm.ldb,
cfg.stride_b,
beta,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_32F,
cfg.gemm.ldc,
cfg.stride_c,
cfg.batch_size,
compute_type,
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
)
}
unsafe fn gemm_strided_batched_f16(
cublas: &cudarc::cublas::CudaBlas,
cfg: StridedBatchedConfig<f16>,
a: &cudarc::driver::CudaView<f16>,
b: &cudarc::driver::CudaView<f16>,
c: &mut CudaSlice<f16>,
) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
use cudarc::cublas::sys;
use cudarc::driver::DevicePtrMut;
let alpha = cfg.gemm.alpha;
let beta = cfg.gemm.beta;
let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
let beta_f32: f32 = cfg.gemm.beta.to_f32();
let (compute_type, alpha, beta) = if gemm_reduced_precision_f16() {
(
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
(&alpha) as *const f16 as *const _,
(&beta) as *const f16 as *const _,
)
} else {
(
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
(&alpha_f32) as *const f32 as *const _,
(&beta_f32) as *const f32 as *const _,
)
};
cudarc::cublas::result::gemm_strided_batched_ex(
*cublas.handle(),
cfg.gemm.transa,
cfg.gemm.transb,
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
alpha,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.lda,
cfg.stride_a,
*b.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.ldb,
cfg.stride_b,
beta,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.ldc,
cfg.stride_c,
cfg.batch_size,
compute_type,
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
)
}
unsafe fn gemm_strided_batched_bf16(
cublas: &cudarc::cublas::CudaBlas,
cfg: StridedBatchedConfig<bf16>,
a: &cudarc::driver::CudaView<bf16>,
b: &cudarc::driver::CudaView<bf16>,
c: &mut CudaSlice<bf16>,
) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
use cudarc::cublas::sys;
use cudarc::driver::DevicePtrMut;
let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
let beta_f32: f32 = cfg.gemm.beta.to_f32();
// The type for alpha and beta depends on the computeType.
// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() {
(
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF,
(&alpha_f32) as *const f32 as *const _,
(&beta_f32) as *const f32 as *const _,
)
} else {
(
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
(&alpha_f32) as *const f32 as *const _,
(&beta_f32) as *const f32 as *const _,
)
};
cudarc::cublas::result::gemm_strided_batched_ex(
*cublas.handle(),
cfg.gemm.transa,
cfg.gemm.transb,
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
alpha,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.lda,
cfg.stride_a,
*b.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.ldb,
cfg.stride_b,
beta,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.ldc,
cfg.stride_c,
cfg.batch_size,
compute_type,
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
)
}

View File

@ -54,6 +54,44 @@ pub trait Map2 {
}
}
pub trait Map3 {
#[allow(clippy::too_many_arguments)]
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
src3: &CudaSlice<T>,
layout3: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>>;
#[allow(clippy::too_many_arguments)]
fn map(
&self,
s1: &S,
l1: &Layout,
s2: &S,
l2: &Layout,
s3: &S,
l3: &Layout,
d: &CudaDevice,
) -> Result<S> {
let out = match (s1, s2, s3) {
(S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?),
_ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?,
};
Ok(out)
}
}
pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,

View File

@ -306,6 +306,20 @@ impl Device {
}
}
pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
Device::Cuda(device) => {
let storage = device.storage_from_slice(data)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.storage_from_slice(data)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
@ -337,4 +351,12 @@ impl Device {
}
}
}
pub fn synchronize(&self) -> Result<()> {
match self {
Self::Cpu => Ok(()),
Self::Cuda(d) => d.synchronize(),
Self::Metal(d) => d.synchronize(),
}
}
}

View File

@ -1,7 +1,7 @@
//! Types for elements that can be stored and manipulated using tensors.
#![allow(clippy::redundant_closure_call)]
use crate::backend::BackendStorage;
use crate::{CpuStorage, Error, Result};
use crate::{CpuStorage, CpuStorageRef, Error, Result};
/// The different types of elements allowed in tensors.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
@ -100,12 +100,14 @@ pub trait WithDType:
+ 'static
+ Send
+ Sync
+ std::any::Any
+ crate::cpu::kernels::VecOps
{
const DTYPE: DType;
fn from_f64(v: f64) -> Self;
fn to_f64(self) -> f64;
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
@ -129,6 +131,10 @@ macro_rules! with_dtype {
$to_f64(self)
}
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
CpuStorageRef::$dtype(data)
}
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
CpuStorage::$dtype(data)
}

View File

@ -214,6 +214,10 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport)
}
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -229,4 +233,38 @@ impl crate::backend::BackendDevice for CudaDevice {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn synchronize(&self) -> Result<()> {
Ok(())
}
}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with f16 GEMMs.
pub fn gemm_reduced_precision_f16() -> bool {
true
}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with f16 GEMMs.
pub fn set_gemm_reduced_precision_f16(_: bool) {}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with bf16 GEMMs.
pub fn gemm_reduced_precision_bf16() -> bool {
true
}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with bf16 GEMMs.
pub fn set_gemm_reduced_precision_bf16(_: bool) {}
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// allowed with f32 GEMMs.
pub fn gemm_reduced_precision_f32() -> bool {
true
}
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// allowed with f32 GEMMs.
pub fn set_gemm_reduced_precision_f32(_b: bool) {}

View File

@ -226,6 +226,10 @@ impl crate::backend::BackendDevice for MetalDevice {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
@ -241,4 +245,8 @@ impl crate::backend::BackendDevice for MetalDevice {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn synchronize(&self) -> Result<()> {
Ok(())
}
}

View File

@ -219,10 +219,14 @@ impl Error {
Self::Wrapped(Box::new(err)).bt()
}
pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
pub fn msg(err: impl std::error::Error) -> Self {
Self::Msg(err.to_string()).bt()
}
pub fn debug(err: impl std::fmt::Debug) -> Self {
Self::Msg(format!("{err:?}")).bt()
}
pub fn bt(self) -> Self {
let backtrace = std::backtrace::Backtrace::capture();
match backtrace.status() {

View File

@ -47,7 +47,7 @@ mod custom_op;
mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
pub mod dummy_cuda_backend;
mod dummy_metal_backend;
pub mod error;
mod indexer;
@ -63,6 +63,7 @@ pub mod quantized;
pub mod safetensors;
pub mod scalar;
pub mod shape;
mod sort;
mod storage;
mod strided_index;
mod tensor;
@ -74,7 +75,7 @@ mod variable;
#[cfg(feature = "cudnn")]
pub use cuda_backend::cudnn;
pub use cpu_backend::CpuStorage;
pub use cpu_backend::{CpuStorage, CpuStorageRef};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
@ -88,10 +89,12 @@ pub use tensor::{Tensor, TensorId};
pub use variable::Var;
#[cfg(feature = "cuda")]
pub use cuda_backend::{CudaDevice, CudaStorage};
pub use cuda_backend as cuda;
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
pub use dummy_cuda_backend as cuda;
pub use cuda::{CudaDevice, CudaStorage};
#[cfg(feature = "metal")]
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};

View File

@ -100,11 +100,11 @@ impl MetalDevice {
}
pub fn command_buffer(&self) -> Result<CommandBuffer> {
let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
let mut command_buffer = command_buffer_lock.to_owned();
let mut index = self
.command_buffer_index
.try_write()
.write()
.map_err(MetalError::from)?;
if *index > self.compute_per_buffer {
command_buffer.commit();
@ -119,7 +119,7 @@ impl MetalDevice {
}
pub fn wait_until_completed(&self) -> Result<()> {
let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
match command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
@ -179,7 +179,7 @@ impl MetalDevice {
size,
MTLResourceOptions::StorageModeManaged,
);
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
let subbuffers = buffers
.entry((size, MTLResourceOptions::StorageModeManaged))
.or_insert(vec![]);
@ -232,7 +232,7 @@ impl MetalDevice {
}
fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
@ -251,7 +251,7 @@ impl MetalDevice {
option: MTLResourceOptions,
_name: &str,
) -> Result<Arc<Buffer>> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
// Cloning also ensures we increment the strong count
return Ok(b.clone());
@ -283,5 +283,5 @@ impl MetalDevice {
}
fn buf_size(size: NSUInteger) -> NSUInteger {
(size - 1).next_power_of_two() as NSUInteger
size.saturating_sub(1).next_power_of_two() as NSUInteger
}

View File

@ -1,17 +1,22 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels::CallConvTranspose2dCfg;
use candle_metal_kernels::Kernels;
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
use metal::{Buffer, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::{Arc, Mutex, RwLock, TryLockError};
use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError};
mod device;
pub use device::{DeviceId, MetalDevice};
pub fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
BufferOffset {
buffer,
offset_in_bytes: l.start_offset() * dtype.size_in_bytes(),
}
}
/// Simple way to catch lock error without
/// depending on T
#[derive(thiserror::Error, Debug)]
@ -31,6 +36,12 @@ impl<T> From<TryLockError<T>> for MetalError {
}
}
impl<T> From<PoisonError<T>> for MetalError {
fn from(p: PoisonError<T>) -> Self {
MetalError::LockError(LockError::Poisoned(p.to_string()))
}
}
/// Metal related errors
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
@ -102,7 +113,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el, self.dtype, "affine")?;
let command_buffer = self.device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, dtype);
if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "affine_f32",
DType::F16 => "affine_f16",
@ -115,7 +127,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
el,
&self.buffer,
src,
&buffer,
mul as f32,
add as f32,
@ -134,9 +146,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * dtype.size_in_bytes(),
&buffer,
mul as f32,
add as f32,
@ -155,7 +166,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el, self.dtype, "powf")?;
let command_buffer = self.device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, dtype);
if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "powf_f32",
DType::F16 => "powf_f16",
@ -168,7 +180,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
el,
&self.buffer,
src,
&buffer,
pow as f32,
)
@ -186,9 +198,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * dtype.size_in_bytes(),
&buffer,
pow as f32,
)
@ -206,7 +217,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el, self.dtype, "elu")?;
let command_buffer = self.device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, self.dtype);
if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "elu_f32",
DType::F16 => "elu_f16",
@ -219,7 +231,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
el,
&self.buffer,
src,
&buffer,
alpha as f32,
)
@ -237,9 +249,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * dtype.size_in_bytes(),
&buffer,
alpha as f32,
)
@ -309,6 +320,7 @@ impl BackendStorage for MetalStorage {
let dtype = if return_index { DType::U32 } else { self.dtype };
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_reduce_strided(
&device.device,
&command_buffer,
@ -317,8 +329,7 @@ impl BackendStorage for MetalStorage {
&dims,
&stride,
dst_el,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
src,
&buffer,
)
.map_err(MetalError::from)?;
@ -344,7 +355,8 @@ impl BackendStorage for MetalStorage {
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, "todtype")?;
let command_buffer = device.command_buffer()?;
if layout.is_contiguous() && layout.start_offset() == 0 {
let src = buffer_o(&self.buffer, layout, self.dtype);
if layout.is_contiguous() {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U32, DType::F16) => "cast_u32_f16",
@ -392,8 +404,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
el_count,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
src,
&buffer,
)
.map_err(MetalError::from)?;
@ -420,9 +431,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
layout.dims(),
&self.buffer,
src,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
@ -439,137 +449,239 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
let command_buffer = device.command_buffer()?;
command_buffer.set_label(B::KERNEL);
if layout.is_contiguous() && layout.start_offset() == 0 {
use candle_metal_kernels::unary::contiguous;
let src = buffer_o(&self.buffer, layout, self.dtype);
let kernel_name = match (B::KERNEL, dtype) {
("uabs", DType::F16) => contiguous::abs::HALF,
("uabs", DType::F32) => contiguous::abs::FLOAT,
("uabs", DType::BF16) => contiguous::abs::BFLOAT,
("uceil", DType::F16) => contiguous::ceil::HALF,
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("uceil", DType::BF16) => contiguous::ceil::BFLOAT,
("ucos", DType::F16) => contiguous::cos::HALF,
("ucos", DType::F32) => contiguous::cos::FLOAT,
("ucos", DType::BF16) => contiguous::cos::BFLOAT,
("uerf", DType::F16) => contiguous::erf::HALF,
("uerf", DType::F32) => contiguous::erf::FLOAT,
("uerf", DType::BF16) => contiguous::erf::BFLOAT,
("uexp", DType::F16) => contiguous::exp::HALF,
("uexp", DType::F32) => contiguous::exp::FLOAT,
("uexp", DType::BF16) => contiguous::exp::BFLOAT,
("ufloor", DType::F16) => contiguous::floor::HALF,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("ufloor", DType::BF16) => contiguous::floor::BFLOAT,
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT,
("ugelu", DType::F16) => contiguous::gelu::HALF,
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
("ugelu", DType::BF16) => contiguous::gelu::BFLOAT,
("ulog", DType::F16) => contiguous::log::HALF,
("ulog", DType::F32) => contiguous::log::FLOAT,
("ulog", DType::BF16) => contiguous::log::BFLOAT,
("uneg", DType::F16) => contiguous::neg::HALF,
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uneg", DType::BF16) => contiguous::neg::BFLOAT,
("urecip", DType::F16) => contiguous::recip::HALF,
("urecip", DType::F32) => contiguous::recip::FLOAT,
("urecip", DType::BF16) => contiguous::recip::BFLOAT,
("urelu", DType::F16) => contiguous::relu::HALF,
("urelu", DType::F32) => contiguous::relu::FLOAT,
("urelu", DType::BF16) => contiguous::relu::BFLOAT,
("uround", DType::F16) => contiguous::round::HALF,
("uround", DType::F32) => contiguous::round::FLOAT,
("uround", DType::BF16) => contiguous::round::BFLOAT,
("usilu", DType::F16) => contiguous::silu::HALF,
("usilu", DType::F32) => contiguous::silu::FLOAT,
("usilu", DType::BF16) => contiguous::silu::BFLOAT,
("usin", DType::F16) => contiguous::sin::HALF,
("usin", DType::F32) => contiguous::sin::FLOAT,
("usin", DType::BF16) => contiguous::sin::BFLOAT,
("usqr", DType::F16) => contiguous::sqr::HALF,
("usqr", DType::F32) => contiguous::sqr::FLOAT,
("usqr", DType::BF16) => contiguous::sqr::BFLOAT,
("usqrt", DType::F16) => contiguous::sqrt::HALF,
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT,
("utanh", DType::F16) => contiguous::tanh::HALF,
("utanh", DType::F32) => contiguous::tanh::FLOAT,
("utanh", DType::BF16) => contiguous::tanh::BFLOAT,
("usign", DType::F16) => contiguous::sign::HALF,
("usign", DType::F32) => contiguous::sign::FLOAT,
("usign", DType::BF16) => contiguous::sign::BFLOAT,
("usign", DType::I64) => contiguous::sign::I64,
(name, dtype) => {
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
}
};
candle_metal_kernels::call_unary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else {
use candle_metal_kernels::unary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => strided::cos::FLOAT,
("usin", DType::F32) => strided::sin::FLOAT,
("usqr", DType::F32) => strided::sqr::FLOAT,
("usqrt", DType::F32) => strided::sqrt::FLOAT,
("uneg", DType::F32) => strided::neg::FLOAT,
("uexp", DType::F32) => strided::exp::FLOAT,
("ulog", DType::F32) => strided::log::FLOAT,
("ugelu", DType::F32) => strided::gelu::FLOAT,
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
("uerf", DType::F32) => strided::erf::FLOAT,
("usilu", DType::F32) => strided::silu::FLOAT,
("uabs", DType::F32) => strided::abs::FLOAT,
("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT,
("urelu", DType::F32) => strided::relu::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
("utanh", DType::F32) => strided::tanh::FLOAT,
("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF,
("usqr", DType::F16) => strided::sqr::HALF,
("usqrt", DType::F16) => strided::sqrt::HALF,
("uneg", DType::F16) => strided::neg::HALF,
("uexp", DType::F16) => strided::exp::HALF,
("ulog", DType::F16) => strided::log::HALF,
("ugelu", DType::F16) => strided::gelu::HALF,
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
("uerf", DType::F16) => strided::erf::HALF,
("usilu", DType::F16) => strided::silu::HALF,
("uabs", DType::F16) => strided::abs::HALF,
("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF,
("urelu", DType::F16) => strided::relu::HALF,
("uround", DType::F16) => strided::round::HALF,
("utanh", DType::F16) => strided::tanh::HALF,
(name, dtype) => {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
}
};
candle_metal_kernels::call_unary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
layout.dims(),
&self.buffer,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
0,
)
.map_err(MetalError::from)?;
match (el_count % 2, dtype, layout.is_contiguous()) {
(0, DType::BF16 | DType::F16, true) => {
use candle_metal_kernels::unary::contiguous_tiled;
let kernel_name = match (B::KERNEL, dtype) {
("uabs", DType::F16) => contiguous_tiled::abs::HALF,
("uabs", DType::F32) => contiguous_tiled::abs::FLOAT,
("uabs", DType::BF16) => contiguous_tiled::abs::BFLOAT,
("uceil", DType::F16) => contiguous_tiled::ceil::HALF,
("uceil", DType::F32) => contiguous_tiled::ceil::FLOAT,
("uceil", DType::BF16) => contiguous_tiled::ceil::BFLOAT,
("ucos", DType::F16) => contiguous_tiled::cos::HALF,
("ucos", DType::F32) => contiguous_tiled::cos::FLOAT,
("ucos", DType::BF16) => contiguous_tiled::cos::BFLOAT,
("uerf", DType::F16) => contiguous_tiled::erf::HALF,
("uerf", DType::F32) => contiguous_tiled::erf::FLOAT,
("uerf", DType::BF16) => contiguous_tiled::erf::BFLOAT,
("uexp", DType::F16) => contiguous_tiled::exp::HALF,
("uexp", DType::F32) => contiguous_tiled::exp::FLOAT,
("uexp", DType::BF16) => contiguous_tiled::exp::BFLOAT,
("ufloor", DType::F16) => contiguous_tiled::floor::HALF,
("ufloor", DType::F32) => contiguous_tiled::floor::FLOAT,
("ufloor", DType::BF16) => contiguous_tiled::floor::BFLOAT,
("ugelu_erf", DType::F16) => contiguous_tiled::gelu_erf::HALF,
("ugelu_erf", DType::F32) => contiguous_tiled::gelu_erf::FLOAT,
("ugelu_erf", DType::BF16) => contiguous_tiled::gelu_erf::BFLOAT,
("ugelu", DType::F16) => contiguous_tiled::gelu::HALF,
("ugelu", DType::F32) => contiguous_tiled::gelu::FLOAT,
("ugelu", DType::BF16) => contiguous_tiled::gelu::BFLOAT,
("ulog", DType::F16) => contiguous_tiled::log::HALF,
("ulog", DType::F32) => contiguous_tiled::log::FLOAT,
("ulog", DType::BF16) => contiguous_tiled::log::BFLOAT,
("uneg", DType::F16) => contiguous_tiled::neg::HALF,
("uneg", DType::F32) => contiguous_tiled::neg::FLOAT,
("uneg", DType::BF16) => contiguous_tiled::neg::BFLOAT,
("urecip", DType::F16) => contiguous_tiled::recip::HALF,
("urecip", DType::F32) => contiguous_tiled::recip::FLOAT,
("urecip", DType::BF16) => contiguous_tiled::recip::BFLOAT,
("urelu", DType::F16) => contiguous_tiled::relu::HALF,
("urelu", DType::F32) => contiguous_tiled::relu::FLOAT,
("urelu", DType::BF16) => contiguous_tiled::relu::BFLOAT,
("uround", DType::F16) => contiguous_tiled::round::HALF,
("uround", DType::F32) => contiguous_tiled::round::FLOAT,
("uround", DType::BF16) => contiguous_tiled::round::BFLOAT,
("usilu", DType::F16) => contiguous_tiled::silu::HALF,
("usilu", DType::F32) => contiguous_tiled::silu::FLOAT,
("usilu", DType::BF16) => contiguous_tiled::silu::BFLOAT,
("usin", DType::F16) => contiguous_tiled::sin::HALF,
("usin", DType::F32) => contiguous_tiled::sin::FLOAT,
("usin", DType::BF16) => contiguous_tiled::sin::BFLOAT,
("usqr", DType::F16) => contiguous_tiled::sqr::HALF,
("usqr", DType::F32) => contiguous_tiled::sqr::FLOAT,
("usqr", DType::BF16) => contiguous_tiled::sqr::BFLOAT,
("usqrt", DType::F16) => contiguous_tiled::sqrt::HALF,
("usqrt", DType::F32) => contiguous_tiled::sqrt::FLOAT,
("usqrt", DType::BF16) => contiguous_tiled::sqrt::BFLOAT,
("utanh", DType::F16) => contiguous_tiled::tanh::HALF,
("utanh", DType::F32) => contiguous_tiled::tanh::FLOAT,
("utanh", DType::BF16) => contiguous_tiled::tanh::BFLOAT,
("usign", DType::F16) => contiguous_tiled::sign::HALF,
("usign", DType::F32) => contiguous_tiled::sign::FLOAT,
("usign", DType::BF16) => contiguous_tiled::sign::BFLOAT,
("usign", DType::I64) => contiguous_tiled::sign::I64,
(name, dtype) => {
crate::bail!(
"Metal contiguous_tiled unary {name} {dtype:?} not implemented"
)
}
};
candle_metal_kernels::call_unary_contiguous_tiled(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
src,
&buffer,
)
.map_err(MetalError::from)?;
}
(_, _, true) => {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("uabs", DType::F16) => contiguous::abs::HALF,
("uabs", DType::F32) => contiguous::abs::FLOAT,
("uabs", DType::BF16) => contiguous::abs::BFLOAT,
("uceil", DType::F16) => contiguous::ceil::HALF,
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("uceil", DType::BF16) => contiguous::ceil::BFLOAT,
("ucos", DType::F16) => contiguous::cos::HALF,
("ucos", DType::F32) => contiguous::cos::FLOAT,
("ucos", DType::BF16) => contiguous::cos::BFLOAT,
("uerf", DType::F16) => contiguous::erf::HALF,
("uerf", DType::F32) => contiguous::erf::FLOAT,
("uerf", DType::BF16) => contiguous::erf::BFLOAT,
("uexp", DType::F16) => contiguous::exp::HALF,
("uexp", DType::F32) => contiguous::exp::FLOAT,
("uexp", DType::BF16) => contiguous::exp::BFLOAT,
("ufloor", DType::F16) => contiguous::floor::HALF,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("ufloor", DType::BF16) => contiguous::floor::BFLOAT,
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT,
("ugelu", DType::F16) => contiguous::gelu::HALF,
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
("ugelu", DType::BF16) => contiguous::gelu::BFLOAT,
("ulog", DType::F16) => contiguous::log::HALF,
("ulog", DType::F32) => contiguous::log::FLOAT,
("ulog", DType::BF16) => contiguous::log::BFLOAT,
("uneg", DType::F16) => contiguous::neg::HALF,
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uneg", DType::BF16) => contiguous::neg::BFLOAT,
("urecip", DType::F16) => contiguous::recip::HALF,
("urecip", DType::F32) => contiguous::recip::FLOAT,
("urecip", DType::BF16) => contiguous::recip::BFLOAT,
("urelu", DType::F16) => contiguous::relu::HALF,
("urelu", DType::F32) => contiguous::relu::FLOAT,
("urelu", DType::BF16) => contiguous::relu::BFLOAT,
("uround", DType::F16) => contiguous::round::HALF,
("uround", DType::F32) => contiguous::round::FLOAT,
("uround", DType::BF16) => contiguous::round::BFLOAT,
("usilu", DType::F16) => contiguous::silu::HALF,
("usilu", DType::F32) => contiguous::silu::FLOAT,
("usilu", DType::BF16) => contiguous::silu::BFLOAT,
("usin", DType::F16) => contiguous::sin::HALF,
("usin", DType::F32) => contiguous::sin::FLOAT,
("usin", DType::BF16) => contiguous::sin::BFLOAT,
("usqr", DType::F16) => contiguous::sqr::HALF,
("usqr", DType::F32) => contiguous::sqr::FLOAT,
("usqr", DType::BF16) => contiguous::sqr::BFLOAT,
("usqrt", DType::F16) => contiguous::sqrt::HALF,
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT,
("utanh", DType::F16) => contiguous::tanh::HALF,
("utanh", DType::F32) => contiguous::tanh::FLOAT,
("utanh", DType::BF16) => contiguous::tanh::BFLOAT,
("usign", DType::F16) => contiguous::sign::HALF,
("usign", DType::F32) => contiguous::sign::FLOAT,
("usign", DType::BF16) => contiguous::sign::BFLOAT,
("usign", DType::I64) => contiguous::sign::I64,
(name, dtype) => {
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
}
};
candle_metal_kernels::call_unary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
src,
&buffer,
)
.map_err(MetalError::from)?;
}
(_, _, false) => {
use candle_metal_kernels::unary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => strided::cos::FLOAT,
("usin", DType::F32) => strided::sin::FLOAT,
("usqr", DType::F32) => strided::sqr::FLOAT,
("usqrt", DType::F32) => strided::sqrt::FLOAT,
("uneg", DType::F32) => strided::neg::FLOAT,
("uexp", DType::F32) => strided::exp::FLOAT,
("ulog", DType::F32) => strided::log::FLOAT,
("ugelu", DType::F32) => strided::gelu::FLOAT,
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
("uerf", DType::F32) => strided::erf::FLOAT,
("usilu", DType::F32) => strided::silu::FLOAT,
("uabs", DType::F32) => strided::abs::FLOAT,
("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT,
("urelu", DType::F32) => strided::relu::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
("utanh", DType::F32) => strided::tanh::FLOAT,
("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF,
("usqr", DType::F16) => strided::sqr::HALF,
("usqrt", DType::F16) => strided::sqrt::HALF,
("uneg", DType::F16) => strided::neg::HALF,
("uexp", DType::F16) => strided::exp::HALF,
("ulog", DType::F16) => strided::log::HALF,
("ugelu", DType::F16) => strided::gelu::HALF,
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
("uerf", DType::F16) => strided::erf::HALF,
("usilu", DType::F16) => strided::silu::HALF,
("uabs", DType::F16) => strided::abs::HALF,
("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF,
("urelu", DType::F16) => strided::relu::HALF,
("uround", DType::F16) => strided::round::HALF,
("utanh", DType::F16) => strided::tanh::HALF,
("ucos", DType::BF16) => strided::cos::BFLOAT,
("usin", DType::BF16) => strided::sin::BFLOAT,
("usqr", DType::BF16) => strided::sqr::BFLOAT,
("usqrt", DType::BF16) => strided::sqrt::BFLOAT,
("uneg", DType::BF16) => strided::neg::BFLOAT,
("uexp", DType::BF16) => strided::exp::BFLOAT,
("ulog", DType::BF16) => strided::log::BFLOAT,
("ugelu", DType::BF16) => strided::gelu::BFLOAT,
("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT,
("uerf", DType::BF16) => strided::erf::BFLOAT,
("usilu", DType::BF16) => strided::silu::BFLOAT,
("uabs", DType::BF16) => strided::abs::BFLOAT,
("uceil", DType::BF16) => strided::ceil::BFLOAT,
("ufloor", DType::BF16) => strided::floor::BFLOAT,
("urelu", DType::BF16) => strided::relu::BFLOAT,
("uround", DType::BF16) => strided::round::BFLOAT,
("utanh", DType::BF16) => strided::tanh::BFLOAT,
(name, dtype) => {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
}
};
let dst = BufferOffset::zero_offset(&buffer);
candle_metal_kernels::call_unary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
layout.dims(),
src,
layout.stride(),
dst,
)
.map_err(MetalError::from)?;
}
}
Ok(Self::new(buffer, device.clone(), el_count, dtype))
}
@ -606,6 +718,7 @@ impl BackendStorage for MetalStorage {
}
let name = match (self.dtype, t.dtype()) {
(DType::U8, DType::F32) => "where_u8_f32",
(DType::U32, DType::F32) => "where_u32_f32",
(DType::U8, DType::BF16) => "where_u8_bf16",
(DType::U8, DType::F16) => "where_u8_f16",
(DType::U8, DType::I64) => "where_u8_i64",
@ -613,21 +726,21 @@ impl BackendStorage for MetalStorage {
(DType::U8, DType::U8) => "where_u8_u8",
(left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"),
};
let src = buffer_o(&self.buffer, layout, self.dtype);
let t = buffer_o(&t.buffer, t_l, t.dtype);
let f = buffer_o(&f.buffer, f_l, f.dtype);
candle_metal_kernels::call_where_cond_strided(
&device.device,
&command_buffer,
&device.kernels,
name,
dims,
&self.buffer,
(
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
),
&t.buffer,
(t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
&f.buffer,
(f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
src,
layout.stride(),
t,
t_l.stride(),
f,
f_l.stride(),
&buffer,
)
.map_err(MetalError::from)?;
@ -660,6 +773,7 @@ impl BackendStorage for MetalStorage {
DType::F32 => "im2col1d_f32",
dtype => crate::bail!("Metal conv1d {dtype:?} not implemented"),
};
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_im2col1d_strided(
&self.device.device,
&command_buffer,
@ -668,8 +782,7 @@ impl BackendStorage for MetalStorage {
layout.shape().dims(),
strides,
(k_size, stride, padding, dilation),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
src,
&dst,
)
.map_err(MetalError::from)?;
@ -712,44 +825,107 @@ impl BackendStorage for MetalStorage {
k_layout: &Layout,
params: &ParamsConvTranspose1D,
) -> Result<Self> {
const USE_COL2IM_CONV1D_TR: bool = true;
let can_use_col2im = k_layout.is_contiguous()
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
let l_out = params.l_out();
let dst_el = params.c_out * l_out * params.b_size;
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "conv_transpose1d_f32",
DType::F16 => "conv_transpose1d_f16",
DType::BF16 => "conv_transpose1d_bf16",
DType::U32 => "conv_transpose1d_u32",
DType::U8 => "conv_transpose1d_u8",
dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
let buffer = if USE_COL2IM_CONV1D_TR && can_use_col2im {
let (b_size, c_in, l_in) = layout.shape().dims3()?;
let (c_in2, c_out, k_size) = k_layout.shape().dims3()?;
if c_in != c_in2 {
crate::bail!(
"convtr1d: shape mismatch on c_in {:?} {:?}",
layout.shape(),
k_layout.shape()
)
}
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
let name = match self.dtype {
DType::F32 => "col2im1d_f32",
DType::U32 => "col2im1d_u32",
DType::U8 => "col2im1d_u8",
dtype => crate::bail!("metal col2im1d {dtype:?} not implemented"),
};
let col = {
// This merges the last two dimensions of the kernel together.
let kernel_l_mm = Layout::new(
(b_size, c_in, k_size * c_out).into(),
vec![0, k_size * c_out, 1],
k_layout.start_offset(),
);
self.matmul(
k,
(b_size, l_in, c_out * k_size, c_in),
&layout.transpose(1, 2)?,
&kernel_l_mm,
)?
};
// It is important for the command buffer to be obtained *after* the matmul
// kernel has run, otherwise we might use a command-buffer that has been commited
// already resulting in the following error.
// _status < MTLCommandBufferStatusCommitted >
// -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_col2im1d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
&[b_size, l_in, c_out, k_size],
params.k_size,
params.stride,
BufferOffset::zero_offset(&col.buffer),
&buffer,
)
.map_err(MetalError::from)?;
buffer
} else {
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "conv_transpose1d_f32",
DType::F16 => "conv_transpose1d_f16",
DType::BF16 => "conv_transpose1d_bf16",
DType::U32 => "conv_transpose1d_u32",
DType::U8 => "conv_transpose1d_u8",
dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
};
candle_metal_kernels::call_conv_transpose1d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
params.dilation,
params.stride,
params.padding,
params.output_padding,
params.c_out,
l_out,
params.b_size,
layout.dims(),
layout.stride(),
k_layout.dims(),
k_layout.stride(),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&k.buffer,
k_layout.start_offset() * k.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
buffer
};
candle_metal_kernels::call_conv_transpose1d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
params.dilation,
params.stride,
params.padding,
params.output_padding,
params.c_out,
l_out,
params.b_size,
layout.dims(),
layout.stride(),
k_layout.dims(),
k_layout.stride(),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&k.buffer,
k_layout.start_offset() * k.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
}
@ -787,6 +963,7 @@ impl BackendStorage for MetalStorage {
DType::U32 => "im2col_u32",
dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"),
};
let src = buffer_o(&self.buffer, layout, self.dtype);
candle_metal_kernels::call_im2col_strided(
&self.device.device,
&command_buffer,
@ -795,8 +972,7 @@ impl BackendStorage for MetalStorage {
layout.shape().dims(),
layout.stride(),
(h_k, w_k, stride, padding, dilation),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
src,
&dst,
)
.map_err(MetalError::from)?;
@ -1009,6 +1185,7 @@ impl BackendStorage for MetalStorage {
.device
.new_buffer(dst_el, self.dtype, "upsample_nearest2d")?;
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, inp_l, self.dtype);
candle_metal_kernels::call_upsample_nearest_2d(
&self.device.device,
&command_buffer,
@ -1018,8 +1195,7 @@ impl BackendStorage for MetalStorage {
strides,
out_w,
out_h,
&self.buffer,
inp_l.start_offset() * self.dtype.size_in_bytes(),
src,
&buffer,
)
.map_err(MetalError::from)?;
@ -1027,9 +1203,8 @@ impl BackendStorage for MetalStorage {
}
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
let (ids_o1, _) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
if !ids_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "gather" }.bt());
};
let ids_el = ids_l.dims()[dim];
let dst_el = ids_l.shape().elem_count();
@ -1039,9 +1214,12 @@ impl BackendStorage for MetalStorage {
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "gather_u32_f32",
(DType::U32, DType::F16) => "gather_u32_f16",
(DType::U32, DType::BF16) => "gather_u32_bf16",
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, src_l, dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_gather(
&device.device,
&command_buffer,
@ -1050,10 +1228,8 @@ impl BackendStorage for MetalStorage {
src_l.dims(),
ids_el,
dim,
&self.buffer,
src_l.start_offset() * dtype.size_in_bytes(),
&ids.buffer,
ids_o1 * ids.dtype.size_in_bytes(),
src,
ids,
&buffer,
)
.map_err(MetalError::from)?;
@ -1071,13 +1247,8 @@ impl BackendStorage for MetalStorage {
) -> Result<Self> {
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
let (ids_offset, _) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
};
let src_offset = match src_l.contiguous_offsets() {
Some((o1, _)) => o1,
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
};
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::F32) => "sa_u8_f32",
@ -1096,6 +1267,8 @@ impl BackendStorage for MetalStorage {
})?,
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_scatter_add(
&self.device.device,
&command_buffer,
@ -1104,10 +1277,8 @@ impl BackendStorage for MetalStorage {
src_l.dims(),
l.dims(),
dim,
&src.buffer,
src_offset * src.dtype.size_in_bytes(),
&ids.buffer,
ids_offset * ids.dtype.size_in_bytes(),
src,
ids,
&acc.buffer,
)
.map_err(MetalError::from)?;
@ -1143,6 +1314,8 @@ impl BackendStorage for MetalStorage {
}
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&self.buffer, src_l, dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_index_select(
&device.device,
&command_buffer,
@ -1154,10 +1327,8 @@ impl BackendStorage for MetalStorage {
src_l.is_contiguous(),
src_l.dims(),
src_l.stride(),
&self.buffer,
src_l.start_offset() * dtype.size_in_bytes(),
&ids.buffer,
ids_l.start_offset() * ids.dtype.size_in_bytes(),
src,
ids,
&buffer,
)
.map_err(MetalError::from)?;
@ -1175,13 +1346,8 @@ impl BackendStorage for MetalStorage {
) -> Result<Self> {
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
let (ids_offset, _) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
};
let src_offset = match src_l.contiguous_offsets() {
Some((o1, _)) => o1,
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt());
};
let name = match (ids.dtype, self.dtype) {
(DType::I64, DType::BF16) => "ia_i64_bf16",
@ -1212,6 +1378,8 @@ impl BackendStorage for MetalStorage {
})?,
};
let command_buffer = self.device.command_buffer()?;
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_index_add(
&self.device.device,
&command_buffer,
@ -1221,10 +1389,8 @@ impl BackendStorage for MetalStorage {
l.dims(),
ids_l.dims(),
dim,
&src.buffer,
src_offset * src.dtype.size_in_bytes(),
&ids.buffer,
ids_offset * ids.dtype.size_in_bytes(),
src,
ids,
&acc.buffer,
)
.map_err(MetalError::from)?;
@ -1358,17 +1524,20 @@ impl BackendStorage for MetalStorage {
DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"),
};
let src = buffer_o(&self.buffer, src_l, self.dtype);
let dst = BufferOffset {
buffer: &dst.buffer,
offset_in_bytes: dst_offset * dst.dtype.size_in_bytes(),
};
candle_metal_kernels::call_unary_strided(
&self.device.device,
&command_buffer,
&self.device.kernels,
kernel_name,
src_l.dims(),
&self.buffer,
src,
src_l.stride(),
src_l.start_offset() * self.dtype.size_in_bytes(),
&dst.buffer,
dst_offset * dst.dtype.size_in_bytes(),
dst,
)
.map_err(MetalError::from)?;
command_buffer.set_label("copy_strided");
@ -1402,10 +1571,9 @@ impl MetalStorage {
let shape = lhs_l.shape();
let el_count = shape.elem_count();
let command_buffer = device.command_buffer()?;
let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
&& &op[..1] != "b"
{
let lhs = buffer_o(&self.buffer, lhs_l, self.dtype);
let rhs = buffer_o(&rhs.buffer, rhs_l, rhs.dtype);
let (buffer, dtype) = if lhs_l.is_contiguous() && rhs_l.is_contiguous() && &op[..1] != "b" {
use candle_metal_kernels::binary::contiguous;
let (kernel_name, dtype) = match (op, self.dtype) {
@ -1486,8 +1654,8 @@ impl MetalStorage {
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&rhs.buffer,
lhs,
rhs,
&buffer,
)
.map_err(MetalError::from)?;
@ -1585,12 +1753,10 @@ impl MetalStorage {
&device.kernels,
kernel_name,
lhs_l.dims(),
&self.buffer,
lhs,
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&rhs.buffer,
rhs,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
@ -1688,6 +1854,19 @@ impl BackendDevice for MetalDevice {
self.storage_from_cpu_storage(&cpu_storage)
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
let (count, buffer) = match T::cpu_storage_ref(s) {
CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorageRef::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
CpuStorageRef::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
};
Ok(Self::Storage::new(buffer?, self.clone(), count, T::DTYPE))
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
let (count, buffer) = match storage {
CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
@ -1796,6 +1975,10 @@ impl BackendDevice for MetalDevice {
Ok(())
}
fn synchronize(&self) -> Result<()> {
self.wait_until_completed()
}
}
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {

View File

@ -330,7 +330,7 @@ impl Tensor {
path: P,
) -> Result<()> {
let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
let options =
let options: zip::write::FileOptions<()> =
zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
for (name, tensor) in ts.iter() {

View File

@ -2,6 +2,7 @@ use super::{GgmlDType, QStorage};
use crate::quantized::k_quants::GgmlType;
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};
use half::f16;
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
@ -40,6 +41,7 @@ fn quantize_q8_1(
src: &CudaView<f32>,
dst: &mut CudaSlice<u8>,
elem_count: usize,
ky: usize,
dev: &CudaDevice,
) -> Result<()> {
use cudarc::driver::LaunchAsync;
@ -49,7 +51,7 @@ fn quantize_q8_1(
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, 1, 1),
grid_dim: (num_blocks as u32, ky as u32, 1),
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
shared_mem_bytes: 0,
};
@ -58,7 +60,7 @@ fn quantize_q8_1(
Ok(())
}
fn dequantize(
fn dequantize_f32(
data: &CudaSlice<u8>,
dtype: GgmlDType,
elem_count: usize,
@ -68,27 +70,27 @@ fn dequantize(
let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
GgmlDType::Q5_0 => (
"dequantize_block_q5_0",
"dequantize_block_q5_0_f32",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q5_1 => (
"dequantize_block_q5_1",
"dequantize_block_q5_1_f32",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
@ -115,6 +117,63 @@ fn dequantize(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_f16(
data: &CudaSlice<u8>,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
GgmlDType::Q5_0 => (
"dequantize_block_q5_0_f16",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q5_1 => (
"dequantize_block_q5_1_f16",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, 1, 1),
block_dim: (block_dim as u32, 1, 1),
shared_mem_bytes: 0,
};
if is_k {
let params = (data, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (data, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_mul_mat_vec(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
@ -165,6 +224,7 @@ fn mul_mat_vec_via_q8_1(
dtype: GgmlDType,
ncols: usize,
nrows: usize,
b_size: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
@ -173,14 +233,18 @@ fn mul_mat_vec_via_q8_1(
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
if y.len() != ncols {
if y.len() != ncols * b_size {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
if b_size == 0 || b_size > 8 {
crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
}
// Start by quantizing y
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let y_size_in_bytes =
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, ncols, dev)?;
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
let kernel_name = match dtype {
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
@ -195,11 +259,19 @@ fn mul_mat_vec_via_q8_1(
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let kernel_name = format!("{kernel_name}{b_size}");
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
let (nblocks, nwarps) = match b_size {
1 => (nrows as u32, 4),
2..=4 => ((nrows as u32 + 1) / 2, 4),
5..=8 => ((nrows as u32 + 1) / 2, 2),
_ => crate::bail!("unexpected bsize {b_size}"),
};
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (nrows as u32, 1, 1),
block_dim: (WARP_SIZE as u32, 4, 1),
grid_dim: (nblocks, 1, 1),
block_dim: (WARP_SIZE as u32, nwarps, 1),
shared_mem_bytes: 0,
};
@ -209,13 +281,83 @@ fn mul_mat_vec_via_q8_1(
&dst,
/* ncols_x */ ncols as i32,
/* nrows_x */ nrows as i32,
/* nrows_y */ ncols as i32,
/* nrows_y */ ncols_padded as i32,
/* nrows_dst */ nrows as i32,
);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
#[allow(clippy::too_many_arguments)]
fn mul_mat_via_q8_1(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
x_rows: usize,
x_cols: usize,
y_rows: usize,
y_cols: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < x_rows * x_cols {
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
}
if y.len() != y_rows * y_cols {
crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len())
}
if x_cols != y_rows {
crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}")
}
let k = x_cols;
// Start by quantizing y
let k_padded = pad(k, MATRIX_ROW_PADDING);
let y_size_in_bytes =
k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
let (kernel_name, mmq_x, mmq_y) = match dtype {
GgmlDType::Q4_0 => ("mul_mat_q4_0", 64, 128),
GgmlDType::Q4_1 => ("mul_mat_q4_1", 64, 128),
GgmlDType::Q5_0 => ("mul_mat_q5_0", 128, 64),
GgmlDType::Q5_1 => ("mul_mat_q5_1", 128, 64),
GgmlDType::Q8_0 => ("mul_mat_q8_0", 128, 64),
GgmlDType::Q2K => ("mul_mat_q2_K", 64, 128),
GgmlDType::Q3K => ("mul_mat_q3_K", 128, 128),
GgmlDType::Q4K => ("mul_mat_q4_K", 64, 128),
GgmlDType::Q5K => ("mul_mat_q5_K", 64, 128),
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (
ceil_div(x_rows, mmq_y) as u32,
ceil_div(y_cols, mmq_x) as u32,
1,
),
block_dim: (WARP_SIZE as u32, 4, 1),
shared_mem_bytes: 0,
};
let params = (
/* vx */ data,
/* vy */ &y_q8_1,
/* dst */ &dst,
/* ncols_x */ x_cols as i32,
/* nrows_x */ x_rows as i32,
/* ncols_y */ y_cols as i32,
/* nrows_y */ k_padded as i32,
/* nrows_dst */ x_rows as i32,
);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
@ -257,7 +399,7 @@ impl QCudaStorage {
| GgmlDType::Q8K
);
if fast_kernel {
return dequantize(&self.data, self.dtype, elem_count, self.device());
return dequantize_f32(&self.data, self.dtype, elem_count, self.device());
}
// Run the dequantization on cpu.
@ -285,6 +427,10 @@ impl QCudaStorage {
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
}
pub fn dequantize_f16(&self, elem_count: usize) -> Result<CudaStorage> {
dequantize_f16(&self.data, self.dtype, elem_count, self.device())
}
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
// Run the quantization on cpu.
let src = match &src.slice {
@ -313,7 +459,17 @@ impl QCudaStorage {
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
1
} else {
8
};
let use_vec_kernel = match layout.shape().dims() {
[b, m, _k] => b * m <= max_bm,
[b, _k] => *b <= max_bm,
_ => false,
};
if use_vec_kernel {
self.dequantize_matmul_vec(self_shape, storage, layout)
} else {
self.dequantize_matmul(self_shape, storage, layout)
@ -334,25 +490,31 @@ impl QCudaStorage {
Some((o1, o2)) => rhs.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
};
let (with_batch, k) = match rhs_l.shape().dims() {
[1, 1, k] => (true, k),
[1, k] => (false, k),
let (b_size, k) = match rhs_l.shape().dims() {
[b, m, k] => (b * m, *k),
[b, k] => (*b, *k),
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
};
if ncols != *k {
if ncols != k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
}
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
} else {
mul_mat_vec_via_q8_1(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
};
let out_shape = if with_batch {
vec![1, 1, nrows]
} else {
vec![1, nrows]
mul_mat_vec_via_q8_1(
&self.data,
&rhs,
self.dtype,
ncols,
nrows,
b_size,
self.device(),
)?
};
let mut out_shape = rhs_l.shape().dims().to_vec();
out_shape.pop();
out_shape.push(nrows);
Ok((out, out_shape.into()))
}
@ -373,9 +535,30 @@ impl QCudaStorage {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
}
let data_f32 = self.dequantize(n * k)?;
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
let data_f32 = self.dequantize(n * k)?;
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?
} else {
let storage = storage.as_cuda_slice::<f32>()?;
let storage = match layout.contiguous_offsets() {
Some((o1, o2)) => storage.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous {
op: "quantized-matmul",
}
.bt())?,
};
mul_mat_via_q8_1(
&self.data,
&storage,
self.dtype,
/* x_rows */ n,
/* x_cols */ k,
/* y_rows */ k,
/* y_cols */ b * m,
self.device(),
)?
};
let mut out_shape = layout.shape().dims().to_vec();
out_shape.pop();
out_shape.push(n);
@ -412,7 +595,7 @@ mod test {
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, &dev)?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
Ok(())
}
@ -430,6 +613,7 @@ mod test {
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
/* b_size */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
@ -453,4 +637,44 @@ mod test {
assert_eq!(vs[0], 5561851.0);
Ok(())
}
#[test]
fn cuda_mm_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* x_rows */ 4,
/* x_cols */ ncols,
/* y_rows */ ncols,
/* y_cols */ 4,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
/*
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
x @ x.t() / 16
tensor([[ 347480.0000, 869720.0000, 1391960.0000, 1914200.0000],
[ 869720.0000, 2440536.0000, 4011352.0000, 5582166.5000],
[ 1391960.0000, 4011352.0000, 6630742.0000, 9250132.0000],
[ 1914200.0000, 5582166.5000, 9250132.0000, 12918099.0000]])
*/
assert_eq!(vs.len(), 16);
assert_eq!(vs[0], 347604.0);
assert_eq!(vs[1], 888153.06);
assert_eq!(vs[4], 869780.7);
assert_eq!(vs[5], 2483145.0);
assert_eq!(vs[11], 9407368.0);
assert_eq!(vs[14], 9470856.0);
assert_eq!(vs[15], 13138824.0);
Ok(())
}
}

View File

@ -24,6 +24,10 @@ impl QCudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn dequantize_f16(&self, _elem_count: usize) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -135,7 +135,6 @@ pub enum ValueType {
// The value is a UTF-8 non-null-terminated string, with length prepended.
String,
// The value is an array of other values, with the length and type prepended.
///
// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
Array,
}
@ -218,10 +217,16 @@ impl Value {
}
}
/// This will also automatically upcast any integral types which will not truncate.
pub fn to_u64(&self) -> Result<u64> {
match self {
Self::U64(v) => Ok(*v),
v => crate::bail!("not a u64 {v:?}"),
// Autoupcast cases here
Self::U8(v) => Ok(*v as u64),
Self::U16(v) => Ok(*v as u64),
Self::U32(v) => Ok(*v as u64),
Self::Bool(v) => Ok(*v as u64),
v => crate::bail!("not a u64 or upcastable to u64 {v:?}"),
}
}

View File

@ -152,9 +152,9 @@ impl QMetalStorage {
// We always use a single batch dimension and stack all the tensors in the batch on the
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
// properly.
let (b, m) = match dst_shape.len() {
3 => (1, dst_shape[0] * dst_shape[1]),
2 => (1, dst_shape[0]),
let m = match dst_shape.len() {
3 => dst_shape[0] * dst_shape[1],
2 => dst_shape[0],
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};
let last_k = dst_shape.pop().unwrap();
@ -166,18 +166,23 @@ impl QMetalStorage {
let device = storage.device().clone();
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
let command_buffer = device.command_buffer()?;
candle_metal_kernels::call_quantized_matmul_t(
device.device(),
&command_buffer,
device.kernels(),
self.dtype.into(),
(b, m, n, k),
storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(),
&self.buffer,
&dst,
)
.map_err(MetalError::from)?;
// In some cases it would be better to use the mm variant, though it has its drawbacks
// around memory alignemnt.
for batch_id in 0..m {
candle_metal_kernels::call_quantized_matmul_mv_t(
device.device(),
&command_buffer,
device.kernels(),
self.dtype.into(),
(1, 1, n, k),
storage.buffer(),
(layout.start_offset() + batch_id * k) * storage.dtype().size_in_bytes(),
&self.buffer,
batch_id * n * DType::F32.size_in_bytes(),
&dst,
)
.map_err(MetalError::from)?;
}
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
Ok((dst_storage, dst_shape))
}

View File

@ -1,4 +1,4 @@
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;
@ -360,9 +360,24 @@ impl QTensor {
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
let storage = self.storage.dequantize(self.shape.elem_count())?;
let none = crate::op::BackpropOp::none();
let is_variable = false;
crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
.to_device(device)
crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
}
pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
// In the CUDA case, we have a specialized kernel as this can be useful for volta
// architectures. https://github.com/huggingface/candle/issues/2136
match &self.storage {
QStorage::Cuda(s) => {
let s = s.dequantize_f16(self.shape.elem_count())?;
let none = crate::op::BackpropOp::none();
crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
.to_device(device)
}
_ => {
let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
Ok(s)
}
}
}
pub fn storage_size_in_bytes(&self) -> usize {
@ -378,6 +393,7 @@ impl QTensor {
pub enum QMatMul {
QTensor(std::sync::Arc<QTensor>),
Tensor(Tensor),
TensorF16(Tensor),
}
thread_local! {
@ -391,6 +407,17 @@ thread_local! {
}
}
thread_local! {
static DEQUANTIZE_ALL_F16: bool = {
match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
impl QMatMul {
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
let dequantize = match qtensor.dtype() {
@ -400,6 +427,9 @@ impl QMatMul {
let t = if dequantize {
let tensor = qtensor.dequantize(&qtensor.device())?;
Self::Tensor(tensor)
} else if DEQUANTIZE_ALL_F16.with(|b| *b) {
let tensor = qtensor.dequantize_f16(&qtensor.device())?;
Self::TensorF16(tensor)
} else {
Self::QTensor(qtensor)
};
@ -409,6 +439,25 @@ impl QMatMul {
pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
Self::from_arc(std::sync::Arc::new(qtensor))
}
pub fn dequantize_f16(&self) -> Result<Tensor> {
match self {
Self::QTensor(t) => t.dequantize_f16(&t.device()),
Self::Tensor(t) => t.to_dtype(DType::F16),
Self::TensorF16(t) => Ok(t.clone()),
}
}
pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
let w = self.dequantize_f16()?;
let in_dtype = xs.dtype();
let w = match *xs.dims() {
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
_ => w.t()?,
};
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
}
}
impl crate::CustomOp1 for QTensor {
@ -486,6 +535,15 @@ impl crate::Module for QMatMul {
};
xs.matmul(&w)
}
Self::TensorF16(w) => {
let in_dtype = xs.dtype();
let w = match *xs.dims() {
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
_ => w.t()?,
};
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
}
}
}
}

View File

@ -349,6 +349,30 @@ impl MmapedSafetensors {
}
}
pub struct SliceSafetensors<'a> {
safetensors: SafeTensors<'a>,
}
impl<'a> SliceSafetensors<'a> {
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
pub fn new(buffer: &'a [u8]) -> Result<Self> {
let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
Ok(Self { safetensors })
}
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
self.safetensors.tensor(name)?.load(dev)
}
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
self.safetensors.tensors()
}
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
Ok(self.safetensors.tensor(name)?)
}
}
pub struct BufferedSafetensors {
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
}

239
candle-core/src/sort.rs Normal file
View File

@ -0,0 +1,239 @@
use crate::{Result, Tensor};
use rayon::prelude::*;
#[derive(Debug, Clone, Copy)]
struct ArgSort {
asc: bool,
last_dim: usize,
}
impl ArgSort {
fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
#[allow(clippy::uninit_vec)]
// Safety: indexes are set later in the parallelized section.
let mut sort_indexes = unsafe {
let el_count = layout.shape().elem_count();
let mut v = Vec::with_capacity(el_count);
v.set_len(el_count);
v
};
if self.asc {
sort_indexes
.par_chunks_exact_mut(self.last_dim)
.zip(vs.par_chunks_exact(self.last_dim))
.for_each(|(indexes, vs)| {
indexes
.iter_mut()
.enumerate()
.for_each(|(i, v)| *v = i as u32);
indexes.sort_by(|&i, &j| {
vs[i as usize]
.partial_cmp(&vs[j as usize])
.unwrap_or(std::cmp::Ordering::Greater)
})
});
} else {
sort_indexes
.par_chunks_exact_mut(self.last_dim)
.zip(vs.par_chunks_exact(self.last_dim))
.for_each(|(indexes, vs)| {
indexes
.iter_mut()
.enumerate()
.for_each(|(i, v)| *v = i as u32);
indexes.sort_by(|&j, &i| {
vs[i as usize]
.partial_cmp(&vs[j as usize])
.unwrap_or(std::cmp::Ordering::Greater)
})
});
}
sort_indexes
}
}
impl crate::CustomOp1 for ArgSort {
fn name(&self) -> &'static str {
"argsort"
}
fn cpu_fwd(
&self,
storage: &crate::CpuStorage,
layout: &crate::Layout,
) -> Result<(crate::CpuStorage, crate::Shape)> {
let sort_indexes = match storage {
crate::CpuStorage::U8(vs) => self.asort(vs, layout),
crate::CpuStorage::U32(vs) => self.asort(vs, layout),
crate::CpuStorage::I64(vs) => self.asort(vs, layout),
crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
crate::CpuStorage::F16(vs) => self.asort(vs, layout),
crate::CpuStorage::F32(vs) => self.asort(vs, layout),
crate::CpuStorage::F64(vs) => self.asort(vs, layout),
};
let sort_indexes = crate::CpuStorage::U32(sort_indexes);
Ok((sort_indexes, layout.shape().into()))
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
storage: &crate::CudaStorage,
layout: &crate::Layout,
) -> Result<(crate::CudaStorage, crate::Shape)> {
use crate::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
use crate::{CudaDevice, WithDType};
impl Map1Any for ArgSort {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &crate::Layout,
_wrap: W,
) -> Result<S> {
let slice = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
let func = if self.asc {
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
} else {
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
};
let ncols = self.last_dim;
let nrows = elem_count / ncols;
let ncols_pad = next_power_of_2(ncols);
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
let cfg = LaunchConfig {
grid_dim: (1, nrows as u32, 1),
block_dim: (ncols_pad as u32, 1, 1),
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
};
unsafe { func.launch(cfg, params) }.w()?;
Ok(S::U32(dst))
}
}
use crate::backend::BackendStorage;
let dev = storage.device();
let slice = self.map(&storage.slice, dev, layout)?;
let dst = crate::cuda_backend::CudaStorage {
slice,
device: dev.clone(),
};
Ok((dst, layout.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
storage: &crate::MetalStorage,
layout: &crate::Layout,
) -> Result<(crate::MetalStorage, crate::Shape)> {
use crate::backend::BackendStorage;
use crate::DType;
let name = {
if self.asc {
match storage.dtype() {
DType::BF16 => "asort_asc_bf16",
DType::F16 => "asort_asc_f16",
DType::F32 => "asort_asc_f32",
DType::F64 => "asort_asc_f64",
DType::U8 => "asort_asc_u8",
DType::U32 => "asort_asc_u32",
DType::I64 => "asort_asc_i64",
}
} else {
match storage.dtype() {
DType::BF16 => "asort_desc_bf16",
DType::F16 => "asort_desc_f16",
DType::F32 => "asort_desc_f32",
DType::F64 => "asort_desc_f64",
DType::U8 => "asort_desc_u8",
DType::U32 => "asort_desc_u32",
DType::I64 => "asort_desc_i64",
}
}
};
let device = storage.device();
let kernels = device.kernels();
let command_buffer = device.command_buffer()?;
let el = layout.shape().elem_count();
let ncols = self.last_dim;
let nrows = el / ncols;
let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
let dst = device.new_buffer(el, DType::U32, "asort")?;
let mut ncols_pad = 1;
while ncols_pad < ncols {
ncols_pad *= 2;
}
candle_metal_kernels::call_arg_sort(
device.metal_device(),
&command_buffer,
kernels,
name,
nrows,
ncols,
ncols_pad,
src,
&dst,
)
.map_err(crate::Error::wrap)?;
let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
Ok((dst, layout.shape().clone()))
}
}
#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
let mut n = 1;
while n < x {
n *= 2
}
n
}
impl Tensor {
/// Returns the indices that sort the tensor along the last dimension.
///
/// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
/// descending order. The sort is unstable so there is no guarantees on the final order when it
/// comes to ties.
pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
if !self.is_contiguous() {
return Err(crate::Error::RequiresContiguous {
op: "arg_sort_last_dim",
});
}
let last_dim = match self.dims().last() {
None => crate::bail!("empty last-dim in arg-sort"),
Some(last_dim) => *last_dim,
};
// No need for a backward pass for arg sort.
self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
}
/// Sorts the tensor along the last dimension, returns the sorted tensor together with the
/// sorted indexes.
///
/// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
/// descending order. The sort is unstable so there is no guarantees on the final order when it
/// comes to ties.
pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
if !self.is_contiguous() {
return Err(crate::Error::RequiresContiguous {
op: "sort_last_dim",
});
}
let asort = self.arg_sort_last_dim(asc)?;
let sorted = self.gather(&asort, crate::D::Minus1)?;
Ok((sorted, asort))
}
}

View File

@ -79,6 +79,9 @@ macro_rules! unary_op {
($fn_name:ident, $op_name:ident) => {
pub fn $fn_name(&self) -> Result<Self> {
let shape = self.shape();
if shape.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self
.storage()
.unary_impl::<crate::op::$op_name>(self.layout())?;
@ -92,6 +95,9 @@ macro_rules! binary_op {
($fn_name:ident, $op_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
if shape.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().binary_impl::<crate::op::$op_name>(
&*rhs.storage(),
self.layout(),
@ -114,6 +120,9 @@ macro_rules! binary_op_scalar {
.broadcast_as(self.shape())?,
};
let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().binary_impl::<crate::op::$op_name>(
&*rhs.storage(),
self.layout(),
@ -447,7 +456,15 @@ impl Tensor {
shape: S,
device: &Device,
) -> Result<Self> {
Self::new_impl(array, shape.into(), device, false)
let shape = shape.into();
let n: usize = shape.elem_count();
let buffer_size: usize = array.len();
if buffer_size != n {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let storage = device.storage_from_slice(array)?;
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, false))
}
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
@ -646,6 +663,9 @@ impl Tensor {
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().affine(self.layout(), mul, add)?;
let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
Ok(from_storage(storage, self.shape(), op, false))
@ -653,6 +673,9 @@ impl Tensor {
/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
pub fn elu(&self, alpha: f64) -> Result<Self> {
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().elu(self.layout(), alpha)?;
let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
Ok(from_storage(storage, self.shape(), op, false))
@ -660,6 +683,9 @@ impl Tensor {
/// Raise the tensor to some float exponent `e`.
pub fn powf(&self, e: f64) -> Result<Self> {
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().powf(self.layout(), e)?;
let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
Ok(from_storage(storage, self.shape(), op, false))
@ -1154,6 +1180,9 @@ impl Tensor {
let n = b_dims[dim - 1];
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
if c_shape.elem_count() == 0 || k == 0 {
return Tensor::zeros(c_shape, self.dtype(), self.device());
}
let batching: usize = a_dims[..dim - 2].iter().product();
let batching_b: usize = b_dims[..dim - 2].iter().product();
if k != k2 || batching != batching_b {

View File

@ -235,4 +235,66 @@ impl Tensor {
}
Ok(crate::tensor::from_storage(storage, shape, op, false))
}
/// Set the values on `self` using values from `src`. The copy starts at the specified
/// `offset` for the target dimension `dim` on `self`.
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
/// has to be greater than or equal to `offset` plus the `src` size.
///
/// Note that this modifies `self` in place and as such is not compatibel with
/// back-propagation.
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
let dim = dim.to_index(self.shape(), "slice-set")?;
if !self.is_contiguous() || !src.is_contiguous() {
Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
}
if self.dtype() != src.dtype() {
Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: src.dtype(),
op: "slice-set",
}
.bt())?
}
if self.device().location() != src.device().location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: self.device().location(),
rhs: src.device().location(),
op: "slice-set",
}
.bt())?
}
if self.rank() != src.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: self.rank(),
got: src.rank(),
shape: self.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
if dim_idx == dim && *v2 + offset > *v1 {
crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
}
if dim_idx != dim && v1 != v2 {
crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
}
}
let block_size: usize = src.dims().iter().skip(1 + dim).product();
let d1: usize = src.dims().iter().take(dim).product();
let d2 = block_size * src.dims()[dim];
let dst_o = self.layout().start_offset() + offset * block_size;
let src_o = src.layout().start_offset();
src.storage().copy2d(
&mut self.storage_mut(),
d1,
d2,
/* src_s */ d2,
/* dst_s */ block_size * self.dims()[dim],
src_o,
dst_o,
)?;
Ok(())
}
}

View File

@ -34,9 +34,14 @@ impl Var {
Ok(Self(inner))
}
// Convert a tensor to a variable, if the tensor is already a variable then it is returned as is.
pub fn from_tensor(t: &Tensor) -> Result<Self> {
let inner = t.make_var()?;
Ok(Self(inner))
if t.is_variable() {
Ok(Self(t.clone()))
} else {
let inner = t.make_var()?;
Ok(Self(inner))
}
}
pub fn rand_f64<S: Into<Shape>>(

View File

@ -3,7 +3,7 @@ use candle_core::{
quantized::{self, GgmlDType},
test_device,
test_utils::to_vec2_round,
Device, Module, Result, Tensor,
DType, Device, IndexOp, Module, Result, Tensor,
};
use quantized::{k_quants, GgmlType};
use rand::prelude::*;
@ -47,18 +47,14 @@ fn test_matmul(
}
fn quantized_matmul(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let lhs_s = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?;
let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?;
assert_eq!(
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
&[
@ -67,7 +63,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
]
);
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
let mm = tensor_lhs.matmul(&tensor_rhs)?;
let mm = lhs.matmul(&tensor_rhs)?;
assert_eq!(
mm.to_vec2::<f32>()?,
&[
@ -79,7 +75,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
let res = matmul.forward(&lhs)?;
match device {
Device::Metal(_) => assert_eq!(
to_vec2_round(&res, 0)?,
@ -89,7 +85,15 @@ fn quantized_matmul(device: &Device) -> Result<()> {
[341970.0, 994574.0, 1656181.0, 2302182.0]
]
),
_ => assert_eq!(
Device::Cuda(_) => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[84866.0, 214045.0, 344676.0, 473707.0],
[213425.0, 604313.0, 1000431.0, 1387960.0],
[342030.0, 994630.0, 1656248.0, 2302250.0]
]
),
Device::Cpu => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[85120.0, 214562.0, 345455.0, 474748.0],
@ -98,22 +102,16 @@ fn quantized_matmul(device: &Device) -> Result<()> {
]
),
}
test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
Ok(())
}
fn quantized_matmul_neg(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k))
let lhs_s = (0..(m * k))
.map(|v| v as f32 - (m * k) as f32 / 2.0)
.collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?;
let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..k * n)
@ -121,7 +119,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
.collect::<Vec<_>>();
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?;
assert_eq!(
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
&[
@ -129,7 +127,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
-196472.0, 63012.0, 324585.0, 587902.0
]
);
let mm = tensor_lhs.matmul(&tensor_rhs)?;
let mm = lhs.matmul(&tensor_rhs)?;
assert_eq!(
to_vec2_round(&mm, 0)?,
&[
@ -141,7 +139,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
let res = matmul.forward(&lhs)?;
match device {
Device::Metal(_) => assert_eq!(
to_vec2_round(&res, 0)?,
@ -151,7 +149,15 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
[-196102.0, 63022.0, 324233.0, 587191.0]
]
),
_ => assert_eq!(
Device::Cuda(_) => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[243740.0, -19762.0, -285476.0, -550498.0],
[23774.0, 21645.0, 19395.0, 18364.0],
[-196045.0, 63030.0, 324120.0, 587079.0]
]
),
Device::Cpu => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[243524.0, -19596.0, -285051.0, -549815.0],
@ -160,22 +166,58 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
]
),
}
let lhs2 = Tensor::stack(&[&lhs, &lhs], 0)?;
let res2 = matmul.forward(&lhs2)?;
let res2 = res2.i(1)?;
let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if device.is_cuda() {
assert!(diff < 0.1);
} else {
assert_eq!(diff, 0.);
}
Ok(())
}
test_device!(
quantized_matmul,
quantized_matmul_cpu,
quantized_matmul_cuda,
quantized_matmul_metal
);
test_device!(
quantized_matmul_neg,
quantized_matmul_neg_cpu,
quantized_matmul_neg_cuda,
quantized_matmul_neg_metal
);
fn qmm_batch(dev: &Device) -> Result<()> {
let (lhs, rhs, _mm) = get_random_tensors(2, 256, 6, dev)?;
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.shape().dims(), [2, 6]);
let lhs2 = Tensor::cat(&[&lhs, &lhs], 0)?;
let mm2 = rhs.forward(&lhs2)?;
assert_eq!(mm2.shape().dims(), [4, 6]);
let diff2 = (mm2.i(2..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff2, 0.0);
let lhs3 = Tensor::cat(&[&lhs2, &lhs], 0)?;
let mm3 = rhs.forward(&lhs3)?;
assert_eq!(mm3.shape().dims(), [6, 6]);
let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff3, 0.0);
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff3, 0.0);
let lhs4 = Tensor::cat(&[&lhs3, &lhs3], 0)?;
let mm4 = rhs.forward(&lhs4)?;
assert_eq!(mm4.shape().dims(), [12, 6]);
let diff4 = (mm4.i(..6)? - &mm3)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if dev.is_cuda() {
// We use a different kernel for sizes from 1 to 8 on cuda which explains
// the difference here.
assert!(0. < diff4 && diff4 < 1e-4)
} else {
assert_eq!(diff4, 0.0)
};
let diff4 = (mm4.i(6..)? - &mm4.i(..6)?)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff4, 0.0);
Ok(())
}
test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal);
test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal);
test_device!(qmm_batch, qmm_b_cpu, qmm_b_cuda, qmm_b_metal);
fn quantize_q4_0(device: &Device) -> Result<()> {
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
@ -183,6 +225,13 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
assert_eq!(
dst.to_vec1::<f32>()?,
&[
@ -209,6 +258,13 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
assert_eq!(
round_vector(&dst.to_vec1::<f32>()?),
&[
@ -235,6 +291,13 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
assert_eq!(
round_vector(&dst.to_vec1::<f32>()?),
&[
@ -261,6 +324,13 @@ fn quantize_q5_1(device: &Device) -> Result<()> {
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
assert_eq!(
round_vector(&dst.to_vec1::<f32>()?),
&[
@ -345,6 +415,13 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
if error > max_error {
bail!(
@ -362,6 +439,13 @@ fn quantize_q2k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
@ -381,6 +465,13 @@ fn quantize_q2k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
@ -395,6 +486,13 @@ fn quantize_q3k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
@ -414,6 +512,13 @@ fn quantize_q3k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
@ -428,6 +533,13 @@ fn quantize_q4k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
@ -447,6 +559,13 @@ fn quantize_q4k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
@ -461,6 +580,13 @@ fn quantize_q5k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
@ -480,6 +606,13 @@ fn quantize_q5k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
@ -494,6 +627,13 @@ fn quantize_q6k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
@ -513,6 +653,13 @@ fn quantize_q6k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
@ -527,6 +674,13 @@ fn quantize_q8k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
@ -546,6 +700,13 @@ fn quantize_q8k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;

View File

@ -1,5 +1,31 @@
use candle_core::{DType, Result, Tensor};
struct TmpFile(std::path::PathBuf);
impl TmpFile {
fn create(base: &str) -> TmpFile {
let filename = std::env::temp_dir().join(format!(
"candle-{}-{}-{:?}",
base,
std::process::id(),
std::thread::current().id(),
));
TmpFile(filename)
}
}
impl std::convert::AsRef<std::path::Path> for TmpFile {
fn as_ref(&self) -> &std::path::Path {
self.0.as_path()
}
}
impl Drop for TmpFile {
fn drop(&mut self) {
std::fs::remove_file(&self.0).unwrap()
}
}
#[test]
fn npy() -> Result<()> {
let npy = Tensor::read_npy("tests/test.npy")?;
@ -22,3 +48,24 @@ fn npz() -> Result<()> {
);
Ok(())
}
#[test]
fn safetensors() -> Result<()> {
use candle_core::safetensors::Load;
let tmp_file = TmpFile::create("st");
let t = Tensor::arange(0f32, 24f32, &candle_core::Device::Cpu)?;
t.save_safetensors("t", &tmp_file)?;
// Load from file.
let st = candle_core::safetensors::load(&tmp_file, &candle_core::Device::Cpu)?;
let t2 = st.get("t").unwrap();
let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0f32);
// Load from bytes.
let bytes = std::fs::read(tmp_file)?;
let st = candle_core::safetensors::SliceSafetensors::new(&bytes)?;
let t2 = st.get("t").unwrap().load(&candle_core::Device::Cpu);
let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0f32);
Ok(())
}

View File

@ -96,6 +96,40 @@ fn clamp(device: &Device) -> Result<()> {
Ok(())
}
fn asort(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1.1, 5.], [2.1, 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?;
let indexes = tensor.arg_sort_last_dim(true)?;
assert_eq!(
indexes.to_vec2::<u32>()?,
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
);
let indexes = tensor.arg_sort_last_dim(false)?;
assert_eq!(
indexes.to_vec2::<u32>()?,
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
);
let (sorted, indexes) = tensor.sort_last_dim(true)?;
assert_eq!(
indexes.to_vec2::<u32>()?,
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
);
assert_eq!(
sorted.to_vec2::<f32>()?,
[[1.0, 1.1, 3.0, 4.0, 5.0], [1.0, 2.0, 2.1, 7.0, 8.0]]
);
let (sorted, indexes) = tensor.sort_last_dim(false)?;
assert_eq!(
indexes.to_vec2::<u32>()?,
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
);
assert_eq!(
sorted.to_vec2::<f32>()?,
[[5.0, 4.0, 3.0, 1.1, 1.0], [8.0, 7.0, 2.1, 2.0, 1.0]]
);
Ok(())
}
fn unary_op(device: &Device) -> Result<()> {
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
let tensor = Tensor::new(data, device)?;
@ -631,6 +665,30 @@ fn broadcast(device: &Device) -> Result<()> {
Ok(())
}
fn slice_set(device: &Device) -> Result<()> {
let (b, h, max_t, d) = (2, 4, 7, 3);
let cache = Tensor::zeros((b, h, max_t, d), DType::F32, device)?;
let tensor = Tensor::randn(0f32, 1f32, (b, h, 4, d), device)?;
cache.slice_set(&tensor, 2, 0)?;
let cache_t = cache.narrow(2, 0, 4)?;
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
cache.slice_set(&tensor, 2, 1)?;
let cache_t = cache.narrow(2, 1, 4)?;
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let ones = Tensor::ones((b, h, 1, d), DType::F32, device)?;
cache.slice_set(&ones, 2, 6)?;
let diff = cache.narrow(2, 5, 1)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let diff = (cache.narrow(2, 6, 1)? - 1.)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
Ok(())
}
fn cat(device: &Device) -> Result<()> {
// 1D
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
@ -1083,6 +1141,27 @@ fn randn(device: &Device) -> Result<()> {
Ok(())
}
fn zero_dim(device: &Device) -> Result<()> {
let t = Tensor::zeros((4, 0, 1), DType::F32, device)?;
assert_eq!(t.dims3()?, (4, 0, 1));
let t2 = Tensor::zeros((4, 3, 1), DType::F32, device)?;
let t_cat = Tensor::cat(&[&t, &t2], 1)?;
assert_eq!(t_cat.dims3()?, (4, 3, 1));
let t_cat = Tensor::cat(&[&t, &t], 1)?;
assert_eq!(t_cat.dims3()?, (4, 0, 1));
let t_unary = t.sqrt()?;
assert_eq!(t_unary.dims3()?, (4, 0, 1));
let t_plus = (&t + 1.)?;
assert_eq!(t_plus.dims3()?, (4, 0, 1));
let t_mm = t2.matmul(&t.t()?)?;
assert_eq!(t_mm.dims3()?, (4, 3, 0));
let t_mm = t.matmul(&t2.t()?)?;
assert_eq!(t_mm.dims3()?, (4, 0, 3));
let t_mm = t.t()?.matmul(&t)?;
assert_eq!(t_mm.dims3()?, (4, 1, 1));
Ok(())
}
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
@ -1091,6 +1170,7 @@ test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(slice_set, ss_cpu, ss_gpu, ss_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
test_device!(min, min_cpu, min_gpu, min_metal);
@ -1130,7 +1210,9 @@ test_device!(
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(asort, asort_cpu, asort_gpu, asort_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381

View File

@ -89,7 +89,7 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
pub fn load() -> Result<crate::vision::Dataset> {
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let dataset_id = "mnist".to_string();
let dataset_id = "ylecun/mnist".to_string();
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,

View File

@ -25,6 +25,8 @@ hf-hub = { workspace = true, features = ["tokio"] }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true}
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }
@ -65,6 +67,7 @@ onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
encodec = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
[[example]]
name = "llama_multiprocess"
@ -101,3 +104,7 @@ required-features = ["candle-datasets"]
[[example]]
name = "encodec"
required-features = ["encodec"]
[[example]]
name = "depth_anything_v2"
required-features = ["depth_anything_v2"]

View File

@ -0,0 +1,13 @@
# candle-dinov2
[Depth Anything V2] is a model for Monocular Depth Estimation (MDE, i.e. just using a single image) which
builds on the [DINOv2](https://github.com/facebookresearch/dinov2) vision transformer.
This example first instantiates the DINOv2 model and then proceeds to create DepthAnythingV2 and run it.
## Running an example with color map and CUDA
```bash
cargo run --features cuda,depth_anything_v2 --package candle-examples --example depth_anything_v2 -- --color-map --image candle-examples/examples/yolo-v8/assets/bike.jpg
```

View File

@ -0,0 +1,50 @@
use enterpolation::linear::ConstEquidistantLinear;
use enterpolation::Generator;
use palette::LinSrgb;
use candle::Tensor;
pub struct SpectralRColormap {
gradient: ConstEquidistantLinear<f32, LinSrgb, 9>,
}
impl SpectralRColormap {
pub(crate) fn new() -> Self {
// Define a colormap similar to 'Spectral_r' by specifying key colors.
// got the colors from ChatGPT-4o
let gradient = ConstEquidistantLinear::<f32, _, 9>::equidistant_unchecked([
LinSrgb::new(0.3686, 0.3098, 0.6353), // Dark blue
LinSrgb::new(0.1961, 0.5333, 0.7412), // Blue
LinSrgb::new(0.4000, 0.7608, 0.6471), // Cyan
LinSrgb::new(0.6706, 0.8667, 0.6431), // Green
LinSrgb::new(0.9020, 0.9608, 0.5961), // Yellow
LinSrgb::new(0.9961, 0.8784, 0.5451), // Orange
LinSrgb::new(0.9922, 0.6824, 0.3804), // Red
LinSrgb::new(0.9569, 0.4275, 0.2627), // Dark red
LinSrgb::new(0.8353, 0.2431, 0.3098), // Dark purple
]);
Self { gradient }
}
fn get_color(&self, value: f32) -> LinSrgb {
self.gradient.gen(value)
}
pub fn gray2color(&self, gray: &Tensor) -> candle::Result<Tensor> {
println!("Gray: {:?}", gray.dims());
let gray_values: Vec<f32> = gray.flatten_all()?.to_vec1()?;
let rgb_values: Vec<f32> = gray_values
.iter()
.map(|g| self.get_color(*g))
.flat_map(|rgb| [rgb.red, rgb.green, rgb.blue])
.collect();
let [.., height, width] = gray.dims() else {
candle::bail!("Not enough dims!")
};
let color = Tensor::from_vec(rgb_values, (*height, *width, 3), gray.device())?;
color.permute((2, 0, 1))
}
}

View File

@ -0,0 +1,187 @@
//! Depth Anything V2
//! https://huggingface.co/spaces/depth-anything/Depth-Anything-V2
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use std::ffi::OsString;
use std::path::PathBuf;
use clap::Parser;
use candle::DType::{F32, U8};
use candle::{DType, Device, Module, Result, Tensor};
use candle_examples::{load_image, load_image_and_resize, save_image};
use candle_nn::VarBuilder;
use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config};
use candle_transformers::models::dinov2;
use crate::color_map::SpectralRColormap;
mod color_map;
// taken these from: https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py#L207
const MAGIC_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const MAGIC_STD: [f32; 3] = [0.229, 0.224, 0.225];
const DINO_IMG_SIZE: usize = 518;
#[derive(Parser)]
struct Args {
#[arg(long)]
dinov2_model: Option<PathBuf>,
#[arg(long)]
depth_anything_v2_model: Option<PathBuf>,
#[arg(long)]
image: PathBuf,
#[arg(long)]
output_dir: Option<PathBuf>,
#[arg(long)]
cpu: bool,
#[arg(long)]
color_map: bool,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let dinov2_model_file = match args.dinov2_model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-dino-v2".into());
api.get("dinov2_vits14.safetensors")?
}
Some(dinov2_model) => dinov2_model,
};
println!("Using file {:?}", dinov2_model_file);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dinov2_model_file], F32, &device)? };
let dinov2 = dinov2::vit_small(vb)?;
println!("DinoV2 model built");
let depth_anything_model_file = match args.depth_anything_v2_model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("jeroenvlek/depth-anything-v2-safetensors".into());
api.get("depth_anything_v2_vits.safetensors")?
}
Some(depth_anything_model) => depth_anything_model,
};
println!("Using file {:?}", depth_anything_model_file);
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[depth_anything_model_file], DType::F32, &device)?
};
let config = DepthAnythingV2Config::vit_small();
let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;
let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;
println!("Loaded image {image:?}");
let depth = depth_anything.forward(&image)?;
println!("Got predictions {:?}", depth.shape());
let output_image = post_process_image(&depth, original_height, original_width, args.color_map)?;
let output_path = full_output_path(&args.image, &args.output_dir);
println!("Saving image to {}", output_path.to_string_lossy());
save_image(&output_image, output_path)?;
Ok(())
}
fn full_output_path(image_path: &PathBuf, output_dir: &Option<PathBuf>) -> PathBuf {
let input_file_name = image_path.file_name().unwrap();
let mut output_file_name = OsString::from("depth_");
output_file_name.push(input_file_name);
let mut output_path = match output_dir {
None => image_path.parent().unwrap().to_path_buf(),
Some(output_path) => output_path.clone(),
};
output_path.push(output_file_name);
output_path
}
fn load_and_prep_image(
image_path: &PathBuf,
device: &Device,
) -> anyhow::Result<(usize, usize, Tensor)> {
let (_original_image, original_height, original_width) = load_image(&image_path, None)?;
let image = load_image_and_resize(&image_path, DINO_IMG_SIZE, DINO_IMG_SIZE)?
.unsqueeze(0)?
.to_dtype(F32)?
.to_device(&device)?;
let max_pixel_val = Tensor::try_from(255.0f32)?
.to_device(&device)?
.broadcast_as(image.shape())?;
let image = (image / max_pixel_val)?;
let image = normalize_image(&image, &MAGIC_MEAN, &MAGIC_STD)?;
Ok((original_height, original_width, image))
}
fn normalize_image(image: &Tensor, mean: &[f32; 3], std: &[f32; 3]) -> Result<Tensor> {
let mean_tensor =
Tensor::from_vec(mean.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
let std_tensor =
Tensor::from_vec(std.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
image.sub(&mean_tensor)?.div(&std_tensor)
}
fn post_process_image(
image: &Tensor,
original_height: usize,
original_width: usize,
color_map: bool,
) -> Result<Tensor> {
let out = image.interpolate2d(original_height, original_width)?;
let out = scale_image(&out)?;
let out = if color_map {
let spectral_r = SpectralRColormap::new();
spectral_r.gray2color(&out)?
} else {
let rgb_slice = [&out, &out, &out];
Tensor::cat(&rgb_slice, 0)?.squeeze(1)?
};
let max_pixel_val = Tensor::try_from(255.0f32)?
.to_device(out.device())?
.broadcast_as(out.shape())?;
let out = (out * max_pixel_val)?;
out.to_dtype(U8)
}
fn scale_image(depth: &Tensor) -> Result<Tensor> {
let flat_values: Vec<f32> = depth.flatten_all()?.to_vec1()?;
let min_val = flat_values.iter().min_by(|a, b| a.total_cmp(b)).unwrap();
let max_val = flat_values.iter().max_by(|a, b| a.total_cmp(b)).unwrap();
let min_val_tensor = Tensor::try_from(*min_val)?
.to_device(depth.device())?
.broadcast_as(depth.shape())?;
let depth = (depth - min_val_tensor)?;
let range = max_val - min_val;
let range_tensor = Tensor::try_from(range)?
.to_device(depth.device())?
.broadcast_as(depth.shape())?;
depth / range_tensor
}

View File

@ -30,6 +30,14 @@ enum Which {
InstructV1_1_2B,
#[value(name = "1.1-7b-it")]
InstructV1_1_7B,
#[value(name = "code-2b")]
CodeBase2B,
#[value(name = "code-7b")]
CodeBase7B,
#[value(name = "code-2b-it")]
CodeInstruct2B,
#[value(name = "code-7b-it")]
CodeInstruct7B,
}
struct TextGeneration {
@ -185,6 +193,9 @@ struct Args {
/// The model to use.
#[arg(long, default_value = "2b")]
which: Which,
#[arg(long)]
use_flash_attn: bool,
}
fn main() -> Result<()> {
@ -224,6 +235,10 @@ fn main() -> Result<()> {
Which::Base7B => "google/gemma-7b".to_string(),
Which::Instruct2B => "google/gemma-2b-it".to_string(),
Which::Instruct7B => "google/gemma-7b-it".to_string(),
Which::CodeBase2B => "google/codegemma-2b".to_string(),
Which::CodeBase7B => "google/codegemma-7b".to_string(),
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
@ -258,7 +273,7 @@ fn main() -> Result<()> {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
let model = Model::new(args.use_flash_attn, &config, vb)?;
println!("loaded the model in {:?}", start.elapsed());

View File

@ -0,0 +1,19 @@
# gte-Qwen1.5-7B-instruct
gte-Qwen1.5-7B-instruct is a variant of the GTE embedding model family.
- [Model card](https://huggingface.co/Alibaba-NLP/gte-Qwen1.5-7B-instruct) on the HuggingFace Hub.
- [Technical report](https://arxiv.org/abs/2308.03281) *Towards General Text Embeddings with Multi-stage Contrastive Learning*
## Running the example
Automatically download the model from the HuggingFace hub:
```bash
$ cargo run --example gte-qwen --release
```
or, load the model from a local directory:
```bash
cargo run --example gte-qwen --release --features cuda -- --local-repo /path/to/gte_Qwen1.5-7B-instruct/
```

View File

@ -0,0 +1,178 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::qwen2::{Config, Model};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{
utils::padding::{PaddingDirection, PaddingParams, PaddingStrategy},
Tokenizer,
};
// gte-Qwen1.5-7B-instruct use EOS token as padding token
const EOS_TOKEN: &str = "<|endoftext|>";
const EOS_TOKEN_ID: u32 = 151643;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long, default_value = "Alibaba-NLP/gte-Qwen1.5-7B-instruct")]
model_id: String,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
local_repo: Option<String>,
}
#[derive(Debug)]
struct ConfigFiles {
pub config: std::path::PathBuf,
pub tokenizer: std::path::PathBuf,
pub weights: Vec<std::path::PathBuf>,
}
// Loading the model from the HuggingFace Hub. Network access is required.
fn load_from_hub(model_id: &str, revision: &str) -> Result<ConfigFiles> {
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
model_id.to_string(),
RepoType::Model,
revision.to_string(),
));
Ok(ConfigFiles {
config: repo.get("config.json")?,
tokenizer: repo.get("tokenizer.json")?,
weights: candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
})
}
// Loading the model from a local directory.
fn load_from_local(local_path: &str) -> Result<ConfigFiles> {
let local_path = std::path::PathBuf::from(local_path);
let weight_path = local_path.join("model.safetensors.index.json");
let json: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(weight_path)?)?;
let weight_map = match json.get("weight_map") {
Some(serde_json::Value::Object(map)) => map,
Some(_) => panic!("`weight map` is not a map"),
None => panic!("`weight map` not found"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
safetensors_files.insert(
value
.as_str()
.expect("Weight files should be parsed as strings"),
);
}
let safetensors_paths = safetensors_files
.iter()
.map(|v| local_path.join(v))
.collect::<Vec<_>>();
Ok(ConfigFiles {
config: local_path.join("config.json"),
tokenizer: local_path.join("tokenizer.json"),
weights: safetensors_paths,
})
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
// Fetch the model. Do this offline if local path provided.
println!("Fetching model files...");
let start = std::time::Instant::now();
let config_files = match args.local_repo {
Some(local_path) => load_from_local(&local_path)?,
None => load_from_hub(&args.model_id, &args.revision)?,
};
println!("Model file retrieved in {:?}", start.elapsed());
// Inputs will be padded to the longest sequence in the batch.
let padding = PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Left,
pad_to_multiple_of: None,
pad_id: EOS_TOKEN_ID,
pad_type_id: 0,
pad_token: String::from(EOS_TOKEN),
};
// Tokenizer setup
let mut tokenizer = Tokenizer::from_file(config_files.tokenizer).map_err(E::msg)?;
tokenizer.with_padding(Some(padding));
// Model initialization
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let config: Config = serde_json::from_slice(&std::fs::read(config_files.config)?)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&config_files.weights, dtype, &device)? };
let mut model = Model::new(&config, vb)?;
println!("Model loaded in {:?}", start.elapsed());
// Encode the queries and the targets
let instruct = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: ";
let documents = vec![
format!("{instruct}how much protein should a female eat{EOS_TOKEN}"),
format!("{instruct}summit define{EOS_TOKEN}"),
format!("As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.{EOS_TOKEN}"),
format!("Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.{EOS_TOKEN}"),
];
let encoded = tokenizer.encode_batch(documents, true).map_err(E::msg)?;
let tokens: Vec<&[u32]> = encoded.iter().map(|x| x.get_ids()).collect();
let tokens = Tensor::new(tokens, &device)?;
let mask: Vec<&[u32]> = encoded.iter().map(|x| x.get_attention_mask()).collect();
let mask = Tensor::new(mask, &device)?;
// Inference
let start_gen = std::time::Instant::now();
let logits = model.forward(&tokens, 0, Some(&mask))?;
// Extract the last hidden states as embeddings since inputs are padded left.
let (_, seq_len, _) = logits.dims3()?;
let embd = logits
.narrow(1, seq_len - 1, 1)?
.squeeze(1)?
.to_dtype(DType::F32)?;
// Calculate the relativity scores. Note the embeddings should be normalized.
let norm = embd.broadcast_div(&embd.sqr()?.sum_keepdim(1)?.sqrt()?)?;
let scores = norm.narrow(0, 0, 2)?.matmul(&norm.narrow(0, 2, 2)?.t()?)?;
// Print the results
println!("Embedding done in {:?}", start_gen.elapsed());
println!("Scores: {:?}", scores.to_vec2::<f32>()?);
Ok(())
}

View File

@ -17,7 +17,7 @@ use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
@ -31,6 +31,8 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
enum Which {
V1,
V2,
V3,
V3Instruct,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
@ -45,19 +47,23 @@ struct Args {
cpu: bool,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 10000)]
#[arg(short = 'n', long, default_value_t = 10000)]
sample_len: usize,
/// Disable the key-value cache.
@ -83,18 +89,18 @@ struct Args {
revision: Option<String>,
/// The model size to use.
#[arg(long, default_value = "v2")]
#[arg(long, default_value = "v3")]
which: Which,
#[arg(long)]
use_flash_attn: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)]
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
#[arg(long, default_value_t = 128)]
repeat_last_n: usize,
}
@ -120,11 +126,13 @@ fn main() -> Result<()> {
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
let (llama, tokenizer_filename, mut cache) = {
let (llama, tokenizer_filename, mut cache, config) = {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.which {
Which::V1 => "Narsil/amall-7b".to_string(),
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
});
@ -138,7 +146,7 @@ fn main() -> Result<()> {
let config = config.into_config(args.use_flash_attn);
let filenames = match args.which {
Which::V1 | Which::V2 | Which::Solar10_7B => {
Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@ -146,10 +154,12 @@ fn main() -> Result<()> {
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &config)?, tokenizer_filename, cache)
(Llama::load(vb, &config)?, tokenizer_filename, cache, config)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
let eos_token_id = config
.eos_token_id
.or_else(|| tokenizer.token_to_id(EOS_TOKEN));
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
@ -160,8 +170,22 @@ fn main() -> Result<()> {
println!("starting the inference loop");
print!("{prompt}");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let start_gen = std::time::Instant::now();
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let mut start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;
for index in 0..args.sample_len {
@ -170,6 +194,9 @@ fn main() -> Result<()> {
} else {
(tokens.len(), 0)
};
if index == 1 {
start_gen = std::time::Instant::now()
}
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, context_index, &mut cache)?;
@ -205,7 +232,7 @@ fn main() -> Result<()> {
println!(
"\n\n{} tokens generated ({} token/s)\n",
token_generated,
token_generated as f64 / dt.as_secs_f64(),
(token_generated - 1) as f64 / dt.as_secs_f64(),
);
Ok(())
}

View File

@ -10,7 +10,7 @@
extern crate intel_mkl_src;
use anyhow::{bail, Error as E, Result};
use clap::Parser;
use clap::{Parser, ValueEnum};
use candle::{DType, Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
@ -24,57 +24,15 @@ mod model;
use model::{Config, Llama};
const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = r"
EDWARD:
I wonder how our princely father 'scaped,
Or whether he be 'scaped away or no
From Clifford's and Northumberland's pursuit:
Had he been ta'en, we should have heard the news;
Had he been slain, we should have heard the news;
Or had he 'scaped, methinks we should have heard
The happy tidings of his good escape.
How fares my brother? why is he so sad?
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
RICHARD:
I cannot joy, until I be resolved
Where our right valiant father is become.
I saw him in the battle range about;
And watch'd him how he singled Clifford forth.
Methought he bore him in the thickest troop
As doth a lion in a herd of neat;
Or as a bear, encompass'd round with dogs,
Who having pinch'd a few and made them cry,
The rest stand all aloof, and bark at him.
So fared our father with his enemies;
So fled his enemies my warlike father:
Methinks, 'tis prize enough to be his son.
See how the morning opes her golden gates,
And takes her farewell of the glorious sun!
How well resembles it the prime of youth,
Trimm'd like a younker prancing to his love!
EDWARD:
Dazzle mine eyes, or do I see three suns?
RICHARD:
Three glorious suns, each one a perfect sun;
Not separated with the racking clouds,
But sever'd in a pale clear-shining sky.
See, see! they join, embrace, and seem to kiss,
As if they vow'd some league inviolable:
Now are they but one lamp, one light, one sun.
In this the heaven figures some event.
EDWARD:
'Tis wondrous strange, the like yet never heard of.
I think it cites us, brother, to the field,
That we, the sons of brave Plantagenet,
Each one already blazing by our meeds,
Should notwithstanding join our lights together
And over-shine the earth as this the world.
Whate'er it bodes, henceforward will I bear
Upon my target three fair-shining suns.
";
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
V2_7b,
V2_70b,
V3_8b,
V3_70b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
@ -86,8 +44,8 @@ struct Args {
rank: Option<usize>,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
@ -117,6 +75,12 @@ struct Args {
#[arg(long)]
dtype: Option<String>,
#[arg(long, default_value = "v3-8b")]
which: Which,
#[arg(long, default_value = "nccl_id.txt")]
comm_file: String,
}
fn main() -> Result<()> {
@ -129,14 +93,27 @@ fn main() -> Result<()> {
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
None => match args.which {
Which::V2_7b | Which::V2_70b => DType::F16,
Which::V3_8b | Which::V3_70b => DType::BF16,
},
};
let api = Api::new()?;
let comm_file = std::path::PathBuf::from(&args.comm_file);
if comm_file.exists() {
bail!("comm file {comm_file:?} already exists, please remove it first")
}
let model_id = args
.model_id
.unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
let api = Api::new()?;
let model_id = match args.model_id {
Some(model) => model,
None => match args.which {
Which::V2_7b => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V2_70b => "meta-llama/Llama-2-70b-hf".to_string(),
Which::V3_8b => "meta-llama/Meta-Llama-3-8B".to_string(),
Which::V3_70b => "meta-llama/Meta-Llama-3-70B".to_string(),
},
};
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
@ -145,39 +122,40 @@ fn main() -> Result<()> {
let tokenizer_filename = api.get("tokenizer.json")?;
let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
if args.rank.is_none() {
let children: Vec<_> = (0..args.num_shards)
.map(|rank| {
let mut args: std::collections::VecDeque<_> = std::env::args().collect();
args.push_back("--rank".to_string());
args.push_back(format!("{rank}"));
let name = args.pop_front().unwrap();
std::process::Command::new(name).args(args).spawn().unwrap()
})
.collect();
for mut child in children {
child.wait().unwrap();
let rank = match args.rank {
None => {
println!("creating {} child processes", args.num_shards);
let children: Vec<_> = (0..args.num_shards)
.map(|rank| {
let mut args: std::collections::VecDeque<_> = std::env::args().collect();
args.push_back("--rank".to_string());
args.push_back(format!("{rank}"));
let name = args.pop_front().unwrap();
std::process::Command::new(name).args(args).spawn().unwrap()
})
.collect();
for mut child in children {
child.wait()?;
}
return Ok(());
}
return Ok(());
}
Some(rank) => rank,
};
let i = args.rank.unwrap();
let num_shards = args.num_shards;
let rank = i;
// Primitive IPC
let id = if rank == 0 {
let id = Id::new().unwrap();
std::fs::File::create("nccl_id.txt.tmp")?
.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())
.unwrap();
std::fs::rename("nccl_id.txt.tmp", "nccl_id.txt")?;
let tmp_file = comm_file.with_extension(".comm.tgz");
std::fs::File::create(&tmp_file)?
.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())?;
std::fs::rename(&tmp_file, &comm_file)?;
id
} else {
let path = std::path::PathBuf::from("nccl_id.txt");
while !path.exists() {
while !comm_file.exists() {
std::thread::sleep(std::time::Duration::from_secs(1));
}
let data = std::fs::read("nccl_id.txt")?;
let data = std::fs::read(&comm_file)?;
let internal: [i8; 128] = data
.into_iter()
.map(|i| i as i8)
@ -187,14 +165,17 @@ fn main() -> Result<()> {
let id: Id = Id::uninit(internal);
id
};
let device = CudaDevice::new(i)?;
let comm = Rc::new(Comm::from_rank(device, i, num_shards, id).unwrap());
let device = CudaDevice::new(rank)?;
let comm = match Comm::from_rank(device, rank, num_shards, id) {
Ok(comm) => Rc::new(comm),
Err(err) => anyhow::bail!("nccl error {:?}", err.0),
};
if rank == 0 {
std::fs::remove_file("nccl_id.txt")?;
std::fs::remove_file(comm_file)?;
}
println!("Rank {rank:?} spawned");
let device = Device::new_cuda(i)?;
let device = Device::new_cuda(rank)?;
let cache = model::Cache::new(dtype, &config, &device)?;
println!("building the model");
@ -210,14 +191,24 @@ fn main() -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let temperature = if args.temperature <= 0. {
None
} else {
Some(args.temperature)
};
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
let mut new_tokens = vec![];
let start_gen = std::time::Instant::now();
let mut start_gen = std::time::Instant::now();
let mut index_pos = 0;
for index in 0..args.sample_len {
let start_gen = std::time::Instant::now();
// Only start timing at the second token as processing the first token waits for all the
// weights to be loaded in an async way.
if index == 1 {
start_gen = std::time::Instant::now()
};
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
@ -228,25 +219,23 @@ fn main() -> Result<()> {
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
if Some(next_token) == config.eos_token_id {
break;
}
if rank == 0 {
println!("> {:?}", start_gen.elapsed());
println!(
"{} token: {} '{}'",
index + 1,
next_token,
tokenizer.decode(&[next_token], true).map_err(E::msg)?
);
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
}
let dt = start_gen.elapsed();
println!();
if rank == 0 {
let dt = start_gen.elapsed();
println!(
"{} tokens generated ({} token/s)\n----\n{}\n----",
"\n\n{} tokens generated ({} token/s)\n",
args.sample_len,
args.sample_len as f64 / dt.as_secs_f64(),
tokenizer
.decode(new_tokens.as_slice(), true)
.map_err(E::msg)?
(args.sample_len - 1) as f64 / dt.as_secs_f64(),
);
}
Ok(())

View File

@ -1,15 +1,14 @@
use candle::backend::BackendStorage;
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
use candle_nn::{Embedding, Linear, Module, RmsNorm};
use cudarc::nccl::safe::{Comm, ReduceOp};
use half::f16;
use serde::Deserialize;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use super::MAX_SEQ_LEN;
use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
pub type Config = candle_transformers::models::llama::LlamaConfig;
struct TensorParallelColumnLinear {
linear: Linear,
@ -26,7 +25,7 @@ impl TensorParallelColumnLinear {
struct TensorParallelRowLinear {
linear: Linear,
comm: Rc<Comm>,
all_reduce: AllReduce,
}
struct AllReduce {
@ -36,8 +35,6 @@ struct AllReduce {
/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
/// But for this example purposes, this will work
unsafe impl Sync for AllReduce {}
/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
/// But for this example purposes, this will work
unsafe impl Send for AllReduce {}
impl CustomOp1 for AllReduce {
@ -46,7 +43,7 @@ impl CustomOp1 for AllReduce {
}
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
todo!("implement allreduce for cpu is not necessary for single node");
candle::bail!("AllReduce is never used on cpu")
}
#[cfg(feature = "cuda")]
@ -56,31 +53,49 @@ impl CustomOp1 for AllReduce {
l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::WrapErr;
use cudarc::driver::DeviceSlice;
use half::{bf16, f16};
let elem_count = l.shape().elem_count();
let dev = s.device().clone();
let s = s.as_cuda_slice::<f16>()?;
// let s = match l.contiguous_offsets() {
// None => Err(Error::Wrapped("input has to be contiguous".into()))?,
// Some((o1, o2)) => s.slice(o1..o2),
// };
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
let dst = match s.dtype() {
DType::BF16 => {
let s = s.as_cuda_slice::<bf16>()?;
let s = match l.contiguous_offsets() {
Some((0, l)) if l == s.len() => s,
Some(_) | None => candle::bail!("input has to be contiguous"),
};
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
self.comm
.all_reduce(s, &mut dst, &ReduceOp::Sum)
.map_err(candle::Error::debug)?;
candle::CudaStorage::wrap_cuda_slice(dst, dev)
}
DType::F16 => {
let s = s.as_cuda_slice::<f16>()?;
let s = match l.contiguous_offsets() {
Some((0, l)) if l == s.len() => s,
Some(_) | None => candle::bail!("input has to be contiguous"),
};
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
self.comm
.all_reduce(s, &mut dst, &ReduceOp::Sum)
.map_err(candle::Error::debug)?;
candle::CudaStorage::wrap_cuda_slice(dst, dev)
}
dtype => candle::bail!("unsupported dtype {dtype:?}"),
};
Ok((dst, l.shape().clone()))
}
}
fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
x.apply_op1(AllReduce { comm: comm.clone() })
}
impl TensorParallelRowLinear {
fn new(linear: Linear, comm: Rc<Comm>) -> Self {
Self { linear, comm }
let all_reduce = AllReduce { comm };
Self { linear, all_reduce }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = self.linear.forward(x)?;
all_reduce_sum(&x, &self.comm)
self.linear.forward(x)?.apply_op1_no_bwd(&self.all_reduce)
}
}
@ -121,23 +136,6 @@ impl TensorParallelRowLinear {
}
}
#[derive(Deserialize)]
pub struct Config {
pub hidden_size: usize,
pub intermediate_size: usize,
pub vocab_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub rms_norm_eps: f64,
#[serde(default = "default_rope")]
pub rope_theta: f32,
}
fn default_rope() -> f32 {
10_000.0
}
#[derive(Clone)]
pub struct Cache {
#[allow(clippy::type_complexity)]
@ -161,7 +159,6 @@ impl Cache {
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
Ok(Self {
@ -197,16 +194,10 @@ struct CausalSelfAttention {
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, _, seq_len, hidden_size) = x.shape().dims4()?;
let (_b_sz, _, seq_len, _hidden_size) = x.shape().dims4()?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
Ok(rope)
candle_nn::rotary_emb::rope(x, &cos, &sin)
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
@ -232,13 +223,16 @@ impl CausalSelfAttention {
let q = q
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
.transpose(1, 2)?;
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?;
.transpose(1, 2)?
.contiguous()?;
let mut v = v
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?;
.transpose(1, 2)?
.contiguous()?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?;
@ -269,25 +263,14 @@ impl CausalSelfAttention {
let v = v.transpose(1, 2)?;
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
.transpose(1, 2)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
.reshape((b_sz, seq_len, hidden_size))?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.num_attention_heads / self.num_key_value_heads;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().dims4()?;
let x = x
.unsqueeze(2)?
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
.reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
Ok(x)
}
candle_transformers::utils::repeat_kv(x, n_rep)
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
@ -301,7 +284,7 @@ impl CausalSelfAttention {
qkv_proj,
o_proj,
num_attention_heads: cfg.num_attention_heads / comm.world_size(),
num_key_value_heads: cfg.num_key_value_heads / comm.world_size(),
num_key_value_heads: cfg.num_key_value_heads() / comm.world_size(),
head_dim: cfg.hidden_size / cfg.num_attention_heads,
cache: cache.clone(),
})
@ -315,18 +298,6 @@ struct Mlp {
}
impl Mlp {
fn new(
c_fc1: TensorParallelColumnLinear,
c_fc2: TensorParallelColumnLinear,
c_proj: TensorParallelRowLinear,
) -> Self {
Self {
c_fc1,
c_fc2,
c_proj,
}
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
@ -336,7 +307,11 @@ impl Mlp {
let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?;
Ok(Self::new(c_fc1, c_fc2, c_proj))
Ok(Self {
c_fc1,
c_fc2,
c_proj,
})
}
}
@ -427,10 +402,8 @@ impl Llama {
cfg,
comm.clone(),
)
.unwrap()
})
.collect();
.collect::<Result<Vec<_>>>()?;
Ok(Self::new(wte, blocks, norm, lm_head))
}
}

View File

@ -0,0 +1,4 @@
pub const DEFAULT_IMAGE_TOKEN: &str = "<image>";
pub const DEFAULT_IM_START_TOKEN: &str = "<im_start>";
pub const DEFAULT_IM_END_TOKEN: &str = "<im_end>";
pub const IMAGE_PLACEHOLDER: &str = "<image-placeholder>";

View File

@ -0,0 +1,114 @@
pub enum SeparatorStyle {
Two,
Mpt,
}
pub struct Conversation {
pub system: String,
pub roles: Vec<String>,
pub messages: Vec<(String, Option<String>)>,
pub offset: i32,
pub sep_style: SeparatorStyle,
pub sep: String,
pub sep2: Option<String>,
pub version: String,
}
impl Conversation {
pub fn new(
system: &str,
roles: &[String],
offset: i32,
sep_style: SeparatorStyle,
sep: &str,
sep2: Option<&str>,
version: &str,
) -> Self {
Conversation {
system: system.to_string(),
roles: roles.to_vec(),
messages: Vec::new(),
offset,
sep_style,
sep: sep.to_string(),
sep2: sep2.map(|s| s.to_string()),
version: version.to_string(),
}
}
pub fn conv_chatml_direct() -> Self {
Conversation::new(
"<|im_start|>system\nAnswer the questions.",
&[
"<|im_start|>user\n".to_string(),
"<|im_start|>assistant\n".to_string(),
],
0,
SeparatorStyle::Mpt,
"<|im_end|>",
None,
"mpt",
)
}
pub fn conv_llava_v1() -> Self {
Conversation::new(
"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
&[
"USER".to_string(),
"ASSISTANT".to_string(),
],
0,
SeparatorStyle::Two,
" ",
Some("</s>"),
"v1"
)
}
pub fn append_message(&mut self, role: String, message: Option<&str>) {
self.messages.push((role, message.map(|s| s.to_string())))
}
pub fn append_user_message(&mut self, message: Option<&str>) {
self.append_message(self.roles[0].clone(), message);
}
pub fn append_assistant_message(&mut self, message: Option<&str>) {
self.append_message(self.roles[1].clone(), message);
}
pub fn get_prompt(&self) -> String {
match self.sep_style {
SeparatorStyle::Mpt => {
let mut ret = String::new();
ret.push_str(&self.system);
ret.push_str(&self.sep);
for (role, message) in &self.messages {
ret.push_str(role);
if let Some(message) = message {
ret.push_str(message);
};
ret.push_str(&self.sep);
}
ret
}
SeparatorStyle::Two => {
let seps = [self.sep.clone(), self.sep2.clone().unwrap()];
let mut ret = String::new();
ret.push_str(&self.system);
ret.push_str(&seps[0]);
for (i, (role, message)) in self.messages.iter().enumerate() {
ret.push_str(role);
if let Some(message) = message {
ret.push_str(": "); // strictly follow the python implementation, otherwise it will cause some minor difference between tokens ^_^
ret.push_str(message);
ret.push_str(&seps[i % 2]);
} else {
ret.push(':')
}
}
ret
}
}
}
}

View File

@ -0,0 +1,317 @@
use std::cmp::min;
use candle::{bail, DType, Device, Result, Tensor};
use candle_transformers::models::llava::{
config::{HFPreProcessorConfig, LLaVAConfig},
utils::select_best_resolution,
};
use hf_hub::api::sync::Api;
use image::{imageops::overlay, DynamicImage, GenericImageView, Rgb, RgbImage};
use serde::{Deserialize, Serialize};
//This struct is mainly for LLaVA aplications, hence it's not completely compatible with python transformer CLIPImageProcessor few several preprocess that LLaVA used, including "openai/clip-vit-large-patch14-336" and "openai/clip-vit-large-patch14".
#[derive(Serialize, Deserialize, Debug)]
pub struct ImageProcessor {
#[serde(default = "default_size")]
pub size: u32, // this is not the same as python transformer
#[serde(default = "default_do_resize")]
pub do_resize: bool,
//resample: u32 // 3 for PIL bicubic, equivalent to rust CatmullRom. Hence below we use CatmullRom
#[serde(default = "default_do_center_crop")]
pub do_center_crop: bool,
#[serde(default = "default_crop_size")]
pub crop_size: u32, // this is not the same as python transformer
#[serde(default = "default_do_rescale")]
pub do_rescale: bool,
#[serde(default = "default_rescale_factor")]
pub rescale_factor: f32,
#[serde(default = "default_do_normalize")]
pub do_normalize: bool,
#[serde(default = "default_image_mean")]
pub image_mean: Vec<f32>,
#[serde(default = "default_image_std")]
pub image_std: Vec<f32>,
}
fn default_size() -> u32 {
224
}
fn default_do_resize() -> bool {
true
}
fn default_do_center_crop() -> bool {
true
}
fn default_crop_size() -> u32 {
224
}
fn default_do_rescale() -> bool {
true
}
fn default_rescale_factor() -> f32 {
1.0 / 255.0
}
fn default_do_normalize() -> bool {
true
}
fn default_image_mean() -> Vec<f32> {
vec![0.48145466, 0.4578275, 0.40821073]
}
fn default_image_std() -> Vec<f32> {
vec![0.26862954, 0.2613026, 0.2757771]
}
impl ImageProcessor {
pub fn from_pretrained(clip_id: &str) -> Result<Self> {
let api = Api::new().map_err(|e| candle::Error::Msg(e.to_string()))?;
let api = api.model(clip_id.to_string());
let config_filename = api
.get("preprocessor_config.json")
.map_err(|e| candle::Error::Msg(e.to_string()))?;
let image_processor =
serde_json::from_slice(&std::fs::read(config_filename).map_err(candle::Error::Io)?)
.map_err(|e| candle::Error::Msg(e.to_string()))?;
Ok(image_processor)
}
pub fn from_hf_preprocessor_config(hf_preprocessor_config: &HFPreProcessorConfig) -> Self {
Self {
size: hf_preprocessor_config.size["shortest_edge"] as u32,
do_resize: hf_preprocessor_config.do_resize,
do_center_crop: hf_preprocessor_config.do_center_crop,
crop_size: hf_preprocessor_config.crop_size["height"] as u32,
do_rescale: hf_preprocessor_config.do_rescale,
rescale_factor: hf_preprocessor_config.rescale_factor,
do_normalize: hf_preprocessor_config.do_normalize,
image_mean: hf_preprocessor_config.image_mean.clone(),
image_std: hf_preprocessor_config.image_std.clone(),
}
}
///shortest edge to self.resize, other edge is resized to maintain aspect ratio
pub fn resize(&self, image: &DynamicImage) -> DynamicImage {
let (width, height) = image.dimensions();
let size = self.size;
if width == size && height == size {
image.clone()
} else {
let (new_width, new_height) = if width < height {
(
size,
(((size * height) as f32) / width as f32).ceil() as u32,
)
} else {
(
(((size * width) as f32) / height as f32).ceil() as u32,
size,
)
};
image.resize(
new_width,
new_height,
image::imageops::FilterType::CatmullRom,
)
}
}
pub fn center_crop(&self, image: &DynamicImage) -> DynamicImage {
let (width, height) = image.dimensions();
let crop_size = self.crop_size;
let (left, top) = calculate_middle((width, height), (crop_size, crop_size));
image.crop_imm(left, top, crop_size, crop_size)
}
pub fn to_tensor(&self, image: &DynamicImage) -> Result<Tensor> {
let img = image.to_rgb8().into_raw();
let (width, height) = image.dimensions();
Tensor::from_vec(img, (height as usize, width as usize, 3), &Device::Cpu)?
.to_dtype(DType::F32) // only for internal compute
}
pub fn rescale(&self, tensor: &Tensor) -> Result<Tensor> {
let rescale_factor = self.rescale_factor as f64;
tensor.affine(rescale_factor, 0.0)
}
pub fn normalize(&self, tensor: &Tensor) -> Result<Tensor> {
let image_mean = self.image_mean.clone();
let image_std = self.image_std.clone();
let mean = Tensor::from_vec(image_mean, (3,), &Device::Cpu)?;
let std = Tensor::from_vec(image_std, (3,), &Device::Cpu)?;
tensor.broadcast_sub(&mean)?.broadcast_div(&std)
}
pub fn to_channel_dimension_format(&self, tensor: &Tensor) -> Result<Tensor> {
tensor.permute((2, 0, 1))
}
pub fn preprocess(&self, image: &DynamicImage) -> Result<Tensor> {
let image = if self.do_resize {
self.resize(image)
} else {
image.clone()
};
let image = if self.do_center_crop {
self.center_crop(&image)
} else {
image
};
let tensor = self.to_tensor(&image)?;
let tensor = if self.do_rescale {
self.rescale(&tensor)?
} else {
tensor
};
let tensor = if self.do_normalize {
self.normalize(&tensor)?
} else {
tensor
};
self.to_channel_dimension_format(&tensor)
}
}
pub fn calculate_middle(image_size: (u32, u32), center_size: (u32, u32)) -> (u32, u32) {
let (width, height) = image_size;
let (center_width, center_height) = center_size;
let left = if width <= center_width {
0
} else {
((width as f32 - center_width as f32) / 2.0).ceil() as u32
};
let top = if height <= center_height {
0
} else {
((height as f32 - center_height as f32) / 2.0).ceil() as u32
};
(left, top)
}
pub fn process_image(
image: &DynamicImage,
processor: &ImageProcessor,
llava_config: &LLaVAConfig,
) -> candle::Result<Tensor> {
if llava_config.image_aspect_ratio == *"square" {
processor.preprocess(image)?.unsqueeze(0)
} else if llava_config.image_aspect_ratio == *"anyres" {
process_anyres_image(image, processor, &llava_config.image_grid_pinpoints)
} else if llava_config.image_aspect_ratio == *"pad" {
process_pad_image(image, processor)
} else {
bail!("Invalid image aspect ratio")
}
}
fn process_pad_image(image: &DynamicImage, processor: &ImageProcessor) -> Result<Tensor> {
let mean_color = processor
.image_mean
.iter()
.map(|x| ((*x) * 255.0) as u8)
.collect::<Vec<u8>>();
let mean_color = Rgb::from([mean_color[0], mean_color[1], mean_color[2]]);
let image_padded = expand2square(image, mean_color);
processor.preprocess(&image_padded)
}
fn process_anyres_image(
image: &DynamicImage,
processor: &ImageProcessor,
grid_pinpoints: &[(u32, u32)],
) -> Result<Tensor> {
let original_size = image.dimensions();
let best_resolution = select_best_resolution(original_size, grid_pinpoints);
let image_padded = resize_and_pad_image(image, best_resolution);
let image_original_resize = image.resize_exact(
processor.size,
processor.size,
image::imageops::FilterType::CatmullRom,
);
let mut patches = vec![image_original_resize];
for patch in divide_to_patches(&image_padded, processor.crop_size) {
patches.push(patch);
}
let tensors = patches
.iter()
.map(|patch| processor.preprocess(patch))
.collect::<Result<Vec<Tensor>>>()?;
Tensor::stack(&tensors, 0)
}
fn expand2square(image: &DynamicImage, background_color: Rgb<u8>) -> DynamicImage {
let (width, height) = image.dimensions();
match width.cmp(&height) {
std::cmp::Ordering::Less => {
let mut new_image =
DynamicImage::from(RgbImage::from_pixel(height, height, background_color));
overlay(&mut new_image, image, ((height - width) / 2) as i64, 0);
new_image
}
std::cmp::Ordering::Equal => image.clone(),
std::cmp::Ordering::Greater => {
let mut new_image =
DynamicImage::from(RgbImage::from_pixel(width, width, background_color));
overlay(&mut new_image, image, 0, ((width - height) / 2) as i64);
new_image
}
}
}
fn resize_and_pad_image(image: &DynamicImage, target_resolution: (u32, u32)) -> DynamicImage {
let (original_width, original_height) = image.dimensions();
let original_width_f = original_width as f32;
let original_height_f = original_height as f32;
let (target_width, target_height) = target_resolution;
let target_width_f = target_width as f32;
let target_height_f = target_height as f32;
let scale_w = target_width_f / original_width_f;
let scale_h = target_height_f / original_height_f;
let (new_width, new_height) = if scale_w < scale_h {
(
target_width,
min((original_height_f * scale_w).ceil() as u32, target_height),
)
} else {
(
min((original_width_f * scale_h).ceil() as u32, target_width),
target_height,
)
};
let resized_image = image.resize_exact(
new_width,
new_height,
image::imageops::FilterType::CatmullRom,
);
let mut new_image = DynamicImage::new_rgb8(target_width, target_height);
let (paste_x, paste_y) =
calculate_middle((target_width, target_height), (new_width, new_height));
overlay(
&mut new_image,
&resized_image,
paste_x.into(),
paste_y.into(),
);
new_image
}
fn divide_to_patches(image: &DynamicImage, patch_size: u32) -> Vec<DynamicImage> {
let (width, height) = image.dimensions();
let mut patches = Vec::new();
for y in (0..height).step_by(patch_size as usize) {
for x in (0..width).step_by(patch_size as usize) {
let patch = image.crop_imm(x, y, patch_size, patch_size);
patches.push(patch);
}
}
patches
}

View File

@ -0,0 +1,316 @@
pub mod constants;
pub mod conversation;
pub mod image_processor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::Cache;
use anyhow::{bail, Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::llava::config::{
HFGenerationConfig, HFLLaVAConfig, HFPreProcessorConfig,
};
use candle_transformers::models::llava::{config::LLaVAConfig, LLaVA};
use clap::Parser;
use constants::*;
use conversation::Conversation;
use hf_hub::api::sync::Api;
use image_processor::{process_image, ImageProcessor};
use std::io::Write;
use tokenizers::Tokenizer;
#[derive(Parser, Debug)]
#[command(author, version, about,long_about=None)]
struct Args {
#[arg(long, default_value = "llava-hf/llava-v1.6-vicuna-7b-hf")]
model_path: String,
#[arg(long, default_value = "tokenizer/tokenizer.json")]
tokenizer_path: String,
#[arg(long)]
model_base: Option<String>,
#[arg(long)]
image_file: String, // Required
#[arg(long)]
conv_mode: Option<String>,
#[arg(long, default_value_t = 0.2)]
temperature: f32,
#[arg(long, default_value_t = 512)]
max_new_tokens: usize,
#[arg(long, action)]
hf: bool,
#[arg(long, action)]
cpu: bool,
#[arg(long, action)]
no_kv_cache: bool,
#[arg(long)]
prompt: String,
/// The seed to use when generating random samples. Copy from candle llama. Not exist in python llava.
#[arg(long, default_value_t = 299792458)]
seed: u64,
}
//from https://github.com/huggingface/candle/blob/main/candle-examples/examples/clip/main.rs
fn load_image<T: AsRef<std::path::Path>>(
path: T,
processor: &ImageProcessor,
llava_config: &LLaVAConfig,
dtype: DType,
) -> Result<((u32, u32), Tensor)> {
let img = image::io::Reader::open(path)?.decode()?;
let img_tensor = process_image(&img, processor, llava_config)?;
Ok(((img.width(), img.height()), img_tensor.to_dtype(dtype)?))
}
fn get_model_name_from_path(model_path: &str) -> String {
let model_paths: Vec<String> = model_path
.trim_matches('/')
.split('/')
.map(|s| s.to_string())
.collect();
if model_paths.last().unwrap().starts_with("checkpoint-") {
format!(
"{}_{}",
model_paths[model_paths.len() - 2],
model_paths.last().unwrap()
)
} else {
model_paths.last().unwrap().to_string()
}
}
fn duplicate_vec<T>(vec: &[T], n: usize) -> Vec<T>
where
T: Clone,
{
let mut res = Vec::new();
for _ in 0..n {
res.extend(vec.to_owned());
}
res
}
fn insert_separator<T>(x: Vec<Vec<T>>, sep: Vec<T>) -> Vec<Vec<T>>
where
T: Clone,
{
let sep = vec![sep];
let sep = duplicate_vec(&sep, x.len());
let mut res = x
.iter()
.zip(sep.iter())
.flat_map(|(x, y)| vec![x.clone(), y.clone()])
.collect::<Vec<Vec<T>>>();
res.pop();
res
}
fn tokenizer_image_token(
prompt: &str,
tokenizer: &Tokenizer,
image_token_index: i64,
llava_config: &LLaVAConfig,
) -> Result<Tensor> {
let prompt_chunks = prompt
.split("<image>")
.map(|s| {
tokenizer
.encode(s, true)
.unwrap()
.get_ids()
.to_vec()
.iter()
.map(|x| *x as i64)
.collect()
})
.collect::<Vec<Vec<i64>>>();
let mut input_ids = Vec::new();
let mut offset = 0;
if !prompt_chunks.is_empty()
&& !prompt_chunks[0].is_empty()
&& prompt_chunks[0][0] == llava_config.bos_token_id as i64
{
offset = 1;
input_ids.push(prompt_chunks[0][0]);
}
for x in insert_separator(
prompt_chunks,
duplicate_vec(&[image_token_index], offset + 1),
)
.iter()
{
input_ids.extend(x[1..].to_vec())
}
let input_len = input_ids.len();
Tensor::from_vec(input_ids, (1, input_len), &Device::Cpu).map_err(E::msg)
}
fn main() -> Result<()> {
let mut args = Args::parse();
let device = candle_examples::device(args.cpu)?;
println!("Start loading model");
let api = Api::new()?;
let api = api.model(args.model_path.clone());
let (llava_config, tokenizer, clip_vision_config, image_processor) = if args.hf {
let config_filename = api.get("config.json")?;
let hf_llava_config: HFLLaVAConfig =
serde_json::from_slice(&std::fs::read(config_filename)?)?;
let generation_config_filename = api.get("generation_config.json")?;
let generation_config: HFGenerationConfig =
serde_json::from_slice(&std::fs::read(generation_config_filename)?)?;
let preprocessor_config_filename = api.get("preprocessor_config.json")?;
let preprocessor_config: HFPreProcessorConfig =
serde_json::from_slice(&std::fs::read(preprocessor_config_filename)?)?;
let llava_config =
hf_llava_config.to_llava_config(&generation_config, &preprocessor_config);
let tokenizer_filename = api.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let clip_vision_config = hf_llava_config.to_clip_vision_config();
(
llava_config,
tokenizer,
Some(clip_vision_config),
ImageProcessor::from_hf_preprocessor_config(&preprocessor_config),
)
} else {
let config_filename = api.get("config.json")?;
let llava_config: LLaVAConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let tokenizer = Tokenizer::from_file(&args.tokenizer_path)
.map_err(|e| E::msg(format!("Error loading {}: {}", &args.tokenizer_path, e)))?;
(
llava_config.clone(),
tokenizer,
None,
ImageProcessor::from_pretrained(&llava_config.mm_vision_tower.unwrap())?,
)
};
let llama_config = llava_config.to_llama_config();
let dtype: DType = match llava_config.torch_dtype.as_str() {
"float16" => DType::F16,
"bfloat16" => DType::BF16,
_ => bail!("unsupported dtype"),
};
let eos_token_id = llava_config.eos_token_id;
println!("setting kv cache");
let mut cache = Cache::new(!args.no_kv_cache, dtype, &llama_config, &device)?;
println!("loading model weights");
let weight_filenames =
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_filenames, dtype, &device)? };
let llava: LLaVA = LLaVA::load(vb, &llava_config, clip_vision_config)?;
println!("generating conv template");
let image_token_se = format!(
"{}{}{}",
DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_END_TOKEN
);
let qs = if args.prompt.contains(IMAGE_PLACEHOLDER) {
if llava_config.mm_use_im_start_end {
args.prompt.replace(IMAGE_PLACEHOLDER, &image_token_se)
} else {
args.prompt.replace(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN)
}
} else if llava_config.mm_use_im_start_end {
format!("{}\n{}", image_token_se, args.prompt)
} else {
format!("{}\n{}", DEFAULT_IMAGE_TOKEN, args.prompt)
};
let model_name = get_model_name_from_path(&args.model_path).to_lowercase();
let conv_mode = if model_name.contains("llama-2") {
"llava_llama_2"
} else if model_name.contains("mistral") {
"mistral_instruct"
} else if model_name.contains("v1.6-34b") {
"chatml_direct"
} else if model_name.contains("v1") {
"llava_v1"
} else if model_name.contains("mpt") {
"mpt"
} else {
"llava_v0"
};
if args.conv_mode.is_some() && args.conv_mode.as_deref() != Some(conv_mode) {
println!(
"Warning: the model is trained with {}, but you are using {}",
conv_mode,
args.conv_mode.as_deref().unwrap()
);
} else {
args.conv_mode = Some(conv_mode.to_string());
}
let mut conv = match args.conv_mode {
Some(conv_mode) => match conv_mode.as_str() {
"chatml_direct" => Conversation::conv_chatml_direct(),
"llava_v1" => Conversation::conv_llava_v1(),
_ => todo!("not implement yet"),
},
None => bail!("conv_mode is required"),
};
conv.append_user_message(Some(&qs));
conv.append_assistant_message(None);
let prompt = conv.get_prompt();
println!("loading image");
let (image_size, image_tensor) =
load_image(&args.image_file, &image_processor, &llava_config, dtype)
.map_err(|e| E::msg(format!("Error loading {}: {}", &args.image_file, e)))?;
let image_tensor = image_tensor.to_device(&device)?;
let mut logits_processor = {
let temperature = f64::from(args.temperature);
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
Sampling::All { temperature }
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
// get input tokens
let tokens = tokenizer_image_token(
&prompt,
&tokenizer,
llava_config.image_token_index as i64,
&llava_config,
)?;
let mut input_embeds =
llava.prepare_inputs_labels_for_multimodal(&tokens, &[image_tensor], &[image_size])?;
//inference loop, based on https://github.com/huggingface/candle/blob/main/candle-examples/examples/llama/main.rs
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
let mut index_pos = 0;
for index in 0..args.max_new_tokens {
let (_, input_embeds_len, _) = input_embeds.dims3()?;
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
(1, index_pos)
} else {
(input_embeds_len, 0)
};
let input = input_embeds.i((.., input_embeds_len.saturating_sub(context_size).., ..))?;
let logits = llava.forward(&input, context_index, &mut cache)?; //[1,32000]
let logits = logits.squeeze(0)?;
let (_, input_len, _) = input.dims3()?;
index_pos += input_len;
let next_token = logits_processor.sample(&logits)?;
let next_token_tensor = Tensor::from_vec(vec![next_token], 1, &device)?;
let next_embeds = llava.llama.embed(&next_token_tensor)?.unsqueeze(0)?;
input_embeds = Tensor::cat(&[input_embeds, next_embeds], 1)?;
if next_token == eos_token_id as u32 {
break;
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
Ok(())
}

View File

@ -0,0 +1,40 @@
# candle-llava
LLaVA (Large Language-and-Vision Assistant) is an end-to-end trained large
multimodal model. This example is from [candle-llava](https://github.com/chenwanqq/candle-llava)
The code is based on [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA), Hence the llava-hf version of config may perform differently.
## model zoo
* [liuhaotian/LLaVA](https://huggingface.co/liuhaotian)
* [llava-hf](https://huggingface.co/llava-hf)
Right now this has been tested on `liuhaotian/llava-v1.6-vicuna-7b` and
`llava-hf/llava-v1.6-vicuna-7b-hf`. Memory usage might have room for optimization.
## Tokenizer Setup
The llava-hf models contain a `tokenizer.json` file so can be used directly with
the `-hf` command line flag.
For the original llava models, you can use the following code to generate the `tokenizer.json` file.
```bash
conda create -n llava python=3.10
pip install transformers protobuf
conda activate llava
python -c "from transformers import AutoTokenizer;tokenizer=AutoTokenizer.from_pretrained('liuhaotian/llava-v1.6-vicuna-7b');tokenizer.save_pretrained('tokenizer')"
```
Then the `tokenizer.json` file should be in `tokenizer/tokenizer.json` (which is the default path).
## eval
```bash
cargo run --example llava --features cuda -- --image-file "llava_logo.png" --prompt "is this a cat?" --hf # default args, use llava-hf/llava-v1.6-vicuna-7b-hf. image-file is required^_^
cargo run --example llava --features cuda -- --model-path liuhaotian/llava-v1.6-vicuna-7b --image-file "llava_logo.png" --prompt "is this a cat?" # use liuhaotian/llava-v1.6-vicuna-7b, tokenizer setup should be done
```
## Major Limitations
1. Currently only support llama-2/vicuna llm. Haven't supoort Mistral yet.
2. There are some ops like split, nonzero and where are not supported by candle.
3. Lack of quantization and LoRA support.

View File

@ -54,6 +54,7 @@ impl TextGeneration {
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let dtype = self.model.dtype();
let mut tokens = self
.tokenizer
.tokenizer()
@ -66,7 +67,7 @@ impl TextGeneration {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let mut state = State::new(1, &self.config, &self.device)?;
let mut state = State::new(1, &self.config, dtype, &self.device)?;
let mut next_logits = None;
for &t in tokens.iter() {
let input = Tensor::new(&[t], &self.device)?;
@ -84,7 +85,7 @@ impl TextGeneration {
Some(logits) => logits,
None => anyhow::bail!("cannot work on an empty prompt"),
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = logits.squeeze(0)?.to_dtype(dtype)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
@ -210,6 +211,9 @@ struct Args {
#[arg(long)]
config_file: Option<String>,
#[arg(long, default_value = "f32")]
dtype: String,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
@ -220,6 +224,7 @@ struct Args {
}
fn main() -> Result<()> {
use std::str::FromStr;
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
@ -279,7 +284,8 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let dtype = DType::from_str(&args.dtype)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb.pp("backbone"))?;
println!("loaded the model in {:?}", start.elapsed());

View File

@ -123,7 +123,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
if next_token == eos_token || tokens.ends_with(&[27, 10619, 29] /* <END> */) {
break;
}
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;

View File

@ -0,0 +1,36 @@
# candle-olmo: Open Language Models designed to enable the science of language models
OLMo is a series of Open Language Models designed to enable the science of language models.
- **Project Page:** https://allenai.org/olmo
- **Paper:** [Link](https://arxiv.org/abs/2402.00838)
- **Technical blog post:** https://blog.allenai.org/olmo-open-language-model-87ccfc95f580
- **W&B Logs:** https://wandb.ai/ai2-llm/OLMo-1B/reports/OLMo-1B--Vmlldzo2NzY1Njk1
<!-- - **Press release:** TODO -->
## Running the example
```bash
$ cargo run --example olmo --release -- --prompt "It is only with the heart that one can see rightly"
avx: true, neon: false, simd128: false, f16c: true
temp: 0.20 repeat-penalty: 1.10 repeat-last-n: 64
retrieved the files in 354.977µs
loaded the model in 19.87779666s
It is only with the heart that one can see rightly; what is essential is invisible to the eye.
```
Various model sizes are available via the `--model` argument.
```bash
$ cargo run --example olmo --release -- --model 1.7-7b --prompt 'It is only with the heart that one can see rightly'
avx: true, neon: false, simd128: false, f16c: true
temp: 0.20 repeat-penalty: 1.10 repeat-last-n: 64
retrieved the files in 1.226087ms
loaded the model in 171.274578609s
It is only with the heart that one can see rightly; what is essential is invisible to the eye.”
~ Antoine de Saint-Exupery, The Little Prince
I am a big fan of this quote. It reminds me that I need to be open and aware of my surroundings in order to truly appreciate them.
```

View File

@ -0,0 +1,284 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::olmo::{Config, Model as OLMo};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
enum Model {
OLMo(OLMo),
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, false)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = match &mut self.model {
Model::OLMo(m) => m.forward(&input, start_pos)?,
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
enum Which {
#[value(name = "1b")]
W1b,
#[value(name = "7b")]
W7b,
#[value(name = "7b-twin-2t")]
W7bTwin2T,
#[value(name = "1.7-7b")]
V1_7W7b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 1000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long, default_value = "1b")]
model: Which,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => match args.model {
Which::W1b => "allenai/OLMo-1B-hf".to_string(),
Which::W7b => "allenai/OLMo-7B-hf".to_string(),
Which::W7bTwin2T => "allenai/OLMo-7B-Twin-2T-hf".to_string(),
Which::V1_7W7b => "allenai/OLMo-1.7-7B-hf".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
Which::W1b => {
vec![repo.get("model.safetensors")?]
}
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = {
let config_filename = repo.get("config.json")?;
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
config
};
let device = candle_examples::device(args.cpu)?;
let model = {
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = OLMo::new(&config, vb)?;
Model::OLMo(model)
};
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,8 +1,9 @@
# candle-phi: 1.3b and 2.7b LLM with state of the art performance for <10b models.
[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) and
[Phi-2](https://huggingface.co/microsoft/phi-2) are language models using
only 1.3 and 2.7 billion parameters but with state of the art performance compared to
[Phi-1.5](https://huggingface.co/microsoft/phi-1_5),
[Phi-2](https://huggingface.co/microsoft/phi-2), and
[Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) are language models using
only 1.3, 2.7, and 3.8 billion parameters but with state of the art performance compared to
models with up to 10 billion parameters.
The candle implementation provides both the standard version as well as a

View File

@ -7,11 +7,13 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
use candle_transformers::models::phi3::{Config as Phi3Config, Model as Phi3};
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
use candle::{DType, Device, Tensor};
use candle::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
@ -20,13 +22,14 @@ use tokenizers::Tokenizer;
enum Model {
MixFormer(MixFormer),
Phi(Phi),
Phi3(Phi3),
Quantized(QMixFormer),
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
@ -49,7 +52,7 @@ impl TextGeneration {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
@ -61,7 +64,11 @@ impl TextGeneration {
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
let tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?;
if tokens.is_empty() {
anyhow::bail!("Empty prompts are not supported in the phi model.")
}
@ -73,13 +80,14 @@ impl TextGeneration {
}
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the endoftext token"),
};
print!("{prompt}");
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
let mut pos = 0;
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
@ -88,6 +96,7 @@ impl TextGeneration {
Model::MixFormer(m) => m.forward(&input)?,
Model::Phi(m) => m.forward(&input)?,
Model::Quantized(m) => m.forward(&input)?,
Model::Phi3(m) => m.forward(&input, pos)?.i((.., 0, ..))?,
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
@ -107,9 +116,11 @@ impl TextGeneration {
if next_token == eos_token {
break;
}
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
pos += context_size;
}
let dt = start_gen.elapsed();
println!(
@ -128,6 +139,10 @@ enum WhichModel {
V1_5,
#[value(name = "2")]
V2,
#[value(name = "3")]
V3,
#[value(name = "3-medium")]
V3Medium,
#[value(name = "2-old")]
V2Old,
PuffinPhiV2,
@ -196,6 +211,10 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The dtype to be used for running the model, e.g. f32, bf16, or f16.
#[arg(long)]
dtype: Option<String>,
}
fn main() -> Result<()> {
@ -236,6 +255,8 @@ fn main() -> Result<()> {
WhichModel::V1 => "microsoft/phi-1".to_string(),
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
WhichModel::V3 => "microsoft/Phi-3-mini-4k-instruct".to_string(),
WhichModel::V3Medium => "microsoft/Phi-3-medium-4k-instruct".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string()
}
@ -253,9 +274,11 @@ fn main() -> Result<()> {
WhichModel::V1 => "refs/pr/8".to_string(),
WhichModel::V1_5 => "refs/pr/73".to_string(),
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"main".to_string()
}
WhichModel::V2
| WhichModel::V3
| WhichModel::V3Medium
| WhichModel::PuffinPhiV2
| WhichModel::PhiHermes => "main".to_string(),
}
}
}
@ -264,9 +287,12 @@ fn main() -> Result<()> {
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old => {
repo.get("tokenizer.json")?
}
WhichModel::V1
| WhichModel::V1_5
| WhichModel::V2
| WhichModel::V2Old
| WhichModel::V3
| WhichModel::V3Medium => repo.get("tokenizer.json")?,
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")?
}
@ -282,14 +308,19 @@ fn main() -> Result<()> {
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
WhichModel::V3 | WhichModel::V3Medium => anyhow::bail!(
"use the quantized or quantized-phi examples for quantized phi-v3"
),
}
} else {
match args.model {
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
WhichModel::V2 | WhichModel::V2Old => candle_examples::hub_load_safetensors(
&repo,
"model.safetensors.index.json",
)?,
WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 | WhichModel::V3Medium => {
candle_examples::hub_load_safetensors(
&repo,
"model.safetensors.index.json",
)?
}
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
}
@ -306,6 +337,9 @@ fn main() -> Result<()> {
WhichModel::V2 | WhichModel::V2Old => Config::v2(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
WhichModel::V3 | WhichModel::V3Medium => {
panic!("use the quantized or quantized-phi examples for quantized phi-v3")
}
};
let device = candle_examples::device(args.cpu)?;
let model = if args.quantized {
@ -320,7 +354,19 @@ fn main() -> Result<()> {
};
Model::Quantized(model)
} else {
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let dtype = match args.dtype {
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
None => {
if (args.model == WhichModel::V3 || args.model == WhichModel::V3Medium)
&& device.is_cuda()
{
DType::BF16
} else {
DType::F32
}
}
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
let config_filename = repo.get("config.json")?;
@ -329,6 +375,13 @@ fn main() -> Result<()> {
let phi = Phi::new(&config, vb)?;
Model::Phi(phi)
}
WhichModel::V3 | WhichModel::V3Medium => {
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: Phi3Config = serde_json::from_str(&config)?;
let phi3 = Phi3::new(&config, vb)?;
Model::Phi3(phi3)
}
WhichModel::V2Old => {
let config = config();
Model::MixFormer(MixFormer::new_v2(&config, vb)?)
@ -421,6 +474,10 @@ fn mmlu<P: AsRef<std::path::Path>>(
m.clear_kv_cache();
m.forward(&input)?
}
Model::Phi3(m) => {
m.clear_kv_cache();
m.forward(&input, 0)?
}
Model::Quantized(m) => {
m.clear_kv_cache();
m.forward(&input)?

View File

@ -0,0 +1,325 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_llama::ModelWeights as Phi3b;
use candle_transformers::models::quantized_phi::ModelWeights as Phi2;
use candle_transformers::models::quantized_phi3::ModelWeights as Phi3;
const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "phi-2")]
Phi2,
#[value(name = "phi-3")]
Phi3,
/// Alternative implementation of phi-3, based on llama.
#[value(name = "phi-3b")]
Phi3b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
#[arg(long)]
model: Option<String>,
/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
/// and 'chat' for an interactive model where history of previous prompts and generated tokens
/// is preserved.
#[arg(long)]
prompt: Option<String>,
/// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 1000)]
sample_len: usize,
/// The tokenizer config in json format.
#[arg(long)]
tokenizer: Option<String>,
/// The temperature used to generate samples, use 0 for greedy sampling.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Process prompt elements separately.
#[arg(long)]
split_prompt: bool,
/// Run on CPU rather than GPU even if a GPU is available.
#[arg(long)]
cpu: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "phi-3b")]
which: Which,
#[arg(long)]
use_flash_attn: bool,
}
impl Args {
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
let tokenizer_path = match &self.tokenizer {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = match self.which {
Which::Phi2 => "microsoft/phi-2",
Which::Phi3 | Which::Phi3b => "microsoft/Phi-3-mini-4k-instruct",
};
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
}
};
Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
}
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
let model_path = match &self.model {
Some(config) => std::path::PathBuf::from(config),
None => {
let (repo, filename, revision) = match self.which {
Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf", "main"),
Which::Phi3 => (
"microsoft/Phi-3-mini-4k-instruct-gguf",
"Phi-3-mini-4k-instruct-q4.gguf",
"main",
),
Which::Phi3b => (
"microsoft/Phi-3-mini-4k-instruct-gguf",
"Phi-3-mini-4k-instruct-q4.gguf",
"5eef2ce24766d31909c0b269fe90c817a8f263fb",
),
};
let api = hf_hub::api::sync::Api::new()?;
api.repo(hf_hub::Repo::with_revision(
repo.to_string(),
hf_hub::RepoType::Model,
revision.to_string(),
))
.get(filename)?
}
};
Ok(model_path)
}
}
fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
} else if size_in_bytes < 1_000_000 {
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
} else if size_in_bytes < 1_000_000_000 {
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
} else {
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
}
}
enum Model {
Phi2(Phi2),
Phi3(Phi3),
Phi3b(Phi3b),
}
impl Model {
fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result<Tensor> {
match self {
Self::Phi2(m) => m.forward(xs, pos),
Self::Phi3(m) => m.forward(xs, pos),
Self::Phi3b(m) => m.forward(xs, pos),
}
}
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let mut model = {
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
total_size_in_bytes +=
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
model.tensor_infos.len(),
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
match args.which {
Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
Which::Phi3 => Model::Phi3(Phi3::from_gguf(
args.use_flash_attn,
model,
&mut file,
&device,
)?),
Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?),
}
};
println!("model built");
let tokenizer = args.tokenizer()?;
let mut tos = TokenOutputStream::new(tokenizer);
let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
print!("{}", &prompt_str);
let tokens = tos
.tokenizer()
.encode(prompt_str, true)
.map_err(anyhow::Error::msg)?;
let tokens = tokens.get_ids();
let to_sample = args.sample_len.saturating_sub(1);
let mut all_tokens = vec![];
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let start_prompt_processing = std::time::Instant::now();
let mut next_token = if !args.split_prompt {
let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
} else {
let mut next_token = 0;
for (pos, token) in tokens.iter().enumerate() {
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, pos)?;
let logits = logits.squeeze(0)?;
next_token = logits_processor.sample(&logits)?
}
next_token
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
let eos_token = *tos
.tokenizer()
.get_vocab(true)
.get("<|endoftext|>")
.unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&all_tokens[start_at..],
)?
};
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
sampled += 1;
if next_token == eos_token {
break;
};
}
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
let dt = start_post_prompt.elapsed();
println!(
"\n\n{:4} prompt tokens processed: {:.2} token/s",
tokens.len(),
tokens.len() as f64 / prompt_dt.as_secs_f64(),
);
println!(
"{sampled:4} tokens generated: {:.2} token/s",
sampled as f64 / dt.as_secs_f64(),
);
Ok(())
}

View File

@ -67,6 +67,10 @@ enum Which {
Mixtral,
#[value(name = "mixtral-instruct")]
MixtralInstruct,
#[value(name = "llama3-8b")]
L8b,
#[value(name = "phi3")]
Phi3,
}
impl Which {
@ -82,7 +86,9 @@ impl Which {
| Self::L13bCode
| Self::L34bCode
| Self::Leo7b
| Self::Leo13b => false,
| Self::Leo13b
| Self::L8b
| Self::Phi3 => false,
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
// same way. Starling is a fine tuned version of OpenChat.
Self::OpenChat35
@ -116,7 +122,9 @@ impl Which {
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::OpenChat35
| Self::Starling7bAlpha => false,
| Self::Starling7bAlpha
| Self::L8b
| Self::Phi3 => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
}
}
@ -140,33 +148,37 @@ impl Which {
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta => false,
| Self::Zephyr7bBeta
| Self::L8b
| Self::Phi3 => false,
Self::OpenChat35 | Self::Starling7bAlpha => true,
}
}
fn tokenizer_repo(&self) -> &'static str {
match self {
Which::L7b
| Which::L13b
| Which::L70b
| Which::L7bChat
| Which::L13bChat
| Which::L70bChat
| Which::L7bCode
| Which::L13bCode
| Which::L34bCode => "hf-internal-testing/llama-tokenizer",
Which::Leo7b => "LeoLM/leo-hessianai-7b",
Which::Leo13b => "LeoLM/leo-hessianai-13b",
Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Mistral7bInstructV02
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
Which::OpenChat35 => "openchat/openchat_3.5",
Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
Self::L7b
| Self::L13b
| Self::L70b
| Self::L7bChat
| Self::L13bChat
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode => "hf-internal-testing/llama-tokenizer",
Self::Leo7b => "LeoLM/leo-hessianai-7b",
Self::Leo13b => "LeoLM/leo-hessianai-13b",
Self::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
Self::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
Self::Mistral7b
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
Self::OpenChat35 => "openchat/openchat_3.5",
Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
Self::L8b => "meta-llama/Meta-Llama-3-8B",
Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
}
}
}
@ -322,10 +334,28 @@ impl Args {
"TheBloke/Starling-LM-7B-alpha-GGUF",
"starling-lm-7b-alpha.Q4_K_M.gguf",
),
// TODO: swap to TheBloke model when available
Which::L8b => (
"QuantFactory/Meta-Llama-3-8B-GGUF",
"Meta-Llama-3-8B.Q4_K_S.gguf",
),
Which::Phi3 => (
"microsoft/Phi-3-mini-4k-instruct-gguf",
"Phi-3-mini-4k-instruct-q4.gguf",
),
};
let revision = if self.which == Which::Phi3 {
"5eef2ce24766d31909c0b269fe90c817a8f263fb"
} else {
"main"
};
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string());
api.get(filename)?
api.repo(hf_hub::Repo::with_revision(
repo.to_string(),
hf_hub::RepoType::Model,
revision.to_string(),
))
.get(filename)?
}
};
Ok(model_path)
@ -353,6 +383,9 @@ fn main() -> anyhow::Result<()> {
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
candle::cuda::set_gemm_reduced_precision_f16(true);
candle::cuda::set_gemm_reduced_precision_bf16(true);
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
@ -420,7 +453,9 @@ fn main() -> anyhow::Result<()> {
| Which::L13bCode
| Which::L34bCode
| Which::Leo7b
| Which::Leo13b => 1,
| Which::Leo13b
| Which::L8b
| Which::Phi3 => 1,
Which::Mixtral
| Which::MixtralInstruct
| Which::Mistral7b
@ -537,11 +572,14 @@ fn main() -> anyhow::Result<()> {
std::io::stdout().flush()?;
}
let eos_token = if args.which.is_open_chat() {
"<|end_of_turn|>"
} else {
"</s>"
let eos_token = match args.which {
Which::L8b => "<|end_of_text|>",
_ => match args.which.is_open_chat() {
true => "<|end_of_turn|>",
false => "</s>",
},
};
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;

View File

@ -7,7 +7,7 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::qwen2::{Config as ConfigBase, Model as ModelBase};
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
use candle::{DType, Device, Tensor};
@ -144,6 +144,14 @@ enum WhichModel {
W72b,
#[value(name = "moe-a2.7b")]
MoeA27b,
#[value(name = "2-0.5b")]
W2_0_5b,
#[value(name = "2-1.5b")]
W2_1_5b,
#[value(name = "2-7b")]
W2_7b,
#[value(name = "2-72b")]
W2_72b,
}
#[derive(Parser, Debug)]
@ -234,16 +242,20 @@ fn main() -> Result<()> {
let model_id = match args.model_id {
Some(model_id) => model_id,
None => {
let size = match args.model {
WhichModel::W0_5b => "0.5B",
WhichModel::W1_8b => "1.8B",
WhichModel::W4b => "4B",
WhichModel::W7b => "7B",
WhichModel::W14b => "14B",
WhichModel::W72b => "72B",
WhichModel::MoeA27b => "MoE-A2.7B",
let (version, size) = match args.model {
WhichModel::W2_0_5b => ("2", "0.5B"),
WhichModel::W2_1_5b => ("2", "1.5B"),
WhichModel::W2_7b => ("2", "7B"),
WhichModel::W2_72b => ("2", "72B"),
WhichModel::W0_5b => ("1.5", "0.5B"),
WhichModel::W1_8b => ("1.5", "1.8B"),
WhichModel::W4b => ("1.5", "4B"),
WhichModel::W7b => ("1.5", "7B"),
WhichModel::W14b => ("1.5", "14B"),
WhichModel::W72b => ("1.5", "72B"),
WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
};
format!("Qwen/Qwen1.5-{size}")
format!("Qwen/Qwen{version}-{size}")
}
};
let repo = api.repo(Repo::with_revision(
@ -261,11 +273,15 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
vec![repo.get("model.safetensors")?]
}
WhichModel::W4b
| WhichModel::W7b
| WhichModel::W2_7b
| WhichModel::W14b
| WhichModel::W72b
| WhichModel::W2_72b
| WhichModel::MoeA27b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}

View File

@ -0,0 +1,9 @@
# candle-recurrent-gemma
This model card corresponds to the 2B base version of the RecurrentGemma model
[huggingface model card](https://huggingface.co/google/recurrentgemma-2b).
```bash
cargo run --features cuda -r --example recurrent-gemma -- \
--prompt "Write me a poem about Machine Learning."
```

View File

@ -0,0 +1,321 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::quantized_recurrent_gemma::Model as QModel;
use candle_transformers::models::recurrent_gemma::{Config, Model as BModel};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
enum Model {
B(BModel),
Q(QModel),
}
impl Model {
fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result<Tensor> {
match self {
Self::B(m) => m.forward(xs, pos),
Self::Q(m) => m.forward(xs, pos),
}
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "2b")]
Base2B,
#[value(name = "2b-it")]
Instruct2B,
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
top_k: usize,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let sampling = match temp {
None => candle_transformers::generation::Sampling::ArgMax,
Some(temperature) => match top_p {
None => candle_transformers::generation::Sampling::TopK {
temperature,
k: top_k,
},
Some(top_p) => candle_transformers::generation::Sampling::TopKThenTopP {
temperature,
k: top_k,
p: top_p,
},
},
};
let logits_processor = LogitsProcessor::from_sampling(seed, sampling);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
#[arg(long, default_value_t = 250)]
top_k: usize,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 8000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model to use.
#[arg(long, default_value = "2b")]
which: Which,
#[arg(long)]
quantized: bool,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::Base2B => "google/recurrentgemma-2b".to_string(),
Which::Instruct2B => "google/recurrentgemma-2b-it".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
if args.quantized {
let filename = match args.which {
Which::Base2B => "recurrent-gemma-2b-q4k.gguf",
Which::Instruct2B => "recurrent-gemma-7b-q4k.gguf",
};
let filename = api.model("lmz/candle-gemma".to_string()).get(filename)?;
vec![filename]
} else {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let model = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
&filenames[0],
&device,
)?;
Model::Q(QModel::new(&config, vb.pp("model"))?)
} else {
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
Model::B(BModel::new(&config, vb.pp("model"))?)
};
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.top_k,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -39,7 +39,7 @@ struct Args {
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
/// mask, positive makes the mask more selective.
#[arg(long, default_value_t = 0.)]
#[arg(long, allow_hyphen_values = true, default_value_t = 0.)]
threshold: f32,
/// Enable tracing (generates a trace-timestamp.json file).

View File

@ -46,7 +46,8 @@ The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
- `--cpu`: use the cpu rather than the gpu (much slower).
- `--height`, `--width`: set the height and width for the generated image.
- `--n-steps`: the number of steps to be used in the diffusion process.
- `--num-samples`: the number of samples to generate.
- `--num-samples`: the number of samples to generate iteratively.
- `--bsize`: the numbers of samples to generate simultaneously.
- `--final-image`: the filename for the generated image(s).
### Using flash-attention

View File

@ -9,6 +9,7 @@ use candle_transformers::models::stable_diffusion;
use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use stable_diffusion::vae::AutoEncoderKL;
use tokenizers::Tokenizer;
#[derive(Parser)]
@ -64,9 +65,13 @@ struct Args {
#[arg(long)]
n_steps: Option<usize>,
/// The number of samples to generate.
/// The number of samples to generate iteratively.
#[arg(long, default_value_t = 1)]
num_samples: i64,
num_samples: usize,
/// The numbers of samples to generate simultaneously.
#[arg[long, default_value_t = 1]]
bsize: usize,
/// The name of the final image to generate.
#[arg(long, value_name = "FILE", default_value = "sd_final.png")]
@ -236,8 +241,8 @@ impl ModelFile {
fn output_filename(
basename: &str,
sample_idx: i64,
num_samples: i64,
sample_idx: usize,
num_samples: usize,
timestep_idx: Option<usize>,
) -> String {
let filename = if num_samples > 1 {
@ -261,6 +266,33 @@ fn output_filename(
}
}
#[allow(clippy::too_many_arguments)]
fn save_image(
vae: &AutoEncoderKL,
latents: &Tensor,
vae_scale: f64,
bsize: usize,
idx: usize,
final_image: &str,
num_samples: usize,
timestep_ids: Option<usize>,
) -> Result<()> {
let images = vae.decode(&(latents / vae_scale)?)?;
let images = ((images / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let images = (images.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?;
for batch in 0..bsize {
let image = images.i(batch)?;
let image_filename = output_filename(
final_image,
(bsize * idx) + batch + 1,
batch + num_samples,
timestep_ids,
);
candle_examples::save_image(&image, image_filename)?;
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn text_embeddings(
prompt: &str,
@ -382,6 +414,7 @@ fn run(args: Args) -> Result<()> {
final_image,
sliced_attention_size,
num_samples,
bsize,
sd_version,
clip_weights,
vae_weights,
@ -475,6 +508,7 @@ fn run(args: Args) -> Result<()> {
.collect::<Result<Vec<_>>>()?;
let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
let text_embeddings = text_embeddings.repeat((bsize, 1, 1))?;
println!("{text_embeddings:?}");
println!("Building the autoencoder.");
@ -496,7 +530,6 @@ fn run(args: Args) -> Result<()> {
} else {
0
};
let bsize = 1;
let vae_scale = match sd_version {
StableDiffusionVersion::V1_5
@ -560,12 +593,16 @@ fn run(args: Args) -> Result<()> {
println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
if args.intermediary_images {
let image = vae.decode(&(&latents / vae_scale)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename =
output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
candle_examples::save_image(&image, image_filename)?
save_image(
&vae,
&latents,
vae_scale,
bsize,
idx,
&final_image,
num_samples,
Some(timestep_index + 1),
)?;
}
}
@ -574,11 +611,16 @@ fn run(args: Args) -> Result<()> {
idx + 1,
num_samples
);
let image = vae.decode(&(&latents / vae_scale)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
candle_examples::save_image(&image, image_filename)?
save_image(
&vae,
&latents,
vae_scale,
bsize,
idx,
&final_image,
num_samples,
None,
)?;
}
Ok(())
}

View File

@ -115,7 +115,7 @@ pub fn main() -> anyhow::Result<()> {
let processor = image_processor::ViTImageProcessor::new(&processor_config);
let image = vec![args.image.as_str()];
let image = processor.preprocess(image)?;
let image = processor.preprocess(image)?.to_device(&device)?;
let encoder_xs = model.encoder().forward(&image)?;

View File

@ -13,7 +13,7 @@ struct Block {
impl Block {
fn get(&self, key: &str) -> Result<&str> {
match self.parameters.get(&key.to_string()) {
match self.parameters.get(key) {
None => candle::bail!("cannot find {} in {}", key, self.block_type),
Some(value) => Ok(value),
}
@ -28,7 +28,7 @@ pub struct Darknet {
impl Darknet {
fn get(&self, key: &str) -> Result<&str> {
match self.parameters.get(&key.to_string()) {
match self.parameters.get(key) {
None => candle::bail!("cannot find {} in net parameters", key),
Some(value) => Ok(value),
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 175 KiB

View File

@ -448,9 +448,9 @@ pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows10
/// Perform gating and averaging for a BS.1770-4 integrated loudness measurement.
///
/// The integrated loudness measurement is not just the average power over the
/// entire signal. BS.1770-4 defines defines two stages of gating that exclude
/// entire signal. BS.1770-4 defines two stages of gating that exclude
/// parts of the signal, to ensure that silent parts do not contribute to the
/// loudness measurment. This function performs that gating, and returns the
/// loudness measurement. This function performs that gating, and returns the
/// average power over the windows that were not excluded.
///
/// The result of this function is the integrated loudness measurement.

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.5.0"
version = "0.6.0"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.5.0" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.6.0" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -42,6 +42,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);

View File

@ -139,7 +139,9 @@ impl FlashAttn {
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
let softmax_lse = dev
.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
.w()?;
let is_bf16 = if is_bf16 { 1 } else { 0 };

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.5.0"
version = "0.6.0"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -97,6 +97,50 @@ __device__ void im2col1d(
}
}
template <typename T>
__device__ void col2im1d(
const size_t dst_el,
const size_t l_out,
const size_t l_in,
const size_t c_out,
const size_t k_size,
const size_t stride,
const T *src,
T *dst
) {
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
// src: (b_size, l_in, c_out, l_k)
// dst: (b_size, c_out, l_out)
if (dst_i >= dst_el) {
return;
}
const size_t dst_s0 = c_out * l_out;
const size_t dst_s1 = l_out;
const size_t src_s0 = c_out * k_size * l_in;
const size_t src_s1 = c_out * k_size;
const size_t src_s2 = k_size;
size_t tmp_dst_i = dst_i;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t c_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= c_idx * dst_s1;
const int l_out_idx = tmp_dst_i;
dst[dst_i] = static_cast<T>(0);
int l_in_idx = l_out_idx / stride;
int k0 = l_out_idx - l_in_idx * stride;
// l_out_idx = l_in_idx * stride + k0
for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
if (l_in_idx < l_in) {
const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
dst[dst_i] += src[src_i];
}
}
}
template <typename T>
__device__ void im2col(
const size_t dst_numel,
@ -542,6 +586,20 @@ extern "C" __global__ void FN_NAME( \
im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
} \
#define COL2IM1D_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t dst_el, \
const size_t l_out, \
const size_t l_in, \
const size_t c_out, \
const size_t k_size, \
const size_t stride, \
const TYPENAME *src, \
TYPENAME *dst \
) { \
col2im1d<TYPENAME>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst); \
} \
#define IM2COL_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t dst_numel, \
@ -643,6 +701,7 @@ MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
IM2COL_OP(__nv_bfloat16, im2col_bf16)
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16)
#endif
#if __CUDA_ARCH__ >= 530
@ -655,6 +714,7 @@ MAX_POOL2D_OP(__half, max_pool2d_f16)
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
IM2COL_OP(__half, im2col_f16)
IM2COL1D_OP(__half, im2col1d_f16)
COL2IM1D_OP(__half, col2im1d_f16)
#endif
CONV1D_OP(float, float, conv1d_f32)
@ -701,3 +761,8 @@ IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(double, im2col1d_f64)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
COL2IM1D_OP(float, col2im1d_f32)
COL2IM1D_OP(double, col2im1d_f64)
COL2IM1D_OP(uint8_t, col2im1d_u8)
COL2IM1D_OP(uint32_t, col2im1d_u32)

View File

@ -6,5 +6,6 @@ pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));

File diff suppressed because it is too large Load Diff

View File

@ -50,6 +50,15 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
dst[dst_id] = shr[0];
}
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
}
return a;
}
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
@ -58,6 +67,70 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
return x;
}
// LayerNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477
template <typename T>
__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int block_size = blockDim.x;
float2 mean_var = make_float2(0.f, 0.f);
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col];
mean_var.x += xi;
mean_var.y += xi * xi;
}
// sum up partial sums
mean_var = warp_reduce_sum(mean_var);
if (block_size > WARP_SIZE) {
__shared__ float2 s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = mean_var;
}
__syncthreads();
mean_var = s_sum[lane_id];
mean_var = warp_reduce_sum(mean_var);
}
const float mean = mean_var.x / ncols;
const float var = mean_var.y / ncols - mean * mean;
const float inv_std = rsqrtf(var + eps);
if (alpha == nullptr && beta == nullptr) {
for (int col = tid; col < ncols; col += block_size) {
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
dst[row*ncols + col] = static_cast<T>(lhs);
}
}
else if (alpha == nullptr && beta != nullptr) {
for (int col = tid; col < ncols; col += block_size) {
float b = static_cast<float>(beta[col]);
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
dst[row*ncols + col] = static_cast<T>(lhs + b);
}
}
else if (alpha != nullptr && beta == nullptr) {
for (int col = tid; col < ncols; col += block_size) {
float a = static_cast<float>(alpha[col]);
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
dst[row*ncols + col] = static_cast<T>(lhs * a);
}
}
else {
for (int col = tid; col < ncols; col += block_size) {
float a = static_cast<float>(alpha[col]);
float b = static_cast<float>(beta[col]);
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
dst[row*ncols + col] = static_cast<T>(lhs * a + b);
}
}
}
// RmsNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
template <typename T>
@ -461,6 +534,13 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
} \
#define LAYERNORM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
const TYPENAME *beta, const int n_cols, const float eps) { \
layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, eps); \
} \
#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
extern "C" __global__ void FN_NAME_I( \
const TYPENAME *src, \
@ -496,6 +576,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
#if __CUDA_ARCH__ >= 800
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
@ -504,6 +585,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
LAYERNORM_OP(__half, layernorm_f16)
ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
@ -516,6 +598,8 @@ SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
LAYERNORM_OP(float, layernorm_f32)
LAYERNORM_OP(double, layernorm_f64)
ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32)
ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)

View File

@ -0,0 +1,88 @@
// Adapted from https://github.com/ggerganov/llama.cpp/blob/master/ggml-cuda/argsort.cu
#define SORT_ORDER_ASC 1
#define SORT_ORDER_DESC 0
#include "cuda_utils.cuh"
#include<stdint.h>
template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
T tmp = a;
a = b;
b = tmp;
}
template<int order, typename T>
static __device__ void k_argsort(const T * x, uint32_t * dst, const int ncols, int ncols_pad) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
if (col >= ncols_pad) {
return;
}
const T * x_row = x + row * ncols;
extern __shared__ int dst_row[];
// initialize indices
dst_row[col] = col;
__syncthreads();
for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols && (order == SORT_ORDER_ASC ?
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
} else {
if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols && (order == SORT_ORDER_ASC ?
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
}
}
__syncthreads();
}
}
// copy the result to dst without the padding
if (col < ncols) {
dst[row * ncols + col] = dst_row[col];
}
}
#define ASORT_OP(TYPENAME, RUST_NAME) \
extern "C" __global__ void asort_asc_##RUST_NAME( \
const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
) { \
k_argsort<SORT_ORDER_ASC>(x, dst, ncols, ncols_pad); \
} \
extern "C" __global__ void asort_desc_##RUST_NAME( \
const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
) { \
k_argsort<SORT_ORDER_DESC>(x, dst, ncols, ncols_pad); \
} \
#if __CUDA_ARCH__ >= 800
ASORT_OP(__nv_bfloat16, bf16)
#endif
#if __CUDA_ARCH__ >= 530
ASORT_OP(__half, f16)
#endif
ASORT_OP(float, f32)
ASORT_OP(double, f64)
ASORT_OP(uint8_t, u8)
ASORT_OP(uint32_t, u32)
ASORT_OP(int64_t, i64)

View File

@ -60,6 +60,11 @@ __device__ __forceinline__ T silu_fwd(T x) {
return x / (static_cast<T>(1) + expg(-x));
}
template<typename T>
__device__ __forceinline__ T sigmoid_fwd(T x) {
return recipg(static_cast<T>(1) + expg(-x));
}
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
@ -116,6 +121,7 @@ UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
#endif
#if __CUDA_ARCH__ >= 530
@ -142,6 +148,7 @@ UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
UNARY_OP(__half, usilu_f16, silu_fwd(x))
UNARY_OP1(__half, upowf_f16, powg(x, param))
UNARY_OP(__half, usign_f16, sign_(x))
UNARY_OP(__half, usigmoid_f16, sigmoid_fwd(x))
#endif
UNARY_OP(uint8_t, ucopy_u8, x)
@ -193,3 +200,5 @@ UNARY_OP1(float, upowf_f32, powg(x, param))
UNARY_OP1(double, upowf_f64, powg(x, param))
UNARY_OP(float, usign_f32, sign_(x))
UNARY_OP(double, usign_f64, sign_(x))
UNARY_OP(float, usigmoid_f32, sigmoid_fwd(x))
UNARY_OP(double, usigmoid_f64, sigmoid_fwd(x))

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.5.0"
version = "0.6.0"
edition = "2021"
description = "Metal kernels for Candle"

View File

@ -68,6 +68,50 @@ METAL_FUNC void im2col(
}
}
template <typename T>
METAL_FUNC void col2im1d(
constant size_t &dst_el,
constant size_t &l_out,
constant size_t &l_in,
constant size_t &c_out,
constant size_t &k_size,
constant size_t &stride,
device const T *src,
device T *dst,
uint dst_i [[ thread_position_in_grid ]]
) {
// src: (b_size, l_in, c_out, l_k)
// dst: (b_size, c_out, l_out)
if (dst_i >= dst_el) {
return;
}
const size_t dst_s0 = c_out * l_out;
const size_t dst_s1 = l_out;
const size_t src_s0 = c_out * k_size * l_in;
const size_t src_s1 = c_out * k_size;
const size_t src_s2 = k_size;
size_t tmp_dst_i = dst_i;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t c_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= c_idx * dst_s1;
const int l_out_idx = tmp_dst_i;
dst[dst_i] = static_cast<T>(0);
int l_in_idx = l_out_idx / stride;
int k0 = l_out_idx - l_in_idx * stride;
// l_out_idx = l_in_idx * stride + k0
for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
if (l_in_idx < l_in) {
const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
dst[dst_i] += src[src_i];
}
}
}
template <typename T>
METAL_FUNC void im2col1d(
constant size_t &dst_numel,
@ -190,6 +234,21 @@ kernel void FN_NAME( \
) { \
im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
} \
#define COL2IM1D_OP(T, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dst_el, \
constant size_t &l_out, \
constant size_t &l_in, \
constant size_t &c_out, \
constant size_t &k_size, \
constant size_t &stride, \
device const T *src, \
device T *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
col2im1d<T>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst, tid); \
} \
#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
kernel void FN_NAME( \
@ -493,6 +552,10 @@ IM2COL_OP(uint32_t, im2col_u32)
IM2COL_OP(bfloat, im2col_bf16)
#endif
COL2IM1D_OP(float, col2im1d_f32)
COL2IM1D_OP(uint8_t, col2im1d_u8)
COL2IM1D_OP(uint32_t, col2im1d_u32)
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
@ -533,4 +596,4 @@ CONVT2D_OP(float, float, conv_transpose2d_f32)
CONVT2D_OP(half, float, conv_transpose2d_f16)
#if defined(__HAVE_BFLOAT__)
CONVT1D_OP(bfloat, float, conv_transpose2d_bf16)
#endif
#endif

View File

@ -207,6 +207,9 @@ INDEX_OP(is_u8_bf16, uint8_t, bfloat)
GATHER_OP(gather_u32_f32, uint, float)
GATHER_OP(gather_u32_f16, uint, half)
#if defined(__HAVE_BFLOAT__)
GATHER_OP(gather_u32_bf16, uint, bfloat)
#endif
SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,4 @@
// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
#include <metal_stdlib>
using namespace metal;

View File

@ -353,6 +353,65 @@ METAL_FUNC void rmsnorm(
}
}
template<typename T>
METAL_FUNC void layernorm(
constant size_t & src_numel,
constant size_t & el_to_sum_per_block,
device const T * src,
device T * dst,
device const T * alpha,
device const T * beta,
constant float & eps,
uint id,
uint tid,
uint dst_id,
uint block_dim,
threadgroup float * shared_memory
) {
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;
float tmp1 = 0;
float tmp2 = 0;
while (idx < stop_idx) {
tmp1 += float(src[idx]);
tmp2 += float(src[idx]) * float(src[idx]);
idx += block_dim;
}
shared_memory[tid] = tmp1;
shared_memory[tid + block_dim] = tmp2;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
shared_memory[block_dim + tid] = shared_memory[block_dim + tid] + shared_memory[block_dim + tid + s];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
/* wait for shared_memory[0] to be filled */
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = shared_memory[0] / float(el_to_sum_per_block);
float var = shared_memory[block_dim] / float(el_to_sum_per_block) - mean * mean;
float inv_norm = 1.0f / sqrt(var + eps);
idx = start_idx + tid;
while (idx < stop_idx) {
float val = (float(src[idx]) - mean) * inv_norm;
if (alpha != nullptr) {
val *= float(alpha[idx - start_idx]);
}
if (beta != nullptr) {
val += float(beta[idx - start_idx]);
}
dst[idx] = T(val);
idx += block_dim;
}
}
#define RMSNORM(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
@ -371,6 +430,25 @@ kernel void NAME( \
rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
} \
#define LAYERNORM(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
device const T *alpha, \
device const T *beta, \
constant float &eps, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = 0; \
layernorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, beta, eps, id, tid, dst_id, block_dim, shared_memory); \
} \
template<typename T>
METAL_FUNC void ropei(
constant size_t &bh,
@ -511,6 +589,8 @@ SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
RMSNORM(rmsnorm_f32, float)
RMSNORM(rmsnorm_f16, half)
LAYERNORM(layernorm_f32, float)
LAYERNORM(layernorm_f16, half)
ROPE(rope_f32, rope_i_f32, rope_thd_f32, float)
ROPE(rope_f16, rope_i_f16, rope_thd_f16, half)
@ -535,5 +615,6 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
RMSNORM(rmsnorm_bf16, bfloat)
LAYERNORM(layernorm_bf16, bfloat)
ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat)
#endif

View File

@ -0,0 +1,97 @@
// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
#include <metal_stdlib>
using namespace metal;
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
#define SORT_ASC 1
#define SORT_DESC 0
template<int order, typename T>
METAL_FUNC void argsort(
device const T * x,
device uint32_t * dst,
constant int64_t & ncols,
constant int64_t & ncols_pad,
threadgroup uint32_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]) {
int col = tpitg[0];
int row = tgpig[1];
if (col >= ncols_pad) return;
device const T * x_row = x + row * ncols;
threadgroup uint32_t * dst_row = shared_values;
// initialize indices
dst_row[col] = col;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols && (order == SORT_ASC ?
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
) {
SWAP(dst_row[col], dst_row[ixj]);
}
} else {
if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols && (order == SORT_ASC ?
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
) {
SWAP(dst_row[col], dst_row[ixj]);
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
// copy the result to dst without the padding
if (col < ncols) {
dst[row * ncols + col] = dst_row[col];
}
}
#define ARGSORT(T, RUST_T) \
kernel void asort_asc_##RUST_T( \
device const T * x, \
device uint32_t * dst, \
constant int64_t & ncols, \
constant int64_t & ncols_pad, \
threadgroup uint32_t * shared_values [[threadgroup(0)]], \
uint3 tgpig[[threadgroup_position_in_grid]], \
uint3 tpitg[[thread_position_in_threadgroup]] \
) { \
argsort<SORT_ASC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
} \
kernel void asort_desc_##RUST_T( \
device const T * x, \
device uint32_t * dst, \
constant int64_t & ncols, \
constant int64_t & ncols_pad, \
threadgroup uint32_t * shared_values [[threadgroup(0)]], \
uint3 tgpig[[threadgroup_position_in_grid]], \
uint3 tpitg[[thread_position_in_threadgroup]] \
) { \
argsort<SORT_DESC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
} \
ARGSORT(float, f32)
ARGSORT(half, f16)
ARGSORT(uint8_t, u8)
ARGSORT(uint32_t, u32)
#if __METAL_VERSION__ >= 220
ARGSORT(int64_t, i64)
#endif
#if defined(__HAVE_BFLOAT__)
ARGSORT(bfloat, bf16)
#endif

View File

@ -1,5 +1,4 @@
#include <metal_stdlib>
#
using namespace metal;
METAL_FUNC uint get_strided_index(
@ -57,27 +56,31 @@ kernel void FN_NAME(
where_cond<T, ID>(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \
} \
// WHERE_OP(float, int64_t, where_i64_f32)
// WHERE_OP(double, int64_t, where_i64_f64)
// WHERE_OP(uint8_t, int64_t, where_i64_u8)
// WHERE_OP(uint32_t, int64_t, where_i64_u32)
// WHERE_OP(int64_t, int64_t, where_i64_i64)
//
// WHERE_OP(float, uint32_t, where_u32_f32)
// WHERE_OP(double, uint32_t, where_u32_f64)
// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
WHERE_OP(half, uint32_t, where_u32_f16)
WHERE_OP(float, uint32_t, where_u32_f32)
WHERE_OP(uint8_t, uint32_t, where_u32_u8)
WHERE_OP(uint32_t, uint32_t, where_u32_u32)
WHERE_OP(float, uint8_t, where_u8_f32)
WHERE_OP(half, uint8_t, where_u8_f16)
WHERE_OP(float, uint8_t, where_u8_f32)
WHERE_OP(uint8_t, uint8_t, where_u8_u8)
WHERE_OP(uint32_t, uint8_t, where_u8_u32)
#if __METAL_VERSION__ >= 220
WHERE_OP(int64_t, uint8_t, where_u8_i64)
WHERE_OP(int64_t, uint32_t, where_u32_i64)
WHERE_OP(half, int64_t, where_i64_f16)
WHERE_OP(float, int64_t, where_i64_f32)
WHERE_OP(uint8_t, int64_t, where_i64_u8)
WHERE_OP(uint32_t, int64_t, where_i64_u32)
WHERE_OP(int64_t, int64_t, where_i64_i64)
#if defined(__HAVE_BFLOAT__)
WHERE_OP(bfloat, int64_t, where_i64_bf16)
#endif
#endif
#if defined(__HAVE_BFLOAT__)
WHERE_OP(bfloat, uint8_t, where_u8_bf16)
#endif
WHERE_OP(bfloat, uint32_t, where_u32_bf16)
#endif

View File

@ -12,7 +12,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged;
let ptr = data.as_ptr() as *const c_void;
let size = (data.len() * std::mem::size_of::<T>()) as u64;
let size = std::mem::size_of_val(data) as u64;
device.new_buffer_with_data(ptr, size, options)
}
@ -41,6 +41,10 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let input = BufferOffset {
buffer: &input,
offset_in_bytes: 0,
};
let output = new_buffer(&device, v);
call_unary_contiguous(
&device,
@ -48,7 +52,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
&kernels,
name,
v.len(),
&input,
input,
&output,
)
.unwrap();
@ -72,8 +76,8 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
&kernels,
name,
x.len(),
&left,
&right,
BufferOffset::zero_offset(&left),
BufferOffset::zero_offset(&right),
&output,
)
.unwrap();
@ -93,7 +97,15 @@ fn run_strided<T: Clone>(
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
let input = BufferOffset {
buffer: &input,
offset_in_bytes: offset,
};
let output_b = new_buffer(&device, v);
let output = BufferOffset {
buffer: &output_b,
offset_in_bytes: 0,
};
let kernels = Kernels::new();
call_unary_strided(
&device,
@ -101,16 +113,14 @@ fn run_strided<T: Clone>(
&kernels,
kernel,
shape,
&input,
input,
strides,
offset,
&output,
0,
output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, v.len())
read_to_vec(&output_b, v.len())
}
#[test]
@ -308,8 +318,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
&kernels,
name,
v.len(),
&input,
0,
BufferOffset::zero_offset(&input),
&output,
)
.unwrap();
@ -521,7 +530,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
&kernels,
"affine_f32",
size,
&input,
BufferOffset::zero_offset(&input),
&output,
mul as f32,
add as f32,
@ -554,9 +563,8 @@ fn run_affine_strided<T: Clone>(
&kernels,
"affine_f32_strided",
shape,
&input,
BufferOffset::zero_offset(&input),
strides,
0,
&output,
mul as f32,
add as f32,
@ -633,7 +641,7 @@ fn index_select_strided() {
fn index_select_f16() {
let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
.into_iter()
.map(|x| f16::from_f32(x))
.map(f16::from_f32)
.collect();
let shape = [5, 2];
let stride = [2, 1];
@ -700,8 +708,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let embeddings_buffer = new_buffer(&device, &embeddings);
let ids_buffer = new_buffer(&device, &ids);
let embeddings_buffer = new_buffer(&device, embeddings);
let ids_buffer = new_buffer(&device, ids);
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
@ -711,7 +719,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let kernels = Kernels::new();
call_index_select(
&device,
&command_buffer,
command_buffer,
&kernels,
name,
shape,
@ -720,10 +728,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
true,
shape,
stride,
&embeddings_buffer,
0,
&ids_buffer,
0,
BufferOffset::zero_offset(&embeddings_buffer),
BufferOffset::zero_offset(&ids_buffer),
&dst_buffer,
)
.unwrap();
@ -746,8 +752,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let embeddings_buffer = new_buffer(&device, &embeddings);
let ids_buffer = new_buffer(&device, &ids);
let embeddings_buffer = new_buffer(&device, embeddings);
let ids_buffer = new_buffer(&device, ids);
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
@ -757,7 +763,7 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
let kernels = Kernels::new();
call_index_select(
&device,
&command_buffer,
command_buffer,
&kernels,
name,
shape,
@ -766,10 +772,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
false,
shape,
stride,
&embeddings_buffer,
0,
&ids_buffer,
0,
BufferOffset::zero_offset(&embeddings_buffer),
BufferOffset::zero_offset(&ids_buffer),
&dst_buffer,
)
.unwrap();
@ -811,8 +815,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
&dims,
&strides,
out_length,
&input,
0,
BufferOffset::zero_offset(&input),
&output,
)
.unwrap();
@ -931,6 +934,7 @@ fn softmax() {
);
}
#[allow(clippy::too_many_arguments)]
fn run_where_cond<I: Clone, T: Clone>(
shape: &[usize],
cond: &[I],
@ -965,18 +969,30 @@ fn run_where_cond<I: Clone, T: Clone>(
);
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
let cond = BufferOffset {
buffer: &cond,
offset_in_bytes: cond_offset,
};
let left = BufferOffset {
buffer: &left,
offset_in_bytes: left_offset,
};
let right = BufferOffset {
buffer: &right,
offset_in_bytes: cond_offset,
};
call_where_cond_strided(
&device,
command_buffer,
&kernels,
name,
shape,
&cond,
(&cond_stride, cond_offset),
&left,
(&left_stride, left_offset),
&right,
(&cond_stride, cond_offset),
cond,
&cond_stride,
left,
&left_stride,
right,
&cond_stride,
&output,
)
.unwrap();
@ -1007,6 +1023,27 @@ fn where_cond() {
);
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
#[test]
fn where_cond_u32_f32() {
let shape = vec![6];
let cond = vec![0u32, 1, 0, 0, 1, 1];
let cond_l = (vec![1], 0);
let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let left_l = (vec![1], 0);
let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];
let right_l = (vec![1], 0);
let results = run_where_cond(
&shape,
&cond,
cond_l,
&left_true,
left_l,
&right_false,
right_l,
"where_u32_f32",
);
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
fn run_gemm<T: Clone>(
(b, m, n, k): (usize, usize, usize, usize),
@ -1148,7 +1185,7 @@ fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b:
#[test]
fn random() {
fn calc_mean(data: &[f32]) -> f32 {
let sum = data.iter().sum::<f32>() as f32;
let sum = data.iter().sum::<f32>();
let count = data.len();
assert!(count > 0);
sum / count as f32
@ -1162,7 +1199,7 @@ fn random() {
let variance = data
.iter()
.map(|value| {
let diff = mean - (*value as f32);
let diff = mean - *value;
diff * diff
})
.sum::<f32>()
@ -1241,10 +1278,8 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
shape,
shape,
dim,
&input_buffer,
0,
&ids_buffer,
0,
BufferOffset::zero_offset(&input_buffer),
BufferOffset::zero_offset(&ids_buffer),
&output,
)
.unwrap();
@ -1346,10 +1381,8 @@ fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>(
shape,
shape,
dim,
&input_buffer,
0,
&indices_buffer,
0,
BufferOffset::zero_offset(&input_buffer),
BufferOffset::zero_offset(&indices_buffer),
&output,
)
.unwrap();
@ -1787,6 +1820,7 @@ fn avg_pool2d_u32() {
assert_eq!(results, expected);
}
#[allow(clippy::too_many_arguments)]
fn run_conv_transpose1d<T: Clone>(
input: &[T],
input_shape: &[usize],

View File

@ -67,6 +67,11 @@ template <typename T> METAL_FUNC T relu(T in){
template <typename T> METAL_FUNC T silu(T in){
return in / (static_cast<T>(1) + exp(-in));
}
template <typename T> METAL_FUNC T sigmoid(T in) {
return recip(static_cast<T>(1) + exp(-in));
}
#define TILE_SIZE 2
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
@ -79,8 +84,8 @@ kernel void FN_NAME( \
return; \
} \
output[tid] = TYPENAME(FN(float(input[tid]))); \
}\
kernel void FN_NAME_STRIDED( \
} \
kernel void FN_NAME##_##strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
@ -93,6 +98,17 @@ kernel void FN_NAME_STRIDED( \
return; \
} \
output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \
} \
kernel void FN_NAME##_##tiled( \
constant size_t &dim, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
for (uint i = 0; i < TILE_SIZE; i++) { \
const uint idx = tid * TILE_SIZE + i; \
output[idx] = TYPENAME(FN(float(input[idx]))); \
} \
}
#define UNARY_OP(NAME) \
@ -104,21 +120,17 @@ UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);
#define COPY2D(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &d1, \
constant size_t &d2, \
constant size_t &src_s, \
constant size_t &dst_s, \
constant int64_t &d1, \
constant int64_t &d2, \
constant int64_t &src_s, \
constant int64_t &dst_s, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
uint2 idx [[thread_position_in_grid]] \
) { \
if (tid >= d1 * d2) { \
return; \
} \
size_t idx1 = tid / d2; \
size_t idx2 = tid - idx1 * d2; \
size_t src_idx = idx1 * src_s + idx2; \
size_t dst_idx = idx1 * dst_s + idx2; \
if (idx.x >= d1 || idx.y >= d2) return; \
int64_t src_idx = idx.x * src_s + idx.y; \
int64_t dst_idx = idx.x * dst_s + idx.y; \
output[dst_idx] = input[src_idx]; \
}
@ -146,6 +158,7 @@ UNARY_OP(tanh)
UNARY_OP(recip)
UNARY_OP(relu)
UNARY_OP(sign)
UNARY_OP(sigmoid)
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
@ -176,8 +189,9 @@ BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)
BFLOAT_UNARY_OP(sign)
BFLOAT_UNARY_OP(sigmoid)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
COPY2D(copy2d_bf64, bfloat)
COPY2D(copy2d_bf16, bfloat)
#endif

View File

@ -0,0 +1,162 @@
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, MTLSize};
use std::ffi::c_void;
/// Most kernels apply similarly across the tensors
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
/// actual total buffer length).
/// Then kernels can just do their op on their single point in the buffer.
pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64;
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
let count = (size + width - 1) / width;
let thread_group_count = MTLSize {
width: count,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
(thread_group_count, thread_group_size)
}
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
let mut pows0 = 0u64;
let mut pows1 = 0u64;
let mut pows2 = 0u64;
let mut sum = 0u64;
loop {
let presum = sum;
// Check all the pows
if dim0 >= (1 << (pows0 + 1)) {
pows0 += 1;
sum += 1;
}
if sum == 10 {
break;
}
if dim1 >= (1 << (pows1 + 1)) {
pows1 += 1;
sum += 1;
}
if sum == 10 {
break;
}
if dim2 >= (1 << (pows2 + 1)) {
pows2 += 1;
sum += 1;
}
if sum == presum || sum == 10 {
break;
}
}
MTLSize {
width: 1 << pows0,
height: 1 << pows1,
depth: 1 << pows2,
}
}
pub(crate) fn set_param<P: EncoderParam>(
encoder: &ComputeCommandEncoderRef,
position: u64,
data: P,
) {
<P as EncoderParam>::set_param(encoder, position, data)
}
/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
pub(crate) trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
($type:ty) => {
impl EncoderParam for $type {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
core::mem::size_of::<$type>() as u64,
&data as *const $type as *const c_void,
);
}
}
};
}
primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u32);
primitive!(u64);
primitive!(f32);
pub struct BufferOffset<'a> {
pub buffer: &'a Buffer,
pub offset_in_bytes: usize,
}
impl<'a> BufferOffset<'a> {
pub fn zero_offset(buffer: &'a Buffer) -> Self {
Self {
buffer,
offset_in_bytes: 0,
}
}
}
impl<T> EncoderParam for &[T] {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
core::mem::size_of_val(data) as u64,
data.as_ptr() as *const c_void,
);
}
}
impl EncoderParam for &Buffer {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data), 0);
}
}
impl EncoderParam for (&Buffer, usize) {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data.0), data.1 as u64);
}
}
impl<'a> EncoderParam for &BufferOffset<'a> {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64);
}
}
impl EncoderParam for &mut Buffer {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data), 0);
}
}
impl EncoderParam for (&mut Buffer, usize) {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data.0), data.1 as u64);
}
}
#[macro_export]
macro_rules! set_params {
($encoder:ident, ($($param:expr),+)) => (
let mut _index = 0;
$(
$crate::utils::set_param($encoder, _index, $param);
_index += 1;
)*
);
}

View File

@ -5,7 +5,7 @@ use criterion::{black_box, criterion_group, Criterion};
use std::time::Instant;
fn run(input: &Tensor, weight: &Tensor, bias: &Tensor) {
let _ = LayerNorm::new(weight.clone(), bias.clone(), 1e-5).forward(&input);
let _ = LayerNorm::new(weight.clone(), bias.clone(), 1e-5).forward(input);
}
const B: usize = 1;

147
candle-nn/src/kv_cache.rs Normal file
View File

@ -0,0 +1,147 @@
use candle::{Result, Tensor};
#[derive(Debug, Clone)]
pub struct Cache {
// all_data is an option on a Tensor, this makes it possible to only create the actual tensor
// on the first call where the batch size is easily known.
// Also this makes it safe to clone a KvCache that has been reseted (as in it will not share
// its internal state with the cloned instance).
all_data: Option<Tensor>,
dim: usize,
current_seq_len: usize,
max_seq_len: usize,
}
impl Cache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
Self {
all_data: None,
dim,
current_seq_len: 0,
max_seq_len,
}
}
pub fn dim(&self) -> usize {
self.dim
}
pub fn current_seq_len(&self) -> usize {
self.current_seq_len
}
pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}
pub fn all_data(&self) -> &Option<Tensor> {
&self.all_data
}
pub fn current_data(&self) -> Result<Option<Tensor>> {
let data = match self.all_data.as_ref() {
None => None,
Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?),
};
Ok(data)
}
pub fn reset(&mut self) {
self.current_seq_len = 0;
self.all_data = None;
}
pub fn append(&mut self, src: &Tensor) -> Result<()> {
let seq_len = src.dim(self.dim)?;
// This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
// self.all_data.get_or_insert_with.
if self.all_data.is_none() {
let mut shape = src.dims().to_vec();
shape[self.dim] = self.max_seq_len;
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
self.all_data = Some(ad)
};
let ad = self.all_data.as_mut().unwrap();
if self.current_seq_len + seq_len > self.max_seq_len {
candle::bail!(
"kv-cache: above max-seq-len {}+{seq_len}>{}",
self.current_seq_len,
self.max_seq_len
)
}
ad.slice_set(src, self.dim, self.current_seq_len)?;
self.current_seq_len += seq_len;
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct KvCache {
k: Cache,
v: Cache,
}
impl KvCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
let k = Cache::new(dim, max_seq_len);
let v = Cache::new(dim, max_seq_len);
Self { k, v }
}
pub fn k_cache(&self) -> &Cache {
&self.k
}
pub fn v_cache(&self) -> &Cache {
&self.v
}
pub fn k_cache_mut(&mut self) -> &mut Cache {
&mut self.k
}
pub fn v_cache_mut(&mut self) -> &mut Cache {
&mut self.v
}
pub fn k(&self) -> Result<Option<Tensor>> {
self.k.current_data()
}
pub fn v(&self) -> Result<Option<Tensor>> {
self.v.current_data()
}
pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.k.append(k)?;
self.v.append(v)?;
let out_k = self.k.current_data()?;
let out_v = self.v.current_data()?;
let k = match out_k {
None => {
let mut shape = k.dims().to_vec();
shape[self.k.dim] = 0;
Tensor::zeros(shape, k.dtype(), k.device())?
}
Some(k) => k,
};
let v = match out_v {
None => {
let mut shape = v.dims().to_vec();
shape[self.k.dim] = 0;
Tensor::zeros(shape, v.dtype(), v.device())?
}
Some(v) => v,
};
Ok((k, v))
}
pub fn current_seq_len(&self) -> usize {
self.k.current_seq_len()
}
pub fn reset(&mut self) {
self.k.reset();
self.v.reset();
}
}

View File

@ -11,8 +11,8 @@
//! use candle_nn::{LayerNorm, Module};
//! # fn main() -> candle::Result<()> {
//!
//! let w = Tensor::new(1f32, &Cpu)?;
//! let b = Tensor::new(0f32, &Cpu)?;
//! let w = Tensor::new(&[1f32, 1f32, 1f32], &Cpu)?;
//! let b = Tensor::new(&[0f32, 0f32, 0f32], &Cpu)?;
//! let layer = LayerNorm::new(w, b, 1e-5);
//!
//! let xs = Tensor::new(
@ -28,7 +28,7 @@
//! ```
//!
//! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
use candle::{DType, Result, Tensor, D};
use candle::{DType, Module, Result, Tensor, D};
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LayerNormConfig {
@ -105,8 +105,13 @@ impl LayerNorm {
}
}
impl crate::Module for LayerNorm {
impl Module for LayerNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
if x.is_contiguous() && self.remove_mean {
if let Some(bias) = self.bias.as_ref() {
return crate::ops::layer_norm(x, &self.weight, bias, self.eps as f32);
}
}
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
@ -162,11 +167,20 @@ impl RmsNorm {
pub fn into_inner(self) -> LayerNorm {
self.0
}
/// Faster variant of the forward kernel, this can only be used on contiguous tensors though.
pub fn forward_diff(&self, xs: &Tensor) -> Result<Tensor> {
self.0.forward(xs)
}
}
impl crate::Module for RmsNorm {
impl Module for RmsNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self.0.forward(xs)
if xs.is_contiguous() {
crate::ops::rms_norm(xs, &self.0.weight, self.0.eps as f32)
} else {
self.0.forward(xs)
}
}
}

View File

@ -6,6 +6,7 @@ pub mod encoding;
pub mod func;
pub mod group_norm;
pub mod init;
pub mod kv_cache;
pub mod layer_norm;
pub mod linear;
pub mod loss;

View File

@ -1,4 +1,4 @@
use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
use rayon::prelude::*;
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
@ -39,13 +39,197 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> {
}
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
let xs = xs.chunk(2, candle::D::Minus1)?;
let xs = xs.chunk(2, D::Minus1)?;
&xs[0].silu()? * &xs[1]
}
struct Sigmoid;
impl candle::CustomOp1 for Sigmoid {
fn name(&self) -> &'static str {
"sigmoid"
}
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
use candle::backend::BackendStorage;
fn fwd<T: num_traits::Float>(v: T) -> T {
(v.neg().exp() + T::one()).recip()
}
// FIXME: using `candle::map_dtype` causes compilation errors.
let storage = match storage {
CpuStorage::BF16(slice) => {
CpuStorage::BF16(candle::cpu_backend::unary_map(slice, layout, fwd))
}
CpuStorage::F16(slice) => {
CpuStorage::F16(candle::cpu_backend::unary_map(slice, layout, fwd))
}
CpuStorage::F32(slice) => {
CpuStorage::F32(candle::cpu_backend::unary_map(slice, layout, fwd))
}
CpuStorage::F64(slice) => {
CpuStorage::F64(candle::cpu_backend::unary_map(slice, layout, fwd))
}
_ => Err(candle::Error::UnsupportedDTypeForOp(
storage.dtype(),
self.name(),
))?,
};
Ok((storage, layout.shape().clone()))
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
storage: &candle::CudaStorage,
layout: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use candle::cuda_backend::SlicePtrOrNull;
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
use candle::{CudaDevice, WithDType};
struct S;
impl Map1 for S {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
let params = (el_count, dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
let dev = storage.device();
let slice = S.map(&storage.slice, dev, layout)?;
let dst = candle::CudaStorage {
slice,
device: dev.clone(),
};
Ok((dst, layout.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
storage: &candle::MetalStorage,
layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
use candle::MetalError;
let device = storage.device();
let dtype = storage.dtype();
let shape = layout.shape();
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
let command_buffer = device.command_buffer()?;
command_buffer.set_label("sigmoid");
let src = candle_metal_kernels::BufferOffset {
buffer: storage.buffer(),
offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
};
match (el_count % 2, dtype, layout.is_contiguous()) {
(0, DType::BF16 | DType::F16, true) => {
use candle_metal_kernels::unary::contiguous_tiled;
let kernel_name = match dtype {
DType::F16 => contiguous_tiled::sigmoid::HALF,
DType::F32 => contiguous_tiled::sigmoid::FLOAT,
DType::BF16 => contiguous_tiled::sigmoid::BFLOAT,
dtype => {
candle::bail!(
"Metal contiguous_tiled unary sigmoid {dtype:?} not implemented"
)
}
};
candle_metal_kernels::call_unary_contiguous_tiled(
device.metal_device(),
&command_buffer,
device.kernels(),
kernel_name,
el_count,
src,
&buffer,
)
.map_err(MetalError::from)?;
}
(_, _, true) => {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match dtype {
DType::F16 => contiguous::sigmoid::HALF,
DType::F32 => contiguous::sigmoid::FLOAT,
DType::BF16 => contiguous::sigmoid::BFLOAT,
dtype => {
candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
}
};
candle_metal_kernels::call_unary_contiguous(
device.metal_device(),
&command_buffer,
device.kernels(),
kernel_name,
el_count,
src,
&buffer,
)
.map_err(MetalError::from)?;
}
(_, _, false) => {
use candle_metal_kernels::unary::strided;
let kernel_name = match dtype {
DType::F16 => strided::sigmoid::HALF,
DType::F32 => strided::sigmoid::FLOAT,
DType::BF16 => strided::sigmoid::BFLOAT,
dtype => {
candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
}
};
let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
candle_metal_kernels::call_unary_strided(
device.metal_device(),
&command_buffer,
device.kernels(),
kernel_name,
layout.dims(),
src,
layout.stride(),
dst,
)
.map_err(MetalError::from)?;
}
}
let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype);
Ok((new_storage, layout.shape().clone()))
}
fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
// d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x)
let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
Ok(Some(grad_res.mul(&d_dx_sigmoid)?))
}
}
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
// TODO: Should we have a specialized op for this?
(xs.neg()?.exp()? + 1.0)?.recip()
xs.apply_op1(Sigmoid)
}
pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
@ -70,7 +254,7 @@ pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
let rand = Tensor::rand(0f32, 1f32, xs.shape(), xs.device())?;
let scale = 1.0 / (1.0 - drop_p as f64);
let drop_p = Tensor::new(drop_p, xs.device())?.broadcast_as(xs.shape())?;
let mask = (rand.ge(&drop_p)? * scale)?.to_dtype(xs.dtype())?;
let mask = (rand.ge(&drop_p)?.to_dtype(xs.dtype())? * scale)?;
xs * mask
}
@ -436,15 +620,15 @@ pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = x.dim(candle::D::Minus1)?;
let hidden_size = x.dim(D::Minus1)?;
let x = x.to_dtype(internal_dtype)?;
let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?;
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
}
pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
let hidden_size_xs = xs.dim(candle::D::Minus1)?;
let hidden_size_xs = xs.dim(D::Minus1)?;
let hidden_size_alpha = alpha.dims1()?;
if hidden_size_xs != hidden_size_alpha {
candle::bail!(
@ -456,6 +640,254 @@ pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
}
#[derive(Debug, Clone)]
struct LayerNorm {
eps: f32,
}
impl candle::CustomOp3 for LayerNorm {
fn name(&self) -> &'static str {
"layer-norm"
}
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)> {
use candle::backend::BackendStorage;
let eps = self.eps;
fn inner<
T: candle::WithDType
+ num_traits::Float
+ num_traits::AsPrimitive<f32>
+ num_traits::FromPrimitive,
>(
src: &[T],
layout: &Layout,
alpha: &[T],
alpha_layout: &Layout,
beta: &[T],
beta_layout: &Layout,
eps: f32,
) -> Result<(CpuStorage, Shape)> {
let src = match layout.contiguous_offsets() {
None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => &src[o1..o2],
};
let alpha = match alpha_layout.contiguous_offsets() {
None => candle::bail!("alpha has to be contiguous"),
Some((o1, o2)) => &alpha[o1..o2],
};
let beta = match beta_layout.contiguous_offsets() {
None => candle::bail!("beta has to be contiguous"),
Some((o1, o2)) => &beta[o1..o2],
};
let el_count = layout.shape().elem_count();
let dims = layout.shape().dims();
let dim_m1 = dims[dims.len() - 1];
let mut dst = vec![T::zero(); el_count];
src.par_chunks(dim_m1)
.zip(dst.par_chunks_mut(dim_m1))
.for_each(|(src, dst)| {
let mut sum = 0f32;
let mut sum2 = 0f32;
for v in src {
let v = v.as_();
sum += v;
sum2 += v * v;
}
let mean = sum / dim_m1 as f32;
let var = sum2 / dim_m1 as f32 - mean * mean;
let inv_std = (var + eps).sqrt().recip();
for ((d, s), (alpha, beta)) in
dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta))
{
let alpha = alpha.as_();
let beta = beta.as_();
let d_ = (s.as_() - mean) * inv_std * alpha + beta;
*d = T::from_f32(d_).unwrap_or_else(T::nan);
}
});
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, Shape::from_dims(dims)))
}
use CpuStorage as C;
match (s1, s2, s3) {
(C::BF16(s1), C::BF16(s2), C::BF16(s3)) => {
inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps)
}
(C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps),
(C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps),
_ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
}
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,
l1: &Layout,
s2: &candle::CudaStorage,
l2: &Layout,
s3: &candle::CudaStorage,
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
use candle::{CudaDevice, WithDType};
struct S {
eps: f32,
}
impl Map3 for S {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
layout: &Layout,
alpha: &CudaSlice<T>,
alpha_layout: &Layout,
beta: &CudaSlice<T>,
beta_layout: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
let src = match layout.contiguous_offsets() {
None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let alpha = match alpha_layout.contiguous_offsets() {
None => candle::bail!("alpha has to be contiguous"),
Some((o1, o2)) => alpha.slice(o1..o2),
};
let beta = match beta_layout.contiguous_offsets() {
None => candle::bail!("beta has to be contiguous"),
Some((o1, o2)) => beta.slice(o1..o2),
};
let el = layout.shape().elem_count();
let dims = layout.shape().dims();
let dim_m1 = dims[dims.len() - 1];
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
let cfg = LaunchConfig {
grid_dim: (n_rows as u32, 1, 1),
block_dim: (1024, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &dst, &alpha, &beta, n_cols as i32, self.eps);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
use candle::backend::BackendStorage;
let dev = s1.device();
let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?;
let dst = candle::cuda_backend::CudaStorage {
slice,
device: dev.clone(),
};
Ok((dst, l1.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
s1: &candle::MetalStorage,
l1: &Layout,
s2: &candle::MetalStorage,
l2: &Layout,
s3: &candle::MetalStorage,
l3: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
let device = s1.device();
let command_buffer = device.command_buffer()?;
let kernels = device.kernels();
let name = match (s1.dtype(), s2.dtype(), s3.dtype()) {
(DType::F32, DType::F32, DType::F32) => "layernorm_f32",
(DType::F16, DType::F16, DType::F16) => "layernorm_f16",
(DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16",
(dt1, dt2, dt3) => {
candle::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}")
}
};
if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) {
candle::bail!("Non contiguous layernorm is not implemented");
}
let last_dim = l1.dims()[l1.shape().rank() - 1];
let elem_count = l1.shape().elem_count();
let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?;
candle_metal_kernels::call_layer_norm(
device.metal_device(),
&command_buffer,
kernels,
name,
elem_count,
last_dim,
self.eps,
s1.buffer(),
l1.start_offset() * s1.dtype().size_in_bytes(),
s2.buffer(),
l2.start_offset() * s2.dtype().size_in_bytes(),
s3.buffer(),
l3.start_offset() * s3.dtype().size_in_bytes(),
&output,
)
.map_err(candle::Error::wrap)?;
let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
Ok((newstorage, l1.shape().clone()))
}
}
pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = x.dim(D::Minus1)?;
let x = x.to_dtype(internal_dtype)?;
let x = {
let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
x.broadcast_sub(&mean_x)?
};
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
x_normed
.to_dtype(x_dtype)?
.broadcast_mul(alpha)?
.broadcast_add(beta)
}
pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
let hidden_size_xs = xs.dim(D::Minus1)?;
let hidden_size_alpha = alpha.dims1()?;
let hidden_size_beta = beta.dims1()?;
if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta {
candle::bail!(
"shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}",
xs.shape(),
alpha.shape(),
beta.shape()
)
}
xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
}
// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
let (b_size, c, h, w) = xs.dims4()?;
@ -494,3 +926,24 @@ pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
n => candle::bail!("replication-pad with a size of {n} is not supported"),
}
}
#[derive(Clone, Debug)]
pub struct Identity;
impl Identity {
pub fn new() -> Identity {
Self
}
}
impl Default for Identity {
fn default() -> Self {
Self
}
}
impl Module for Identity {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
Ok(xs.clone())
}
}

Some files were not shown because too many files have changed in this diff Show More