Compare commits

..

192 Commits

Author SHA1 Message Date
7ff921c538 Add RandomNormal ONNX operator (#2200) 2024-05-21 21:47:32 +02:00
9b8537a62f Remove the deprecated wav crate in favor of hound. (#2202) 2024-05-21 21:43:35 +02:00
7ebc3548e1 Use flash-attn in gemma. (#2195)
* Use flash-attn in gemma.

* Fix flash-attn for head dim 256.
2024-05-18 19:18:59 +02:00
eefc1c77ef Support flash-attn in quantized phi3. (#2194) 2024-05-18 17:12:56 +02:00
01545f7303 Add a slice_set op. (#2193)
* Add a slice_set op.

* Add some testing.

* Add the dedicated kv-cache module.

* Derive debug and clone.

* Expose more kv-cache functions.

* Return the current data when appending.

* Use the new cache in the quantized phi3 model.
2024-05-18 15:58:18 +02:00
349c3e806a Support embedding model gte-Qwen1.5-7B-instruct (#2190)
* Support embedding model gte-Qwen1.5-7B-instruct

This is a text embedding model based on Qwen2. They share same
model architecture except the last MLP module. This commit brings in
minimal modification of the old Qwen2 implementation to support both
models.

An example is provided, and had been verified according to the official
PyTorch implementation.

* Avoid doing the 'last-token filtering' based on the absence of attention mask.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-05-16 21:34:10 +02:00
bdaa34216a chore: add fix for windows cudarc into the readme (#2189) 2024-05-16 14:32:50 +02:00
cc80e065e5 Allow the threshold argumet to be negative in the segment-anything example (#2187)
Threshold is 0.0 by default, negative values make more points included,
expanding the mask. Positive values make it more picky, making the mask
smaller.

Negative numbers start with a minus sign, which normally makes clap
consider it a flag.
2024-05-15 13:17:20 +02:00
13c64f6828 Fix VarBuilder::from_slice_safetensors (#2180)
Also implement SimpleBackend for SliceSafetensors

Signed-off-by: Harry Stern <harry@harrystern.net>
2024-05-12 07:26:06 +02:00
21f82a5155 Add SliceSafetensors. (#2179)
* Add SlicedSafetensors.

* And add some testing.
2024-05-11 13:15:42 +02:00
9cff7bc3f4 Make it possible to use TF32 accumulation in F32 matmuls. (#2178)
* Allow the use of tf32 accumulation in matmul.

* Better timings.

* Dummy versions for use when cuda is not enabled.
2024-05-11 12:28:39 +02:00
d9bc5ec151 Switch cudarc back to dynamic linking. (#2176) 2024-05-09 10:35:44 +02:00
84328e2b60 Update cudarc requirement from 0.11.0 to 0.11.1 (#2174)
* Upgrading cudarc dependency from v0.11.0 to v0.11.1 due to that version having resolved a compile-time bug.

See: https://github.com/huggingface/candle/issues/2173
2024-05-08 20:40:36 +02:00
82b641fd27 Update cudarc requirement from 0.10.0 to 0.11.0 (#2165)
* Update cudarc requirement from 0.10.0 to 0.11.0

Updates the requirements on [cudarc](https://github.com/coreylowman/cudarc) to permit the latest version.
- [Release notes](https://github.com/coreylowman/cudarc/releases)
- [Commits](https://github.com/coreylowman/cudarc/compare/v0.10.0...v0.10.0)

---
updated-dependencies:
- dependency-name: cudarc
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

* Use the default cuda version.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-05-06 17:12:14 +02:00
01794dc16e Use write rather than try-write on the metal rw-locks. (#2162) 2024-05-05 07:22:46 +02:00
a75cd8164f Force the revision for the phi3-llama quantized models. (#2159) 2024-05-04 10:41:18 +02:00
b13a82a438 Separate quantized phi-3 implementation. (#2157)
* Separate quantized phi-3 implementation.

* Integrate the quantized phi3 model.=

* Small fixes, get the generation to work properly.

* Keep the old llama implementation around.

* Change the default.
2024-05-04 10:14:57 +02:00
59b18d974e Pin the version used for the quantized phi 3 gguf file. (#2156) 2024-05-03 15:03:22 +02:00
89f53b9d7b Bump the version number to 0.5.1. (#2155)
* Bump the version number to 0.5.1.

* Fix clippy lints for 1.78.

* More clippy fixes.
2024-05-03 11:17:05 +02:00
a09d451d11 Support top-k in tthe llama example. (#2150) 2024-05-01 22:25:47 +02:00
fa06f5f5f9 F16/BF16 bugfix (bis). (#2143)
* F16/BF16 bugfix (bis).

* Another fix.

* Yet another fix.
2024-04-29 14:08:44 +02:00
09d4845aa8 Bugfix the recent f16/bf16 changes. (#2142) 2024-04-29 13:30:11 +02:00
a0d03aded1 Bug Fix: When converting a tensor to a variable, clone if the tensor is already a variable. (#2124)
* When converting a tensor to a variable, clone if the tensor is already a variable.

* Add a test to ensure training a batch norm works with VarMaps

---------

Co-authored-by: Jeffrey Dallatezza <jeffreydallatezza@Jeffreys-Laptop.local>
2024-04-29 11:21:53 +02:00
3bbb88fcb4 Fix sigmoid gradient calculation and move sigmoid into a specialized op (#2114)
* add sigmoid op

* small fix

* add as a method on `Tensor`

* implement gradient calculation for sigmoid

* add sigmoid tests

* we should have a specialized op for this

* fix clippy

* fix clippy 2

* Revert all previous commits in favor of a `CustomOp` based solution

* use `CustomOp1` implementation

* fix rustfmt

* experimental add metal impl

* add cuda kernel impl

* fix fmt

* Add a test + reduce some cuda duplication.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-29 11:04:43 +02:00
ed7b99f525 Add a toggle for F16/BF16 accumulation in gemm. (#2141)
* Add a toggle to control f16/bf16 gemm precision.

* Use the faster variant in the quantized example.

* Bugfix.
2024-04-29 09:21:07 +02:00
287013ef28 Add a forward_via_f16 method to the qmatmul op. (#2138) 2024-04-28 20:35:01 +02:00
eb26e2467e Add the cuda dequantize f16 kernels. (#2137)
* Add the cuda dequantize f16 kernels.

* Expose the cuda kernels.

* Add some testing + fix.

* Test the other cases too.

* A few more tests.

* Add an environment variable to enable the dequantize f16 + matmul behavior.
2024-04-28 20:05:05 +02:00
c68ed8963f chore: fix some typos in comments (#2121)
Signed-off-by: hardlydearly <799511800@qq.com>
2024-04-28 08:34:32 +02:00
e5c8b88f90 Apply the cast before the scaling. (#2135) 2024-04-28 08:30:35 +02:00
805f3be8e1 Add a sort function. (#2134) 2024-04-28 08:18:04 +02:00
3b429f3023 Make the dtype configurable for phi. (#2133) 2024-04-27 21:32:49 +02:00
96a48e5cc4 Add argsort. (#2132)
* Add the argsort cuda kernels.

* CPU version of arg-sort.

* Hook the cuda kernel + rework the cpu bits.

* Add some dedicated test.

* Working cuda kernel.

* Metal kernel.

* Metal adjustments.

* Bugfix.

* Use the fast rope in qwen.

* Rework the expert selection in qwen.
2024-04-27 20:17:35 +02:00
6cf82fd7a3 Add Olmo models (#2127)
* add olmo support

* add olmo readme

* Fix fmt.

* Fix clippy.

* Get olmo to work on cuda.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-26 11:02:51 +02:00
cfab6e7616 Mention phi-v3 in the readmes. (#2122) 2024-04-24 20:54:24 +02:00
11d4a3c588 Add the phi-3 model. (#2120)
* Add the phi-3 model.

* Faster rope.

* Bugfix.

* Fix the detokenization.
2024-04-24 09:48:13 +02:00
9d3f1c8af5 Add the phi-v3 quantized model. (#2118)
* Add the phi-v3 quantized model.

* Also include phi-3 in the main phi example.
2024-04-24 08:22:23 +02:00
7211009179 Fix for rustfmt. (#2117) 2024-04-23 19:09:33 +02:00
6fadaf2eff candle-onnx: add operators RandomUniform and Exp (#2116)
* Add basic RandomUniform implementation

* Use is_some to check if seed is present

* Added Exp operator implementation

---------

Co-authored-by: Mateusz Okulus <mmokulus@gmail.com>
2024-04-23 19:02:19 +02:00
8a05743a21 Add StorageRef. (#2113)
* Add the storage-ref bits.

* Add the metal implementation.
2024-04-23 13:23:27 +02:00
b2e816752b Use the faster rms-norm kernel for llama. (#2107)
* Use the faster rms-norm kernel for llama.

* Use the fast variant by default.
2024-04-22 18:52:00 +02:00
618ecf5e23 Better time measurement for the llama example. (#2106) 2024-04-22 17:54:27 +02:00
267601eec1 Update tokenizers requirement from 0.15.0 to 0.19.1 (#2104)
Updates the requirements on [tokenizers](https://github.com/huggingface/tokenizers) to permit the latest version.
- [Release notes](https://github.com/huggingface/tokenizers/releases)
- [Changelog](https://github.com/huggingface/tokenizers/blob/main/RELEASE.md)
- [Commits](https://github.com/huggingface/tokenizers/compare/v0.15.0...v0.15.2)

---
updated-dependencies:
- dependency-name: tokenizers
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-04-22 17:10:46 +02:00
08a15cb79e Update zip requirement from 0.6.6 to 1.1.1 (#2103)
* Update zip requirement from 0.6.6 to 1.1.1

---
updated-dependencies:
- dependency-name: zip
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

* Fix for the zip crate update.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-22 16:23:27 +02:00
c388be93e7 Updated quantized phi model (#2099)
* Quantized phi in a separate file.

* Add the quantized phi example + rework the model code.

* Improve the phi model.

* Get some generation out.

* Use the appropriate rope shape.

* Tweak the default prompt.

---------

Co-authored-by: Jane Doe <jane.doe@example.org>
2024-04-21 07:37:07 +02:00
d22f1d4f4e Derive clone and debug traits for Moondream model (#2100)
* moondream implementation

* add moondream example

* change config default activation

* Add assets and integrate phi mixformer with example

* Make use of kv cache and fix seq_len bug; Clean up example code

* Add README link to example

* Remove pos_embed scaling; Remove assets; Add to README; Expand VisionConfig

* Delete image

* Use apply instead of forward

* Use latest release special token; Fix token/s accuracy; Use GeluPytorchTanh in VisionConfig v2

* Derive debug and clone traits for Moondream model.
2024-04-21 07:08:28 +02:00
0067fe00a8 Metal Unary: Add benchmarks and process kernels in a tile based fashion (#2056)
* add basic unary bench for sqrt

* process unary commands in tiles of 4

* re-enable all benchmarks

* rename helper to unary

* modify approach to split up tiled and non-tiled operations

* undo bench ignore for other tests

* update tile size to 2

* only perform the optimization on the contiguous even numbered element case
2024-04-21 00:10:33 +02:00
587ee3bb6f Small cleanups to the llama multi-process example. (#2098) 2024-04-20 22:19:46 +02:00
dd78422701 Handle multiple dimensions in metal QMM + two fixes. (#2097) 2024-04-20 18:55:45 +02:00
9215e9ce8c Add missing onnx operations (#2096)
* Add missing onnx operations

* Add tests and fix errors

* Run rustfmt
2024-04-20 18:44:22 +02:00
52ae332910 Use llama v3 by default + add to readme. (#2094) 2024-04-20 16:11:24 +02:00
8b390ddd29 Only download the weights in the main process (and not in the child processes). (#2093) 2024-04-20 13:01:23 +02:00
c97d639fa0 Multiprocess/multi-GPU support for llama 3. (#2092)
* Multiprocess/multi-GPU support for llama 3.

* Modernize the mp example a bit.
2024-04-20 12:49:21 +02:00
b45c710dbf Fix for gemma MQA. (#2091) 2024-04-19 21:49:55 +02:00
9c532aef47 Also enable llama-v3 8b instruct. (#2088) 2024-04-19 08:50:06 +02:00
f7a6468238 Add support for llama3 on the quantized example (#2086)
* add support for l3b, new tokenizer

* add todo

* Add todo and use k_s model

* Use the official tokenizers.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-18 22:52:00 +02:00
2b93dffe64 Use faster rotary embeddings for llama like models. (#2087) 2024-04-18 22:34:29 +02:00
e6ee7ba4d4 Llama v3. (#2085)
* Llama v3.

* Tweak the default params + handle special tokens.

* Small tweak.
2024-04-18 22:19:54 +02:00
1690ab45d2 Fix the silu gradient issue on 0. (#2083) 2024-04-18 14:31:41 +02:00
8de0ce6cba Add more QMMV cuda kernels. (#2077)
* Add more QMMV cuda kernels.

* Enable the new kernels.

* Adapt the testing.
2024-04-18 08:36:43 +02:00
ce6d08df94 Minor fix to the readme. (#2080)
Co-authored-by: Jane Doe <jane.doe@example.org>
2024-04-17 22:43:00 +02:00
2817643db9 Add the mmv kernels for small batch sizes. (#2075)
* Add the mmv kernels for smaller sizes.

* Support more mmv kernels.

* Use the new kernels.

* Fix the call.

* Silly fix.

* Improve the testing.

* Fix for dmmv.

* Add another dedicated test for the batching mmv.
2024-04-16 21:30:51 +02:00
4d14777673 Utilize batches in Stable Diffusion (#2071)
* Utilize batches in Stable Diffusion that were already there, but unutilized.

Also refactor out the `save_image` function.

* Clippy + cosmetic fixes.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-16 06:49:04 +02:00
f135b7963d Fix for the batch dim in the quantized matmul example. (#2073)
* Fix for the batch dim in the quantized matmul example.

* Enable more tests on cuda.

* Add a test for qmm with a batch.

* Fix the zeros-dim test on metal.
2024-04-15 20:00:28 +02:00
af955f260c Make the falcon model cloneable. (#2067) 2024-04-15 09:39:03 +02:00
8ad822a983 Add a function to clear the KV cache in falcon. (#2066)
* Add a function to clear the KV cache in falcon.

* Clippy.
2024-04-15 09:29:25 +02:00
e198bb0816 Handle zero dims in some simple operations. (#2064)
* Handle zero dims in some simple operations.

* Handle zero-dims in matmul.

* More testing.
2024-04-15 09:18:54 +02:00
f7d5bf5b97 Faster kernels for quantized matmul on cuda (#2060)
* Hook the quantized matmul cuda kernels.

* Add a (currently broken) test.

* Kernel fixes.

* Fix by transposing the rhs matrix.

* Add the q4-1 kernels.

* Proper block sizes.

* More details in the tests.
2024-04-15 08:32:47 +02:00
c119600d6e Move image tensor to device in trocr example (#2063)
Signed-off-by: Harry Stern <harry@harrystern.net>
2024-04-15 06:50:32 +02:00
c449f65b12 Expose the synchronize function on the generic device. (#2062) 2024-04-14 23:02:03 +02:00
db7dbf3071 Add missing bfloat unary strided kernels and fix typo (#2058) 2024-04-14 20:01:13 +02:00
4ecedb1598 Add the full quantized matmul kernels for cuda. (#2057) 2024-04-14 17:52:08 +02:00
53e5380bf6 Add a synchronize method to devices. (#2055)
* Add a synchronize method to devices.

* Metal version.
2024-04-14 16:32:55 +02:00
50e49ecc5f Add a quantized version of recurrent-gemma. (#2054)
* Add a quantized version of recurrent-gemma.

* Share the rglru part.

* Get the quantized gemma model to work.
2024-04-13 20:07:01 +02:00
4c88c3ce06 Add benchmarks for qmatmul operations (#2048)
* Add qmatmul bench

* add all dtypes
2024-04-13 12:30:14 +02:00
8b8fb630df Add a convenient way to rename tensors accessed through a varbuilder. (#2052) 2024-04-13 12:09:41 +02:00
fb805b8ca2 Avoid crashes when running T5 models with F16 tensors on CPU (#2047)
* This change avoids crashes when running T5 models with F16 tensors on CPU.

* This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores it's embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN. You could write more, like: This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point.

* Revert "This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores it's embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN. You could write more, like: This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point."

This reverts commit d886d3ce5e.
2024-04-13 11:07:28 +02:00
79e3bec789 Change for the encoder-only ProstT5 model (#2045)
* This change avoids crashes when running T5 models with F16 tensors on CPU.

* This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores it's embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN.  This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point.
2024-04-13 11:06:24 +02:00
e6d412b156 Add ReduceMean onnx operation (#2049)
* Add ReduceMean onnx operation

* Format code with rustfmt
2024-04-13 11:00:25 +02:00
26cbbf8d84 Mandatory topk sampling for recurrent-gemma. (#2051) 2024-04-13 10:31:39 +02:00
2bf413caa3 Add the recurrent-gemma model. (#2039)
* Start adding the recurrent-gemma model.

* More griffin.

* Add the example + get the weights to load from the HF version.

* More inference code.

* Rope + kv-cache on the attention side.

* Add to the inference code.

* Add more to the recurrent gemma inference.

* Get some first inference to run.

* Add the softcap on logits.

* Fixes.

* Use partial rotary embeddings.

* Get inference to work.

* Add a comment.

* And add a readme.
2024-04-13 00:05:21 +02:00
3ad4770eb6 Use cat for faster MQA computation. (#2043)
* Use cat for faster MQA computation.

* Move the function to utils + use it in mistral.

* Use the shared repeat-kv in a few more models.

* Fix.
2024-04-12 09:15:10 +02:00
a0460cd2b1 Add the code-gemma models. (#2038)
* Add the code-gemma models.

* Tweak to the gemma config.
2024-04-10 21:19:21 +02:00
b81ecf712d Support alternative dtypes for mamba (#2036)
* Allow different dtypes in mamba.

* Add a dtype flag.
2024-04-10 18:10:01 +02:00
a4d5a414e3 Support gather on bf16 for metal. (#2035) 2024-04-10 12:49:25 +02:00
798e0335cd Handle more tensor shapes in onnx "Gather" operation (#2026)
* Handle more tensor shapes in onnx "Gather" operation

* Add more tests

* Add comment

* Fix typo
2024-04-08 14:06:14 +02:00
718671a0d5 Use BufferOffset in metal backend ops. (#2029)
* Use BufferOffset in the metal backend.

* More BufferOffset usage.

* Use in where-cond.
2024-04-08 09:37:25 +02:00
c5fe4a7f89 Rework the buffer offset logic for metal kernels (#2028)
* Move the metal kernels utils in a separate module.

* Use the BufferOffset for unary ops.

* Fix clippy lints.

* Use the new BufferOffset.

* Adapt the binary ops.

* Affine.

* More ops (powf, elu, cast).
2024-04-07 22:37:53 +02:00
7f354473cf Optimize copy-2d for metal. (#2024)
* Optimize copy-2d for metal.

* Add a hacky stopping rule for moondream.
2024-04-07 12:34:16 +02:00
33c9b66554 Add the new gemma models. (#2023)
* Add the new gemma models.

* Revert the lightning changes.

* Support for the 1.1 models.
2024-04-06 21:25:38 +02:00
9fd52b3b71 Handle the batch dimension in quantized MMV on metal. (#2022) 2024-04-06 20:02:24 +02:00
e662431acf Fix the final rmsnorm for quantized-metavoice. (#2021) 2024-04-06 19:35:01 +02:00
ab892274d1 first commit (#2018) 2024-04-05 15:20:28 +02:00
b869a659ec Faster mask implementation for mixformers. (#2017)
* Faster mask implementation for mixformers.

* Clippy.
2024-04-05 09:38:26 +02:00
88f7793598 Moondream tracing. (#2016)
* Moondream tracing.

* A bit more tracing.
2024-04-05 09:11:08 +02:00
2ac302a5d1 Add the rope THD kernel. (#2014)
* Add the rope THD kernel.

* Cuda kernel for rope-thd.

* Add the metal kernels.

* Add a dedicated test.
2024-04-05 08:32:58 +02:00
ace282e5c2 Add flag to run Moondream in f16 precision (#2015)
* moondream implementation

* add moondream example

* change config default activation

* Add assets and integrate phi mixformer with example

* Make use of kv cache and fix seq_len bug; Clean up example code

* Add README link to example

* Remove pos_embed scaling; Remove assets; Add to README; Expand VisionConfig

* Delete image

* Use apply instead of forward

* Use latest release special token; Fix token/s accuracy; Use GeluPytorchTanh in VisionConfig v2

* Add flag to use f16

* Avoid breaking the quantized version on cuda.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-05 07:03:33 +02:00
c87381fc96 Use F16 for moondream on cuda. (#2013) 2024-04-04 23:30:10 +02:00
c5626b8271 Add support for "sign" on tensors (#2012)
* add the sign unary operator

* remove uneeded import

* remove uneeded import

* undo formatting

* undo formatting

* remove unnecessary redefintion

* allow gradient to flow through for sign and round

* fix cpu ops to ensure that negzero and positive zero are handled properly

* clippy fixes

* Properly avoid gradient tracking.

* Use a branchless version.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-04 22:32:47 +02:00
e6a5b82ba6 Fix the matmul layout for accelerate & mkl. (#2011)
* Fix the matmul layout for accelerate & mkl.

* Reduce the required precision for pow (because of accelerate).

* And a fix the gelu f16 test.
2024-04-04 19:18:03 +02:00
5aebe53dd2 update dtypes checks for several metal operations (#2010) 2024-04-04 18:39:06 +02:00
f76bb7794a Bumping the version number to 0.5.0. (#2009) 2024-04-04 17:48:45 +02:00
30b145150f Optimize the gelu f16 opt. (#2008)
* Optimize the gelu f16 opt.

* And add a test.
2024-04-04 16:28:23 +02:00
f48c07e242 Include topk sampling in the quantized example. (#2005)
* Include topk sampling in the quantized example.

* Also sample with top-k on the mistral side.
2024-04-04 09:27:54 +02:00
8967c46563 Split the cuda error file. (#2003) 2024-04-04 08:27:23 +02:00
1e46cf8b19 Minor cleanups in reduce.metal. (#2004) 2024-04-04 08:26:02 +02:00
bd8db2a771 refactor to reduce the amount of code wrapped in template syntax (#2002) 2024-04-04 08:13:12 +02:00
318d143224 Relax the contiguous check for cuda kernels. (#2000)
* Relax the contiguous check for cuda kernels.

* Ensure contiguity for RNNs.

* Unrelated fix for segment anything.

* Better error message + allow concatenating empty slices.
2024-04-03 09:02:38 +02:00
2be1a35710 Added link to the Coursera ML algorithm implementations (#1989)
* Added link to the coursera ML algo implementations

* Fixed link
2024-04-03 07:16:32 +02:00
26226068a4 Moondream WASM (#1999)
* moondream wasm wip

* examples, more

* fix eos token check

* README

* cleanip

* cleanup, clippy
2024-04-03 07:11:50 +02:00
cd6b9e317c Add benchmarks for the candle-nn package (#1995)
* add benchmarks for the candle-nn package

* uncomment test

* format
2024-04-03 07:03:54 +02:00
08c049def3 Improve the handling of matmul with squeezed layouts. (#1998)
* Improve the handling of matmul with squeezed layouts.

* Fix for the cuda backend.

* Revert the temporary fix.
2024-04-02 23:17:05 +02:00
d17b2cdad9 Match Moondream's latest release (#1997)
* moondream implementation

* add moondream example

* change config default activation

* Add assets and integrate phi mixformer with example

* Make use of kv cache and fix seq_len bug; Clean up example code

* Add README link to example

* Remove pos_embed scaling; Remove assets; Add to README; Expand VisionConfig

* Delete image

* Use apply instead of forward

* Use latest release special token; Fix token/s accuracy; Use GeluPytorchTanh in VisionConfig v2
2024-04-02 21:37:09 +02:00
fb918a23c8 first commit (#1994) 2024-04-02 16:31:05 +02:00
b23436bf90 Stable diffusion fix. (#1993)
* Stable diffusion fix.

* And add a comment.
2024-04-02 14:36:28 +02:00
be9c200cbb Expose the t5 config fields + allow t5-large. (#1987) 2024-04-01 20:58:34 +02:00
ea0d8d3753 Quantized moondream implementation and BOS token (#1980)
* moondream implementation

* add moondream example

* change config default activation

* Add assets and integrate phi mixformer with example

* Make use of kv cache and fix seq_len bug; Clean up example code

* Add README link to example

* Remove pos_embed scaling; Remove assets; Add to README; Expand VisionConfig

* Delete image

* Use apply instead of forward

* Pass bos token at the beginning of tensor.

* Quantize moondream.

* Forward with image bos token.

* Clippy.

* Use q4_0 quantization.

* Add pointers for sequence and tokens; Remove seq_len conditional
2024-04-01 19:37:54 +02:00
308ea070ed modify access for conv and op to be pub to allow external packages to have custom backends (#1986) 2024-04-01 17:44:49 +02:00
b20acd622c Update for pyo3 0.21. (#1985)
* Update for pyo3 0.21.

* Also adapt the RL example.

* Fix for the pyo3-onnx bindings...

* Print details on failures.

* Revert pyi.
2024-04-01 17:07:02 +02:00
5522bbc57c Add fn 'get_with_hints_dtype' in VarBuilder (#1877) (#1897)
* quantized models(awq/squeezellm/...) have multiple data type tensors, use 'get_with_hints_dtype' to load tensors with given dtype
2024-04-01 12:10:08 +02:00
888c09a3db add identity op (#1976) 2024-04-01 12:08:25 +02:00
318cb82f16 Quantized cuda tweaks. (#1981)
* Quantized cuda tweaks.

* Add some safety checks.

* Factorize the dequantization bits.
2024-04-01 11:06:42 +02:00
c7557b65dc Switch the default to using the faster kernels. (#1978)
* Switch the default to using the faster kernels.

* Add the force-dmmv flag.
2024-04-01 10:00:11 +02:00
cd29c7ccd4 More ggml cuda kernels (#1977)
* Add more cuda kernels for quantized matmul.

* Add the vec-dot bits.

* Expose the quantized matmul-vec kernels.

* Also include the quantize-q8-1 kernel.

* Glue code for the q8-1 quantization.

* mm-vec product via q8-1 quantization.

* Add a test.

* Add a mm test.

* Get the test to return some sensible results.

* Also test dmmv.

* Fix the launch params.

* Allow for tweaking the force_dmmv parameter while it's experimental.
2024-04-01 00:15:48 +02:00
f9954b73ba Add options to use local files + specify a custom repo or branch. (#1973) 2024-03-31 09:32:50 +02:00
eead1dcead Clippy fix. (#1972) 2024-03-31 08:57:40 +02:00
92f81d2fcb Add Moondream transformer implementation and example (#1970)
* moondream implementation

* add moondream example

* change config default activation

* Add assets and integrate phi mixformer with example

* Make use of kv cache and fix seq_len bug; Clean up example code

* Add README link to example

* Remove pos_embed scaling; Remove assets; Add to README; Expand VisionConfig

* Delete image

* Use apply instead of forward
2024-03-31 08:54:56 +02:00
3144150b8d Move the tensor-tools binary in a separate crate. (#1969) 2024-03-30 15:49:37 +01:00
b190fd8592 Remove some unnecessary calls to contiguous. (#1968)
* Remove some unnecessary calls to contiguous.

* Slightly improved kv cache concatenation.
2024-03-30 13:22:00 +01:00
efe4a0c84b Add a print command to tensor-tools. (#1967)
* Add a print command to tensor-tools.

* Add some flags to tweak the formatting.
2024-03-30 11:34:33 +01:00
665da30487 Backend refactoring. (#1966)
* Backend refactoring.

* Metal tweaks.

* Move the cudnn module.
2024-03-29 23:02:11 +01:00
356a170ae9 Update parquet requirement from 50.0.0 to 51.0.0 (#1867)
Updates the requirements on [parquet](https://github.com/apache/arrow-rs) to permit the latest version.
- [Changelog](https://github.com/apache/arrow-rs/blob/master/CHANGELOG-old.md)
- [Commits](https://github.com/apache/arrow-rs/compare/50.0.0...50.0.0)

---
updated-dependencies:
- dependency-name: parquet
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-03-29 21:58:15 +01:00
7ecbc6d50b fix minor typo (#1924) 2024-03-29 18:09:57 +01:00
8ad12a0e81 Add some examples using the MT5 variants. (#1963) 2024-03-29 18:09:29 +01:00
eb1b27abcd Readme fix. (#1961) 2024-03-28 23:24:46 +01:00
708e422456 Qwen MoE model. (#1960)
* Qwen MoE model.

* Add the MoE model to the example.

* Fix the scaling.

* Readme updates.

* Readme tweaks.
2024-03-28 23:10:57 +01:00
c5092f2c29 Add a couple t5 models. (#1958) 2024-03-28 17:58:06 +01:00
cdc8b57b5c Fix clippy lints + minor cleanups. (#1957)
* Fix clippy lints + minor cleanups.

* fmt.

* Derive clone.
2024-03-28 14:17:46 +01:00
b0340d72ec CLIP model implementation with example (#1950)
* CLIP model implementation with example

* CLIP Implementation fixes, batch images

* CLIP model remove images from git

* CLIP model remove unnecessary use of batch_indices
2024-03-28 13:44:12 +01:00
b3484e7a5e Fix for the RWKV models. (#1955)
* Fix for the RWKV models.

* More general fix + revert the rwkv hack.

* Remove the old hack.
2024-03-28 10:17:38 +01:00
ada5d7c096 add send and sync trait bounds for scheduler config in stable diffusion models (#1952)
* first commit

* add Sync deriving

* static

* remove static
2024-03-28 10:03:00 +01:00
13ae5a34c7 Ensure that the kernels get rebuilt on cuh changes. (#1954) 2024-03-28 06:56:48 +01:00
ab86cd37c8 Support i64 in index-select on metal. (#1951)
* Support i64 in index-select on metal.

* Add some testing of index-select for all dtypes.
2024-03-27 16:30:07 +01:00
a9abde5f93 More flexible matmul contiguity checks. (#1949)
* More flexible matmul contiguity checks.

* Also relax the checks on the metal side.
2024-03-27 10:59:05 +01:00
75b6d4b0da add config for mamba 2.8b model parameter (#1946)
* first commit

* Make the mamba config public.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-03-27 07:47:23 +01:00
66f0a4eeea Another fix for squeezing. (#1943) 2024-03-26 17:05:26 +01:00
4523ecfb2a Faster repeat penalty (#1940)
* Avoid the attention mask where possible.

* Faster repeat penalty.
2024-03-26 11:31:20 +01:00
f5dfe883d7 Extend supported dtypes for metal (im2col & upsample_2d) (#1938)
* update im2col dtype implementations

* update dtypes for upsample
2024-03-26 06:48:56 +01:00
196765e995 Use the new rope kernel in mistral. (#1937)
* Use the new rope kernel in mistral.

* Compute the cos and sin with full precision.

* Bugfix.
2024-03-25 23:26:05 +01:00
60676780a9 Fix detail in new RoPE implementation (#1935) 2024-03-25 18:20:09 +01:00
d3a8d291d5 Avoid the attention mask where possible. (#1933) 2024-03-25 15:31:04 +01:00
cd254074f3 Really unique identifier for metal device ids. (#1932)
* Really unique identifier for metal device ids.

* Same device.
2024-03-25 11:48:16 +01:00
e7f8e72588 Contiguous variant of the rope kernel. (#1929)
* Contiguous variant of the rope kernel.

* Add the cuda kernel.

* Metal kernel.
2024-03-25 09:11:20 +01:00
1b98f84a2b Fast kernels for rotary embeddings. (#1928)
* Fast kernels for rotary embeddings.

* Add a test for the fast CPU kernel.

* Rope cuda bindings.

* Cuda kernel.

* Metal kernel (part 1).

* Cuda kernels.

* Finish the metal kernel.

* Use the new kernels in the quantized example.

* Fix warning.
2024-03-24 22:48:52 +01:00
cf7d7fcf2f Also avoid the mask in the llama example. 2024-03-24 19:04:32 +01:00
8c0db87992 Avoid using the attn mask when not necessary. 2024-03-24 18:55:56 +01:00
e2b4829531 Support more mistral models. (#1927)
* Support more mistral models.

* Use the appropriate rope parameter.
2024-03-24 08:04:04 +01:00
5e70821dd0 Allow for arbitrary temperature modifications. 2024-03-23 15:47:39 +01:00
a62a97340c Add topk sampling. (#1923) 2024-03-23 15:26:09 +01:00
fdfe8fd129 Preliminary support for inplace ops. (#1921)
* Preliminary support for inplace ops.

* Add a test.
2024-03-23 14:16:19 +01:00
790037390c Add cast_bf16_x/cast_x_bf16 when CUDA_ARCH<800 but CUDA_VERSION >= 11000 (#1919)
- it make possible to load bf16 models on T4(sm75)
2024-03-23 13:44:10 +01:00
6f877592a7 Avoid broadcasting on the batch dimension for the attention mask. (#1920) 2024-03-23 13:08:53 +01:00
cc856db9ce Backwards for ConvTranspose2D (#1910)
* add documentation  for nackprop

* add backwards for ConvTranspose2D

* add test python code to test
2024-03-23 07:05:55 +01:00
fc1fe5e45b Support scatter/index_add with i64 indices for f16 (#1915) 2024-03-22 11:51:41 +01:00
32f567bac4 Fix loading the gguf files. (#1913) 2024-03-22 10:28:38 +01:00
fee33b45c2 Add support for strided index-select on Metal (#1909)
* initial implementation

* use correct index, but still not breaking like it should have...

* fix test
2024-03-22 07:30:02 +01:00
6708870e63 Add the alloc_uninit function. (#1901)
* Add the alloc_uninit function.

* Dummy metal fix.

* Lazy initialization.
2024-03-22 07:25:23 +01:00
a00e24d752 Improve the error message on overlong prompts. (#1908) 2024-03-21 21:08:07 +01:00
c07e4057ab Fix for the llama model. (#1906) 2024-03-21 19:36:10 +01:00
c0bdd9c7a6 Use the fast RmsNorm in the quantized model. (#1904) 2024-03-21 18:49:35 +01:00
9563a5fee4 Add support for conv_transpose2d on Metal backend (#1903)
* add support for conv transpose 2d and add bench mark for float types

* update bench calculation

* enable testing all conv operations on metal
2024-03-21 18:08:45 +01:00
ec97c98e81 Async tensor copying. (#1900) 2024-03-21 13:09:42 +01:00
bb3ee48039 whisper readme (#1899) 2024-03-21 12:54:09 +01:00
0c11e055be support distil-large-v3 (#1898) 2024-03-21 11:46:49 +01:00
18036c6ccb Update the image crate + use the re-exported version. (#1893)
* Update the image crate + use the re-exported version.

* Update to using ab_glyph.
2024-03-21 10:56:41 +01:00
0fddec762e RmsNorm kernel for metal. (#1895)
* RmsNorm kernel for metal.

* Wrapper for the metal kernel.

* Get the ops to actually work.

* Fix, get the tests to pass.
2024-03-21 09:48:56 +01:00
74b7f59261 Prepare for the custom-op extension. (#1892) 2024-03-21 07:02:20 +01:00
af7f8b87d3 Custom op for RmsNorm (#1890)
* Trying out a custom RmsNorm cuda kernel.

* CPU implementation for rms-norm.

* Cuda wrappers.

* Add some validation.

* Add some testing.

* More testing.
2024-03-21 06:36:28 +01:00
b219903d0f Cuda backend optimization (#1886)
* Attempt at making the kernel faster.

* Also adapt the cast kernels.

* Also apply to binary ops.
2024-03-20 18:32:55 +01:00
469635a3eb Minor cleanup. (#1885) 2024-03-20 14:38:27 +01:00
455c42aa72 Avoid copying the data on squeeze and unsqueeze. (#1884)
* Avoid copying the data on squeeze and unsqueeze.

* Fix the quantized llama example.

* Unrelated fix for the quantized stable-lm example on cuda.

* Fix for mamba on cuda (unrelated to the PR).
2024-03-20 13:04:36 +01:00
2a8679509e Add support for conv_transpose1d for metal backend (#1874)
* first attempt

* progress

* integrate into metal backend

* finish and get test passing

* add other dtype support

* update transpose1d dtypes supported
2024-03-19 08:46:58 +01:00
143c481c20 Expose candle gather op in pyo3. (#1870) 2024-03-18 21:54:15 +01:00
f115895b9e Apply rustfmt. (#1873) 2024-03-18 21:43:31 +01:00
90fc82211f Use a common with_tracing::RmsNorm in a few models. (#1871)
* Add RmsNorm with tracing.

* Use with_tracing::RmsNorm in some models.
2024-03-18 21:40:06 +01:00
6a966cf9e0 Add a DQN example to the reinforcement-learning section (#1872) 2024-03-18 21:22:53 +01:00
04a61a9c72 Add avg_pool2d metal implementation for the metal backend (#1869)
* implement metal avg pool 2d

* fixX

* add suggested precision workaround for the accumulator
2024-03-18 18:50:14 +01:00
58605252e8 Microphone support for the encodec example. (#1866) 2024-03-18 11:19:46 +01:00
d365ef32d9 Improve the encodec example: handle resampling. (#1865)
* Improve the encodec example: handle resampling.

* Play the audio directly.
2024-03-18 10:09:40 +01:00
754fa1e813 Add support for max_pool2d for Metal backend (#1863)
* first pass at implementation of maxpool2d

* Add definitions for other dtypes

* add tests for other dtypes

* Cosmetic tweaks + re-enable maxpool2d tests for metal.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-03-18 08:33:30 +01:00
184105792f add test for index add and add missing match statements (#1862) 2024-03-17 22:19:12 +01:00
a15f859ab4 Fix for the encodec example. (#1861) 2024-03-17 21:15:12 +01:00
e316cb6997 add support for casting between all datatypes (#1860) 2024-03-17 20:55:11 +01:00
204 changed files with 22935 additions and 4285 deletions

View File

@ -9,6 +9,7 @@ members = [
"candle-transformers",
"candle-wasm-examples/*",
"candle-wasm-tests",
"tensor-tools",
]
exclude = [
"candle-flash-attn",
@ -19,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.4.2"
version = "0.5.1"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -28,49 +29,49 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[workspace.dependencies]
ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.4.2" }
candle-datasets = { path = "./candle-datasets", version = "0.4.2" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.2" }
candle-kernels = { path = "./candle-kernels", version = "0.4.2" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.2" }
candle-nn = { path = "./candle-nn", version = "0.4.2" }
candle-onnx = { path = "./candle-onnx", version = "0.4.2" }
candle-transformers = { path = "./candle-transformers", version = "0.4.2" }
candle = { path = "./candle-core", package = "candle-core", version = "0.5.1" }
candle-datasets = { path = "./candle-datasets", version = "0.5.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.1" }
candle-kernels = { path = "./candle-kernels", version = "0.5.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.1" }
candle-nn = { path = "./candle-nn", version = "0.5.1" }
candle-onnx = { path = "./candle-onnx", version = "0.5.1" }
candle-transformers = { path = "./candle-transformers", version = "0.5.1" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] }
cudarc = { version = "0.11.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.23.0", default-features = false }
hound = "3.5.1"
image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.24.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
num_cpus = "1.15.0"
num-traits = "0.2.15"
parquet = { version = "50.0.0" }
parquet = { version = "51.0.0" }
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false }
safetensors = "0.4.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99"
thiserror = "1"
tokenizers = { version = "0.15.0", default-features = false }
tokenizers = { version = "0.19.1", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}
[profile.release-with-debug]

View File

@ -60,12 +60,14 @@ These online demos run entirely in your browser:
We also provide a some command line based examples using state of the art models:
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
- [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
the SOLAR-10.7B variant.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google
Deepmind.
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
Griffin based models from Google that mix attention with a RNN like state.
- [Phi-1, Phi-1.5, Phi-2, and Phi-3](./candle-examples/examples/phi/): 1.3b,
2.7b, and 3.8b general LLMs with performance on par with 7b models.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets. Also supports
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
@ -110,7 +112,7 @@ We also provide a some command line based examples using state of the art models
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmantation model.
- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmentation model.
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
model using residual vector quantization.
@ -125,10 +127,14 @@ We also provide a some command line based examples using state of the art models
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
generate captions for an image.
- [CLIP](./candle-examples/examples/clip/): multi-model vision and language
model.
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
dedicated submodels for hand-writing and printed recognition.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
model, generates the translated text from the input text.
- [Moondream](./candle-examples/examples/moondream/): tiny computer-vision model
that can answer real-world questions about images.
Run them using commands like:
```
@ -172,6 +178,7 @@ And then head over to
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
serving local LLMs including an OpenAI compatible API server.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`candle-coursera-ml`](https://github.com/vishpat/candle-coursera-ml): Implementation of ML algorithms from Coursera's [Machine Learning Specialization](https://www.coursera.org/specializations/machine-learning-introduction) course.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
@ -194,10 +201,10 @@ If you have an addition to this list, please submit a pull request.
- WASM support, run your models in a browser.
- Included models.
- Language Models.
- LLaMA v1 and v2 with variants such as SOLAR-10.7B.
- LLaMA v1, v2, and v3 with variants such as SOLAR-10.7B.
- Falcon.
- StarCoder, StarCoder2.
- Phi 1, 1.5, and 2.
- Phi 1, 1.5, 2, and 3.
- Mamba, Minimal Mamba
- Gemma 2b and 7b.
- Mistral 7b v0.1.
@ -206,7 +213,7 @@ If you have an addition to this list, please submit a pull request.
- Replit-code-v1.5-3B.
- Bert.
- Yi-6B and Yi-34B.
- Qwen1.5.
- Qwen1.5, Qwen1.5 MoE.
- RWKV v5 and v6.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.
@ -369,9 +376,9 @@ git submodule update --init
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ...:
```
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the NVCC_CCBIN environment variable.
```
env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
```
#### Linking error on windows when running rustdoc or mdbook tests
@ -401,3 +408,10 @@ This may be caused by the models being loaded from `/mnt/c`, more details on
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
error is generated.
#### CudaRC error
If you encounter an error like this one `called `Result::unwrap()` on an `Err` value: LoadLibraryExW { source: Os { code: 126, kind: Uncategorized, message: "The specified module could not be found." } }` on windows. To fix copy and rename these 3 files (make sure they are in path). The paths depend on your cuda version.
`c:\Windows\System32\nvcuda.dll` -> `cuda.dll`
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cublas64_12.dll` -> `cublas.dll`
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\curand64_10.dll` -> `curand.dll`

View File

@ -37,7 +37,6 @@ tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
parquet = { workspace = true }
image = { workspace = true }

View File

@ -81,7 +81,7 @@ let mut tp_shape = view.shape().to_vec();
let size = tp_shape[0];
if size % world_size != 0 {
panic!("The dimension is not divisble by `world_size`");
panic!("The dimension is not divisible by `world_size`");
}
let block_size = size / world_size;
let start = rank * block_size;

View File

@ -5,5 +5,8 @@ criterion_main!(
benchmarks::affine::benches,
benchmarks::matmul::benches,
benchmarks::random::benches,
benchmarks::where_cond::benches
benchmarks::where_cond::benches,
benchmarks::conv_transpose2d::benches,
benchmarks::qmatmul::benches,
benchmarks::unary::benches
);

View File

@ -0,0 +1,59 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(
x: &Tensor,
k: &Tensor,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) {
x.conv_transpose2d(k, padding, output_padding, stride, dilation)
.unwrap();
}
fn run_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let t = Tensor::arange(0.0f32, 10000.0, device)
.unwrap()
.reshape((1, 4, 50, 50))
.unwrap()
.to_dtype(dtype)
.unwrap();
let kernel = Tensor::arange(0.0f32, 100.0, device)
.unwrap()
.reshape((4, 1, 5, 5))
.unwrap()
.to_dtype(dtype)
.unwrap();
let flops = t.dims().iter().product::<usize>() * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&t), black_box(&kernel), 1, 0, 1, 2);
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_benchmark(c, &device, DType::F32, "conv_transpose2d_f32");
run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16");
run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,6 +1,9 @@
pub(crate) mod affine;
pub(crate) mod conv_transpose2d;
pub(crate) mod matmul;
pub(crate) mod qmatmul;
pub(crate) mod random;
pub(crate) mod unary;
pub(crate) mod where_cond;
use candle_core::{Device, Result};

View File

@ -0,0 +1,72 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{
quantized::{self, GgmlDType, QMatMul},
Device, Module, Tensor,
};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(matmul: &QMatMul, x: &Tensor) {
matmul.forward(&x).unwrap();
}
fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
let b = 1;
let m = 1;
let n = 1024;
let k = 1024;
let lhs = (0..(m * k))
.map(|v| v as f32 / (m * k) as f32)
.collect::<Vec<_>>();
let rhs = (0..(k * n))
.map(|v| v as f32 / (n * k) as f32)
.collect::<Vec<_>>();
let lhs = Tensor::from_slice(&lhs, (m, k), device).unwrap();
let rhs = Tensor::from_slice(&rhs, (k, n), device).unwrap();
let qtensor = quantized::QTensor::quantize(&rhs.t().unwrap(), dtype).unwrap();
let matmul = quantized::QMatMul::from_qtensor(qtensor).unwrap();
let flops = b * m * n * k;
let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype)));
group.sample_size(200);
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&matmul), black_box(&lhs));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
for dtype in vec![
GgmlDType::F32,
GgmlDType::F16,
GgmlDType::Q4_0,
GgmlDType::Q4_1,
GgmlDType::Q5_0,
GgmlDType::Q5_1,
GgmlDType::Q8_0,
GgmlDType::Q2K,
GgmlDType::Q3K,
GgmlDType::Q4K,
GgmlDType::Q5K,
GgmlDType::Q6K,
] {
run_bench(c, &device, dtype);
}
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -0,0 +1,49 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor) {
a.sqrt().unwrap();
}
fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let b = 1;
let m = 1024;
let k = 1024;
let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, &device)
.unwrap()
.to_dtype(dtype)
.unwrap()
.reshape((b, m, k))
.unwrap();
let flops = b * m * k * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
for dtype in [DType::F32, DType::BF16, DType::F16] {
let name = format!("sqrt_{:?}", dtype);
run_unary_benchmark(c, &device, dtype, &name);
}
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -5,32 +5,26 @@ extern crate accelerate_src;
extern crate intel_mkl_src;
use anyhow::Result;
use candle_core::{Device, Module, Tensor};
use candle_core::quantized::{QMatMul, QTensor};
use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
let q_cpu = q.to_device(&Device::Cpu)?;
let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
let q = QMatMul::from_qtensor(q)?;
let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
let res_q_cuda = q.forward(&x)?;
println!("{res_q_cuda}");
let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
let q_cpu = QMatMul::from_qtensor(q_cpu)?;
let x_cpu = x.to_device(&Device::Cpu)?;
let res_q_cpu = q_cpu.forward(&x_cpu)?;
println!("{res_q_cpu}");
let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
.abs()?
.flatten_all()?
.max(0)?;
println!("{diff}");
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?;
candle_core::cuda::set_gemm_reduced_precision_f32(false);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("fp32: {:?}", start_time.elapsed());
drop(_x1);
candle_core::cuda::set_gemm_reduced_precision_f32(true);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("tf32: {:?}", start_time.elapsed());
drop(_x1);
Ok(())
}

View File

@ -127,11 +127,24 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
/// # Safety
/// This function is unsafe as it doesn't initialize the underlying data store.
/// The caller should ensure that the data is properly initialized as early as possible
/// after this call.
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn set_seed(&self, _: u64) -> Result<()>;
/// Synchronize should block until all the operations on the device are completed.
fn synchronize(&self) -> Result<()>;
}

View File

@ -1,3 +1,4 @@
/// Methods for backpropagation of gradients.
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;
@ -111,7 +112,8 @@ impl Tensor {
}
Op::Unary(_node, UnaryOp::Ceil)
| Op::Unary(_node, UnaryOp::Floor)
| Op::Unary(_node, UnaryOp::Round) => nodes,
| Op::Unary(_node, UnaryOp::Round)
| Op::Unary(_node, UnaryOp::Sign) => nodes,
Op::Reshape(node)
| Op::UpsampleNearest1D { arg: node, .. }
| Op::UpsampleNearest2D { arg: node, .. }
@ -310,9 +312,32 @@ impl Tensor {
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose1d",
})?,
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose2d",
})?,
Op::ConvTranspose2D {
arg,
kernel,
padding,
stride,
dilation,
output_padding: _output_padding,
} => {
let grad_arg = grad.conv2d(kernel, *padding, *dilation, *stride, 1)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = grad
.transpose(0, 1)?
.conv2d(&arg.transpose(0, 1)?, *padding, *stride, *dilation, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0, k1) = kernel.dims4()?;
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::AvgPool2D {
arg,
kernel_size,
@ -464,7 +489,6 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad)?;
}
Op::Cmp(_args, _) => {}
Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
let node = broadcast_back(arg, node, reduced_dims)?;
let grad = broadcast_back(arg, &grad, reduced_dims)?;
@ -554,20 +578,18 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Reduce(_, ReduceOp::ArgMin, _) => {}
Op::Reduce(_, ReduceOp::ArgMax, _) => {}
Op::Unary(_, UnaryOp::Floor)
| Op::Unary(_, UnaryOp::Round)
| Op::Reduce(_, ReduceOp::ArgMin, _)
| Op::Reduce(_, ReduceOp::ArgMax, _)
| Op::Unary(_, UnaryOp::Sign)
| Op::Cmp(_, _) => {}
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
Op::Unary(_, UnaryOp::Floor) => {
Err(Error::BackwardNotSupported { op: "floor" })?
}
Op::Unary(_, UnaryOp::Round) => {
Err(Error::BackwardNotSupported { op: "round" })?
}
Op::Unary(arg, UnaryOp::Gelu) => {
let sum_grad = grads.or_insert(arg)?;
let cube = arg.powf(3.)?;
@ -602,7 +624,7 @@ impl Tensor {
Op::Unary(arg, UnaryOp::Silu) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
let sigmoid_arg = (*node / arg)?;
let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
}
@ -690,30 +712,38 @@ impl Tensor {
}
}
/// A store for gradients, associating a tensor id to the corresponding gradient tensor, used for back propagation.
#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>);
impl GradStore {
/// Create a new gradient store
fn new() -> Self {
GradStore(HashMap::new())
}
/// Get the gradient tensor corresponding to the given tensor id
pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
self.0.get(&id)
}
/// Get the gradient tensor associated with the given tensor
pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
self.0.get(&tensor.id())
}
/// Remove the gradient tensor associated with the given tensor, returning it if it exists
pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
self.0.remove(&tensor.id())
}
/// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed
pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
self.0.insert(tensor.id(), grad)
}
/// Get the gradient tensor associated with the given tensor, or, if it does not exist,
/// insert a tensor of zeroes, with the same shape and type as the given tensors and return it
fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
use std::collections::hash_map::Entry;
let grad = match self.0.entry(tensor.id()) {

View File

@ -1,6 +1,7 @@
pub mod erf;
pub mod kernels;
#[allow(unused)]
trait Cpu<const ARR: usize> {
type Unit;
type Array;
@ -18,6 +19,7 @@ trait Cpu<const ARR: usize> {
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
}
#[allow(unused)]
trait CpuF16<const ARR: usize> {
type Unit;
type Array;

View File

@ -4,6 +4,11 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
use rayon::prelude::*;
mod utils;
pub use utils::{
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
};
const USE_IM2COL_CONV1D: bool = true;
const USE_IM2COL_CONV1D_TR: bool = true;
const USE_IM2COL_CONV2D: bool = true;
@ -21,105 +26,20 @@ pub enum CpuStorage {
F64(Vec<f64>),
}
#[derive(Debug, Clone)]
pub enum CpuStorageRef<'a> {
U8(&'a [u8]),
U32(&'a [u32]),
I64(&'a [i64]),
BF16(&'a [bf16]),
F16(&'a [f16]),
F32(&'a [f32]),
F64(&'a [f64]),
}
#[derive(Debug, Clone)]
pub struct CpuDevice;
pub trait Map1 {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
match vs {
CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)),
}
}
}
pub trait Map1Any {
fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
&self,
vs: &[T],
layout: &Layout,
wrap: W,
) -> Result<CpuStorage>;
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
match vs {
CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?),
}
}
}
type C = CpuStorage;
pub trait Map2 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
fn map(
&self,
v1: &CpuStorage,
l1: &Layout,
v2: &CpuStorage,
l2: &Layout,
) -> Result<CpuStorage> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt()),
}
}
}
pub trait Map2U8 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
fn map(
&self,
v1: &CpuStorage,
l1: &Layout,
v2: &CpuStorage,
l2: &Layout,
) -> Result<CpuStorage> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt()),
}
}
}
struct Cmp(CmpOp);
impl Map2U8 for Cmp {
const OP: &'static str = "cmp";
@ -366,275 +286,6 @@ impl<'a> Map1 for ReduceSum<'a> {
}
}
pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
vs: &[T],
layout: &Layout,
mut f: F,
) -> Vec<U> {
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
[start_offset..start_offset + len]
.iter()
.map(|&v| f(v))
.collect(),
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
let mut result = Vec::with_capacity(layout.shape().elem_count());
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
} else {
for index in block_start_index {
for offset in 0..block_len {
let v = unsafe { vs.get_unchecked(index + offset) };
result.push(f(*v))
}
}
}
result
}
}
}
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
vs: &[T],
layout: &Layout,
mut f: F,
mut f_vec: FV,
) -> Vec<U> {
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => {
let mut ys: Vec<U> = Vec::with_capacity(len);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(len) };
ys
}
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
let el_count = layout.shape().elem_count();
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
let mut result = Vec::with_capacity(el_count);
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
result
} else {
let mut ys: Vec<U> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
let mut dst_index = 0;
for src_index in block_start_index {
let vs = &vs[src_index..src_index + block_len];
let ys = &mut ys_to_set[dst_index..dst_index + block_len];
f_vec(vs, ys);
dst_index += block_len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
}
}
}
// This function maps over two strided index sequences.
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
rhs: &[T],
mut f: F,
) -> Vec<U> {
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
.iter()
.zip(rhs[o_r1..o_r2].iter())
.map(|(&l, &r)| f(l, r))
.collect(),
(Some((o_l1, o_l2)), None) => {
// TODO: Maybe we want to avoid going through the layout twice.
match rhs_l.offsets_b() {
Some(ob) => {
let mut i_in_block = 0;
let mut i_right_broadcast = 0;
lhs[o_l1..o_l2]
.iter()
.map(|&l| {
let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
i_right_broadcast += 1;
if i_right_broadcast >= ob.right_broadcast {
i_in_block += 1;
i_right_broadcast = 0;
}
if i_in_block >= ob.len {
i_in_block = 0
}
f(l, *r)
})
.collect()
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
(None, Some((o_r1, o_r2))) => {
// TODO: Maybe we want to avoid going through the layout twice.
match lhs_l.offsets_b() {
Some(ob) => {
let mut i_in_block = 0;
let mut i_right_broadcast = 0;
rhs[o_r1..o_r2]
.iter()
.map(|&r| {
let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
i_right_broadcast += 1;
if i_right_broadcast >= ob.right_broadcast {
i_in_block += 1;
i_right_broadcast = 0;
}
if i_in_block >= ob.len {
i_in_block = 0
}
f(*l, r)
})
.collect()
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
_ => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
// Similar to binary_map but with vectorized variants.
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
rhs: &[T],
mut f: F,
mut f_vec: FV,
) -> Vec<T> {
let el_count = lhs_l.shape().elem_count();
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
(Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
Some(ob) if ob.right_broadcast == 1 => {
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let mut dst_i = 0;
for src_i in (o_l1..o_l2).step_by(ob.len) {
f_vec(
&lhs[src_i..src_i + ob.len],
rhs,
&mut ys_to_set[dst_i..dst_i + ob.len],
);
dst_i += ob.len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
Some(ob) => {
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys = lhs[o_l1..o_l2].to_vec();
for idx_l in 0..ob.left_broadcast {
let start = idx_l * ob.len * ob.right_broadcast;
for (i, &r) in rhs.iter().enumerate() {
let start = start + i * ob.right_broadcast;
for v in ys[start..start + ob.right_broadcast].iter_mut() {
*v = f(*v, r)
}
}
}
ys
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
},
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
Some(ob) if ob.right_broadcast == 1 => {
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let mut dst_i = 0;
for src_i in (o_r1..o_r2).step_by(ob.len) {
f_vec(
lhs,
&rhs[src_i..src_i + ob.len],
&mut ys_to_set[dst_i..dst_i + ob.len],
);
dst_i += ob.len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
Some(ob) => {
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys = rhs[o_r1..o_r2].to_vec();
for idx_l in 0..ob.left_broadcast {
let start = idx_l * ob.len * ob.right_broadcast;
for (i, &l) in lhs.iter().enumerate() {
let start = start + i * ob.right_broadcast;
for v in ys[start..start + ob.right_broadcast].iter_mut() {
*v = f(l, *v)
}
}
}
ys
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
},
_ => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
struct Affine(f64, f64);
impl Map1 for Affine {
@ -1564,6 +1215,30 @@ impl MatMul {
}))
.bt()
}
fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride();
let rank = lhs_stride.len();
let (_b, m, n, k) = self.0;
let a_skip: usize = match lhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[_, stride] if lhs_l.dims()[0] == 1 => stride,
[stride, _] if lhs_l.dims()[1] == 1 => stride,
[stride] => stride,
[] => m * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[_, stride] if rhs_l.dims()[0] == 1 => stride,
[stride, _] if rhs_l.dims()[1] == 1 => stride,
[stride] => stride,
[] => n * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
Ok((a_skip, b_skip))
}
}
impl Map2 for MatMul {
@ -1597,18 +1272,7 @@ impl Map2 for MatMul {
let rhs_cs = rhs_stride[rank - 1];
let rhs_rs = rhs_stride[rank - 2];
let a_skip: usize = match lhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
let c_skip: usize = m * n;
let dst_shape: Shape = (m, n).into();
@ -1668,20 +1332,8 @@ impl Map2 for MatMul {
let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride();
let rank = lhs_stride.len();
let a_skip: usize = match lhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
let c_skip: usize = m * n;
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
@ -1689,7 +1341,7 @@ impl Map2 for MatMul {
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
(n as i32, b'N')
} else if rhs_m1 == k && rhs_m2 == 1 {
(k as i32, b'T')
@ -1697,7 +1349,7 @@ impl Map2 for MatMul {
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
};
// The b tensor has dims batching, m, k (lhs)
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
(k as i32, b'N')
} else if lhs_m1 == m && lhs_m2 == 1 {
(m as i32, b'T')
@ -1771,20 +1423,8 @@ impl Map2 for MatMul {
let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride();
let rank = lhs_stride.len();
let a_skip: usize = match lhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
_ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
let c_skip: usize = m * n;
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
@ -1792,7 +1432,7 @@ impl Map2 for MatMul {
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
(n as i32, b'N')
} else if rhs_m1 == k && rhs_m2 == 1 {
(k as i32, b'T')
@ -1800,7 +1440,7 @@ impl Map2 for MatMul {
Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
};
// The b tensor has dims batching, m, k (lhs)
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
(k as i32, b'N')
} else if lhs_m1 == m && lhs_m2 == 1 {
(m as i32, b'T')
@ -2582,7 +2222,10 @@ impl BackendStorage for CpuStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
let mut kernel_c = unsafe {
self.device()
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
};
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
@ -2590,7 +2233,7 @@ impl BackendStorage for CpuStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
@ -2681,7 +2324,10 @@ impl BackendStorage for CpuStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
let mut kernel_c = unsafe {
self.device()
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
};
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
@ -2691,7 +2337,7 @@ impl BackendStorage for CpuStorage {
let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
.transpose(1, 2)?
.transpose(1, 3)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
@ -2810,10 +2456,18 @@ impl BackendDevice for CpuDevice {
true
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
Ok(T::to_cpu_storage(s))
}
fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
Ok(s.clone())
}
fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result<Self::Storage> {
Ok(s)
}
fn new(_: usize) -> Result<Self> {
Ok(Self)
}
@ -2915,6 +2569,53 @@ impl BackendDevice for CpuDevice {
}
}
#[allow(clippy::uninit_vec)]
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
let elem_count = shape.elem_count();
// The code below is highly unsafe but hopefully not directly unsound as we only consider
// types that are Copy, not Drop, and for which all bit patterns are proper values.
// It's still pretty risky, see the following for more details:
// https://github.com/rust-lang/rust-clippy/issues/4483
let storage = match dtype {
DType::U8 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::U8(v)
}
DType::U32 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::U32(v)
}
DType::I64 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::I64(v)
}
DType::BF16 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::BF16(v)
}
DType::F16 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::F16(v)
}
DType::F32 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::F32(v)
}
DType::F64 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::F64(v)
}
};
Ok(storage)
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
let elem_count = shape.elem_count();
let storage = match dtype {
@ -2942,6 +2643,10 @@ impl BackendDevice for CpuDevice {
};
Ok(storage)
}
fn synchronize(&self) -> Result<()> {
Ok(())
}
}
#[macro_export]

View File

@ -0,0 +1,350 @@
/// Helper functions to write CPU kernels.
use crate::backend::BackendStorage;
use crate::{Error, Layout, Result, WithDType};
type C = super::CpuStorage;
pub trait Map1 {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
match vs {
C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
}
}
}
pub trait Map1Any {
fn f<T: WithDType, W: Fn(Vec<T>) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result<C>;
fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
match vs {
C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
}
}
}
pub trait Map2 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt()),
}
}
}
pub trait Map2U8 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt()),
}
}
}
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
rhs: &[T],
mut f: F,
) -> Vec<U> {
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
.iter()
.zip(rhs[o_r1..o_r2].iter())
.map(|(&l, &r)| f(l, r))
.collect(),
(Some((o_l1, o_l2)), None) => {
// TODO: Maybe we want to avoid going through the layout twice.
match rhs_l.offsets_b() {
Some(ob) => {
let mut i_in_block = 0;
let mut i_right_broadcast = 0;
lhs[o_l1..o_l2]
.iter()
.map(|&l| {
let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
i_right_broadcast += 1;
if i_right_broadcast >= ob.right_broadcast {
i_in_block += 1;
i_right_broadcast = 0;
}
if i_in_block >= ob.len {
i_in_block = 0
}
f(l, *r)
})
.collect()
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
(None, Some((o_r1, o_r2))) => {
// TODO: Maybe we want to avoid going through the layout twice.
match lhs_l.offsets_b() {
Some(ob) => {
let mut i_in_block = 0;
let mut i_right_broadcast = 0;
rhs[o_r1..o_r2]
.iter()
.map(|&r| {
let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
i_right_broadcast += 1;
if i_right_broadcast >= ob.right_broadcast {
i_in_block += 1;
i_right_broadcast = 0;
}
if i_in_block >= ob.len {
i_in_block = 0
}
f(*l, r)
})
.collect()
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
_ => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
// Similar to binary_map but with vectorized variants.
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
rhs: &[T],
mut f: F,
mut f_vec: FV,
) -> Vec<T> {
let el_count = lhs_l.shape().elem_count();
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
(Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
Some(ob) if ob.right_broadcast == 1 => {
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let mut dst_i = 0;
for src_i in (o_l1..o_l2).step_by(ob.len) {
f_vec(
&lhs[src_i..src_i + ob.len],
rhs,
&mut ys_to_set[dst_i..dst_i + ob.len],
);
dst_i += ob.len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
Some(ob) => {
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys = lhs[o_l1..o_l2].to_vec();
for idx_l in 0..ob.left_broadcast {
let start = idx_l * ob.len * ob.right_broadcast;
for (i, &r) in rhs.iter().enumerate() {
let start = start + i * ob.right_broadcast;
for v in ys[start..start + ob.right_broadcast].iter_mut() {
*v = f(*v, r)
}
}
}
ys
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
},
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
Some(ob) if ob.right_broadcast == 1 => {
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let mut dst_i = 0;
for src_i in (o_r1..o_r2).step_by(ob.len) {
f_vec(
lhs,
&rhs[src_i..src_i + ob.len],
&mut ys_to_set[dst_i..dst_i + ob.len],
);
dst_i += ob.len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
Some(ob) => {
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys = rhs[o_r1..o_r2].to_vec();
for idx_l in 0..ob.left_broadcast {
let start = idx_l * ob.len * ob.right_broadcast;
for (i, &l) in lhs.iter().enumerate() {
let start = start + i * ob.right_broadcast;
for v in ys[start..start + ob.right_broadcast].iter_mut() {
*v = f(l, *v)
}
}
}
ys
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
},
_ => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
vs: &[T],
layout: &Layout,
mut f: F,
) -> Vec<U> {
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
[start_offset..start_offset + len]
.iter()
.map(|&v| f(v))
.collect(),
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
let mut result = Vec::with_capacity(layout.shape().elem_count());
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
} else {
for index in block_start_index {
for offset in 0..block_len {
let v = unsafe { vs.get_unchecked(index + offset) };
result.push(f(*v))
}
}
}
result
}
}
}
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
vs: &[T],
layout: &Layout,
mut f: F,
mut f_vec: FV,
) -> Vec<U> {
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => {
let mut ys: Vec<U> = Vec::with_capacity(len);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(len) };
ys
}
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
let el_count = layout.shape().elem_count();
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
let mut result = Vec::with_capacity(el_count);
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
result
} else {
let mut ys: Vec<U> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
let mut dst_index = 0;
for src_index in block_start_index {
let vs = &vs[src_index..src_index + block_len];
let ys = &mut ys_to_set[dst_index..dst_index + block_len];
f_vec(vs, ys);
dst_index += block_len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
}
}
}

View File

@ -0,0 +1,452 @@
use crate::backend::BackendDevice;
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
use half::{bf16, f16};
use std::sync::{Arc, Mutex};
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);
impl DeviceId {
fn new() -> Self {
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
use std::sync::atomic;
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
}
}
struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}
#[derive(Clone)]
pub struct CudaDevice {
id: DeviceId,
device: Arc<cudarc::driver::CudaDevice>,
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
curand: Arc<Mutex<CudaRng>>,
}
impl std::fmt::Debug for CudaDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "CudaDevice({:?})", self.id)
}
}
impl std::ops::Deref for CudaDevice {
type Target = Arc<cudarc::driver::CudaDevice>;
fn deref(&self) -> &Self::Target {
&self.device
}
}
impl CudaDevice {
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
self.device.clone()
}
pub fn id(&self) -> DeviceId {
self.id
}
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let slice = match dtype {
DType::U8 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
let params = (&data, v as u8, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
let params = (&data, v as u32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
let params = (&data, v as i64, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
let params = (&data, bf16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
let params = (&data, f16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
let params = (&data, v as f32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
let params = (&data, v, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
if !self.has_func(module_name, module_name) {
// Leaking the string here is a bit sad but we need a &'static str and this is only
// done once per kernel name.
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
self.load_ptx(ptx.into(), module_name, &[static_module_name])
.map_err(|cuda| CudaError::Load {
cuda,
module_name: module_name.to_string(),
})
.w()?;
}
self.get_func(module_name, module_name)
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
// able to only build the error value if needed.
.ok_or(CudaError::MissingKernel {
module_name: module_name.to_string(),
})
.w()
}
}
impl BackendDevice for CudaDevice {
type Storage = CudaStorage;
fn new(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
Ok(Self {
id: DeviceId::new(),
device,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
})
}
fn set_seed(&self, seed: u64) -> Result<()> {
// We do not call set_seed but instead create a new curand object. This ensures that the
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
Ok(())
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
gpu_id: self.device.ordinal(),
}
}
fn same_device(&self, rhs: &Self) -> bool {
self.id == rhs.id
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc_zeros::<u8>(elem_count).w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc_zeros::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc_zeros::<i64>(elem_count).w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc_zeros::<f16>(elem_count).w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc_zeros::<f32>(elem_count).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc_zeros::<f64>(elem_count).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
})
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F64(data)
}
};
let slice = if lo == 0. && up == 1.0 {
slice
} else {
use super::utils::Map1;
let layout = Layout::contiguous(shape);
super::Affine(up - lo, lo).map(&slice, self, &layout)?
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
// curand can only generate an odd number of values.
// https://github.com/huggingface/candle/issues/734
let elem_count_round = if elem_count % 2 == 1 {
elem_count + 1
} else {
elem_count
};
let slice = match dtype {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
})
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
self.const_impl(1., shape, dtype)
}
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc::<u8>(elem_count).w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc::<i64>(elem_count).w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc::<bf16>(elem_count).w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc::<f16>(elem_count).w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc::<f32>(elem_count).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc::<f64>(elem_count).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
let slice = match T::cpu_storage_ref(s) {
CpuStorageRef::U8(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorageRef::U32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorageRef::I64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorageRef::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorageRef::F16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorageRef::F32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorageRef::F64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.htod_copy(storage).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn synchronize(&self) -> Result<()> {
self.device.synchronize().map_err(crate::Error::wrap)?;
Ok(())
}
}

View File

@ -0,0 +1,62 @@
use crate::{DType, Layout};
/// cudarc related errors
#[derive(thiserror::Error, Debug)]
pub enum CudaError {
#[error(transparent)]
Cuda(#[from] cudarc::driver::DriverError),
#[error(transparent)]
Compiler(#[from] cudarc::nvrtc::CompileError),
#[error(transparent)]
Cublas(#[from] cudarc::cublas::result::CublasError),
#[error(transparent)]
Curand(#[from] cudarc::curand::result::CurandError),
#[error("missing kernel '{module_name}'")]
MissingKernel { module_name: String },
#[error("unsupported dtype {dtype:?} for {op}")]
UnsupportedDtype { dtype: DType, op: &'static str },
#[error("internal error '{0}'")]
InternalError(&'static str),
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
MatMulNonContiguous {
lhs_stride: Layout,
rhs_stride: Layout,
mnk: (usize, usize, usize),
},
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
UnexpectedDType {
msg: &'static str,
expected: DType,
got: DType,
},
#[error("{cuda} when loading {module_name}")]
Load {
cuda: cudarc::driver::DriverError,
module_name: String,
},
}
impl From<CudaError> for crate::Error {
fn from(val: CudaError) -> Self {
crate::Error::Cuda(Box::new(val)).bt()
}
}
pub trait WrapErr<O> {
fn w(self) -> std::result::Result<O, crate::Error>;
}
impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
fn w(self) -> std::result::Result<O, crate::Error> {
self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
}
}

View File

@ -0,0 +1,134 @@
/// Helper functions to plug cuda kernels in candle.
use crate::{Layout, Result, Shape, WithDType};
pub use cudarc;
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
use super::{CudaDevice, CudaError, WrapErr};
pub type S = super::CudaStorageSlice;
pub trait Map1 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>>;
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U8(s) => S::U8(self.f(s, d, l)?),
S::U32(s) => S::U32(self.f(s, d, l)?),
S::I64(s) => S::I64(self.f(s, d, l)?),
S::BF16(s) => S::BF16(self.f(s, d, l)?),
S::F16(s) => S::F16(self.f(s, d, l)?),
S::F32(s) => S::F32(self.f(s, d, l)?),
S::F64(s) => S::F64(self.f(s, d, l)?),
};
Ok(out)
}
}
pub trait Map2 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>>;
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
(S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
};
Ok(out)
}
}
pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
) -> Result<()>;
fn map(
&self,
dst: &mut S,
dst_s: &Shape,
src: &S,
src_l: &Layout,
d: &CudaDevice,
) -> Result<()> {
match (dst, src) {
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
}
}
}
pub trait Map1Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
wrap: W,
) -> Result<S>;
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U8(s) => self.f(s, d, l, S::U8)?,
S::U32(s) => self.f(s, d, l, S::U32)?,
S::I64(s) => self.f(s, d, l, S::I64)?,
S::BF16(s) => self.f(s, d, l, S::BF16)?,
S::F16(s) => self.f(s, d, l, S::F16)?,
S::F32(s) => self.f(s, d, l, S::F32)?,
S::F64(s) => self.f(s, d, l, S::F64)?,
};
Ok(out)
}
}
pub trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
dev: &CudaDevice,
) -> Result<S>;
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
_ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
};
Ok(out)
}
}

View File

@ -0,0 +1,377 @@
use crate::op::{BackpropOp, Op};
use crate::tensor::from_storage;
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use std::sync::Arc;
/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_storage: &MetalStorage,
_layout: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp2 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp3 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_arg3: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
impl Tensor {
/// Applies a unary custom op without backward support
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a binary custom op without backward support
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
let (storage, shape) =
self.storage()
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a ternary custom op without backward support
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c,
)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a unary custom op.
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
let (storage, shape) = self
.storage()
.apply_op1(self.layout(), c.as_ref().as_ref())?;
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
self.apply_op1_arc(Arc::new(Box::new(c)))
}
/// Applies a binary custom op.
pub fn apply_op2_arc(
&self,
rhs: &Self,
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op2(
self.layout(),
&rhs.storage(),
rhs.layout(),
c.as_ref().as_ref(),
)?;
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
self.apply_op2_arc(r, Arc::new(Box::new(c)))
}
/// Applies a ternary custom op.
pub fn apply_op3_arc(
&self,
t2: &Self,
t3: &Self,
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c.as_ref().as_ref(),
)?;
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
Op::CustomOp3(t1, t2, t3, c.clone())
});
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
&self,
t2: &Self,
t3: &Self,
c: C,
) -> Result<Self> {
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
}
// In place ops.
/// Unary ops that can be defined in user-land.
/// These ops work in place and as such back-prop is unsupported.
pub trait InplaceOp1 {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
}
pub trait InplaceOp2 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout)
-> Result<()>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &mut MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<()> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
}
pub trait InplaceOp3 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &mut CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<()>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &mut CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<()> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &mut MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<()> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
}
impl Tensor {
/// Applies a unary custom op in place.
pub fn inplace_op1<C: InplaceOp1>(&self, c: &C) -> Result<()> {
self.storage_mut().inplace_op1(self.layout(), c)
}
/// Applies a unary custom op in place (for the first tensor).
pub fn inplace_op2<C: InplaceOp2>(&self, rhs: &Self, c: &C) -> Result<()> {
self.storage_mut()
.inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c)
}
/// Applies a ternary custom op in place (for the first tensor).
pub fn inplace_op3<C: InplaceOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> {
self.storage_mut().inplace_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c,
)
}
}

View File

@ -289,17 +289,48 @@ impl Device {
}
}
pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.alloc_uninit(shape, dtype)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.alloc_uninit(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.alloc_uninit(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
Device::Cuda(device) => {
let storage = device.storage_from_slice(data)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.storage_from_slice(data)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
Device::Cuda(device) => {
let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage(&storage)?;
let storage = device.storage_from_cpu_storage_owned(storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage(&storage)?;
let storage = device.storage_from_cpu_storage_owned(storage)?;
Ok(Storage::Metal(storage))
}
}
@ -310,14 +341,22 @@ impl Device {
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
Device::Cuda(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage(&storage)?;
let storage = device.storage_from_cpu_storage_owned(storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage(&storage)?;
let storage = device.storage_from_cpu_storage_owned(storage)?;
Ok(Storage::Metal(storage))
}
}
}
pub fn synchronize(&self) -> Result<()> {
match self {
Self::Cpu => Ok(()),
Self::Cuda(d) => d.synchronize(),
Self::Metal(d) => d.synchronize(),
}
}
}

View File

@ -1,7 +1,7 @@
//! Types for elements that can be stored and manipulated using tensors.
#![allow(clippy::redundant_closure_call)]
use crate::backend::BackendStorage;
use crate::{CpuStorage, Error, Result};
use crate::{CpuStorage, CpuStorageRef, Error, Result};
/// The different types of elements allowed in tensors.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
@ -100,12 +100,14 @@ pub trait WithDType:
+ 'static
+ Send
+ Sync
+ std::any::Any
+ crate::cpu::kernels::VecOps
{
const DTYPE: DType;
fn from_f64(v: f64) -> Self;
fn to_f64(self) -> f64;
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
@ -129,6 +131,10 @@ macro_rules! with_dtype {
$to_f64(self)
}
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
CpuStorageRef::$dtype(data)
}
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
CpuStorage::$dtype(data)
}

View File

@ -210,10 +210,22 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport)
}
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -221,4 +233,38 @@ impl crate::backend::BackendDevice for CudaDevice {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn synchronize(&self) -> Result<()> {
Ok(())
}
}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with f16 GEMMs.
pub fn gemm_reduced_precision_f16() -> bool {
true
}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with f16 GEMMs.
pub fn set_gemm_reduced_precision_f16(_: bool) {}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with bf16 GEMMs.
pub fn gemm_reduced_precision_bf16() -> bool {
true
}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with bf16 GEMMs.
pub fn set_gemm_reduced_precision_bf16(_: bool) {}
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// allowed with f32 GEMMs.
pub fn gemm_reduced_precision_f32() -> bool {
true
}
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// allowed with f32 GEMMs.
pub fn set_gemm_reduced_precision_f32(_b: bool) {}

View File

@ -222,10 +222,22 @@ impl crate::backend::BackendDevice for MetalDevice {
Err(Error::NotCompiledWithMetalSupport)
}
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
@ -233,4 +245,8 @@ impl crate::backend::BackendDevice for MetalDevice {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn synchronize(&self) -> Result<()> {
Ok(())
}
}

View File

@ -219,10 +219,14 @@ impl Error {
Self::Wrapped(Box::new(err)).bt()
}
pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
pub fn msg(err: impl std::error::Error) -> Self {
Self::Msg(err.to_string()).bt()
}
pub fn debug(err: impl std::fmt::Debug) -> Self {
Self::Msg(format!("{err:?}")).bt()
}
pub fn bt(self) -> Self {
let backtrace = std::backtrace::Backtrace::capture();
match backtrace.status() {

View File

@ -14,7 +14,7 @@
//!
//! ## Features
//!
//! - Simple syntax (looks and like PyTorch)
//! - Simple syntax (looks and feels like PyTorch)
//! - CPU and Cuda backends (and M1 support)
//! - Enable serverless (CPU) small and fast deployments
//! - Model training
@ -37,18 +37,17 @@
mod accelerate;
pub mod backend;
pub mod backprop;
mod conv;
pub mod conv;
mod convert;
pub mod cpu;
pub mod cpu_backend;
#[cfg(feature = "cuda")]
pub mod cuda_backend;
#[cfg(feature = "cudnn")]
pub mod cudnn;
mod custom_op;
mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
pub mod dummy_cuda_backend;
mod dummy_metal_backend;
pub mod error;
mod indexer;
@ -58,12 +57,13 @@ pub mod metal_backend;
#[cfg(feature = "mkl")]
mod mkl;
pub mod npy;
mod op;
pub mod op;
pub mod pickle;
pub mod quantized;
pub mod safetensors;
pub mod scalar;
pub mod shape;
mod sort;
mod storage;
mod strided_index;
mod tensor;
@ -72,13 +72,16 @@ pub mod test_utils;
pub mod utils;
mod variable;
pub use cpu_backend::CpuStorage;
#[cfg(feature = "cudnn")]
pub use cuda_backend::cudnn;
pub use cpu_backend::{CpuStorage, CpuStorageRef};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, FloatDType, IntDType, WithDType};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;
pub use layout::Layout;
pub use op::{CustomOp1, CustomOp2, CustomOp3};
pub use shape::{Shape, D};
pub use storage::Storage;
pub use strided_index::{StridedBlocks, StridedIndex};
@ -86,10 +89,12 @@ pub use tensor::{Tensor, TensorId};
pub use variable::Var;
#[cfg(feature = "cuda")]
pub use cuda_backend::{CudaDevice, CudaStorage};
pub use cuda_backend as cuda;
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
pub use dummy_cuda_backend as cuda;
pub use cuda::{CudaDevice, CudaStorage};
#[cfg(feature = "metal")]
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};

View File

@ -0,0 +1,287 @@
use crate::{DType, Result};
use candle_metal_kernels::Kernels;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};
use super::MetalError;
/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);
impl DeviceId {
pub(crate) fn new() -> Self {
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
use std::sync::atomic;
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
}
}
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
type AllocatedBuffers = Arc<RwLock<BufferMap>>;
#[derive(Clone)]
pub struct MetalDevice {
/// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
/// the device itself.
pub(crate) id: DeviceId,
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
pub(crate) device: metal::Device,
/// Single command queue for the entire device.
pub(crate) command_queue: CommandQueue,
/// One command buffer at a time.
/// The scheduler works by allowing multiple
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
/// to start to work).
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
/// for their START time, but there's no guarantee that command buffer1 will finish before
/// command buffer2 starts (or there are metal bugs there)
pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
/// Keeps track of the current amount of compute command encoders on the current
/// command buffer
/// Arc, RwLock because of the interior mutability.
pub(crate) command_buffer_index: Arc<RwLock<usize>>,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
pub(crate) compute_per_buffer: usize,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`]
pub(crate) kernels: Arc<Kernels>,
/// Simple allocator struct.
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
/// (could be linked to FFI communication overhead).
///
/// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the
/// graph calculation, and only we the allocator kept a reference to it, therefore it's free
/// to be reused. However, in order for this to work, we need to guarantee the order of
/// operation, so that this buffer is not being used by another kernel at the same time.
/// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
///
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
/// (strong_count = 1).
pub(crate) buffers: AllocatedBuffers,
/// Seed for random number generation.
pub(crate) seed: Arc<Mutex<Buffer>>,
}
impl std::fmt::Debug for MetalDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MetalDevice({:?})", self.id)
}
}
impl std::ops::Deref for MetalDevice {
type Target = metal::DeviceRef;
fn deref(&self) -> &Self::Target {
&self.device
}
}
impl MetalDevice {
pub fn id(&self) -> DeviceId {
self.id
}
pub fn metal_device(&self) -> &metal::Device {
&self.device
}
pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue
}
pub fn command_buffer(&self) -> Result<CommandBuffer> {
let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
let mut command_buffer = command_buffer_lock.to_owned();
let mut index = self
.command_buffer_index
.write()
.map_err(MetalError::from)?;
if *index > self.compute_per_buffer {
command_buffer.commit();
command_buffer = self.command_queue.new_command_buffer().to_owned();
*command_buffer_lock = command_buffer.clone();
*index = 0;
self.drop_unused_buffers()?;
}
*index += 1;
Ok(command_buffer)
}
pub fn wait_until_completed(&self) -> Result<()> {
let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
match command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => {
panic!("Already committed");
}
_ => {}
}
command_buffer.commit();
command_buffer.wait_until_completed();
*command_buffer = self.command_queue.new_command_buffer().to_owned();
Ok(())
}
pub fn kernels(&self) -> &Kernels {
&self.kernels
}
pub fn device(&self) -> &metal::Device {
&self.device
}
/// Creates a new buffer (not necessarily zeroed).
/// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
/// This means the buffer data cannot be read on the CPU directly.
///
/// [`name`] is only used to keep track of the resource origin in case of bugs
pub fn new_buffer(
&self,
element_count: usize,
dtype: DType,
name: &str,
) -> Result<Arc<Buffer>> {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
}
/// Creates a new buffer (not necessarily zeroed).
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
/// This means the buffer can be read on the CPU but will require manual
/// synchronization when the CPU memory is modified
/// Used as a bridge to gather data back from the GPU
pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
}
/// Creates a new buffer from data.
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
///
/// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
/// allocates the buffer and copies over the existing data before returning the MTLBuffer.
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
let size = core::mem::size_of_val(data) as NSUInteger;
let new_buffer = self.device.new_buffer_with_data(
data.as_ptr() as *const c_void,
size,
MTLResourceOptions::StorageModeManaged,
);
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
let subbuffers = buffers
.entry((size, MTLResourceOptions::StorageModeManaged))
.or_insert(vec![]);
let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());
Ok(new_buffer)
}
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
let buffer = self.allocate_buffer(
size_in_bytes as NSUInteger,
MTLResourceOptions::StorageModePrivate,
"allocate_zeros",
)?;
let command_buffer = self.command_buffer()?;
command_buffer.set_label("zeros");
let blit = command_buffer.new_blit_command_encoder();
blit.fill_buffer(
&buffer,
metal::NSRange {
location: 0,
length: buffer.length(),
},
0,
);
blit.end_encoding();
Ok(buffer)
}
fn find_available_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
buffers: &RwLockWriteGuard<BufferMap>,
) -> Option<Arc<Buffer>> {
let mut best_buffer: Option<&Arc<Buffer>> = None;
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
for sub in subbuffers {
if Arc::strong_count(sub) == 1 {
best_buffer = Some(sub);
best_buffer_size = *buffer_size;
}
}
}
}
best_buffer.cloned()
}
fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(*s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(())
}
/// The critical allocator algorithm
fn allocate_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
_name: &str,
) -> Result<Arc<Buffer>> {
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
// Cloning also ensures we increment the strong count
return Ok(b.clone());
}
let size = buf_size(size);
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
let new_buffer = self.device.new_buffer(size as NSUInteger, option);
let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());
Ok(new_buffer)
}
/// Create a metal GPU capture trace on [`path`].
pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let capture = metal::CaptureManager::shared();
let descriptor = metal::CaptureDescriptor::new();
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
descriptor.set_capture_device(self);
descriptor.set_output_url(path);
capture
.start_capture(&descriptor)
.map_err(MetalError::from)?;
Ok(())
}
}
fn buf_size(size: NSUInteger) -> NSUInteger {
size.saturating_sub(1).next_power_of_two() as NSUInteger
}

View File

@ -330,7 +330,7 @@ impl Tensor {
path: P,
) -> Result<()> {
let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
let options =
let options: zip::write::FileOptions<()> =
zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
for (name, tensor) in ts.iter() {

View File

@ -1,5 +1,5 @@
#![allow(clippy::redundant_closure_call)]
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use crate::Tensor;
use half::{bf16, f16};
use num_traits::float::Float;
@ -66,6 +66,7 @@ pub enum UnaryOp {
Floor,
Ceil,
Round,
Sign,
}
#[derive(Clone)]
@ -161,168 +162,23 @@ pub enum Op {
Permute(Tensor, Vec<usize>),
Elu(Tensor, f64),
Powf(Tensor, f64),
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
CustomOp1(
Tensor,
std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
),
CustomOp2(
Tensor,
Tensor,
std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>,
),
CustomOp3(
Tensor,
Tensor,
Tensor,
std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>,
),
}
/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_storage: &MetalStorage,
_layout: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp2 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp3 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_arg3: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait UnaryOpT {
const NAME: &'static str;
const KERNEL: &'static str;
@ -399,6 +255,7 @@ pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;
pub(crate) struct Sign;
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@ -602,6 +459,13 @@ unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
// Hardcode the value for sqrt(2/pi)
// https://github.com/huggingface/candle/issues/1982
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
/// Tanh based approximation of the `gelu` operation
/// GeluErf is the more precise one.
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
@ -614,7 +478,7 @@ impl UnaryOpT for Gelu {
* v
* (bf16::ONE
+ bf16::tanh(
(bf16::from_f32_const(2.0) / bf16::PI).sqrt()
bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
* v
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
))
@ -625,22 +489,18 @@ impl UnaryOpT for Gelu {
* v
* (f16::ONE
+ f16::tanh(
(f16::from_f32_const(2.0) / f16::PI).sqrt()
f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
* v
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
0.5 * v
* (1.0
+ f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
}
#[inline(always)]
fn f64(v: f64) -> f64 {
0.5 * v
* (1.0
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
}
#[inline(always)]
fn u8(_: u8) -> u8 {
@ -1067,3 +927,37 @@ impl std::ops::Deref for BackpropOp {
&self.0
}
}
impl UnaryOpT for Sign {
const NAME: &'static str = "sign";
const KERNEL: &'static str = "usign";
const V: Self = Sign;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
}
#[inline(always)]
fn f32(v: f32) -> f32 {
f32::from(v > 0.) - f32::from(v < 0.)
}
#[inline(always)]
fn f64(v: f64) -> f64 {
f64::from(v > 0.) - f64::from(v < 0.)
}
#[inline(always)]
fn u8(v: u8) -> u8 {
u8::min(1, v)
}
#[inline(always)]
fn u32(v: u32) -> u32 {
u32::min(1, v)
}
#[inline(always)]
fn i64(v: i64) -> i64 {
(v > 0) as i64 - (v < 0) as i64
}
}

View File

@ -1,24 +1,66 @@
use super::{GgmlDType, QStorage};
use crate::quantized::k_quants::GgmlType;
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};
use half::f16;
use cudarc::driver::{CudaSlice, DeviceSlice};
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
#[derive(Clone, Debug)]
pub struct QCudaStorage {
data: CudaSlice<u8>,
dtype: GgmlDType,
device: CudaDevice,
}
static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
pub fn set_force_dmmv(f: bool) {
FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
}
pub const WARP_SIZE: usize = 32;
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
pub const NWARPS_Q4_0_AMPERE: usize = 4;
pub const GGML_CUDA_MMV_X: usize = 32;
pub const GGML_CUDA_MMV_Y: usize = 1;
pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256;
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
pub const MATRIX_ROW_PADDING: usize = 512;
fn dequantize(
fn ceil_div(p: usize, q: usize) -> usize {
(p + q - 1) / q
}
fn pad(p: usize, q: usize) -> usize {
ceil_div(p, q) * q
}
fn quantize_q8_1(
src: &CudaView<f32>,
dst: &mut CudaSlice<u8>,
elem_count: usize,
ky: usize,
dev: &CudaDevice,
) -> Result<()> {
use cudarc::driver::LaunchAsync;
let kx = elem_count;
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, ky as u32, 1),
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
shared_mem_bytes: 0,
};
let params = (src, dst, kx as i32, kx_padded as i32);
unsafe { func.launch(cfg, params) }.w()?;
Ok(())
}
fn dequantize_f32(
data: &CudaSlice<u8>,
dtype: GgmlDType,
elem_count: usize,
@ -28,39 +70,31 @@ fn dequantize(
let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
GgmlDType::Q5_0 => {
let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
/ (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
(
"dequantize_block_q5_0",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
nb,
)
}
GgmlDType::Q5_1 => {
let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
/ (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
(
"dequantize_block_q5_1",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
nb,
)
}
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
GgmlDType::Q5_0 => (
"dequantize_block_q5_0_f32",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q5_1 => (
"dequantize_block_q5_1_f32",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
@ -83,9 +117,66 @@ fn dequantize(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_mut_mal_vec(
fn dequantize_f16(
data: &CudaSlice<u8>,
y: &cudarc::driver::CudaView<f32>,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
GgmlDType::Q5_0 => (
"dequantize_block_q5_0_f16",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q5_1 => (
"dequantize_block_q5_1_f16",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, 1, 1),
block_dim: (block_dim as u32, 1, 1),
shared_mem_bytes: 0,
};
if is_k {
let params = (data, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (data, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_mul_mat_vec(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
@ -93,6 +184,13 @@ fn dequantize_mut_mal_vec(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
if y.len() != ncols {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
let kernel_name = match dtype {
GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
@ -107,8 +205,8 @@ fn dequantize_mut_mal_vec(
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = dev.alloc_zeros::<f32>(nrows).w()?;
let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (block_num_y as u32, 1, 1),
block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
@ -120,9 +218,149 @@ fn dequantize_mut_mal_vec(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn mul_mat_vec_via_q8_1(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
b_size: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
if y.len() != ncols * b_size {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
if b_size == 0 || b_size > 8 {
crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
}
// Start by quantizing y
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
let y_size_in_bytes =
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
let kernel_name = match dtype {
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda",
GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda",
GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda",
GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda",
GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda",
GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda",
GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda",
GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda",
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let kernel_name = format!("{kernel_name}{b_size}");
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
let (nblocks, nwarps) = match b_size {
1 => (nrows as u32, 4),
2..=4 => ((nrows as u32 + 1) / 2, 4),
5..=8 => ((nrows as u32 + 1) / 2, 2),
_ => crate::bail!("unexpected bsize {b_size}"),
};
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (nblocks, 1, 1),
block_dim: (WARP_SIZE as u32, nwarps, 1),
shared_mem_bytes: 0,
};
let params = (
data,
&y_q8_1,
&dst,
/* ncols_x */ ncols as i32,
/* nrows_x */ nrows as i32,
/* nrows_y */ ncols_padded as i32,
/* nrows_dst */ nrows as i32,
);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
#[allow(clippy::too_many_arguments)]
fn mul_mat_via_q8_1(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
x_rows: usize,
x_cols: usize,
y_rows: usize,
y_cols: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < x_rows * x_cols {
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
}
if y.len() != y_rows * y_cols {
crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len())
}
if x_cols != y_rows {
crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}")
}
let k = x_cols;
// Start by quantizing y
let k_padded = pad(k, MATRIX_ROW_PADDING);
let y_size_in_bytes =
k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
let (kernel_name, mmq_x, mmq_y) = match dtype {
GgmlDType::Q4_0 => ("mul_mat_q4_0", 64, 128),
GgmlDType::Q4_1 => ("mul_mat_q4_1", 64, 128),
GgmlDType::Q5_0 => ("mul_mat_q5_0", 128, 64),
GgmlDType::Q5_1 => ("mul_mat_q5_1", 128, 64),
GgmlDType::Q8_0 => ("mul_mat_q8_0", 128, 64),
GgmlDType::Q2K => ("mul_mat_q2_K", 64, 128),
GgmlDType::Q3K => ("mul_mat_q3_K", 128, 128),
GgmlDType::Q4K => ("mul_mat_q4_K", 64, 128),
GgmlDType::Q5K => ("mul_mat_q5_K", 64, 128),
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (
ceil_div(x_rows, mmq_y) as u32,
ceil_div(y_cols, mmq_x) as u32,
1,
),
block_dim: (WARP_SIZE as u32, 4, 1),
shared_mem_bytes: 0,
};
let params = (
/* vx */ data,
/* vy */ &y_q8_1,
/* dst */ &dst,
/* ncols_x */ x_cols as i32,
/* nrows_x */ x_rows as i32,
/* ncols_y */ y_cols as i32,
/* nrows_y */ k_padded as i32,
/* nrows_dst */ x_rows as i32,
);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = el_count * dtype.type_size() / dtype.block_size();
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
Ok(QCudaStorage {
data,
@ -140,6 +378,12 @@ impl QCudaStorage {
}
pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
fn deq<T: GgmlType>(buffer: &[u8], n: usize, dst: &mut [f32]) -> Result<()> {
let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
let vec = slice.to_vec();
T::to_float(&vec, dst)
}
let fast_kernel = matches!(
self.dtype,
GgmlDType::Q4_0
@ -155,78 +399,38 @@ impl QCudaStorage {
| GgmlDType::Q8K
);
if fast_kernel {
return dequantize(&self.data, self.dtype, elem_count, self.device());
return dequantize_f32(&self.data, self.dtype, elem_count, self.device());
}
// Run the dequantization on cpu.
use crate::quantized::k_quants::GgmlType;
let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
GgmlDType::F32 => {
let slice =
unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const f32, block_len) };
out.copy_from_slice(slice)
}
GgmlDType::F16 => {
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
half::f16::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_0 => {
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_1 => {
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_0 => {
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_1 => {
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_0 => {
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_1 => {
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q2K => {
let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
}
GgmlDType::Q3K => {
let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
}
GgmlDType::Q4K => {
let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
}
GgmlDType::Q5K => {
let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
}
GgmlDType::Q6K => {
let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
}
GgmlDType::Q8K => {
let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
}
GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q5_1 => deq::<crate::quantized::BlockQ5_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q8_0 => deq::<crate::quantized::BlockQ8_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q8_1 => deq::<crate::quantized::BlockQ8_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q2K => deq::<crate::quantized::BlockQ2K>(&buffer, block_len, &mut out)?,
GgmlDType::Q3K => deq::<crate::quantized::BlockQ3K>(&buffer, block_len, &mut out)?,
GgmlDType::Q4K => deq::<crate::quantized::BlockQ4K>(&buffer, block_len, &mut out)?,
GgmlDType::Q5K => deq::<crate::quantized::BlockQ5K>(&buffer, block_len, &mut out)?,
GgmlDType::Q6K => deq::<crate::quantized::BlockQ6K>(&buffer, block_len, &mut out)?,
GgmlDType::Q8K => deq::<crate::quantized::BlockQ8K>(&buffer, block_len, &mut out)?,
}
self.device
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
}
pub fn dequantize_f16(&self, elem_count: usize) -> Result<CudaStorage> {
dequantize_f16(&self.data, self.dtype, elem_count, self.device())
}
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
// Run the quantization on cpu.
let src = match &src.slice {
@ -255,7 +459,17 @@ impl QCudaStorage {
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
1
} else {
8
};
let use_vec_kernel = match layout.shape().dims() {
[b, m, _k] => b * m <= max_bm,
[b, _k] => *b <= max_bm,
_ => false,
};
if use_vec_kernel {
self.dequantize_matmul_vec(self_shape, storage, layout)
} else {
self.dequantize_matmul(self_shape, storage, layout)
@ -276,22 +490,31 @@ impl QCudaStorage {
Some((o1, o2)) => rhs.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
};
let (with_batch, k) = match rhs_l.shape().dims() {
[1, 1, k] => (true, k),
[1, k] => (false, k),
let (b_size, k) = match rhs_l.shape().dims() {
[b, m, k] => (b * m, *k),
[b, k] => (*b, *k),
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
};
if ncols != *k {
if ncols != k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
}
let out =
dequantize_mut_mal_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?;
let out_shape = if with_batch {
vec![1, 1, nrows]
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
} else {
vec![1, nrows]
mul_mat_vec_via_q8_1(
&self.data,
&rhs,
self.dtype,
ncols,
nrows,
b_size,
self.device(),
)?
};
let mut out_shape = rhs_l.shape().dims().to_vec();
out_shape.pop();
out_shape.push(nrows);
Ok((out, out_shape.into()))
}
@ -312,9 +535,30 @@ impl QCudaStorage {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
}
let data_f32 = self.dequantize(n * k)?;
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
let data_f32 = self.dequantize(n * k)?;
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?
} else {
let storage = storage.as_cuda_slice::<f32>()?;
let storage = match layout.contiguous_offsets() {
Some((o1, o2)) => storage.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous {
op: "quantized-matmul",
}
.bt())?,
};
mul_mat_via_q8_1(
&self.data,
&storage,
self.dtype,
/* x_rows */ n,
/* x_cols */ k,
/* y_rows */ k,
/* y_cols */ b * m,
self.device(),
)?
};
let mut out_shape = layout.shape().dims().to_vec();
out_shape.pop();
out_shape.push(n);
@ -322,11 +566,6 @@ impl QCudaStorage {
}
}
fn read_to_vec<T: Clone>(buffer: &[u8], n: usize) -> Vec<T> {
let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
slice.to_vec()
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
device: &CudaDevice,
data: &[T],
@ -341,3 +580,101 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
dtype: T::DTYPE,
}))
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn cuda_quantize_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let el = 256;
let el_padded = pad(el, MATRIX_ROW_PADDING);
let y_size_in_bytes =
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
Ok(())
}
#[test]
fn cuda_mmv_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_vec_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
/* b_size */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
// Q8 means 1/256 precision.
assert_eq!(vs[0], 5561664.5);
let cuda_storage = dequantize_mul_mat_vec(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
assert_eq!(vs[0], 5561851.0);
Ok(())
}
#[test]
fn cuda_mm_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* x_rows */ 4,
/* x_cols */ ncols,
/* y_rows */ ncols,
/* y_cols */ 4,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
/*
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
x @ x.t() / 16
tensor([[ 347480.0000, 869720.0000, 1391960.0000, 1914200.0000],
[ 869720.0000, 2440536.0000, 4011352.0000, 5582166.5000],
[ 1391960.0000, 4011352.0000, 6630742.0000, 9250132.0000],
[ 1914200.0000, 5582166.5000, 9250132.0000, 12918099.0000]])
*/
assert_eq!(vs.len(), 16);
assert_eq!(vs[0], 347604.0);
assert_eq!(vs[1], 888153.06);
assert_eq!(vs[4], 869780.7);
assert_eq!(vs[5], 2483145.0);
assert_eq!(vs[11], 9407368.0);
assert_eq!(vs[14], 9470856.0);
assert_eq!(vs[15], 13138824.0);
Ok(())
}
}

View File

@ -24,6 +24,10 @@ impl QCudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn dequantize_f16(&self, _elem_count: usize) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -135,7 +135,6 @@ pub enum ValueType {
// The value is a UTF-8 non-null-terminated string, with length prepended.
String,
// The value is an array of other values, with the length and type prepended.
///
// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
Array,
}

View File

@ -149,9 +149,12 @@ impl QMetalStorage {
let (n, k) = self_shape.dims2()?;
let mut dst_shape = src_shape.dims().to_vec();
let (b, m) = match dst_shape.len() {
3 => (dst_shape[0], dst_shape[1]),
2 => (1, dst_shape[0]),
// We always use a single batch dimension and stack all the tensors in the batch on the
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
// properly.
let m = match dst_shape.len() {
3 => dst_shape[0] * dst_shape[1],
2 => dst_shape[0],
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};
let last_k = dst_shape.pop().unwrap();
@ -163,18 +166,23 @@ impl QMetalStorage {
let device = storage.device().clone();
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
let command_buffer = device.command_buffer()?;
candle_metal_kernels::call_quantized_matmul_t(
device.device(),
&command_buffer,
device.kernels(),
self.dtype.into(),
(b, m, n, k),
storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(),
&self.buffer,
&dst,
)
.map_err(MetalError::from)?;
// In some cases it would be better to use the mm variant, though it has its drawbacks
// around memory alignemnt.
for batch_id in 0..m {
candle_metal_kernels::call_quantized_matmul_mv_t(
device.device(),
&command_buffer,
device.kernels(),
self.dtype.into(),
(1, 1, n, k),
storage.buffer(),
(layout.start_offset() + batch_id * k) * storage.dtype().size_in_bytes(),
&self.buffer,
batch_id * n * DType::F32.size_in_bytes(),
&dst,
)
.map_err(MetalError::from)?;
}
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
Ok((dst_storage, dst_shape))
}

View File

@ -1,4 +1,4 @@
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;
@ -360,9 +360,24 @@ impl QTensor {
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
let storage = self.storage.dequantize(self.shape.elem_count())?;
let none = crate::op::BackpropOp::none();
let is_variable = false;
crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
.to_device(device)
crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
}
pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
// In the CUDA case, we have a specialized kernel as this can be useful for volta
// architectures. https://github.com/huggingface/candle/issues/2136
match &self.storage {
QStorage::Cuda(s) => {
let s = s.dequantize_f16(self.shape.elem_count())?;
let none = crate::op::BackpropOp::none();
crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
.to_device(device)
}
_ => {
let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
Ok(s)
}
}
}
pub fn storage_size_in_bytes(&self) -> usize {
@ -378,6 +393,7 @@ impl QTensor {
pub enum QMatMul {
QTensor(std::sync::Arc<QTensor>),
Tensor(Tensor),
TensorF16(Tensor),
}
thread_local! {
@ -391,6 +407,17 @@ thread_local! {
}
}
thread_local! {
static DEQUANTIZE_ALL_F16: bool = {
match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
impl QMatMul {
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
let dequantize = match qtensor.dtype() {
@ -400,6 +427,9 @@ impl QMatMul {
let t = if dequantize {
let tensor = qtensor.dequantize(&qtensor.device())?;
Self::Tensor(tensor)
} else if DEQUANTIZE_ALL_F16.with(|b| *b) {
let tensor = qtensor.dequantize_f16(&qtensor.device())?;
Self::TensorF16(tensor)
} else {
Self::QTensor(qtensor)
};
@ -409,6 +439,25 @@ impl QMatMul {
pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
Self::from_arc(std::sync::Arc::new(qtensor))
}
pub fn dequantize_f16(&self) -> Result<Tensor> {
match self {
Self::QTensor(t) => t.dequantize_f16(&t.device()),
Self::Tensor(t) => t.to_dtype(DType::F16),
Self::TensorF16(t) => Ok(t.clone()),
}
}
pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
let w = self.dequantize_f16()?;
let in_dtype = xs.dtype();
let w = match *xs.dims() {
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
_ => w.t()?,
};
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
}
}
impl crate::CustomOp1 for QTensor {
@ -486,6 +535,15 @@ impl crate::Module for QMatMul {
};
xs.matmul(&w)
}
Self::TensorF16(w) => {
let in_dtype = xs.dtype();
let w = match *xs.dims() {
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
_ => w.t()?,
};
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
}
}
}
}

View File

@ -349,6 +349,30 @@ impl MmapedSafetensors {
}
}
pub struct SliceSafetensors<'a> {
safetensors: SafeTensors<'a>,
}
impl<'a> SliceSafetensors<'a> {
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
pub fn new(buffer: &'a [u8]) -> Result<Self> {
let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
Ok(Self { safetensors })
}
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
self.safetensors.tensor(name)?.load(dev)
}
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
self.safetensors.tensors()
}
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
Ok(self.safetensors.tensor(name)?)
}
}
pub struct BufferedSafetensors {
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
}

View File

@ -171,7 +171,7 @@ impl Shape {
}
let mut acc = 1;
for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
if stride != acc {
if dim > 1 && stride != acc {
return false;
}
acc *= dim;
@ -186,7 +186,7 @@ impl Shape {
}
let mut acc = 1;
for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
if stride != acc {
if dim > 1 && stride != acc {
return false;
}
acc *= dim;

239
candle-core/src/sort.rs Normal file
View File

@ -0,0 +1,239 @@
use crate::{Result, Tensor};
use rayon::prelude::*;
#[derive(Debug, Clone, Copy)]
struct ArgSort {
asc: bool,
last_dim: usize,
}
impl ArgSort {
fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
#[allow(clippy::uninit_vec)]
// Safety: indexes are set later in the parallelized section.
let mut sort_indexes = unsafe {
let el_count = layout.shape().elem_count();
let mut v = Vec::with_capacity(el_count);
v.set_len(el_count);
v
};
if self.asc {
sort_indexes
.par_chunks_exact_mut(self.last_dim)
.zip(vs.par_chunks_exact(self.last_dim))
.for_each(|(indexes, vs)| {
indexes
.iter_mut()
.enumerate()
.for_each(|(i, v)| *v = i as u32);
indexes.sort_by(|&i, &j| {
vs[i as usize]
.partial_cmp(&vs[j as usize])
.unwrap_or(std::cmp::Ordering::Greater)
})
});
} else {
sort_indexes
.par_chunks_exact_mut(self.last_dim)
.zip(vs.par_chunks_exact(self.last_dim))
.for_each(|(indexes, vs)| {
indexes
.iter_mut()
.enumerate()
.for_each(|(i, v)| *v = i as u32);
indexes.sort_by(|&j, &i| {
vs[i as usize]
.partial_cmp(&vs[j as usize])
.unwrap_or(std::cmp::Ordering::Greater)
})
});
}
sort_indexes
}
}
impl crate::CustomOp1 for ArgSort {
fn name(&self) -> &'static str {
"argsort"
}
fn cpu_fwd(
&self,
storage: &crate::CpuStorage,
layout: &crate::Layout,
) -> Result<(crate::CpuStorage, crate::Shape)> {
let sort_indexes = match storage {
crate::CpuStorage::U8(vs) => self.asort(vs, layout),
crate::CpuStorage::U32(vs) => self.asort(vs, layout),
crate::CpuStorage::I64(vs) => self.asort(vs, layout),
crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
crate::CpuStorage::F16(vs) => self.asort(vs, layout),
crate::CpuStorage::F32(vs) => self.asort(vs, layout),
crate::CpuStorage::F64(vs) => self.asort(vs, layout),
};
let sort_indexes = crate::CpuStorage::U32(sort_indexes);
Ok((sort_indexes, layout.shape().into()))
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
storage: &crate::CudaStorage,
layout: &crate::Layout,
) -> Result<(crate::CudaStorage, crate::Shape)> {
use crate::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
use crate::{CudaDevice, WithDType};
impl Map1Any for ArgSort {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &crate::Layout,
_wrap: W,
) -> Result<S> {
let slice = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
let func = if self.asc {
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
} else {
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
};
let ncols = self.last_dim;
let nrows = elem_count / ncols;
let ncols_pad = next_power_of_2(ncols);
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
let cfg = LaunchConfig {
grid_dim: (1, nrows as u32, 1),
block_dim: (ncols_pad as u32, 1, 1),
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
};
unsafe { func.launch(cfg, params) }.w()?;
Ok(S::U32(dst))
}
}
use crate::backend::BackendStorage;
let dev = storage.device();
let slice = self.map(&storage.slice, dev, layout)?;
let dst = crate::cuda_backend::CudaStorage {
slice,
device: dev.clone(),
};
Ok((dst, layout.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
storage: &crate::MetalStorage,
layout: &crate::Layout,
) -> Result<(crate::MetalStorage, crate::Shape)> {
use crate::backend::BackendStorage;
use crate::DType;
let name = {
if self.asc {
match storage.dtype() {
DType::BF16 => "asort_asc_bf16",
DType::F16 => "asort_asc_f16",
DType::F32 => "asort_asc_f32",
DType::F64 => "asort_asc_f64",
DType::U8 => "asort_asc_u8",
DType::U32 => "asort_asc_u32",
DType::I64 => "asort_asc_i64",
}
} else {
match storage.dtype() {
DType::BF16 => "asort_desc_bf16",
DType::F16 => "asort_desc_f16",
DType::F32 => "asort_desc_f32",
DType::F64 => "asort_desc_f64",
DType::U8 => "asort_desc_u8",
DType::U32 => "asort_desc_u32",
DType::I64 => "asort_desc_i64",
}
}
};
let device = storage.device();
let kernels = device.kernels();
let command_buffer = device.command_buffer()?;
let el = layout.shape().elem_count();
let ncols = self.last_dim;
let nrows = el / ncols;
let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
let dst = device.new_buffer(el, DType::U32, "asort")?;
let mut ncols_pad = 1;
while ncols_pad < ncols {
ncols_pad *= 2;
}
candle_metal_kernels::call_arg_sort(
device.metal_device(),
&command_buffer,
kernels,
name,
nrows,
ncols,
ncols_pad,
src,
&dst,
)
.map_err(crate::Error::wrap)?;
let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
Ok((dst, layout.shape().clone()))
}
}
#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
let mut n = 1;
while n < x {
n *= 2
}
n
}
impl Tensor {
/// Returns the indices that sort the tensor along the last dimension.
///
/// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
/// descending order. The sort is unstable so there is no guarantees on the final order when it
/// comes to ties.
pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
if !self.is_contiguous() {
return Err(crate::Error::RequiresContiguous {
op: "arg_sort_last_dim",
});
}
let last_dim = match self.dims().last() {
None => crate::bail!("empty last-dim in arg-sort"),
Some(last_dim) => *last_dim,
};
// No need for a backward pass for arg sort.
self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
}
/// Sorts the tensor along the last dimension, returns the sorted tensor together with the
/// sorted indexes.
///
/// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
/// descending order. The sort is unstable so there is no guarantees on the final order when it
/// comes to ties.
pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
if !self.is_contiguous() {
return Err(crate::Error::RequiresContiguous {
op: "sort_last_dim",
});
}
let asort = self.arg_sort_last_dim(asc)?;
let sorted = self.gather(&asort, crate::D::Minus1)?;
Ok((sorted, asort))
}
}

View File

@ -1,6 +1,7 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::op::{self, CmpOp, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@ -43,9 +44,19 @@ impl Storage {
}
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
let lhs = self.device().location();
let rhs = rhs.device().location();
if lhs != rhs {
let lhs_device = self.device();
let rhs_device = rhs.device();
let lhs = lhs_device.location();
let rhs = rhs_device.location();
let same_device = if self.device().is_metal() {
// On metal, we require the device to be exactly the same rather than
// having the same location. In cuda this is not necessary as all CudaDevice on the
// same GPU will use the same cuda stream.
lhs_device.same_device(&rhs_device)
} else {
lhs == rhs
};
if !same_device {
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
} else {
Ok(())
@ -252,6 +263,51 @@ impl Storage {
}
}
pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> {
match self {
Self::Cpu(storage) => c.cpu_fwd(storage, l),
Self::Cuda(storage) => c.cuda_fwd(storage, l),
Self::Metal(storage) => c.metal_fwd(storage, l),
}
}
pub(crate) fn inplace_op2(
&mut self,
l1: &Layout,
t2: &Self,
l2: &Layout,
c: &dyn InplaceOp2,
) -> Result<()> {
self.same_device(t2, c.name())?;
match (self, t2) {
(Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2),
(Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2),
(Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2),
_ => unreachable!(),
}
}
pub(crate) fn inplace_op3(
&mut self,
l1: &Layout,
t2: &Self,
l2: &Layout,
t3: &Self,
l3: &Layout,
c: &dyn InplaceOp3,
) -> Result<()> {
self.same_device(t2, c.name())?;
self.same_device(t3, c.name())?;
match (self, t2, t3) {
(Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3),
(Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3),
(Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
c.metal_fwd(s1, l1, s2, l2, s3, l3)
}
_ => unreachable!(),
}
}
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
match self {
Storage::Cpu(storage) => {

View File

@ -1,9 +1,7 @@
//! Tensors are N-dimensional matrixes of elements using a single data type.
#![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
};
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
@ -81,6 +79,9 @@ macro_rules! unary_op {
($fn_name:ident, $op_name:ident) => {
pub fn $fn_name(&self) -> Result<Self> {
let shape = self.shape();
if shape.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self
.storage()
.unary_impl::<crate::op::$op_name>(self.layout())?;
@ -94,6 +95,9 @@ macro_rules! binary_op {
($fn_name:ident, $op_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
if shape.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().binary_impl::<crate::op::$op_name>(
&*rhs.storage(),
self.layout(),
@ -116,6 +120,9 @@ macro_rules! binary_op_scalar {
.broadcast_as(self.shape())?,
};
let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().binary_impl::<crate::op::$op_name>(
&*rhs.storage(),
self.layout(),
@ -449,7 +456,15 @@ impl Tensor {
shape: S,
device: &Device,
) -> Result<Self> {
Self::new_impl(array, shape.into(), device, false)
let shape = shape.into();
let n: usize = shape.elem_count();
let buffer_size: usize = array.len();
if buffer_size != n {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let storage = device.storage_from_slice(array)?;
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, false))
}
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
@ -512,6 +527,7 @@ impl Tensor {
unary_op!(ceil, Ceil);
unary_op!(floor, Floor);
unary_op!(round, Round);
unary_op!(sign, Sign);
/// Round element of the input tensor to the nearest integer.
///
@ -647,6 +663,9 @@ impl Tensor {
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().affine(self.layout(), mul, add)?;
let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
Ok(from_storage(storage, self.shape(), op, false))
@ -654,6 +673,9 @@ impl Tensor {
/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
pub fn elu(&self, alpha: f64) -> Result<Self> {
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().elu(self.layout(), alpha)?;
let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
Ok(from_storage(storage, self.shape(), op, false))
@ -661,6 +683,9 @@ impl Tensor {
/// Raise the tensor to some float exponent `e`.
pub fn powf(&self, e: f64) -> Result<Self> {
if self.elem_count() == 0 {
return Ok(self.clone());
}
let storage = self.storage().powf(self.layout(), e)?;
let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
Ok(from_storage(storage, self.shape(), op, false))
@ -1155,6 +1180,9 @@ impl Tensor {
let n = b_dims[dim - 1];
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
if c_shape.elem_count() == 0 || k == 0 {
return Tensor::zeros(c_shape, self.dtype(), self.device());
}
let batching: usize = a_dims[..dim - 2].iter().product();
let batching_b: usize = b_dims[..dim - 2].iter().product();
if k != k2 || batching != batching_b {
@ -1351,7 +1379,7 @@ impl Tensor {
}
.bt())?
}
let mut storage = self.device().zeros(self.shape(), self.dtype())?;
let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let offset = start * src.dims()[1..].iter().product::<usize>();
@ -2001,7 +2029,7 @@ impl Tensor {
Ok(self.clone())
} else {
let shape = self.shape();
let mut storage = self.device().zeros(shape, self.dtype())?;
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let op = BackpropOp::new1(self, Op::Copy);
@ -2009,11 +2037,21 @@ impl Tensor {
}
}
/// Returns a tensor that is in row major order. This always makes a copy.
pub fn force_contiguous(&self) -> Result<Tensor> {
let shape = self.shape();
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let op = BackpropOp::new1(self, Op::Copy);
Ok(from_storage(storage, shape.clone(), op, false))
}
/// Create a variable based on the values currently stored in a tensor. The storage is always
/// copied.
pub(crate) fn make_var(&self) -> Result<Tensor> {
let shape = self.shape().clone();
let mut storage = self.device().zeros(&shape, self.dtype())?;
let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(storage, shape, BackpropOp::none(), true))
@ -2066,7 +2104,7 @@ impl Tensor {
};
Ok(Tensor(Arc::new(tensor_)))
} else {
let mut storage = self.device().zeros(&shape, self.dtype())?;
let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(storage, shape, op, false))
@ -2093,8 +2131,19 @@ impl Tensor {
let dim = dim.to_index(self.shape(), "squeeze")?;
if dims[dim] == 1 {
let mut dims = dims.to_vec();
let mut strides = self.stride().to_vec();
dims.remove(dim);
self.reshape(dims)
strides.remove(dim);
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
op: BackpropOp::new1(self, Op::Reshape),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
} else {
Ok(self.clone())
}
@ -2115,10 +2164,24 @@ impl Tensor {
/// ```
pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<Self> {
let mut dims = self.dims().to_vec();
let mut strides = self.stride().to_vec();
let dim = dim.to_index_plus_one(self.shape(), "unsqueeze")?;
// Cannot panic because to_index_plus_one already checks dimensions
dims.insert(dim, 1);
self.reshape(dims)
// Any stride would work here, but we pick one so as to maximize the probability to remain
// C contiguous.
let stride = if dim < strides.len() { strides[dim] } else { 1 };
strides.insert(dim, stride);
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: Layout::new(dims.into(), strides, self.layout.start_offset()),
op: BackpropOp::new1(self, Op::Reshape),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
/// Stacks two or more tensors along a particular dimension.
@ -2231,6 +2294,10 @@ impl Tensor {
self.storage.read().unwrap()
}
pub(crate) fn storage_mut(&self) -> std::sync::RwLockWriteGuard<'_, Storage> {
self.storage.write().unwrap()
}
// If we extend the visibility of this function to be usable outside of this crate, we should
// make it unsafe.
pub(crate) fn storage_mut_and_layout(
@ -2252,96 +2319,6 @@ impl Tensor {
std::ptr::eq(lhs, rhs)
}
/// Applies a unary custom op without backward support
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a binary custom op without backward support
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
let (storage, shape) =
self.storage()
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a ternary custom op without backward support
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c,
)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a unary custom op.
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
let (storage, shape) = self
.storage()
.apply_op1(self.layout(), c.as_ref().as_ref())?;
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
self.apply_op1_arc(Arc::new(Box::new(c)))
}
/// Applies a binary custom op.
pub fn apply_op2_arc(
&self,
rhs: &Self,
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op2(
self.layout(),
&rhs.storage(),
rhs.layout(),
c.as_ref().as_ref(),
)?;
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
self.apply_op2_arc(r, Arc::new(Box::new(c)))
}
/// Applies a ternary custom op.
pub fn apply_op3_arc(
&self,
t2: &Self,
t3: &Self,
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c.as_ref().as_ref(),
)?;
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
Op::CustomOp3(t1, t2, t3, c.clone())
});
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
&self,
t2: &Self,
t3: &Self,
c: C,
) -> Result<Self> {
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
/// Normalize a 'relative' axis value: positive values are kept, negative
/// values means counting the dimensions from the back.
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {

View File

@ -58,20 +58,18 @@ impl Tensor {
}
}
}
if dim == 0 {
let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
if all_contiguous {
Self::cat_contiguous(args, dim)
} else if dim == 0 {
Self::cat0(args)
} else {
let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
if all_contiguous {
Self::cat_contiguous(args, dim)
} else {
let args: Vec<Tensor> = args
.iter()
.map(|a| a.as_ref().transpose(0, dim))
.collect::<Result<Vec<_>>>()?;
let cat = Self::cat0(&args)?;
cat.transpose(0, dim)
}
let args: Vec<Tensor> = args
.iter()
.map(|a| a.as_ref().transpose(0, dim))
.collect::<Result<Vec<_>>>()?;
let cat = Self::cat0(&args)?;
cat.transpose(0, dim)
}
}
@ -141,7 +139,7 @@ impl Tensor {
}
let shape = Shape::from(cat_dims);
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
let mut storage = device.zeros(&shape, dtype)?;
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
for (arg, &offset) in args.iter().zip(offsets.iter()) {
let arg = arg.as_ref();
arg.storage()
@ -215,7 +213,7 @@ impl Tensor {
let block_size: usize = cat_dims.iter().skip(1 + dim).product();
let shape = Shape::from(cat_dims);
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
let mut storage = device.zeros(&shape, dtype)?;
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
let mut dst_o = 0;
for arg in args.iter() {
let arg = arg.as_ref();
@ -237,4 +235,66 @@ impl Tensor {
}
Ok(crate::tensor::from_storage(storage, shape, op, false))
}
/// Set the values on `self` using values from `src`. The copy starts at the specified
/// `offset` for the target dimension `dim` on `self`.
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
/// has to be greater than or equal to `offset` plus the `src` size.
///
/// Note that this modifies `self` in place and as such is not compatibel with
/// back-propagation.
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
let dim = dim.to_index(self.shape(), "slice-set")?;
if !self.is_contiguous() || !src.is_contiguous() {
Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
}
if self.dtype() != src.dtype() {
Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: src.dtype(),
op: "slice-set",
}
.bt())?
}
if self.device().location() != src.device().location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: self.device().location(),
rhs: src.device().location(),
op: "slice-set",
}
.bt())?
}
if self.rank() != src.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: self.rank(),
got: src.rank(),
shape: self.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
if dim_idx == dim && *v2 + offset > *v1 {
crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
}
if dim_idx != dim && v1 != v2 {
crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
}
}
let block_size: usize = src.dims().iter().skip(1 + dim).product();
let d1: usize = src.dims().iter().take(dim).product();
let d2 = block_size * src.dims()[dim];
let dst_o = self.layout().start_offset() + offset * block_size;
let src_o = src.layout().start_offset();
src.storage().copy2d(
&mut self.storage_mut(),
d1,
d2,
/* src_s */ d2,
/* dst_s */ block_size * self.dims()[dim],
src_o,
dst_o,
)?;
Ok(())
}
}

View File

@ -34,9 +34,14 @@ impl Var {
Ok(Self(inner))
}
// Convert a tensor to a variable, if the tensor is already a variable then it is returned as is.
pub fn from_tensor(t: &Tensor) -> Result<Self> {
let inner = t.make_var()?;
Ok(Self(inner))
if t.is_variable() {
Ok(Self(t.clone()))
} else {
let inner = t.make_var()?;
Ok(Self(inner))
}
}
pub fn rand_f64<S: Into<Shape>>(

View File

@ -54,11 +54,6 @@ fn conv1d(dev: &Device) -> Result<()> {
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
// conv-transposes are not implemented for metal.
if dev.is_metal() {
return Ok(());
}
let w = w.transpose(0, 1)?;
// The CPU kernels applied in the contiguous and non contiguous cases are different.
for w in [w.clone(), w.contiguous()?] {
@ -140,7 +135,7 @@ fn conv2d(dev: &Device) -> Result<()> {
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
-0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
],
dev,
)?;
@ -168,33 +163,34 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
if !dev.is_metal() {
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7, 7]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7, 7]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
[
[
[
[-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
[1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
[0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
[0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
[-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
[2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
[5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
],
[
[1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
[-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
[1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
[-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
[7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
[-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
[-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
]
[-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
[1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
[0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
[0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
[-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
[2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
[5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
],
[
[1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
[-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
[1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
[-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
[7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
[-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
[-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
]
);
}
]
);
// Dilations.
let res = t.conv2d(&w, 0, 1, 2, 1)?;
assert_eq!(res.dims(), [1, 2, 1, 1]);
@ -203,44 +199,37 @@ fn conv2d(dev: &Device) -> Result<()> {
[2.45, -2.3504],
);
if !dev.is_metal() {
// Transpose and dilations.
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
assert_eq!(res.dims(), [1, 2, 9, 9]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
// Transpose and dilations.
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
assert_eq!(res.dims(), [1, 2, 9, 9]);
assert_eq!(
test_utils::to_vec3_round(&res.i(0)?, 4)?,
[
[
[
[-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
[2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
[-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
[-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
[-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
[0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
[
-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51,
-3.5024
],
[4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
[5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
],
[
[1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
[-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
[1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
[1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
[1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
[3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
[5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
[-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
[
-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827,
1.0171
]
]
[-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
[2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
[-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
[-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
[-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
[0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
[-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024],
[4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
[5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
],
[
[1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
[-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
[1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
[1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
[1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
[3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
[5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
[-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
[-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171]
]
);
}
]
);
Ok(())
}
@ -287,19 +276,13 @@ fn conv2d_small(dev: &Device) -> Result<()> {
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000,
0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855,
-1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1640,
-0.0111, -0.1742, 0.0, 0.0, 0.0, 0.0, 2.6437, -2.0268, 1.1823, 0.0, 0.0, 0.0, 0.0,
3.2855, -1.0324, 0.2539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0
]
);
// conv-transposes are not implemented for metal
if dev.is_metal() {
return Ok(());
}
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
@ -402,9 +385,6 @@ print(w.grad[0])
*/
fn conv2d_grad(dev: &Device) -> Result<()> {
// conv-transposes are not implemented for metal
if dev.is_metal() {
return Ok(());
}
use candle_core::Var;
let t = Var::from_slice(
&[
@ -417,7 +397,7 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
-0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
-0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
],
(1, 4, 5, 5),
dev,
@ -602,6 +582,154 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
]
);
// Conv Transpose 2d Test
//tested against following python
// import torch
// torch.manual_seed(4242)
// padding = 4
// outpadding = 2
// dilation = 3
// stride = 3
// input = torch.randn((1, 4, 7, 5), requires_grad=True)
// kernel = torch.randn((4, 2, 3, 5), requires_grad=True)
// print("input", input.flatten())
// print("kernel", kernel.flatten())
// res = torch.nn.functional.conv_transpose2d(
// input,
// kernel,
// stride=stride,
// padding=padding,
// dilation=dilation,
// output_padding=outpadding,
// )
// res.retain_grad()
// print(res.shape)
// loss = (res**2).sum()
// print(loss)
// loss.backward()
// print(input.grad.shape)
// print("input grad", torch.round(input.grad, decimals=1))
// print(kernel.grad.shape)
// print("kernel grad", torch.round(kernel.grad.flatten(), decimals=1))
let padding = 4;
let outpadding = 2;
let dilation = 3;
let stride = 3;
let t = Var::from_slice(
&[
0.4056_f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997,
3.0616, 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843,
0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013,
-0.6836, 0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130,
1.3123, 1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071,
1.1586, 0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090,
0.2049, 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323,
-1.3712, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742,
0.3790, -0.4431, -0.4720, -0.7890, 0.2620, 0.5411, -1.1715, -2.4997, 2.3249, -0.8912,
-0.4733, -0.5701, -2.8888, -1.4112, -0.5471, -0.9234, -1.1660, 0.4189, -0.7465,
-0.6473, 0.1402, 0.7875, 0.5377, -0.6779, -0.8088, -0.4864, -0.2312, 0.9279, 0.1264,
1.5480, 0.8265, -0.1025, 0.5138, -0.2512, 0.1576, 1.2705, 0.3641, -0.9325, 0.6451,
-0.8537, 0.2378, 0.1794, 0.2752, -0.3687, -1.1149, -0.1410, -0.5829, -0.0892, 1.4258,
-2.2789, 0.5270, 0.1825, 1.7007, -0.5263, -0.2954, 0.4440, 0.5537, 0.3492, 0.6186,
1.6475, 0.2219,
],
(1, 4, 7, 5),
dev,
)?;
#[rustfmt::skip]
let w = Var::from_slice(
&[
-1.1744_f32, 0.3266, 2.5893, 1.0142, 0.1763, 0.7752, 0.6604, 0.2029, -0.2145, 0.7234,
-0.3441, -1.5400, -0.6333, 0.6613, 0.2083, 0.6230, -1.7002, 0.3393, 0.4049, 1.0762,
0.2723, 1.4181, 0.0029, -0.2122, 1.7668, 1.4168, 0.3320, -0.2719, 0.7932, -0.7204,
0.4447, 0.1211, 0.5908, 1.0089, -0.1646, 1.8033, -0.6286, 0.2016, -0.3370, 1.2555,
0.8009, -0.6488, -0.4652, -1.5685, 1.5860, 0.5583, 0.4623, 0.6026, 0.8828, 2.4990,
0.6811, -0.3369, 1.3320, 1.7669, -1.1067, 1.2958, -0.9415, -0.9655, -0.4462, 0.7181,
0.5181, -1.1658, -1.8467, -0.7763, 1.2769, 0.8651, 0.9890, 1.5092, 0.7207, -0.8481,
0.7417, 0.3375, -1.2685, 1.4572, 1.0915, 0.1093, -0.8550, -0.5831, -0.6309, -0.2509,
0.5220, -0.0914, 0.7900, 0.1096, 0.3258, 0.2723, -1.0942, -0.3393, -0.1653, 0.5732,
-0.8014, 1.8194, -1.9023, 0.2127, 1.8636, -0.8979, 0.1927, -0.2778, 0.3105, 0.0071,
-1.1823, 0.2476, -0.7178, -1.3821, 1.0769, -0.4376, -0.9967, -0.1227, 1.6197, -1.0604,
0.1372, 0.8141, -0.6163, 0.7304, -0.8285, 2.0636, -0.7176, 0.2495, -0.2581, -0.4478,
],
(4, 2, 3, 5),
dev,
)?;
let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 2904.0);
let grads = loss.backward()?;
let grad_t = grads.get(&t).unwrap();
let grad_w = grads.get(&w).unwrap();
assert_eq!(grad_t.dims(), [1, 4, 7, 5]);
assert_eq!(grad_w.dims(), [4, 2, 3, 5]);
assert_eq!(
test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?,
[
// torch gets 89.1
-89.0, -135.3, 136.7, 102.0, -53.4, 117.9, 118.6, -43.9, -218.0, -58.5, -114.3, -150.0,
-15.6, 172.1, 66.3, -64.3, -27.9, -19.8, 31.7, 62.1, 5.5, 92.6, 28.2, -29.6, 55.9,
52.7, -72.7, -119.8, 53.8, -25.5, 128.8, 19.3, 68.0, 190.9, -64.1, -86.2, -111.2,
106.6, -67.7, 37.8, 115.9, 50.4, -77.7, -54.9, 22.3, -4.6, 89.8, 61.7, 122.4, 192.6,
-27.8, -104.6, 57.0, 166.4, 27.1, 6.1, 18.7, -93.2, 31.5, 168.2, -3.7, -99.5, -55.5,
-10.8, 17.5, 20.8, 16.9, 43.8, 42.0, -89.2, 18.8, -9.6, -84.1, 212.6, 19.7, -50.0,
-52.0, -40.0, -166.6, -73.2, -10.8, -73.3, 31.5, -23.4, -79.3, -27.0, -84.4, -42.9,
-20.3, 51.8, -16.7, 76.3, -120.5, -65.8, 96.5, -10.7, -45.9, -88.1, 65.4, -7.0, -1.5,
92.8, -25.1, -114.2, -5.8, -14.8, -51.2, -20.7, 54.2, -79.8, 47.7, -29.2, -8.8, 53.5,
-28.4, 85.0, -18.3, 107.0, 28.3, -71.8
]
);
assert_eq!(
test_utils::to_vec3_round(&grad_t.i(0)?, 1)?,
[
[
[32.3, -41.6, -24.0, 14.1, 17.6],
[-11.8, 72.5, 87.6, 46.4, 61.5],
[115.0, 108.5, -48.6, -63.4, -50.0],
[51.3, 5.4, 31.3, 91.1, -30.9],
[52.7, 92.8, -68.0, -47.0, 83.0],
// pytorch gets -107.1
[-10.2, -107.0, -5.4, 213.1, -31.4],
[-2.4, 65.1, 9.2, -146.2, -24.2]
],
[
[-72.6, -63.9, -61.9, 45.3, 33.0],
[79.3, -0.5, -26.2, 78.2, 42.7],
[90.9, 141.6, 40.1, -62.7, 37.0],
[32.8, 198.2, -0.8, -31.1, 27.3],
// torch gets 48.0
[34.5, 34.9, -47.9, 127.6, -12.3],
[-61.4, -3.2, -2.9, -10.9, -16.6],
[74.6, 60.1, -68.9, 34.5, -50.4]
],
[
[37.5, -56.9, -43.6, -13.5, -9.9],
[40.0, 97.3, 28.6, 14.2, -30.1],
[-22.3, -126.3, -68.8, -8.2, 26.1],
[-32.9, 37.3, 108.5, -54.8, 29.6],
[34.9, -176.9, -125.0, -28.3, -13.9],
[-54.9, 142.6, 62.1, -80.4, -65.6],
[7.4, -91.1, -67.6, 35.0, 39.7]
],
[
[-57.2, -40.9, -10.1, 32.6, 29.4],
[18.7, -18.0, 29.5, -1.2, 59.2],
[-14.0, -74.4, 19.8, -117.0, 58.2],
[-21.8, 163.5, -71.1, -99.0, 80.9],
[-58.9, -10.9, 93.8, -139.6, 98.0],
// torch gets 54.5
[-54.4, 135.3, 6.0, -79.1, 134.6],
[27.5, -76.0, 43.4, -2.8, -7.8]
]
]
);
Ok(())
}

View File

@ -112,3 +112,34 @@ fn custom_op1_with_backward() -> Result<()> {
Ok(())
}
impl candle_core::InplaceOp1 for Elu {
fn name(&self) -> &'static str {
"elu"
}
fn cpu_fwd(&self, s: &mut CpuStorage, _l: &Layout) -> Result<()> {
let alpha = self.alpha;
match s {
CpuStorage::BF16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
CpuStorage::F16(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
CpuStorage::F32(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
CpuStorage::F64(s) => s.iter_mut().for_each(|v| *v = fwd(*v, alpha)),
_ => candle_core::bail!("unsupported dtype for inplace elu"),
}
Ok(())
}
}
#[test]
fn inplace_op1() -> Result<()> {
let cpu = &Device::Cpu;
let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
let t = (t - 5.)?;
t.inplace_op1(&Elu { alpha: 1. })?;
assert_eq!(
to_vec1_round(&t, 4)?,
&[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
);
Ok(())
}

View File

@ -88,7 +88,7 @@ fn strided_blocks() -> Result<()> {
}
};
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
let tensor = tensor.i((.., 1))?;
let tensor = tensor.i((.., 1))?.contiguous()?;
match tensor.strided_blocks() {
candle::StridedBlocks::SingleBlock { start_offset, len } => {
assert_eq!(start_offset, 0);
@ -100,6 +100,20 @@ fn strided_blocks() -> Result<()> {
}
};
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
let tensor = tensor.i((.., 1))?;
match tensor.strided_blocks() {
candle::StridedBlocks::SingleBlock { .. } => {
panic!("unexpected block structure")
}
candle::StridedBlocks::MultipleBlocks {
block_len,
block_start_index,
} => {
assert_eq!(block_len, 4);
assert_eq!(block_start_index.collect::<Vec<_>>(), &[4, 16])
}
};
let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
match tensor.t()?.strided_blocks() {
candle::StridedBlocks::SingleBlock { .. } => {
panic!("unexpected block structure")

View File

@ -0,0 +1,106 @@
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
fn matmul(device: &Device) -> Result<()> {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let a = Tensor::from_slice(&data, (2, 2), device)?;
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let b = Tensor::from_slice(&data, (2, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
let data = vec![1.0f32, 2.0];
let a = Tensor::from_slice(&data, (2, 1), device)?;
let data = vec![3.0f32, 4.0];
let b = Tensor::from_slice(&data, (1, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 3), device)?;
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (3, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
let c = a.matmul(&b)?;
assert_eq!(c.to_vec3::<f32>()?, &expected);
// Also perform the matmul on contiguous transposed versions.
let a_tt = a.t()?.contiguous()?.t()?;
assert!(!a_tt.is_contiguous());
assert_eq!(a.dims(), a_tt.dims());
assert_eq!(a_tt.stride(), &[6, 1, 2]);
let b_tt = b.t()?.contiguous()?.t()?;
assert!(!b_tt.is_contiguous());
assert_eq!(b.dims(), b_tt.dims());
assert_eq!(b_tt.stride(), &[6, 1, 3]);
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
Ok(())
}
fn broadcast_matmul(device: &Device) -> Result<()> {
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
let out = lhs.broadcast_matmul(&rhs)?;
assert_eq!(out.dims(), &[3, 6, 4, 2]);
for idx1 in 0..3 {
for idx2 in 0..6 {
let out = out.i((idx1, idx2))?;
let lhs = lhs.i((idx1, 0))?;
let rhs = rhs.i(idx2)?;
let out2 = lhs.matmul(&rhs);
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
// With cuda, we see errors of up to ~1e-12.
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
}
}
Ok(())
}
// https://github.com/huggingface/candle/issues/1948
fn squeeze_mm(device: &Device) -> Result<()> {
let seq_len = 8_usize;
let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
let x = a.i((.., seq_len - 1, ..))?;
let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
let x = x.matmul(&w)?;
assert_eq!(x.dims(), &[1, 32]);
Ok(())
}
// https://github.com/huggingface/candle/issues/1992
fn mm_layout(device: &Device) -> Result<()> {
let a = Tensor::arange(0f32, 16f32, device)?.reshape((1, 1, 4, 4))?;
let b = Tensor::arange(0f32, 8f32, device)?.reshape((1, 1, 4, 2))?;
let mm1 = a.matmul(&b)?;
// Forces the layout to be:
// shape: [1, 1, 4, 2], stride: [8, 2, 2, 1], start_offset: 0
// This is still a contiguous matrix but matmul checks are only the two last dimensions have
// non 1 sizes but matmul check may be reluctant to handle it.
let b = b.transpose(1, 2)?.force_contiguous()?.transpose(1, 2)?;
let mm2 = a.matmul(&b)?;
let diff = (mm1 - mm2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
Ok(())
}
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal);
test_device!(mm_layout, mm_layout_cpu, mm_layout_gpu, mm_layout_metal);

View File

@ -2,9 +2,6 @@ use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor};
// https://github.com/huggingface/candle/issues/364
fn avg_pool2d(dev: &Device) -> Result<()> {
if dev.is_metal() {
return Ok(());
}
let data: Vec<f32> = vec![
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
@ -22,9 +19,6 @@ fn avg_pool2d(dev: &Device) -> Result<()> {
}
fn max_pool2d(dev: &Device) -> Result<()> {
if dev.is_metal() {
return Ok(());
}
let data: Vec<f32> = vec![
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
];

View File

@ -3,7 +3,7 @@ use candle_core::{
quantized::{self, GgmlDType},
test_device,
test_utils::to_vec2_round,
Device, Module, Result, Tensor,
DType, Device, IndexOp, Module, Result, Tensor,
};
use quantized::{k_quants, GgmlType};
use rand::prelude::*;
@ -47,18 +47,14 @@ fn test_matmul(
}
fn quantized_matmul(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let lhs_s = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?;
let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?;
assert_eq!(
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
&[
@ -67,7 +63,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
]
);
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
let mm = tensor_lhs.matmul(&tensor_rhs)?;
let mm = lhs.matmul(&tensor_rhs)?;
assert_eq!(
mm.to_vec2::<f32>()?,
&[
@ -79,7 +75,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
let res = matmul.forward(&lhs)?;
match device {
Device::Metal(_) => assert_eq!(
to_vec2_round(&res, 0)?,
@ -89,7 +85,15 @@ fn quantized_matmul(device: &Device) -> Result<()> {
[341970.0, 994574.0, 1656181.0, 2302182.0]
]
),
_ => assert_eq!(
Device::Cuda(_) => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[84866.0, 214045.0, 344676.0, 473707.0],
[213425.0, 604313.0, 1000431.0, 1387960.0],
[342030.0, 994630.0, 1656248.0, 2302250.0]
]
),
Device::Cpu => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[85120.0, 214562.0, 345455.0, 474748.0],
@ -98,22 +102,16 @@ fn quantized_matmul(device: &Device) -> Result<()> {
]
),
}
test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
Ok(())
}
fn quantized_matmul_neg(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k))
let lhs_s = (0..(m * k))
.map(|v| v as f32 - (m * k) as f32 / 2.0)
.collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?;
let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..k * n)
@ -121,7 +119,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
.collect::<Vec<_>>();
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?;
assert_eq!(
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
&[
@ -129,7 +127,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
-196472.0, 63012.0, 324585.0, 587902.0
]
);
let mm = tensor_lhs.matmul(&tensor_rhs)?;
let mm = lhs.matmul(&tensor_rhs)?;
assert_eq!(
to_vec2_round(&mm, 0)?,
&[
@ -141,7 +139,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
let res = matmul.forward(&lhs)?;
match device {
Device::Metal(_) => assert_eq!(
to_vec2_round(&res, 0)?,
@ -151,7 +149,15 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
[-196102.0, 63022.0, 324233.0, 587191.0]
]
),
_ => assert_eq!(
Device::Cuda(_) => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[243740.0, -19762.0, -285476.0, -550498.0],
[23774.0, 21645.0, 19395.0, 18364.0],
[-196045.0, 63030.0, 324120.0, 587079.0]
]
),
Device::Cpu => assert_eq!(
to_vec2_round(&res, 0)?,
&[
[243524.0, -19596.0, -285051.0, -549815.0],
@ -160,22 +166,58 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
]
),
}
let lhs2 = Tensor::stack(&[&lhs, &lhs], 0)?;
let res2 = matmul.forward(&lhs2)?;
let res2 = res2.i(1)?;
let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if device.is_cuda() {
assert!(diff < 0.1);
} else {
assert_eq!(diff, 0.);
}
Ok(())
}
test_device!(
quantized_matmul,
quantized_matmul_cpu,
quantized_matmul_cuda,
quantized_matmul_metal
);
test_device!(
quantized_matmul_neg,
quantized_matmul_neg_cpu,
quantized_matmul_neg_cuda,
quantized_matmul_neg_metal
);
fn qmm_batch(dev: &Device) -> Result<()> {
let (lhs, rhs, _mm) = get_random_tensors(2, 256, 6, dev)?;
let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.shape().dims(), [2, 6]);
let lhs2 = Tensor::cat(&[&lhs, &lhs], 0)?;
let mm2 = rhs.forward(&lhs2)?;
assert_eq!(mm2.shape().dims(), [4, 6]);
let diff2 = (mm2.i(2..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff2, 0.0);
let lhs3 = Tensor::cat(&[&lhs2, &lhs], 0)?;
let mm3 = rhs.forward(&lhs3)?;
assert_eq!(mm3.shape().dims(), [6, 6]);
let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff3, 0.0);
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff3, 0.0);
let lhs4 = Tensor::cat(&[&lhs3, &lhs3], 0)?;
let mm4 = rhs.forward(&lhs4)?;
assert_eq!(mm4.shape().dims(), [12, 6]);
let diff4 = (mm4.i(..6)? - &mm3)?.abs()?.sum_all()?.to_vec0::<f32>()?;
if dev.is_cuda() {
// We use a different kernel for sizes from 1 to 8 on cuda which explains
// the difference here.
assert!(0. < diff4 && diff4 < 1e-4)
} else {
assert_eq!(diff4, 0.0)
};
let diff4 = (mm4.i(6..)? - &mm4.i(..6)?)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff4, 0.0);
Ok(())
}
test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal);
test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal);
test_device!(qmm_batch, qmm_b_cpu, qmm_b_cuda, qmm_b_metal);
fn quantize_q4_0(device: &Device) -> Result<()> {
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
@ -183,6 +225,13 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
assert_eq!(
dst.to_vec1::<f32>()?,
&[
@ -209,6 +258,13 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
assert_eq!(
round_vector(&dst.to_vec1::<f32>()?),
&[
@ -235,6 +291,13 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
assert_eq!(
round_vector(&dst.to_vec1::<f32>()?),
&[
@ -261,6 +324,13 @@ fn quantize_q5_1(device: &Device) -> Result<()> {
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
assert_eq!(
round_vector(&dst.to_vec1::<f32>()?),
&[
@ -345,6 +415,13 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
if error > max_error {
bail!(
@ -362,6 +439,13 @@ fn quantize_q2k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
@ -381,6 +465,13 @@ fn quantize_q2k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
@ -395,6 +486,13 @@ fn quantize_q3k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
@ -414,6 +512,13 @@ fn quantize_q3k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
@ -428,6 +533,13 @@ fn quantize_q4k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
@ -447,6 +559,13 @@ fn quantize_q4k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
@ -461,6 +580,13 @@ fn quantize_q5k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
@ -480,6 +606,13 @@ fn quantize_q5k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
@ -494,6 +627,13 @@ fn quantize_q6k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
@ -513,6 +653,13 @@ fn quantize_q6k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;
@ -527,6 +674,13 @@ fn quantize_q8k(device: &Device) -> Result<()> {
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
let dst = quant.dequantize(device)?;
let dst_f16 = quant.dequantize_f16(device)?;
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src = src.to_vec1::<f32>()?;
let dst = dst.to_vec1::<f32>()?;
@ -546,6 +700,13 @@ fn quantize_q8k(device: &Device) -> Result<()> {
let src_big = get_test_vector2(128.0, 1024, device)?;
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
let dst_big = quant_big.dequantize(device)?;
let dst_big_f16 = quant_big.dequantize_f16(device)?;
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
.to_dtype(DType::F32)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let src_big = src_big.to_vec1::<f32>()?;
let dst_big = dst_big.to_vec1::<f32>()?;

View File

@ -1,5 +1,31 @@
use candle_core::{DType, Result, Tensor};
struct TmpFile(std::path::PathBuf);
impl TmpFile {
fn create(base: &str) -> TmpFile {
let filename = std::env::temp_dir().join(format!(
"candle-{}-{}-{:?}",
base,
std::process::id(),
std::thread::current().id(),
));
TmpFile(filename)
}
}
impl std::convert::AsRef<std::path::Path> for TmpFile {
fn as_ref(&self) -> &std::path::Path {
self.0.as_path()
}
}
impl Drop for TmpFile {
fn drop(&mut self) {
std::fs::remove_file(&self.0).unwrap()
}
}
#[test]
fn npy() -> Result<()> {
let npy = Tensor::read_npy("tests/test.npy")?;
@ -22,3 +48,24 @@ fn npz() -> Result<()> {
);
Ok(())
}
#[test]
fn safetensors() -> Result<()> {
use candle_core::safetensors::Load;
let tmp_file = TmpFile::create("st");
let t = Tensor::arange(0f32, 24f32, &candle_core::Device::Cpu)?;
t.save_safetensors("t", &tmp_file)?;
// Load from file.
let st = candle_core::safetensors::load(&tmp_file, &candle_core::Device::Cpu)?;
let t2 = st.get("t").unwrap();
let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0f32);
// Load from bytes.
let bytes = std::fs::read(tmp_file)?;
let st = candle_core::safetensors::SliceSafetensors::new(&bytes)?;
let t2 = st.get("t").unwrap().load(&candle_core::Device::Cpu);
let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0f32);
Ok(())
}

View File

@ -96,6 +96,40 @@ fn clamp(device: &Device) -> Result<()> {
Ok(())
}
fn asort(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1.1, 5.], [2.1, 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?;
let indexes = tensor.arg_sort_last_dim(true)?;
assert_eq!(
indexes.to_vec2::<u32>()?,
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
);
let indexes = tensor.arg_sort_last_dim(false)?;
assert_eq!(
indexes.to_vec2::<u32>()?,
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
);
let (sorted, indexes) = tensor.sort_last_dim(true)?;
assert_eq!(
indexes.to_vec2::<u32>()?,
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
);
assert_eq!(
sorted.to_vec2::<f32>()?,
[[1.0, 1.1, 3.0, 4.0, 5.0], [1.0, 2.0, 2.1, 7.0, 8.0]]
);
let (sorted, indexes) = tensor.sort_last_dim(false)?;
assert_eq!(
indexes.to_vec2::<u32>()?,
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
);
assert_eq!(
sorted.to_vec2::<f32>()?,
[[5.0, 4.0, 3.0, 1.1, 1.0], [8.0, 7.0, 2.1, 2.0, 1.0]]
);
Ok(())
}
fn unary_op(device: &Device) -> Result<()> {
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
let tensor = Tensor::new(data, device)?;
@ -106,6 +140,9 @@ fn unary_op(device: &Device) -> Result<()> {
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
]
);
let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
let max_diff = (tensor.gelu()? - t_f16)?.flatten_all()?.max(0)?;
assert!(max_diff.to_vec0::<f32>()? < 5e-3);
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
[
@ -148,6 +185,14 @@ fn unary_op(device: &Device) -> Result<()> {
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
[3000.0, 300.]
);
let tensor = Tensor::new(
&[-1.01f32, -0.9, -0.1, 0.0, -0.0, 0.1, 0.9, 1.0, 1.1],
device,
)?;
assert_eq!(
tensor.sign()?.to_vec1::<f32>()?,
[-1., -1., -1., 0., 0., 1., 1., 1., 1.]
);
Ok(())
}
@ -620,6 +665,30 @@ fn broadcast(device: &Device) -> Result<()> {
Ok(())
}
fn slice_set(device: &Device) -> Result<()> {
let (b, h, max_t, d) = (2, 4, 7, 3);
let cache = Tensor::zeros((b, h, max_t, d), DType::F32, device)?;
let tensor = Tensor::randn(0f32, 1f32, (b, h, 4, d), device)?;
cache.slice_set(&tensor, 2, 0)?;
let cache_t = cache.narrow(2, 0, 4)?;
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
cache.slice_set(&tensor, 2, 1)?;
let cache_t = cache.narrow(2, 1, 4)?;
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let ones = Tensor::ones((b, h, 1, d), DType::F32, device)?;
cache.slice_set(&ones, 2, 6)?;
let diff = cache.narrow(2, 5, 1)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let diff = (cache.narrow(2, 6, 1)? - 1.)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
Ok(())
}
fn cat(device: &Device) -> Result<()> {
// 1D
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
@ -707,6 +776,8 @@ fn embeddings(device: &Device) -> Result<()> {
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
Ok(())
}
@ -734,44 +805,47 @@ fn index_select(device: &Device) -> Result<()> {
[9.0, 10.0, 11.0]
]
);
let hs = t.index_select(&ids, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 2.0, 1.0],
[3.0, 5.0, 4.0],
[6.0, 8.0, 7.0],
[9.0, 11.0, 10.0]
]
);
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
);
// Prior to https://github.com/huggingface/candle/pull/1022
// There would be a bug where the last values in the result tensor would be set to 0.
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
]
);
for dtype in [DType::U8, DType::U32, DType::I64] {
let ids = ids.to_dtype(dtype)?;
let hs = t.index_select(&ids, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 2.0, 1.0],
[3.0, 5.0, 4.0],
[6.0, 8.0, 7.0],
[9.0, 11.0, 10.0]
]
);
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
);
// Prior to https://github.com/huggingface/candle/pull/1022
// There would be a bug where the last values in the result tensor would be set to 0.
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
]
);
// Test when selecting dim > 0 with ids size different from elem count of
// target dim in source/input.
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
let hs = t.index_select(&ids, 1)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
// Test when selecting dim > 0 with ids size different from elem count of
// target dim in source/input.
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
let hs = t.index_select(&ids, 1)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
}
Ok(())
}
@ -933,74 +1007,6 @@ fn gather(device: &Device) -> Result<()> {
Ok(())
}
fn matmul(device: &Device) -> Result<()> {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let a = Tensor::from_slice(&data, (2, 2), device)?;
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let b = Tensor::from_slice(&data, (2, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
let data = vec![1.0f32, 2.0];
let a = Tensor::from_slice(&data, (2, 1), device)?;
let data = vec![3.0f32, 4.0];
let b = Tensor::from_slice(&data, (1, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 3), device)?;
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (3, 2), device)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
let c = a.matmul(&b)?;
assert_eq!(c.to_vec3::<f32>()?, &expected);
// Also perform the matmul on contiguous transposed versions.
let a_tt = a.t()?.contiguous()?.t()?;
assert!(!a_tt.is_contiguous());
assert_eq!(a.dims(), a_tt.dims());
assert_eq!(a_tt.stride(), &[6, 1, 2]);
let b_tt = b.t()?.contiguous()?.t()?;
assert!(!b_tt.is_contiguous());
assert_eq!(b.dims(), b_tt.dims());
assert_eq!(b_tt.stride(), &[6, 1, 3]);
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
Ok(())
}
fn broadcast_matmul(device: &Device) -> Result<()> {
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
let out = lhs.broadcast_matmul(&rhs)?;
assert_eq!(out.dims(), &[3, 6, 4, 2]);
for idx1 in 0..3 {
for idx2 in 0..6 {
let out = out.i((idx1, idx2))?;
let lhs = lhs.i((idx1, 0))?;
let rhs = rhs.i(idx2)?;
let out2 = lhs.matmul(&rhs);
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
// With cuda, we see errors of up to ~1e-12.
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
}
}
Ok(())
}
fn broadcasting(device: &Device) -> Result<()> {
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
let t2 = Tensor::new(&[100f32, 200f32], device)?;
@ -1135,6 +1141,27 @@ fn randn(device: &Device) -> Result<()> {
Ok(())
}
fn zero_dim(device: &Device) -> Result<()> {
let t = Tensor::zeros((4, 0, 1), DType::F32, device)?;
assert_eq!(t.dims3()?, (4, 0, 1));
let t2 = Tensor::zeros((4, 3, 1), DType::F32, device)?;
let t_cat = Tensor::cat(&[&t, &t2], 1)?;
assert_eq!(t_cat.dims3()?, (4, 3, 1));
let t_cat = Tensor::cat(&[&t, &t], 1)?;
assert_eq!(t_cat.dims3()?, (4, 0, 1));
let t_unary = t.sqrt()?;
assert_eq!(t_unary.dims3()?, (4, 0, 1));
let t_plus = (&t + 1.)?;
assert_eq!(t_plus.dims3()?, (4, 0, 1));
let t_mm = t2.matmul(&t.t()?)?;
assert_eq!(t_mm.dims3()?, (4, 3, 0));
let t_mm = t.matmul(&t2.t()?)?;
assert_eq!(t_mm.dims3()?, (4, 0, 3));
let t_mm = t.t()?.matmul(&t)?;
assert_eq!(t_mm.dims3()?, (4, 1, 1));
Ok(())
}
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
@ -1143,6 +1170,7 @@ test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(slice_set, ss_cpu, ss_gpu, ss_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
test_device!(min, min_cpu, min_gpu, min_metal);
@ -1154,13 +1182,6 @@ test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
test_device!(
broadcast_matmul,
broadcast_matmul_cpu,
broadcast_matmul_gpu,
broadcast_matmul_metal
);
test_device!(
broadcasting,
broadcasting_cpu,
@ -1189,7 +1210,9 @@ test_device!(
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(asort, asort_cpu, asort_gpu, asort_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
@ -1317,8 +1340,8 @@ fn pow() -> Result<()> {
let rhs = (&lhs - 2.)?;
let res = lhs.pow(&rhs)?;
assert_eq!(
test_utils::to_vec2_round(&res, 4)?,
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
test_utils::to_vec2_round(&res, 3)?,
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]]
);
Ok(())
}

View File

@ -25,8 +25,9 @@ hf-hub = { workspace = true, features = ["tokio"] }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
@ -41,7 +42,7 @@ clap = { workspace = true }
imageproc = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
rusttype = { workspace = true }
ab_glyph = { workspace = true }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
@ -63,6 +64,7 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
encodec = ["cpal", "symphonia", "rubato"]
[[example]]
name = "llama_multiprocess"
@ -98,6 +100,4 @@ required-features = ["candle-datasets"]
[[example]]
name = "encodec"
required-features = ["symphonia"]
required-features = ["encodec"]

View File

@ -0,0 +1,46 @@
Contrastive Language-Image Pre-Training
Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
pairs of images with related texts.
https://github.com/openai/CLIP
https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
## Running on an example on cpu
```
$ cargo run --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
INFO clip: Probability: 0.0000% Text: a cycling race
INFO clip: Probability: 0.0000% Text: a photo of two cats
INFO clip: Probability: 100.0000% Text: a robot holding a candle
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
INFO clip: Probability: 99.9999% Text: a cycling race
INFO clip: Probability: 0.0001% Text: a photo of two cats
INFO clip: Probability: 0.0000% Text: a robot holding a candle
```
## Running on an example with metal feature (mac)
```
$ cargo run --features metal --example clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
INFO clip: Probability: 0.0000% Text: a cycling race
INFO clip: Probability: 0.0000% Text: a photo of two cats
INFO clip: Probability: 100.0000% Text: a robot holding a candle
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
INFO clip: Probability: 99.9999% Text: a cycling race
INFO clip: Probability: 0.0001% Text: a photo of two cats
INFO clip: Probability: 0.0000% Text: a robot holding a candle
```

View File

@ -0,0 +1,202 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::clip;
use tokenizers::Tokenizer;
use tracing::info;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long, use_value_delimiter = true)]
images: Option<Vec<String>>,
#[arg(long)]
cpu: bool,
#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,
}
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
let img = image::io::Reader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
// .unsqueeze(0)?;
Ok(img)
}
fn load_images<T: AsRef<std::path::Path>>(
paths: &Vec<T>,
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = load_image(path, image_size)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
pub fn main() -> anyhow::Result<()> {
// std::env::set_var("RUST_BACKTRACE", "full");
let args = Args::parse();
tracing_subscriber::fmt::init();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.repo(hf_hub::Repo::with_revision(
"openai/clip-vit-base-patch32".to_string(),
hf_hub::RepoType::Model,
"refs/pr/15".to_string(),
));
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let tokenizer = get_tokenizer(args.tokenizer)?;
let config = clip::ClipConfig::vit_base_patch32();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
// let image = load_image(args.image, config.image_size)?.to_device(&device)?;
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
let model = clip::ClipModel::new(vb, &config)?;
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
info!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
info!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
}
}
Ok(())
}
pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
let tokenizer = match tokenizer {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.repo(hf_hub::Repo::with_revision(
"openai/clip-vit-base-patch32".to_string(),
hf_hub::RepoType::Model,
"refs/pr/15".to_string(),
));
api.get("tokenizer.json")?
}
Some(file) => file.into(),
};
Tokenizer::from_file(tokenizer).map_err(E::msg)
}
pub fn tokenize_sequences(
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
let pad_id = *tokenizer
.get_vocab(true)
.get("<|endoftext|>")
.ok_or(E::msg("No pad token"))?;
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
"a cycling race".to_string(),
"a photo of two cats".to_string(),
"a robot holding a candle".to_string(),
],
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}
let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
if len_diff > 0 {
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -13,8 +13,13 @@ cargo run --example encodec --features symphonia --release -- code-to-audio \
```
This decodes the EnCodec tokens stored in `jfk-codes.safetensors` and generates
an output wav file containing the audio data. Instead of `code-to-audio` one
can use:
an output wav file containing the audio data.
Instead of `code-to-audio` one can use:
- `audio-to-audio in.mp3 out.wav`: encodes the input audio file then decodes it to a wav file.
- `audio-to-code in.mp3 out.safetensors`: generates a safetensors file
containing EnCodec tokens for the input audio file.
If the audio output file name is set to `-`, the audio content directly gets
played on default audio output device. If the audio input file is set to `-`, the audio
gets recorded from the default audio input.

View File

@ -0,0 +1,275 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::sync::{Arc, Mutex};
pub const SAMPLE_RATE: usize = 24_000;
pub(crate) struct AudioOutputData_ {
resampled_data: std::collections::VecDeque<f32>,
resampler: rubato::FastFixedIn<f32>,
output_buffer: Vec<f32>,
input_buffer: Vec<f32>,
input_len: usize,
}
impl AudioOutputData_ {
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
use rubato::Resampler;
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
let resampler = rubato::FastFixedIn::new(
resample_ratio,
f64::max(resample_ratio, 1.0),
rubato::PolynomialDegree::Septic,
1024,
1,
)?;
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
Ok(Self {
resampled_data,
resampler,
input_buffer,
output_buffer,
input_len: 0,
})
}
pub fn reset(&mut self) {
use rubato::Resampler;
self.output_buffer.fill(0.);
self.input_buffer.fill(0.);
self.resampler.reset();
self.resampled_data.clear();
}
pub(crate) fn take_all(&mut self) -> Vec<f32> {
let mut data = Vec::with_capacity(self.resampled_data.len());
while let Some(elem) = self.resampled_data.pop_back() {
data.push(elem);
}
data
}
pub(crate) fn is_empty(&self) -> bool {
self.resampled_data.is_empty()
}
// Assumes that the input buffer is large enough.
fn push_input_buffer(&mut self, samples: &[f32]) {
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
self.input_len += samples.len()
}
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
use rubato::Resampler;
let mut pos_in = 0;
loop {
let rem = self.input_buffer.len() - self.input_len;
let pos_end = usize::min(pos_in + rem, samples.len());
self.push_input_buffer(&samples[pos_in..pos_end]);
pos_in = pos_end;
if self.input_len < self.input_buffer.len() {
break;
}
let (_, out_len) = self.resampler.process_into_buffer(
&[&self.input_buffer],
&mut [&mut self.output_buffer],
None,
)?;
for &elem in self.output_buffer[..out_len].iter() {
self.resampled_data.push_front(elem)
}
self.input_len = 0;
}
Ok(())
}
}
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio output stream!");
let host = cpal::default_host();
let device = host
.default_output_device()
.context("no output device available")?;
let mut supported_configs_range = device.supported_output_configs()?;
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
// On macOS, it's commonly the case that there are only stereo outputs.
None => device
.supported_output_configs()?
.next()
.context("no audio output available")?,
Some(config_range) => config_range,
};
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
let channels = config.channels as usize;
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
SAMPLE_RATE,
config.sample_rate.0 as usize,
)?));
let ad = audio_data.clone();
let stream = device.build_output_stream(
&config,
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
data.fill(0.);
let mut ad = ad.lock().unwrap();
let mut last_elem = 0f32;
for (idx, elem) in data.iter_mut().enumerate() {
if idx % channels == 0 {
match ad.resampled_data.pop_back() {
None => break,
Some(v) => {
last_elem = v;
*elem = v
}
}
} else {
*elem = last_elem
}
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio input stream!");
let host = cpal::default_host();
let device = host
.default_input_device()
.context("no input device available")?;
let mut supported_configs_range = device.supported_input_configs()?;
let config_range = supported_configs_range
.find(|c| c.channels() == 1)
.context("no audio input available")?;
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
config.sample_rate.0 as usize,
SAMPLE_RATE,
)?));
let ad = audio_data.clone();
let stream = device.build_input_stream(
&config,
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let mut ad = ad.lock().unwrap();
if let Err(err) = ad.push_samples(data) {
eprintln!("error processing audio input {err:?}")
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
use symphonia::core::audio::Signal;
use symphonia::core::conv::FromSample;
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
let src = std::fs::File::open(path)?;
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
let hint = symphonia::core::probe::Hint::new();
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.expect("no supported audio tracks");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &Default::default())
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
while let Ok(packet) = format.next_packet() {
while !format.metadata().is_latest() {
format.metadata().pop();
}
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result<Vec<f32>> {
use rubato::Resampler;
let mut pcm_out =
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
let mut output_buffer = resampler.output_buffer_allocate(true);
let mut pos_in = 0;
while pos_in + resampler.input_frames_next() < pcm_in.len() {
let (in_len, out_len) =
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
pos_in += in_len;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
if pos_in < pcm_in.len() {
let (_in_len, out_len) = resampler.process_partial_into_buffer(
Some(&[&pcm_in[pos_in..]]),
&mut output_buffer,
None,
)?;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
Ok(pcm_out)
}

View File

@ -11,59 +11,7 @@ use candle_transformers::models::encodec::{Config, Model};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
use symphonia::core::audio::Signal;
use symphonia::core::conv::FromSample;
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
let src = std::fs::File::open(path)?;
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
let hint = symphonia::core::probe::Hint::new();
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.expect("no supported audio tracks");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &Default::default())
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
while let Ok(packet) = format.next_packet() {
while !format.metadata().is_latest() {
format.metadata().pop();
}
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
mod audio_io;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
@ -109,14 +57,36 @@ fn main() -> Result<()> {
let codes = match args.action {
Action::CodeToAudio => {
let codes = candle::safetensors::load(args.in_file, &device)?;
let codes = codes.get("codes").expect("no codes in input file").i(0)?;
codes
codes.get("codes").expect("no codes in input file").clone()
}
Action::AudioToCode | Action::AudioToAudio => {
let (pcm, sample_rate) = pcm_decode(args.in_file)?;
if sample_rate != 24_000 {
println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}")
}
let pcm = if args.in_file == "-" {
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
let (stream, input_audio) = audio_io::setup_input_stream()?;
let mut pcms = vec![];
let stdin = std::thread::spawn(|| {
let mut s = String::new();
std::io::stdin().read_line(&mut s)
});
while !stdin.is_finished() {
let input = input_audio.lock().unwrap().take_all();
if input.is_empty() {
std::thread::sleep(std::time::Duration::from_millis(100));
continue;
}
pcms.push(input)
}
drop(stream);
pcms.concat()
} else {
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
if sample_rate != 24_000 {
println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}, resampling...");
audio_io::resample(&pcm, sample_rate as usize, 24_000)?
} else {
pcm
}
};
let pcm_len = pcm.len();
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
println!("input pcm shape: {:?}", pcm.shape());
@ -135,8 +105,26 @@ fn main() -> Result<()> {
let pcm = pcm.i(0)?.i(0)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
if args.out_file == "-" {
let (stream, ad) = audio_io::setup_output_stream()?;
{
let mut ad = ad.lock().unwrap();
ad.push_samples(&pcm)?;
}
loop {
let ad = ad.lock().unwrap();
if ad.is_empty() {
break;
}
// That's very weird, calling thread::sleep here triggers the stream to stop
// playing (the callback doesn't seem to be called anymore).
// std::thread::sleep(std::time::Duration::from_millis(100));
}
drop(stream)
} else {
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
}
}
}
Ok(())

View File

@ -16,6 +16,30 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "2b")]
Base2B,
#[value(name = "7b")]
Base7B,
#[value(name = "2b-it")]
Instruct2B,
#[value(name = "7b-it")]
Instruct7B,
#[value(name = "1.1-2b-it")]
InstructV1_1_2B,
#[value(name = "1.1-7b-it")]
InstructV1_1_7B,
#[value(name = "code-2b")]
CodeBase2B,
#[value(name = "code-7b")]
CodeBase7B,
#[value(name = "code-2b-it")]
CodeInstruct2B,
#[value(name = "code-7b-it")]
CodeInstruct7B,
}
struct TextGeneration {
model: Model,
device: Device,
@ -165,6 +189,13 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model to use.
#[arg(long, default_value = "2b")]
which: Which,
#[arg(long)]
use_flash_attn: bool,
}
fn main() -> Result<()> {
@ -196,14 +227,19 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => match model_id.as_str() {
"7b-it" => "google/gemma-7b-it".to_string(),
"7b" => "google/gemma-7b".to_string(),
"2b-it" => "google/gemma-2b-it".to_string(),
"2b" => "google/gemma-2b".to_string(),
_ => model_id.to_string(),
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
Which::Base2B => "google/gemma-2b".to_string(),
Which::Base7B => "google/gemma-7b".to_string(),
Which::Instruct2B => "google/gemma-2b-it".to_string(),
Which::Instruct7B => "google/gemma-7b-it".to_string(),
Which::CodeBase2B => "google/codegemma-2b".to_string(),
Which::CodeBase7B => "google/codegemma-7b".to_string(),
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
},
None => "google/gemma-2b".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
@ -237,7 +273,7 @@ fn main() -> Result<()> {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
let model = Model::new(args.use_flash_attn, &config, vb)?;
println!("loaded the model in {:?}", start.elapsed());

View File

@ -0,0 +1,19 @@
# gte-Qwen1.5-7B-instruct
gte-Qwen1.5-7B-instruct is a variant of the GTE embedding model family.
- [Model card](https://huggingface.co/Alibaba-NLP/gte-Qwen1.5-7B-instruct) on the HuggingFace Hub.
- [Technical report](https://arxiv.org/abs/2308.03281) *Towards General Text Embeddings with Multi-stage Contrastive Learning*
## Running the example
Automatically download the model from the HuggingFace hub:
```bash
$ cargo run --example gte-qwen --release
```
or, load the model from a local directory:
```bash
cargo run --example gte-qwen --release --features cuda -- --local-repo /path/to/gte_Qwen1.5-7B-instruct/
```

View File

@ -0,0 +1,178 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::qwen2::{Config, Model};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{
utils::padding::{PaddingDirection, PaddingParams, PaddingStrategy},
Tokenizer,
};
// gte-Qwen1.5-7B-instruct use EOS token as padding token
const EOS_TOKEN: &str = "<|endoftext|>";
const EOS_TOKEN_ID: u32 = 151643;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long, default_value = "Alibaba-NLP/gte-Qwen1.5-7B-instruct")]
model_id: String,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
local_repo: Option<String>,
}
#[derive(Debug)]
struct ConfigFiles {
pub config: std::path::PathBuf,
pub tokenizer: std::path::PathBuf,
pub weights: Vec<std::path::PathBuf>,
}
// Loading the model from the HuggingFace Hub. Network access is required.
fn load_from_hub(model_id: &str, revision: &str) -> Result<ConfigFiles> {
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
model_id.to_string(),
RepoType::Model,
revision.to_string(),
));
Ok(ConfigFiles {
config: repo.get("config.json")?,
tokenizer: repo.get("tokenizer.json")?,
weights: candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
})
}
// Loading the model from a local directory.
fn load_from_local(local_path: &str) -> Result<ConfigFiles> {
let local_path = std::path::PathBuf::from(local_path);
let weight_path = local_path.join("model.safetensors.index.json");
let json: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(weight_path)?)?;
let weight_map = match json.get("weight_map") {
Some(serde_json::Value::Object(map)) => map,
Some(_) => panic!("`weight map` is not a map"),
None => panic!("`weight map` not found"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
safetensors_files.insert(
value
.as_str()
.expect("Weight files should be parsed as strings"),
);
}
let safetensors_paths = safetensors_files
.iter()
.map(|v| local_path.join(v))
.collect::<Vec<_>>();
Ok(ConfigFiles {
config: local_path.join("config.json"),
tokenizer: local_path.join("tokenizer.json"),
weights: safetensors_paths,
})
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
// Fetch the model. Do this offline if local path provided.
println!("Fetching model files...");
let start = std::time::Instant::now();
let config_files = match args.local_repo {
Some(local_path) => load_from_local(&local_path)?,
None => load_from_hub(&args.model_id, &args.revision)?,
};
println!("Model file retrieved in {:?}", start.elapsed());
// Inputs will be padded to the longest sequence in the batch.
let padding = PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Left,
pad_to_multiple_of: None,
pad_id: EOS_TOKEN_ID,
pad_type_id: 0,
pad_token: String::from(EOS_TOKEN),
};
// Tokenizer setup
let mut tokenizer = Tokenizer::from_file(config_files.tokenizer).map_err(E::msg)?;
tokenizer.with_padding(Some(padding));
// Model initialization
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let config: Config = serde_json::from_slice(&std::fs::read(config_files.config)?)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&config_files.weights, dtype, &device)? };
let mut model = Model::new(&config, vb)?;
println!("Model loaded in {:?}", start.elapsed());
// Encode the queries and the targets
let instruct = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: ";
let documents = vec![
format!("{instruct}how much protein should a female eat{EOS_TOKEN}"),
format!("{instruct}summit define{EOS_TOKEN}"),
format!("As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.{EOS_TOKEN}"),
format!("Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.{EOS_TOKEN}"),
];
let encoded = tokenizer.encode_batch(documents, true).map_err(E::msg)?;
let tokens: Vec<&[u32]> = encoded.iter().map(|x| x.get_ids()).collect();
let tokens = Tensor::new(tokens, &device)?;
let mask: Vec<&[u32]> = encoded.iter().map(|x| x.get_attention_mask()).collect();
let mask = Tensor::new(mask, &device)?;
// Inference
let start_gen = std::time::Instant::now();
let logits = model.forward(&tokens, 0, Some(&mask))?;
// Extract the last hidden states as embeddings since inputs are padded left.
let (_, seq_len, _) = logits.dims3()?;
let embd = logits
.narrow(1, seq_len - 1, 1)?
.squeeze(1)?
.to_dtype(DType::F32)?;
// Calculate the relativity scores. Note the embeddings should be normalized.
let norm = embd.broadcast_div(&embd.sqr()?.sum_keepdim(1)?.sqrt()?)?;
let scores = norm.narrow(0, 0, 2)?.matmul(&norm.narrow(0, 2, 2)?.t()?)?;
// Print the results
println!("Embedding done in {:?}", start_gen.elapsed());
println!("Scores: {:?}", scores.to_vec2::<f32>()?);
Ok(())
}

View File

@ -17,7 +17,7 @@ use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
@ -31,6 +31,8 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
enum Which {
V1,
V2,
V3,
V3Instruct,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
@ -45,19 +47,23 @@ struct Args {
cpu: bool,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 10000)]
#[arg(short = 'n', long, default_value_t = 10000)]
sample_len: usize,
/// Disable the key-value cache.
@ -83,18 +89,18 @@ struct Args {
revision: Option<String>,
/// The model size to use.
#[arg(long, default_value = "v2")]
#[arg(long, default_value = "v3")]
which: Which,
#[arg(long)]
use_flash_attn: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)]
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
#[arg(long, default_value_t = 128)]
repeat_last_n: usize,
}
@ -120,11 +126,13 @@ fn main() -> Result<()> {
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
let (llama, tokenizer_filename, mut cache) = {
let (llama, tokenizer_filename, mut cache, config) = {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.which {
Which::V1 => "Narsil/amall-7b".to_string(),
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
});
@ -138,7 +146,7 @@ fn main() -> Result<()> {
let config = config.into_config(args.use_flash_attn);
let filenames = match args.which {
Which::V1 | Which::V2 | Which::Solar10_7B => {
Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@ -146,10 +154,12 @@ fn main() -> Result<()> {
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &config)?, tokenizer_filename, cache)
(Llama::load(vb, &config)?, tokenizer_filename, cache, config)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
let eos_token_id = config
.eos_token_id
.or_else(|| tokenizer.token_to_id(EOS_TOKEN));
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
@ -160,8 +170,22 @@ fn main() -> Result<()> {
println!("starting the inference loop");
print!("{prompt}");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let start_gen = std::time::Instant::now();
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let mut start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;
for index in 0..args.sample_len {
@ -170,6 +194,9 @@ fn main() -> Result<()> {
} else {
(tokens.len(), 0)
};
if index == 1 {
start_gen = std::time::Instant::now()
}
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, context_index, &mut cache)?;
@ -205,7 +232,7 @@ fn main() -> Result<()> {
println!(
"\n\n{} tokens generated ({} token/s)\n",
token_generated,
token_generated as f64 / dt.as_secs_f64(),
(token_generated - 1) as f64 / dt.as_secs_f64(),
);
Ok(())
}

View File

@ -10,7 +10,7 @@
extern crate intel_mkl_src;
use anyhow::{bail, Error as E, Result};
use clap::Parser;
use clap::{Parser, ValueEnum};
use candle::{DType, Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
@ -24,57 +24,15 @@ mod model;
use model::{Config, Llama};
const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = r"
EDWARD:
I wonder how our princely father 'scaped,
Or whether he be 'scaped away or no
From Clifford's and Northumberland's pursuit:
Had he been ta'en, we should have heard the news;
Had he been slain, we should have heard the news;
Or had he 'scaped, methinks we should have heard
The happy tidings of his good escape.
How fares my brother? why is he so sad?
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
RICHARD:
I cannot joy, until I be resolved
Where our right valiant father is become.
I saw him in the battle range about;
And watch'd him how he singled Clifford forth.
Methought he bore him in the thickest troop
As doth a lion in a herd of neat;
Or as a bear, encompass'd round with dogs,
Who having pinch'd a few and made them cry,
The rest stand all aloof, and bark at him.
So fared our father with his enemies;
So fled his enemies my warlike father:
Methinks, 'tis prize enough to be his son.
See how the morning opes her golden gates,
And takes her farewell of the glorious sun!
How well resembles it the prime of youth,
Trimm'd like a younker prancing to his love!
EDWARD:
Dazzle mine eyes, or do I see three suns?
RICHARD:
Three glorious suns, each one a perfect sun;
Not separated with the racking clouds,
But sever'd in a pale clear-shining sky.
See, see! they join, embrace, and seem to kiss,
As if they vow'd some league inviolable:
Now are they but one lamp, one light, one sun.
In this the heaven figures some event.
EDWARD:
'Tis wondrous strange, the like yet never heard of.
I think it cites us, brother, to the field,
That we, the sons of brave Plantagenet,
Each one already blazing by our meeds,
Should notwithstanding join our lights together
And over-shine the earth as this the world.
Whate'er it bodes, henceforward will I bear
Upon my target three fair-shining suns.
";
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
V2_7b,
V2_70b,
V3_8b,
V3_70b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
@ -86,8 +44,8 @@ struct Args {
rank: Option<usize>,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
@ -117,6 +75,12 @@ struct Args {
#[arg(long)]
dtype: Option<String>,
#[arg(long, default_value = "v3-8b")]
which: Which,
#[arg(long, default_value = "nccl_id.txt")]
comm_file: String,
}
fn main() -> Result<()> {
@ -129,14 +93,27 @@ fn main() -> Result<()> {
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
None => match args.which {
Which::V2_7b | Which::V2_70b => DType::F16,
Which::V3_8b | Which::V3_70b => DType::BF16,
},
};
let api = Api::new()?;
let comm_file = std::path::PathBuf::from(&args.comm_file);
if comm_file.exists() {
bail!("comm file {comm_file:?} already exists, please remove it first")
}
let model_id = args
.model_id
.unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
let api = Api::new()?;
let model_id = match args.model_id {
Some(model) => model,
None => match args.which {
Which::V2_7b => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V2_70b => "meta-llama/Llama-2-70b-hf".to_string(),
Which::V3_8b => "meta-llama/Meta-Llama-3-8B".to_string(),
Which::V3_70b => "meta-llama/Meta-Llama-3-70B".to_string(),
},
};
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
@ -145,39 +122,40 @@ fn main() -> Result<()> {
let tokenizer_filename = api.get("tokenizer.json")?;
let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
if args.rank.is_none() {
let children: Vec<_> = (0..args.num_shards)
.map(|rank| {
let mut args: std::collections::VecDeque<_> = std::env::args().collect();
args.push_back("--rank".to_string());
args.push_back(format!("{rank}"));
let name = args.pop_front().unwrap();
std::process::Command::new(name).args(args).spawn().unwrap()
})
.collect();
for mut child in children {
child.wait().unwrap();
let rank = match args.rank {
None => {
println!("creating {} child processes", args.num_shards);
let children: Vec<_> = (0..args.num_shards)
.map(|rank| {
let mut args: std::collections::VecDeque<_> = std::env::args().collect();
args.push_back("--rank".to_string());
args.push_back(format!("{rank}"));
let name = args.pop_front().unwrap();
std::process::Command::new(name).args(args).spawn().unwrap()
})
.collect();
for mut child in children {
child.wait()?;
}
return Ok(());
}
return Ok(());
}
Some(rank) => rank,
};
let i = args.rank.unwrap();
let num_shards = args.num_shards;
let rank = i;
// Primitive IPC
let id = if rank == 0 {
let id = Id::new().unwrap();
std::fs::File::create("nccl_id.txt.tmp")?
.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())
.unwrap();
std::fs::rename("nccl_id.txt.tmp", "nccl_id.txt")?;
let tmp_file = comm_file.with_extension(".comm.tgz");
std::fs::File::create(&tmp_file)?
.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())?;
std::fs::rename(&tmp_file, &comm_file)?;
id
} else {
let path = std::path::PathBuf::from("nccl_id.txt");
while !path.exists() {
while !comm_file.exists() {
std::thread::sleep(std::time::Duration::from_secs(1));
}
let data = std::fs::read("nccl_id.txt")?;
let data = std::fs::read(&comm_file)?;
let internal: [i8; 128] = data
.into_iter()
.map(|i| i as i8)
@ -187,14 +165,17 @@ fn main() -> Result<()> {
let id: Id = Id::uninit(internal);
id
};
let device = CudaDevice::new(i)?;
let comm = Rc::new(Comm::from_rank(device, i, num_shards, id).unwrap());
let device = CudaDevice::new(rank)?;
let comm = match Comm::from_rank(device, rank, num_shards, id) {
Ok(comm) => Rc::new(comm),
Err(err) => anyhow::bail!("nccl error {:?}", err.0),
};
if rank == 0 {
std::fs::remove_file("nccl_id.txt")?;
std::fs::remove_file(comm_file)?;
}
println!("Rank {rank:?} spawned");
let device = Device::new_cuda(i)?;
let device = Device::new_cuda(rank)?;
let cache = model::Cache::new(dtype, &config, &device)?;
println!("building the model");
@ -210,14 +191,24 @@ fn main() -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let temperature = if args.temperature <= 0. {
None
} else {
Some(args.temperature)
};
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
let mut new_tokens = vec![];
let start_gen = std::time::Instant::now();
let mut start_gen = std::time::Instant::now();
let mut index_pos = 0;
for index in 0..args.sample_len {
let start_gen = std::time::Instant::now();
// Only start timing at the second token as processing the first token waits for all the
// weights to be loaded in an async way.
if index == 1 {
start_gen = std::time::Instant::now()
};
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
@ -228,25 +219,23 @@ fn main() -> Result<()> {
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
if Some(next_token) == config.eos_token_id {
break;
}
if rank == 0 {
println!("> {:?}", start_gen.elapsed());
println!(
"{} token: {} '{}'",
index + 1,
next_token,
tokenizer.decode(&[next_token], true).map_err(E::msg)?
);
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
}
let dt = start_gen.elapsed();
println!();
if rank == 0 {
let dt = start_gen.elapsed();
println!(
"{} tokens generated ({} token/s)\n----\n{}\n----",
"\n\n{} tokens generated ({} token/s)\n",
args.sample_len,
args.sample_len as f64 / dt.as_secs_f64(),
tokenizer
.decode(new_tokens.as_slice(), true)
.map_err(E::msg)?
(args.sample_len - 1) as f64 / dt.as_secs_f64(),
);
}
Ok(())

View File

@ -1,15 +1,14 @@
use candle::backend::BackendStorage;
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
use candle_nn::{Embedding, Linear, Module, RmsNorm};
use cudarc::nccl::safe::{Comm, ReduceOp};
use half::f16;
use serde::Deserialize;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use super::MAX_SEQ_LEN;
use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
pub type Config = candle_transformers::models::llama::LlamaConfig;
struct TensorParallelColumnLinear {
linear: Linear,
@ -26,7 +25,7 @@ impl TensorParallelColumnLinear {
struct TensorParallelRowLinear {
linear: Linear,
comm: Rc<Comm>,
all_reduce: AllReduce,
}
struct AllReduce {
@ -36,8 +35,6 @@ struct AllReduce {
/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
/// But for this example purposes, this will work
unsafe impl Sync for AllReduce {}
/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
/// But for this example purposes, this will work
unsafe impl Send for AllReduce {}
impl CustomOp1 for AllReduce {
@ -46,7 +43,7 @@ impl CustomOp1 for AllReduce {
}
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
todo!("implement allreduce for cpu is not necessary for single node");
candle::bail!("AllReduce is never used on cpu")
}
#[cfg(feature = "cuda")]
@ -56,31 +53,49 @@ impl CustomOp1 for AllReduce {
l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::WrapErr;
use cudarc::driver::DeviceSlice;
use half::{bf16, f16};
let elem_count = l.shape().elem_count();
let dev = s.device().clone();
let s = s.as_cuda_slice::<f16>()?;
// let s = match l.contiguous_offsets() {
// None => Err(Error::Wrapped("input has to be contiguous".into()))?,
// Some((o1, o2)) => s.slice(o1..o2),
// };
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
let dst = match s.dtype() {
DType::BF16 => {
let s = s.as_cuda_slice::<bf16>()?;
let s = match l.contiguous_offsets() {
Some((0, l)) if l == s.len() => s,
Some(_) | None => candle::bail!("input has to be contiguous"),
};
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
self.comm
.all_reduce(s, &mut dst, &ReduceOp::Sum)
.map_err(candle::Error::debug)?;
candle::CudaStorage::wrap_cuda_slice(dst, dev)
}
DType::F16 => {
let s = s.as_cuda_slice::<f16>()?;
let s = match l.contiguous_offsets() {
Some((0, l)) if l == s.len() => s,
Some(_) | None => candle::bail!("input has to be contiguous"),
};
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
self.comm
.all_reduce(s, &mut dst, &ReduceOp::Sum)
.map_err(candle::Error::debug)?;
candle::CudaStorage::wrap_cuda_slice(dst, dev)
}
dtype => candle::bail!("unsupported dtype {dtype:?}"),
};
Ok((dst, l.shape().clone()))
}
}
fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
x.apply_op1(AllReduce { comm: comm.clone() })
}
impl TensorParallelRowLinear {
fn new(linear: Linear, comm: Rc<Comm>) -> Self {
Self { linear, comm }
let all_reduce = AllReduce { comm };
Self { linear, all_reduce }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = self.linear.forward(x)?;
all_reduce_sum(&x, &self.comm)
self.linear.forward(x)?.apply_op1_no_bwd(&self.all_reduce)
}
}
@ -121,23 +136,6 @@ impl TensorParallelRowLinear {
}
}
#[derive(Deserialize)]
pub struct Config {
pub hidden_size: usize,
pub intermediate_size: usize,
pub vocab_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub rms_norm_eps: f64,
#[serde(default = "default_rope")]
pub rope_theta: f32,
}
fn default_rope() -> f32 {
10_000.0
}
#[derive(Clone)]
pub struct Cache {
#[allow(clippy::type_complexity)]
@ -161,7 +159,6 @@ impl Cache {
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
Ok(Self {
@ -197,16 +194,10 @@ struct CausalSelfAttention {
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, _, seq_len, hidden_size) = x.shape().dims4()?;
let (_b_sz, _, seq_len, _hidden_size) = x.shape().dims4()?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
Ok(rope)
candle_nn::rotary_emb::rope(x, &cos, &sin)
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
@ -232,13 +223,16 @@ impl CausalSelfAttention {
let q = q
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
.transpose(1, 2)?;
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?;
.transpose(1, 2)?
.contiguous()?;
let mut v = v
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?;
.transpose(1, 2)?
.contiguous()?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?;
@ -269,25 +263,14 @@ impl CausalSelfAttention {
let v = v.transpose(1, 2)?;
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
.transpose(1, 2)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
.reshape((b_sz, seq_len, hidden_size))?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.num_attention_heads / self.num_key_value_heads;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().dims4()?;
let x = x
.unsqueeze(2)?
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
.reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
Ok(x)
}
candle_transformers::utils::repeat_kv(x, n_rep)
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
@ -301,7 +284,7 @@ impl CausalSelfAttention {
qkv_proj,
o_proj,
num_attention_heads: cfg.num_attention_heads / comm.world_size(),
num_key_value_heads: cfg.num_key_value_heads / comm.world_size(),
num_key_value_heads: cfg.num_key_value_heads() / comm.world_size(),
head_dim: cfg.hidden_size / cfg.num_attention_heads,
cache: cache.clone(),
})
@ -315,18 +298,6 @@ struct Mlp {
}
impl Mlp {
fn new(
c_fc1: TensorParallelColumnLinear,
c_fc2: TensorParallelColumnLinear,
c_proj: TensorParallelRowLinear,
) -> Self {
Self {
c_fc1,
c_fc2,
c_proj,
}
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
@ -336,7 +307,11 @@ impl Mlp {
let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?;
Ok(Self::new(c_fc1, c_fc2, c_proj))
Ok(Self {
c_fc1,
c_fc2,
c_proj,
})
}
}
@ -427,10 +402,8 @@ impl Llama {
cfg,
comm.clone(),
)
.unwrap()
})
.collect();
.collect::<Result<Vec<_>>>()?;
Ok(Self::new(wte, blocks, norm, lm_head))
}
}

View File

@ -54,6 +54,7 @@ impl TextGeneration {
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let dtype = self.model.dtype();
let mut tokens = self
.tokenizer
.tokenizer()
@ -66,7 +67,7 @@ impl TextGeneration {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let mut state = State::new(1, &self.config, &self.device)?;
let mut state = State::new(1, &self.config, dtype, &self.device)?;
let mut next_logits = None;
for &t in tokens.iter() {
let input = Tensor::new(&[t], &self.device)?;
@ -84,7 +85,7 @@ impl TextGeneration {
Some(logits) => logits,
None => anyhow::bail!("cannot work on an empty prompt"),
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = logits.squeeze(0)?.to_dtype(dtype)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
@ -210,6 +211,9 @@ struct Args {
#[arg(long)]
config_file: Option<String>,
#[arg(long, default_value = "f32")]
dtype: String,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
@ -220,6 +224,7 @@ struct Args {
}
fn main() -> Result<()> {
use std::str::FromStr;
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
@ -279,7 +284,8 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let dtype = DType::from_str(&args.dtype)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb.pp("backbone"))?;
println!("loaded the model in {:?}", start.elapsed());

View File

@ -13,7 +13,7 @@ use candle_transformers::models::quantized_mistral::Model as QMistral;
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
@ -39,11 +39,26 @@ impl TextGeneration {
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
top_k: Option<usize>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
let logits_processor = {
let temperature = temp.unwrap_or(0.);
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (top_k, top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(seed, sampling)
};
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
@ -122,6 +137,18 @@ impl TextGeneration {
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "7b-v0.1")]
Mistral7bV01,
#[value(name = "7b-v0.2")]
Mistral7bV02,
#[value(name = "7b-instruct-v0.1")]
Mistral7bInstructV01,
#[value(name = "7b-instruct-v0.2")]
Mistral7bInstructV02,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -147,6 +174,10 @@ struct Args {
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@ -155,6 +186,10 @@ struct Args {
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
/// The model size to use.
#[arg(long, default_value = "7b-v0.1")]
which: Which,
#[arg(long)]
model_id: Option<String>,
@ -164,6 +199,9 @@ struct Args {
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
@ -177,6 +215,10 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// Use the slower dmmv cuda kernel.
#[arg(long)]
force_dmmv: bool,
}
fn main() -> Result<()> {
@ -184,6 +226,9 @@ fn main() -> Result<()> {
use tracing_subscriber::prelude::*;
let args = Args::parse();
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
@ -211,9 +256,17 @@ fn main() -> Result<()> {
Some(model_id) => model_id,
None => {
if args.quantized {
if args.which != Which::Mistral7bV01 {
anyhow::bail!("only 7b-v0.1 is available as a quantized model for now")
}
"lmz/candle-mistral".to_string()
} else {
"mistralai/Mistral-7B-v0.1".to_string()
match args.which {
Which::Mistral7bV01 => "mistralai/Mistral-7B-v0.1".to_string(),
Which::Mistral7bV02 => "mistralai/Mistral-7B-v0.2".to_string(),
Which::Mistral7bInstructV01 => "mistralai/Mistral-7B-Instruct-v0.1".to_string(),
Which::Mistral7bInstructV02 => "mistralai/Mistral-7B-Instruct-v0.2".to_string(),
}
}
}
};
@ -243,7 +296,17 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::config_7b_v0_1(args.use_flash_attn);
let config = match args.config_file {
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
None => {
if args.quantized {
Config::config_7b_v0_1(args.use_flash_attn)
} else {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
}
}
};
let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized {
let filename = &filenames[0];
@ -270,6 +333,7 @@ fn main() -> Result<()> {
args.seed,
args.temperature,
args.top_p,
args.top_k,
args.repeat_penalty,
args.repeat_last_n,
&device,

View File

@ -0,0 +1,26 @@
# candle-moondream
[Moondream](https://github.com/vikhyat/moondream) is a computer-vision model can answer real-world questions about images. It's tiny by today's models, with only 1.6B parameters. That enables it to run on a variety of devices, including mobile phones and edge devices.
## Running some examples
First download an example image
```bash
$ wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg
```
<img src="https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg" width="200">
Now you can run Moondream from the `candle-examples` crate:
```bash
$ cargo run --example moondream --release -- --prompt "What is the girl eating?" --image "./demo-1.jpg"
avavx: false, neon: true, simd128: false, f16c: false
temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64
retrieved the files in 3.395583ms
Running on CPU, to run on GPU(metal), build this example with `--features metal`
loaded the model in 5.485493792s
loaded and encoded the image Tensor[dims 3, 378, 378; f32] in 4.801396417s
starting the inference loop
The girl is eating a hamburger.<
9 tokens generated (0.68 token/s)
```

View File

@ -0,0 +1,343 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
generation::LogitsProcessor,
models::{moondream, quantized_moondream},
};
use tokenizers::Tokenizer;
enum Model {
Moondream(moondream::Model),
Quantized(quantized_moondream::Model),
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, image_embeds: &Tensor, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
if tokens.is_empty() {
anyhow::bail!("Empty prompts are not supported in the Moondream model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
// Moondream tokenizer bos_token and eos_token is "<|endoftext|>"
// https://huggingface.co/vikhyatk/moondream2/blob/main/special_tokens_map.json
let special_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the special token"),
};
let (bos_token, eos_token) = (special_token, special_token);
let start_gen = std::time::Instant::now();
let mut load_t = std::time::Duration::from_secs_f64(0f64);
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = if index > 0 {
match self.model {
Model::Moondream(ref mut model) => model.text_model.forward(&input)?,
Model::Quantized(ref mut model) => model.text_model.forward(&input)?,
}
} else {
let bos_token = Tensor::new(&[bos_token], &self.device)?.unsqueeze(0)?;
let logits = match self.model {
Model::Moondream(ref mut model) => {
model
.text_model
.forward_with_img(&bos_token, &input, image_embeds)?
}
Model::Quantized(ref mut model) => {
model
.text_model
.forward_with_img(&bos_token, &input, image_embeds)?
}
};
load_t = start_gen.elapsed();
println!("load_t: {:?}", load_t);
logits
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || tokens.ends_with(&[27, 10619, 29] /* <END> */) {
break;
}
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed() - load_t;
println!(
"\ngenerated in {} seconds\n{generated_tokens} tokens generated ({:.2} token/s)",
dt.as_secs_f64(),
(generated_tokens - 1) as f64 / dt.as_secs_f64()
);
Ok(())
}
}
#[derive(Parser)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long)]
prompt: String,
#[arg(long)]
image: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 0)]
seed: u64,
#[arg(long, default_value_t = 5000)]
sample_len: usize,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
quantized: bool,
/// Use f16 precision for all the computations rather than f32.
#[arg(long)]
f16: bool,
#[arg(long)]
model_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
}
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 378, 378).
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> candle::Result<Tensor> {
let img = image::io::Reader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(378, 378, image::imageops::FilterType::Triangle); // Adjusted to 378x378
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (378, 378, 3), &Device::Cpu)?.permute((2, 0, 1))?;
let mean = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
let std = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
(data.to_dtype(candle::DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = hf_hub::api::tokio::Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => {
if args.quantized {
"santiagomed/candle-moondream".to_string()
} else {
"vikhyatk/moondream2".to_string()
}
}
};
let repo = api.repo(hf_hub::Repo::with_revision(
model_id,
hf_hub::RepoType::Model,
args.revision,
));
let model_file = match args.model_file {
Some(m) => m.into(),
None => {
if args.quantized {
repo.get("model-q4_0.gguf").await?
} else {
repo.get("model.safetensors").await?
}
}
};
let tokenizer = match args.tokenizer_file {
Some(m) => m.into(),
None => repo.get("tokenizer.json").await?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let config = moondream::Config::v2();
let dtype = if args.quantized {
if args.f16 {
anyhow::bail!("Quantized model does not support f16");
}
DType::F32
} else if device.is_cuda() || args.f16 {
DType::F16
} else {
DType::F32
};
let model = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
&model_file,
&device,
)?;
let model = quantized_moondream::Model::new(&config, vb)?;
Model::Quantized(model)
} else {
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let model = moondream::Model::new(&config, vb)?;
Model::Moondream(model)
};
println!("loaded the model in {:?}", start.elapsed());
let start = std::time::Instant::now();
let image = load_image(args.image)?
.to_device(&device)?
.to_dtype(dtype)?;
let image_embeds = image.unsqueeze(0)?;
let image_embeds = match model {
Model::Moondream(ref m) => image_embeds.apply(m.vision_encoder())?,
Model::Quantized(ref m) => image_embeds.apply(m.vision_encoder())?,
};
println!(
"loaded and encoded the image {image:?} in {:?}",
start.elapsed()
);
let prompt = format!("\n\nQuestion: {0}\n\nAnswer:", args.prompt);
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
);
pipeline.run(&prompt, &image_embeds, args.sample_len)?;
Ok(())
}

View File

@ -0,0 +1,36 @@
# candle-olmo: Open Language Models designed to enable the science of language models
OLMo is a series of Open Language Models designed to enable the science of language models.
- **Project Page:** https://allenai.org/olmo
- **Paper:** [Link](https://arxiv.org/abs/2402.00838)
- **Technical blog post:** https://blog.allenai.org/olmo-open-language-model-87ccfc95f580
- **W&B Logs:** https://wandb.ai/ai2-llm/OLMo-1B/reports/OLMo-1B--Vmlldzo2NzY1Njk1
<!-- - **Press release:** TODO -->
## Running the example
```bash
$ cargo run --example olmo --release -- --prompt "It is only with the heart that one can see rightly"
avx: true, neon: false, simd128: false, f16c: true
temp: 0.20 repeat-penalty: 1.10 repeat-last-n: 64
retrieved the files in 354.977µs
loaded the model in 19.87779666s
It is only with the heart that one can see rightly; what is essential is invisible to the eye.
```
Various model sizes are available via the `--model` argument.
```bash
$ cargo run --example olmo --release -- --model 1.7-7b --prompt 'It is only with the heart that one can see rightly'
avx: true, neon: false, simd128: false, f16c: true
temp: 0.20 repeat-penalty: 1.10 repeat-last-n: 64
retrieved the files in 1.226087ms
loaded the model in 171.274578609s
It is only with the heart that one can see rightly; what is essential is invisible to the eye.”
~ Antoine de Saint-Exupery, The Little Prince
I am a big fan of this quote. It reminds me that I need to be open and aware of my surroundings in order to truly appreciate them.
```

View File

@ -0,0 +1,284 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::olmo::{Config, Model as OLMo};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
enum Model {
OLMo(OLMo),
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, false)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = match &mut self.model {
Model::OLMo(m) => m.forward(&input, start_pos)?,
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
enum Which {
#[value(name = "1b")]
W1b,
#[value(name = "7b")]
W7b,
#[value(name = "7b-twin-2t")]
W7bTwin2T,
#[value(name = "1.7-7b")]
V1_7W7b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 1000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long, default_value = "1b")]
model: Which,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => match args.model {
Which::W1b => "allenai/OLMo-1B-hf".to_string(),
Which::W7b => "allenai/OLMo-7B-hf".to_string(),
Which::W7bTwin2T => "allenai/OLMo-7B-Twin-2T-hf".to_string(),
Which::V1_7W7b => "allenai/OLMo-1.7-7B-hf".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
Which::W1b => {
vec![repo.get("model.safetensors")?]
}
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = {
let config_filename = repo.get("config.json")?;
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
config
};
let device = candle_examples::device(args.cpu)?;
let model = {
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = OLMo::new(&config, vb)?;
Model::OLMo(model)
};
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,8 +1,9 @@
# candle-phi: 1.3b and 2.7b LLM with state of the art performance for <10b models.
[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) and
[Phi-2](https://huggingface.co/microsoft/phi-2) are language models using
only 1.3 and 2.7 billion parameters but with state of the art performance compared to
[Phi-1.5](https://huggingface.co/microsoft/phi-1_5),
[Phi-2](https://huggingface.co/microsoft/phi-2), and
[Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) are language models using
only 1.3, 2.7, and 3.8 billion parameters but with state of the art performance compared to
models with up to 10 billion parameters.
The candle implementation provides both the standard version as well as a

View File

@ -7,11 +7,13 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
use candle_transformers::models::phi3::{Config as Phi3Config, Model as Phi3};
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
use candle::{DType, Device, Tensor};
use candle::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
@ -20,13 +22,14 @@ use tokenizers::Tokenizer;
enum Model {
MixFormer(MixFormer),
Phi(Phi),
Phi3(Phi3),
Quantized(QMixFormer),
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
@ -49,7 +52,7 @@ impl TextGeneration {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
@ -61,7 +64,11 @@ impl TextGeneration {
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
let tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?;
if tokens.is_empty() {
anyhow::bail!("Empty prompts are not supported in the phi model.")
}
@ -73,13 +80,14 @@ impl TextGeneration {
}
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the endoftext token"),
};
print!("{prompt}");
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
let mut pos = 0;
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
@ -88,6 +96,7 @@ impl TextGeneration {
Model::MixFormer(m) => m.forward(&input)?,
Model::Phi(m) => m.forward(&input)?,
Model::Quantized(m) => m.forward(&input)?,
Model::Phi3(m) => m.forward(&input, pos)?.i((.., 0, ..))?,
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
@ -107,9 +116,11 @@ impl TextGeneration {
if next_token == eos_token {
break;
}
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
pos += context_size;
}
let dt = start_gen.elapsed();
println!(
@ -128,6 +139,8 @@ enum WhichModel {
V1_5,
#[value(name = "2")]
V2,
#[value(name = "3")]
V3,
#[value(name = "2-old")]
V2Old,
PuffinPhiV2,
@ -196,6 +209,10 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The dtype to be used for running the model, e.g. f32, bf16, or f16.
#[arg(long)]
dtype: Option<String>,
}
fn main() -> Result<()> {
@ -236,6 +253,7 @@ fn main() -> Result<()> {
WhichModel::V1 => "microsoft/phi-1".to_string(),
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
WhichModel::V3 => "microsoft/Phi-3-mini-4k-instruct".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string()
}
@ -253,9 +271,10 @@ fn main() -> Result<()> {
WhichModel::V1 => "refs/pr/8".to_string(),
WhichModel::V1_5 => "refs/pr/73".to_string(),
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"main".to_string()
}
WhichModel::V2
| WhichModel::V3
| WhichModel::PuffinPhiV2
| WhichModel::PhiHermes => "main".to_string(),
}
}
}
@ -264,9 +283,11 @@ fn main() -> Result<()> {
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old => {
repo.get("tokenizer.json")?
}
WhichModel::V1
| WhichModel::V1_5
| WhichModel::V2
| WhichModel::V2Old
| WhichModel::V3 => repo.get("tokenizer.json")?,
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")?
}
@ -282,14 +303,19 @@ fn main() -> Result<()> {
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
WhichModel::V3 => anyhow::bail!(
"use the quantized or quantized-phi examples for quantized phi-v3"
),
}
} else {
match args.model {
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
WhichModel::V2 | WhichModel::V2Old => candle_examples::hub_load_safetensors(
&repo,
"model.safetensors.index.json",
)?,
WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 => {
candle_examples::hub_load_safetensors(
&repo,
"model.safetensors.index.json",
)?
}
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
}
@ -306,6 +332,9 @@ fn main() -> Result<()> {
WhichModel::V2 | WhichModel::V2Old => Config::v2(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
WhichModel::V3 => {
panic!("use the quantized or quantized-phi examples for quantized phi-v3")
}
};
let device = candle_examples::device(args.cpu)?;
let model = if args.quantized {
@ -320,7 +349,17 @@ fn main() -> Result<()> {
};
Model::Quantized(model)
} else {
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let dtype = match args.dtype {
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
None => {
if args.model == WhichModel::V3 && device.is_cuda() {
DType::BF16
} else {
DType::F32
}
}
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
let config_filename = repo.get("config.json")?;
@ -329,6 +368,13 @@ fn main() -> Result<()> {
let phi = Phi::new(&config, vb)?;
Model::Phi(phi)
}
WhichModel::V3 => {
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: Phi3Config = serde_json::from_str(&config)?;
let phi3 = Phi3::new(&config, vb)?;
Model::Phi3(phi3)
}
WhichModel::V2Old => {
let config = config();
Model::MixFormer(MixFormer::new_v2(&config, vb)?)
@ -421,6 +467,10 @@ fn mmlu<P: AsRef<std::path::Path>>(
m.clear_kv_cache();
m.forward(&input)?
}
Model::Phi3(m) => {
m.clear_kv_cache();
m.forward(&input, 0)?
}
Model::Quantized(m) => {
m.clear_kv_cache();
m.forward(&input)?

View File

@ -0,0 +1,326 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_llama::ModelWeights as Phi3b;
use candle_transformers::models::quantized_phi::ModelWeights as Phi2;
use candle_transformers::models::quantized_phi3::ModelWeights as Phi3;
const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "phi-2")]
Phi2,
#[value(name = "phi-3")]
Phi3,
/// Alternative implementation of phi-3, based on llama.
#[value(name = "phi-3b")]
Phi3b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
#[arg(long)]
model: Option<String>,
/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
/// and 'chat' for an interactive model where history of previous prompts and generated tokens
/// is preserved.
#[arg(long)]
prompt: Option<String>,
/// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 1000)]
sample_len: usize,
/// The tokenizer config in json format.
#[arg(long)]
tokenizer: Option<String>,
/// The temperature used to generate samples, use 0 for greedy sampling.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Process prompt elements separately.
#[arg(long)]
split_prompt: bool,
/// Run on CPU rather than GPU even if a GPU is available.
#[arg(long)]
cpu: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "phi-3b")]
which: Which,
#[arg(long)]
use_flash_attn: bool,
}
impl Args {
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
let tokenizer_path = match &self.tokenizer {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = match self.which {
Which::Phi2 => "microsoft/phi-2",
Which::Phi3 | Which::Phi3b => "microsoft/Phi-3-mini-4k-instruct",
};
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
}
};
Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
}
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
let model_path = match &self.model {
Some(config) => std::path::PathBuf::from(config),
None => {
let (repo, filename, revision) = match self.which {
Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf", "main"),
Which::Phi3 => (
"microsoft/Phi-3-mini-4k-instruct-gguf",
"Phi-3-mini-4k-instruct-q4.gguf",
"main",
),
Which::Phi3b => (
"microsoft/Phi-3-mini-4k-instruct-gguf",
"Phi-3-mini-4k-instruct-q4.gguf",
"5eef2ce24766d31909c0b269fe90c817a8f263fb",
),
};
let api = hf_hub::api::sync::Api::new()?;
api.repo(hf_hub::Repo::with_revision(
repo.to_string(),
hf_hub::RepoType::Model,
revision.to_string(),
))
.get(filename)?
}
};
Ok(model_path)
}
}
fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
} else if size_in_bytes < 1_000_000 {
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
} else if size_in_bytes < 1_000_000_000 {
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
} else {
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
}
}
enum Model {
Phi2(Phi2),
Phi3(Phi3),
Phi3b(Phi3b),
}
impl Model {
fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result<Tensor> {
match self {
Self::Phi2(m) => m.forward(xs, pos),
Self::Phi3(m) => m.forward(xs, pos),
Self::Phi3b(m) => m.forward(xs, pos),
}
}
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let mut model = {
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
total_size_in_bytes +=
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
model.tensor_infos.len(),
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
match args.which {
Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
Which::Phi3 => Model::Phi3(Phi3::from_gguf(
1,
args.use_flash_attn,
model,
&mut file,
&device,
)?),
Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?),
}
};
println!("model built");
let tokenizer = args.tokenizer()?;
let mut tos = TokenOutputStream::new(tokenizer);
let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
print!("{}", &prompt_str);
let tokens = tos
.tokenizer()
.encode(prompt_str, true)
.map_err(anyhow::Error::msg)?;
let tokens = tokens.get_ids();
let to_sample = args.sample_len.saturating_sub(1);
let mut all_tokens = vec![];
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let start_prompt_processing = std::time::Instant::now();
let mut next_token = if !args.split_prompt {
let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
} else {
let mut next_token = 0;
for (pos, token) in tokens.iter().enumerate() {
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, pos)?;
let logits = logits.squeeze(0)?;
next_token = logits_processor.sample(&logits)?
}
next_token
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
let eos_token = *tos
.tokenizer()
.get_vocab(true)
.get("<|endoftext|>")
.unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&all_tokens[start_at..],
)?
};
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
sampled += 1;
if next_token == eos_token {
break;
};
}
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
let dt = start_post_prompt.elapsed();
println!(
"\n\n{:4} prompt tokens processed: {:.2} token/s",
tokens.len(),
tokens.len() as f64 / prompt_dt.as_secs_f64(),
);
println!(
"{sampled:4} tokens generated: {:.2} token/s",
sampled as f64 / dt.as_secs_f64(),
);
Ok(())
}

View File

@ -17,7 +17,7 @@ generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:
```bash
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
$ cargo run --bin tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```
## Using custom models

View File

@ -10,7 +10,7 @@ use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file};
use candle::Tensor;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_llama as model;
@ -67,6 +67,10 @@ enum Which {
Mixtral,
#[value(name = "mixtral-instruct")]
MixtralInstruct,
#[value(name = "llama3-8b")]
L8b,
#[value(name = "phi3")]
Phi3,
}
impl Which {
@ -82,7 +86,9 @@ impl Which {
| Self::L13bCode
| Self::L34bCode
| Self::Leo7b
| Self::Leo13b => false,
| Self::Leo13b
| Self::L8b
| Self::Phi3 => false,
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
// same way. Starling is a fine tuned version of OpenChat.
Self::OpenChat35
@ -116,7 +122,9 @@ impl Which {
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::OpenChat35
| Self::Starling7bAlpha => false,
| Self::Starling7bAlpha
| Self::L8b
| Self::Phi3 => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
}
}
@ -140,33 +148,37 @@ impl Which {
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta => false,
| Self::Zephyr7bBeta
| Self::L8b
| Self::Phi3 => false,
Self::OpenChat35 | Self::Starling7bAlpha => true,
}
}
fn tokenizer_repo(&self) -> &'static str {
match self {
Which::L7b
| Which::L13b
| Which::L70b
| Which::L7bChat
| Which::L13bChat
| Which::L70bChat
| Which::L7bCode
| Which::L13bCode
| Which::L34bCode => "hf-internal-testing/llama-tokenizer",
Which::Leo7b => "LeoLM/leo-hessianai-7b",
Which::Leo13b => "LeoLM/leo-hessianai-13b",
Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Mistral7bInstructV02
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
Which::OpenChat35 => "openchat/openchat_3.5",
Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
Self::L7b
| Self::L13b
| Self::L70b
| Self::L7bChat
| Self::L13bChat
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode => "hf-internal-testing/llama-tokenizer",
Self::Leo7b => "LeoLM/leo-hessianai-7b",
Self::Leo13b => "LeoLM/leo-hessianai-13b",
Self::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
Self::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
Self::Mistral7b
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
Self::OpenChat35 => "openchat/openchat_3.5",
Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
Self::L8b => "meta-llama/Meta-Llama-3-8B",
Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
}
}
}
@ -200,6 +212,10 @@ struct Args {
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@ -235,6 +251,10 @@ struct Args {
/// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
#[arg(long)]
gqa: Option<usize>,
/// Use the slower dmmv cuda kernel.
#[arg(long)]
force_dmmv: bool,
}
impl Args {
@ -314,10 +334,28 @@ impl Args {
"TheBloke/Starling-LM-7B-alpha-GGUF",
"starling-lm-7b-alpha.Q4_K_M.gguf",
),
// TODO: swap to TheBloke model when available
Which::L8b => (
"QuantFactory/Meta-Llama-3-8B-GGUF",
"Meta-Llama-3-8B.Q4_K_S.gguf",
),
Which::Phi3 => (
"microsoft/Phi-3-mini-4k-instruct-gguf",
"Phi-3-mini-4k-instruct-q4.gguf",
),
};
let revision = if self.which == Which::Phi3 {
"5eef2ce24766d31909c0b269fe90c817a8f263fb"
} else {
"main"
};
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string());
api.get(filename)?
api.repo(hf_hub::Repo::with_revision(
repo.to_string(),
hf_hub::RepoType::Model,
revision.to_string(),
))
.get(filename)?
}
};
Ok(model_path)
@ -341,11 +379,13 @@ fn main() -> anyhow::Result<()> {
use tracing_subscriber::prelude::*;
let args = Args::parse();
let temperature = if args.temperature == 0. {
None
} else {
Some(args.temperature)
};
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
candle::cuda::set_gemm_reduced_precision_f16(true);
candle::cuda::set_gemm_reduced_precision_bf16(true);
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
@ -413,7 +453,9 @@ fn main() -> anyhow::Result<()> {
| Which::L13bCode
| Which::L34bCode
| Which::Leo7b
| Which::Leo13b => 1,
| Which::Leo13b
| Which::L8b
| Which::Phi3 => 1,
Which::Mixtral
| Which::MixtralInstruct
| Which::Mistral7b
@ -492,7 +534,20 @@ fn main() -> anyhow::Result<()> {
prompt_tokens
};
let mut all_tokens = vec![];
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let start_prompt_processing = std::time::Instant::now();
let mut next_token = if !args.split_prompt {
@ -517,11 +572,14 @@ fn main() -> anyhow::Result<()> {
std::io::stdout().flush()?;
}
let eos_token = if args.which.is_open_chat() {
"<|end_of_turn|>"
} else {
"</s>"
let eos_token = match args.which {
Which::L8b => "<|end_of_text|>",
_ => match args.which.is_open_chat() {
true => "<|end_of_turn|>",
false => "</s>",
},
};
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;

View File

@ -0,0 +1,27 @@
# candle-qwen: large language model series from Alibaba Cloud
Qwen 1.5 is a series of large language models that provide strong performances
on English and Chinese.
- [Blog post](https://qwenlm.github.io/blog/qwen1.5/) introducing Qwen1.5.
- [Model card](https://huggingface.co/Qwen/Qwen1.5-0.5B) on the HuggingFace Hub.
- [Blog post](https://qwenlm.github.io/blog/qwen-moe/) for the
mixture-of-experts (MoE) variant.
## Running the example
```bash
$ cargo run --example qwen --release -- --prompt "Hello there "
```
Various model sizes are available via the `--model` argument, including the MoE
variant.
```bash
$ cargo run --example qwen --release -- --model moe-a2.7b --prompt 'def print_prime(n: int): '
def print_prime(n: int): # n is the number of primes to be printed
for i in range(2, n + 1):
if all(i % j != 0 for j in range(2, i)):
print(i)
```

View File

@ -7,7 +7,8 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::qwen2::{Config, Model};
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -16,6 +17,20 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
enum Model {
Base(ModelBase),
Moe(ModelMoe),
}
impl Model {
fn forward(&mut self, xs: &Tensor, s: usize) -> candle::Result<Tensor> {
match self {
Self::Moe(ref mut m) => m.forward(xs, s),
Self::Base(ref mut m) => m.forward(xs, s),
}
}
}
struct TextGeneration {
model: Model,
device: Device,
@ -127,6 +142,8 @@ enum WhichModel {
W14b,
#[value(name = "72b")]
W72b,
#[value(name = "moe-a2.7b")]
MoeA27b,
}
#[derive(Parser, Debug)]
@ -224,6 +241,7 @@ fn main() -> Result<()> {
WhichModel::W7b => "7B",
WhichModel::W14b => "14B",
WhichModel::W72b => "72B",
WhichModel::MoeA27b => "MoE-A2.7B",
};
format!("Qwen/Qwen1.5-{size}")
}
@ -244,7 +262,11 @@ fn main() -> Result<()> {
.collect::<Vec<_>>(),
None => match args.model {
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
WhichModel::W4b | WhichModel::W7b | WhichModel::W14b | WhichModel::W72b => {
WhichModel::W4b
| WhichModel::W7b
| WhichModel::W14b
| WhichModel::W72b
| WhichModel::MoeA27b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
@ -254,7 +276,6 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config_file = repo.get("config.json")?;
let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
@ -262,7 +283,16 @@ fn main() -> Result<()> {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
let model = match args.model {
WhichModel::MoeA27b => {
let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Moe(ModelMoe::new(&config, vb)?)
}
_ => {
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
Model::Base(ModelBase::new(&config, vb)?)
}
};
println!("loaded the model in {:?}", start.elapsed());

View File

@ -0,0 +1,9 @@
# candle-recurrent-gemma
This model card corresponds to the 2B base version of the RecurrentGemma model
[huggingface model card](https://huggingface.co/google/recurrentgemma-2b).
```bash
cargo run --features cuda -r --example recurrent-gemma -- \
--prompt "Write me a poem about Machine Learning."
```

View File

@ -0,0 +1,321 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::quantized_recurrent_gemma::Model as QModel;
use candle_transformers::models::recurrent_gemma::{Config, Model as BModel};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
enum Model {
B(BModel),
Q(QModel),
}
impl Model {
fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result<Tensor> {
match self {
Self::B(m) => m.forward(xs, pos),
Self::Q(m) => m.forward(xs, pos),
}
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "2b")]
Base2B,
#[value(name = "2b-it")]
Instruct2B,
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
top_k: usize,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let sampling = match temp {
None => candle_transformers::generation::Sampling::ArgMax,
Some(temperature) => match top_p {
None => candle_transformers::generation::Sampling::TopK {
temperature,
k: top_k,
},
Some(top_p) => candle_transformers::generation::Sampling::TopKThenTopP {
temperature,
k: top_k,
p: top_p,
},
},
};
let logits_processor = LogitsProcessor::from_sampling(seed, sampling);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
#[arg(long, default_value_t = 250)]
top_k: usize,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 8000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model to use.
#[arg(long, default_value = "2b")]
which: Which,
#[arg(long)]
quantized: bool,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::Base2B => "google/recurrentgemma-2b".to_string(),
Which::Instruct2B => "google/recurrentgemma-2b-it".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
if args.quantized {
let filename = match args.which {
Which::Base2B => "recurrent-gemma-2b-q4k.gguf",
Which::Instruct2B => "recurrent-gemma-7b-q4k.gguf",
};
let filename = api.model("lmz/candle-gemma".to_string()).get(filename)?;
vec![filename]
} else {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let model = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
&filenames[0],
&device,
)?;
Model::Q(QModel::new(&config, vb.pp("model"))?)
} else {
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
Model::B(BModel::new(&config, vb.pp("model"))?)
};
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.top_k,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -0,0 +1,118 @@
use std::collections::VecDeque;
use rand::distributions::Uniform;
use rand::{thread_rng, Rng};
use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::loss::mse;
use candle_nn::{linear, seq, Activation, AdamW, Optimizer, VarBuilder, VarMap};
use crate::gym_env::GymEnv;
const DEVICE: Device = Device::Cpu;
const EPISODES: usize = 200;
const BATCH_SIZE: usize = 64;
const GAMMA: f64 = 0.99;
const LEARNING_RATE: f64 = 0.01;
pub fn run() -> Result<()> {
let env = GymEnv::new("CartPole-v1")?;
// Build the model that predicts the estimated rewards given a specific state.
let var_map = VarMap::new();
let vb = VarBuilder::from_varmap(&var_map, DType::F32, &DEVICE);
let observation_space = *env.observation_space().first().unwrap();
let model = seq()
.add(linear(observation_space, 64, vb.pp("linear_in"))?)
.add(Activation::Relu)
.add(linear(64, env.action_space(), vb.pp("linear_out"))?);
let mut optimizer = AdamW::new_lr(var_map.all_vars(), LEARNING_RATE)?;
// Initialize the model's memory.
let mut memory = VecDeque::with_capacity(10000);
// Start the training loop.
let mut state = env.reset(0)?;
let mut episode = 0;
let mut accumulate_rewards = 0.0;
while episode < EPISODES {
// Given the current state, predict the estimated rewards, and take the
// action that is expected to return the most rewards.
let estimated_rewards = model.forward(&state.unsqueeze(0)?)?;
let action: u32 = estimated_rewards.squeeze(0)?.argmax(0)?.to_scalar()?;
// Take that action in the environment, and memorize the outcome:
// - the state for which the action was taken
// - the action taken
// - the new state resulting of taking that action
// - the actual rewards of taking that action
// - whether the environment reached a terminal state or not (e.g. game over)
let step = env.step(action)?;
accumulate_rewards += step.reward;
memory.push_back((
state,
action,
step.state.clone(),
step.reward,
step.terminated || step.truncated,
));
state = step.state;
// If there's enough entries in the memory, perform a learning step, where
// BATCH_SIZE transitions will be sampled from the memory and will be
// fed to the model so that it performs a backward pass.
if memory.len() > BATCH_SIZE {
// Sample randomly from the memory.
let batch = thread_rng()
.sample_iter(Uniform::from(0..memory.len()))
.take(BATCH_SIZE)
.map(|i| memory.get(i).unwrap().clone())
.collect::<Vec<_>>();
// Group all the samples together into tensors with the appropriate shape.
let states: Vec<_> = batch.iter().map(|e| e.0.clone()).collect();
let states = Tensor::stack(&states, 0)?;
let actions = batch.iter().map(|e| e.1);
let actions = Tensor::from_iter(actions, &DEVICE)?.unsqueeze(1)?;
let next_states: Vec<_> = batch.iter().map(|e| e.2.clone()).collect();
let next_states = Tensor::stack(&next_states, 0)?;
let rewards = batch.iter().map(|e| e.3 as f32);
let rewards = Tensor::from_iter(rewards, &DEVICE)?.unsqueeze(1)?;
let non_final_mask = batch.iter().map(|e| !e.4 as u8 as f32);
let non_final_mask = Tensor::from_iter(non_final_mask, &DEVICE)?.unsqueeze(1)?;
// Get the estimated rewards for the actions that where taken at each step.
let estimated_rewards = model.forward(&states)?;
let x = estimated_rewards.gather(&actions, 1)?;
// Get the maximum expected rewards for the next state, apply them a discount rate
// GAMMA and add them to the rewards that were actually gathered on the current state.
// If the next state is a terminal state, just omit maximum estimated
// rewards for that state.
let expected_rewards = model.forward(&next_states)?.detach();
let y = expected_rewards.max_keepdim(1)?;
let y = (y * GAMMA * non_final_mask + rewards)?;
// Compare the estimated rewards with the maximum expected rewards and
// perform the backward step.
let loss = mse(&x, &y)?;
optimizer.backward_step(&loss)?;
}
// If we are on a terminal state, reset the environment and log how it went.
if step.terminated || step.truncated {
episode += 1;
println!("Episode {episode} | Rewards {}", accumulate_rewards as i64);
state = env.reset(0)?;
accumulate_rewards = 0.0;
}
}
Ok(())
}

View File

@ -42,7 +42,7 @@ impl GymEnv {
/// Creates a new session of the specified OpenAI Gym environment.
pub fn new(name: &str) -> Result<GymEnv> {
Python::with_gil(|py| {
let gym = py.import("gymnasium")?;
let gym = py.import_bound("gymnasium")?;
let make = gym.getattr("make")?;
let env = make.call1((name,))?;
let action_space = env.getattr("action_space")?;
@ -66,10 +66,10 @@ impl GymEnv {
/// Resets the environment, returning the observation tensor.
pub fn reset(&self, seed: u64) -> Result<Tensor> {
let state: Vec<f32> = Python::with_gil(|py| {
let kwargs = PyDict::new(py);
let kwargs = PyDict::new_bound(py);
kwargs.set_item("seed", seed)?;
let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
state.as_ref(py).get_item(0)?.extract()
let state = self.env.call_method_bound(py, "reset", (), Some(&kwargs))?;
state.bind(py).get_item(0)?.extract()
})
.map_err(w)?;
Tensor::new(state, &Device::Cpu)
@ -81,8 +81,10 @@ impl GymEnv {
action: A,
) -> Result<Step<A>> {
let (state, reward, terminated, truncated) = Python::with_gil(|py| {
let step = self.env.call_method(py, "step", (action.clone(),), None)?;
let step = step.as_ref(py);
let step = self
.env
.call_method_bound(py, "step", (action.clone(),), None)?;
let step = step.bind(py);
let state: Vec<f32> = step.get_item(0)?.extract()?;
let reward: f64 = step.get_item(1)?.extract()?;
let terminated: bool = step.get_item(2)?.extract()?;

View File

@ -13,6 +13,7 @@ mod gym_env;
mod vec_gym_env;
mod ddpg;
mod dqn;
mod policy_gradient;
#[derive(Parser)]
@ -25,6 +26,7 @@ struct Args {
enum Command {
Pg,
Ddpg,
Dqn,
}
fn main() -> Result<()> {
@ -32,6 +34,7 @@ fn main() -> Result<()> {
match args.command {
Command::Pg => policy_gradient::run()?,
Command::Ddpg => ddpg::run()?,
Command::Dqn => dqn::run()?,
}
Ok(())
}

View File

@ -24,13 +24,13 @@ fn w(res: PyErr) -> candle::Error {
impl VecGymEnv {
pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
Python::with_gil(|py| {
let sys = py.import("sys")?;
let sys = py.import_bound("sys")?;
let path = sys.getattr("path")?;
let _ = path.call_method1(
"append",
("candle-examples/examples/reinforcement-learning",),
)?;
let gym = py.import("atari_wrappers")?;
let gym = py.import_bound("atari_wrappers")?;
let make = gym.getattr("make")?;
let env = make.call1((name, img_dir, nprocesses))?;
let action_space = env.getattr("action_space")?;
@ -60,10 +60,10 @@ impl VecGymEnv {
pub fn step(&self, action: Vec<usize>) -> Result<Step> {
let (obs, reward, is_done) = Python::with_gil(|py| {
let step = self.env.call_method(py, "step", (action,), None)?;
let step = step.as_ref(py);
let step = self.env.call_method_bound(py, "step", (action,), None)?;
let step = step.bind(py);
let obs = step.get_item(0)?.call_method("flatten", (), None)?;
let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
let obs_buffer = pyo3::buffer::PyBuffer::get_bound(&obs)?;
let obs: Vec<u8> = obs_buffer.to_vec(py)?;
let reward: Vec<f32> = step.get_item(1)?.extract()?;
let is_done: Vec<f32> = step.get_item(2)?.extract()?;

View File

@ -5,7 +5,7 @@ use candle_transformers::models::segformer::{
Config, ImageClassificationModel, SemanticSegmentationModel,
};
use clap::{Args, Parser, Subcommand};
use image::Rgb;
use imageproc::image::Rgb;
use imageproc::integral_image::ArrayData;
use std::collections::HashMap;
use std::path::PathBuf;

View File

@ -39,7 +39,7 @@ struct Args {
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
/// mask, positive makes the mask more selective.
#[arg(long, default_value_t = 0.)]
#[arg(long, allow_hyphen_values = true, default_value_t = 0.)]
threshold: f32,
/// Enable tracing (generates a trace-timestamp.json file).

View File

@ -46,7 +46,8 @@ The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
- `--cpu`: use the cpu rather than the gpu (much slower).
- `--height`, `--width`: set the height and width for the generated image.
- `--n-steps`: the number of steps to be used in the diffusion process.
- `--num-samples`: the number of samples to generate.
- `--num-samples`: the number of samples to generate iteratively.
- `--bsize`: the numbers of samples to generate simultaneously.
- `--final-image`: the filename for the generated image(s).
### Using flash-attention

View File

@ -9,6 +9,7 @@ use candle_transformers::models::stable_diffusion;
use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use stable_diffusion::vae::AutoEncoderKL;
use tokenizers::Tokenizer;
#[derive(Parser)]
@ -64,9 +65,13 @@ struct Args {
#[arg(long)]
n_steps: Option<usize>,
/// The number of samples to generate.
/// The number of samples to generate iteratively.
#[arg(long, default_value_t = 1)]
num_samples: i64,
num_samples: usize,
/// The numbers of samples to generate simultaneously.
#[arg[long, default_value_t = 1]]
bsize: usize,
/// The name of the final image to generate.
#[arg(long, value_name = "FILE", default_value = "sd_final.png")]
@ -236,8 +241,8 @@ impl ModelFile {
fn output_filename(
basename: &str,
sample_idx: i64,
num_samples: i64,
sample_idx: usize,
num_samples: usize,
timestep_idx: Option<usize>,
) -> String {
let filename = if num_samples > 1 {
@ -261,6 +266,33 @@ fn output_filename(
}
}
#[allow(clippy::too_many_arguments)]
fn save_image(
vae: &AutoEncoderKL,
latents: &Tensor,
vae_scale: f64,
bsize: usize,
idx: usize,
final_image: &str,
num_samples: usize,
timestep_ids: Option<usize>,
) -> Result<()> {
let images = vae.decode(&(latents / vae_scale)?)?;
let images = ((images / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let images = (images.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?;
for batch in 0..bsize {
let image = images.i(batch)?;
let image_filename = output_filename(
final_image,
(bsize * idx) + batch + 1,
batch + num_samples,
timestep_ids,
);
candle_examples::save_image(&image, image_filename)?;
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn text_embeddings(
prompt: &str,
@ -292,6 +324,13 @@ fn text_embeddings(
.map_err(E::msg)?
.get_ids()
.to_vec();
if tokens.len() > sd_config.clip.max_position_embeddings {
anyhow::bail!(
"the prompt is too long, {} > max-tokens ({})",
tokens.len(),
sd_config.clip.max_position_embeddings
)
}
while tokens.len() < sd_config.clip.max_position_embeddings {
tokens.push(pad_id)
}
@ -319,6 +358,13 @@ fn text_embeddings(
.map_err(E::msg)?
.get_ids()
.to_vec();
if uncond_tokens.len() > sd_config.clip.max_position_embeddings {
anyhow::bail!(
"the negative prompt is too long, {} > max-tokens ({})",
uncond_tokens.len(),
sd_config.clip.max_position_embeddings
)
}
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
uncond_tokens.push(pad_id)
}
@ -368,6 +414,7 @@ fn run(args: Args) -> Result<()> {
final_image,
sliced_attention_size,
num_samples,
bsize,
sd_version,
clip_weights,
vae_weights,
@ -461,6 +508,7 @@ fn run(args: Args) -> Result<()> {
.collect::<Result<Vec<_>>>()?;
let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
let text_embeddings = text_embeddings.repeat((bsize, 1, 1))?;
println!("{text_embeddings:?}");
println!("Building the autoencoder.");
@ -482,7 +530,6 @@ fn run(args: Args) -> Result<()> {
} else {
0
};
let bsize = 1;
let vae_scale = match sd_version {
StableDiffusionVersion::V1_5
@ -546,12 +593,16 @@ fn run(args: Args) -> Result<()> {
println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
if args.intermediary_images {
let image = vae.decode(&(&latents / vae_scale)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename =
output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
candle_examples::save_image(&image, image_filename)?
save_image(
&vae,
&latents,
vae_scale,
bsize,
idx,
&final_image,
num_samples,
Some(timestep_index + 1),
)?;
}
}
@ -560,11 +611,16 @@ fn run(args: Args) -> Result<()> {
idx + 1,
num_samples
);
let image = vae.decode(&(&latents / vae_scale)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
candle_examples::save_image(&image, image_filename)?
save_image(
&vae,
&latents,
vae_scale,
bsize,
idx,
&final_image,
num_samples,
None,
)?;
}
Ok(())
}

View File

@ -288,12 +288,12 @@ fn main() -> Result<()> {
};
let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized {
let model = if args.quantized {
let filename = &filenames[0];
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = QStableLM::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu)
Model::Quantized(model)
} else {
let dtype = if device.is_cuda() {
DType::BF16
@ -302,7 +302,7 @@ fn main() -> Result<()> {
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = StableLM::new(&config, vb)?;
(Model::StableLM(model), device)
Model::StableLM(model)
};
println!("loaded the model in {:?}", start.elapsed());

View File

@ -12,12 +12,23 @@ use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use clap::Parser;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
const DTYPE: DType = DType::F32;
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
T5Base,
T5Small,
T5Large,
T5_3B,
Mt5Base,
Mt5Small,
Mt5Large,
}
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -36,6 +47,15 @@ struct Args {
#[arg(long)]
revision: Option<String>,
#[arg(long)]
model_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// Enable decoding.
#[arg(long)]
decode: bool,
@ -71,6 +91,10 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model to be used.
#[arg(long, default_value = "t5-small")]
which: Which,
}
struct T5ModelBuilder {
@ -82,8 +106,17 @@ struct T5ModelBuilder {
impl T5ModelBuilder {
pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
let device = candle_examples::device(args.cpu)?;
let default_model = "t5-small".to_string();
let default_revision = "refs/pr/15".to_string();
let (default_model, default_revision) = match args.which {
Which::T5Base => ("t5-base", "main"),
Which::T5Small => ("t5-small", "refs/pr/15"),
Which::T5Large => ("t5-large", "main"),
Which::T5_3B => ("t5-3b", "main"),
Which::Mt5Base => ("google/mt5-base", "refs/pr/5"),
Which::Mt5Small => ("google/mt5-small", "refs/pr/6"),
Which::Mt5Large => ("google/mt5-large", "refs/pr/2"),
};
let default_model = default_model.to_string();
let default_revision = default_revision.to_string();
let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
@ -93,14 +126,35 @@ impl T5ModelBuilder {
let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
let api = Api::new()?;
let api = api.repo(repo);
let config_filename = api.get("config.json")?;
let tokenizer_filename = api.get("tokenizer.json")?;
let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
{
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
} else {
vec![api.get("model.safetensors")?]
let repo = api.repo(repo);
let config_filename = match &args.config_file {
None => repo.get("config.json")?,
Some(f) => f.into(),
};
let tokenizer_filename = match &args.tokenizer_file {
None => match args.which {
Which::Mt5Base => api
.model("lmz/mt5-tokenizers".into())
.get("mt5-base.tokenizer.json")?,
Which::Mt5Small => api
.model("lmz/mt5-tokenizers".into())
.get("mt5-small.tokenizer.json")?,
Which::Mt5Large => api
.model("lmz/mt5-tokenizers".into())
.get("mt5-large.tokenizer.json")?,
_ => repo.get("tokenizer.json")?,
},
Some(f) => f.into(),
};
let weights_filename = match &args.model_file {
Some(f) => f.split(',').map(|v| v.into()).collect::<Vec<_>>(),
None => {
if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2" {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
} else {
vec![repo.get("model.safetensors")?]
}
}
};
let config = std::fs::read_to_string(config_filename)?;
let mut config: t5::Config = serde_json::from_str(&config)?;

View File

@ -115,7 +115,7 @@ pub fn main() -> anyhow::Result<()> {
let processor = image_processor::ViTImageProcessor::new(&processor_config);
let image = vec![args.image.as_str()];
let image = processor.preprocess(image)?;
let image = processor.preprocess(image)?.to_device(&device)?;
let encoder_xs = model.encoder().forward(&image)?;

View File

@ -34,6 +34,7 @@ from the hub.
- `--timestamps`: enable the timestamp mode where some timestamps are reported
for each recognized audio extracts.
- `--model`: the model to be used. Models that do not end with `-en` are
multilingual models, other ones are English only models. The supported models
are `tiny`, `tiny.en`, `base`, `base.en`, `small`, `small.en`, `medium`,
`medium.en`, `large`, and `large-v2`.
multilingual models, other ones are English only models. The supported OpenAI
Whisper models are `tiny`, `tiny.en`, `base`, `base.en`, `small`, `small.en`,
`medium`, `medium.en`, `large`, `large-v2` and `large-v3`. The supported
Distil-Whisper models are `distil-medium.en`, `distil-large-v2` and `distil-large-v3`.

View File

@ -374,6 +374,8 @@ enum WhichModel {
DistilMediumEn,
#[value(name = "distil-large-v2")]
DistilLargeV2,
#[value(name = "distil-large-v3")]
DistilLargeV3,
}
impl WhichModel {
@ -386,7 +388,8 @@ impl WhichModel {
| Self::Large
| Self::LargeV2
| Self::LargeV3
| Self::DistilLargeV2 => true,
| Self::DistilLargeV2
| Self::DistilLargeV3 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
false
}
@ -408,6 +411,7 @@ impl WhichModel {
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
Self::DistilLargeV3 => ("distil-whisper/distil-large-v3", "main"),
}
}
}

View File

@ -13,7 +13,7 @@ struct Block {
impl Block {
fn get(&self, key: &str) -> Result<&str> {
match self.parameters.get(&key.to_string()) {
match self.parameters.get(key) {
None => candle::bail!("cannot find {} in {}", key, self.block_type),
Some(value) => Ok(value),
}
@ -28,7 +28,7 @@ pub struct Darknet {
impl Darknet {
fn get(&self, key: &str) -> Result<&str> {
match self.parameters.get(&key.to_string()) {
match self.parameters.get(key) {
None => candle::bail!("cannot find {} in net parameters", key),
Some(value) => Ok(value),
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 175 KiB

View File

@ -99,7 +99,7 @@ pub fn report_detect(
let h_ratio = initial_h as f32 / h as f32;
let mut img = img.to_rgb8();
let font = Vec::from(include_bytes!("roboto-mono-stripped.ttf") as &[u8]);
let font = rusttype::Font::try_from_vec(font);
let font = ab_glyph::FontRef::try_from_slice(&font).map_err(candle::Error::wrap)?;
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
for b in bboxes_for_class.iter() {
println!(
@ -119,27 +119,28 @@ pub fn report_detect(
);
}
if legend_size > 0 {
if let Some(font) = font.as_ref() {
imageproc::drawing::draw_filled_rect_mut(
&mut img,
imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, legend_size),
image::Rgb([170, 0, 0]),
);
let legend = format!(
"{} {:.0}%",
candle_examples::coco_classes::NAMES[class_index],
100. * b.confidence
);
imageproc::drawing::draw_text_mut(
&mut img,
image::Rgb([255, 255, 255]),
xmin,
ymin,
rusttype::Scale::uniform(legend_size as f32 - 1.),
font,
&legend,
)
}
imageproc::drawing::draw_filled_rect_mut(
&mut img,
imageproc::rect::Rect::at(xmin, ymin).of_size(dx as u32, legend_size),
image::Rgb([170, 0, 0]),
);
let legend = format!(
"{} {:.0}%",
candle_examples::coco_classes::NAMES[class_index],
100. * b.confidence
);
imageproc::drawing::draw_text_mut(
&mut img,
image::Rgb([255, 255, 255]),
xmin,
ymin,
ab_glyph::PxScale {
x: legend_size as f32 - 1.,
y: legend_size as f32 - 1.,
},
&font,
&legend,
)
}
}
}

View File

@ -448,9 +448,9 @@ pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows10
/// Perform gating and averaging for a BS.1770-4 integrated loudness measurement.
///
/// The integrated loudness measurement is not just the average power over the
/// entire signal. BS.1770-4 defines defines two stages of gating that exclude
/// entire signal. BS.1770-4 defines two stages of gating that exclude
/// parts of the signal, to ensure that silent parts do not contribute to the
/// loudness measurment. This function performs that gating, and returns the
/// loudness measurement. This function performs that gating, and returns the
/// average power over the windows that were not excluded.
///
/// The result of this function is the integrated loudness measurement.

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.4.2"
version = "0.5.1"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.2" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.5.1" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -42,6 +42,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);

View File

@ -139,7 +139,9 @@ impl FlashAttn {
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
let softmax_lse = dev
.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
.w()?;
let is_bf16 = if is_bf16 { 1 } else { 0 };

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.4.2"
version = "0.5.1"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -1,5 +1,8 @@
fn main() {
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-changed=src/compatibility.cuh");
println!("cargo:rerun-if-changed=src/cuda_utils.cuh");
println!("cargo:rerun-if-changed=src/binary_op_macros.cuh");
let builder = bindgen_cuda::Builder::default();
println!("cargo:info={builder:?}");

Some files were not shown because too many files have changed in this diff Show More