Compare commits

...

1606 Commits

Author SHA1 Message Date
fb660b8d43 Add a cudnn feature to candle-nn/candle-transformers. (#2890) 2025-04-13 17:43:41 +02:00
2f9606b187 Exclude candle-book to avoid some CI failures. (#2889)
* Exclude candle-book to avoid some CI failures.

* Remove the book CIs.
2025-04-13 17:11:41 +02:00
f3a73f80d1 Support for cudnn conv1d. (#2888)
* Support for cudnn conv1d.

* More conv1d work.

* Get the conv1d to work with cudnn.

* Cleanup.
2025-04-13 16:47:37 +02:00
b44d38de0e Add the Orpheus TTS. (#2886)
* Add the Orpheus TTS.

* Add a small readme.

* Token fix.

* Support more voices.

* Clippy fixes.
2025-04-13 12:02:17 +02:00
d9198deb37 Im2col cuda optimization. (#2885) 2025-04-13 10:07:53 +02:00
15ed0b11ce Optimize the batched matmul for the cpu backend. (#2884) 2025-04-12 21:40:40 +02:00
34505fdf3a Avoid using batched-matmul in nn::Linear. (#2883)
* Avoid using batched-matmul in nn::Linear.

* Also avoid batched matmul in conv1d.

* Also tweak the conv2d.

* Batched tests.

* Also cover conv2d.
2025-04-12 19:53:58 +02:00
d7b7ce16e4 Upgrade ug. (#2882) 2025-04-12 13:19:32 +02:00
19fb6dac1f Bump the crate version. (#2881) 2025-04-11 22:28:21 +02:00
acc5bd335f Cuda cleanup. (#2880)
* Cuda cleanup.

* More fixes.
2025-04-11 21:43:35 +02:00
eb478ece92 Implementing DistilBertForMaskedLM. (#2866)
* Initial commit: model weights working, prediciton incorrect

* moved distilbertformaskedlm into distilbert modeling file

* made maskedLM like bert example, still incorrect predictions

* finally not getting NaNs, fixed attention mask

* getting correct output sentences

* get top k predictions

* fixed output formatting slightly

* added default arg for model_id

* lint

* moved masked token example code from distilbertformaskedlm example to distilbert example

* lint

* removed distilbertformaskedlm example

* cleanup

* clippy

* removed embedding normalization from example

* made output and model dependent on args instead of prompt

* lint

* replaced or_ok anyhow error with anyhow context

* changed error message for mask token not found
2025-04-11 13:25:39 +02:00
d339b01726 Fix hardcoded f32 dtype for attention_mask. Use the model dtype for compatibility. (#2872) 2025-04-08 06:12:14 +02:00
2f3bf42bcb Support more snac variants. (#2871) 2025-04-07 08:23:47 +02:00
e3370c6316 Add the SNAC audio tokenizer. (#2869)
* Add the SNAC audio tokenizer.

* More snac.

* Again more snac.

* Add some example code for snac.

* Get the weights to load.

* Add to the snac model.

* Fixes.

* Get round-tripping to work.

* Save/load code files.

* Clippy fix.

* Fmt fix.
2025-04-06 22:15:36 +02:00
338f6a102e Clippy 1.86 fixes for cuda. (#2868) 2025-04-05 15:45:35 +02:00
bc33df77e1 Add the missing voices for CSM. (#2867) 2025-04-05 06:52:36 +02:00
cf9d7bf24c Add the CSM model. (#2862)
* Add the CSM model.

* Add some code to load the model.

* Load the text tokenizer.

* Add frame generation.

* Get the sampling to work.

* Rope fix.

* Autoregressive generation.

* Generate some audio file.

* Use the actual prompt.

* Support multiple turns.

* Add a very barebone readme.

* Move some of the shared bits to the model.
2025-04-04 06:48:03 +02:00
9d31361c4f Fix for clippy 1.86. (#2864)
* Fix for clippy 1.86.

* More clippy fixes.

* More fixes.
2025-04-03 19:38:27 +02:00
648596c073 Added readmes to examples (#2835)
* added chatGLM readme

* changed wording in readme

* added readme for chinese-clip

* added readme for convmixer

* added readme for custom ops

* added readme for efficientnet

* added readme for llama

* added readme to mnist-training

* added readme to musicgen

* added readme to quantized-phi

* added readme to starcoder2

* added readme to whisper-microphone

* added readme to yi

* added readme to yolo-v3

* added readme to whisper-microphone

* added space to example in glm4 readme

* fixed mamba example readme to run mamba instead of mamba-minimal

* removed slash escape character

* changed moondream image to yolo-v8 example image

* added procedure for making the reinforcement-learning example work with a virtual environment on my machine

* added simple one line summaries to the example readmes without

* changed non-existant image to yolo example's bike.jpg

* added backslash to sam command

* removed trailing - from siglip

* added SoX to silero-vad example readme

* replaced procedure for uv on mac with warning that uv isn't currently compatible with pyo3

* added example to falcon readme

* added --which arg to stella-en-v5 readme

* fixed image path in vgg readme

* fixed the image path in the vit readme

* Update README.md

* Update README.md

* Update README.md

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2025-04-03 09:18:29 +02:00
d9904a3baf Update to cudarc 0.14 (breaking change). (#2858)
* Start updating to cudarc 0.14.

* Adapt a couple more things.

* And a couple more fixes.

* More tweaks.

* And a couple more fixes.

* Bump the major version number.

* Proper module system for the cuda kernels.

* Proper ptx loading.

* Launch the sort kernel.

* Custom op.

* Start using the builder pattern.

* More builder.

* More builder.

* Get candle-core to compile.

* Get the tests to pass.

* Get candle-nn to work too.

* Support for custom cuda functions.

* cudnn fixes.

* Get flash attn to run.

* Switch the crate versions to be alpha.

* Bump the ug dependency.
2025-04-03 09:12:19 +02:00
d6db305829 Added new language pairs to marian-mt example. (#2860)
* added new language pairs to marian-mt

* lint

* seperated python code for converting tokenizers into its own file and and added a reqirements.txt for dependencies, updated instructions in readme and included python version

* Cleanup.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-04-02 23:50:14 +02:00
b4daa03e59 add as_cuda_slice_mut to CudaStorage and CudaDType (#2859) 2025-04-01 19:34:52 +02:00
9541467d6b Add flip to tensor (#2855)
* Add `flip` to `tensor`

* Move the tests to the proper places.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2025-04-01 09:07:16 +02:00
6429609090 Added Deepseekr1 Llama8b variant to quantized example (#2842)
* added deepseekr1 llama8b variant to quantized example

* lint
2025-03-30 10:55:21 +02:00
ba473290da Added DeepseekR1 Qwen7B variant to quantized-qwen2-instruct example (#2843)
* quantized deepseek qwen generating tokens

* removed is_deepseek from Args and replaced prompt if statement with pattern matching
2025-03-30 10:54:22 +02:00
59c26195db Fix CIFAR10 dataset types and dimension ordering (#2845) 2025-03-30 10:53:25 +02:00
cb02b389d5 Fix reinforcement learning example (#2837) 2025-03-26 16:27:45 +01:00
0d4097031c fixed rand import for mnist-training (#2833) 2025-03-26 08:10:03 +01:00
10853b803c fixed rand imports for whisper-microphone example (#2834) 2025-03-26 08:09:27 +01:00
f3d472952f fix: candle-flash-attn linux and msvc build (#2829)
* fix: candle-flash-attn linux and msvc build

* Missing newline at eof.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2025-03-25 08:45:12 +01:00
67b85f79f1 Pickle decoder fix and Long1 opcode addition. (#2824)
* Pickle decoder changes: added Long1 opcode, fixed tensor offset calculation

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-03-23 08:10:08 +01:00
0b24f7f0a4 Fix for whisper example. rand::distribution is now rand::distr (#2811) 2025-03-16 19:14:55 +01:00
3afb04925a Allow for growing the default KV cache when needed. (#2810) 2025-03-16 17:30:25 +01:00
cbf5fc80c2 Add Gemma 3 1b IT toe Gemma examples (#2809)
- Updates the Gemma example to include Gemma 3 1b instruction tuned.
2025-03-16 17:00:48 +01:00
468d1d525f Bump the crate version to 0.8.4. (#2808) 2025-03-15 07:42:24 +01:00
c930ab7e1a upgrade half library to fix rand (#2806)
fix lints
2025-03-14 09:01:54 +01:00
111edbc4ea Gemma 3 initial setup (text only). (#2802)
* Gemma 3 initial setup (text only).

* Use the rotating kv cache for the sliding window.
2025-03-14 07:56:02 +01:00
e286cf7cc9 Parse the json config for siglip models. (#2800)
* Parse the json config for siglip models.

* Bump the tokenizers dependency.

* Add a v2 model.

* Support more v2 model.s
2025-03-09 14:01:09 +01:00
e4ffb85228 Add ModernBert sentency classifier (#2796) 2025-03-08 14:48:22 +01:00
37db86ff79 Allow ModernBert to be used to generate embeddings. (#2791) 2025-03-03 12:39:04 +01:00
add3a714aa phi-4-mini (#2790) 2025-03-01 10:07:29 +01:00
26c16923b9 Make sorted_nodes pub function (#2780) 2025-02-22 10:23:45 +01:00
9e8bf70333 Avoid some clippy lints on 1.85. (#2778)
* Avoid some clippy lints on 1.85.

* Upload artifacts v4.
2025-02-22 10:23:22 +01:00
ac9cdbd448 Refactor From<Tuple> implementations by using macros, add tests (#2762) 2025-02-19 10:58:29 +01:00
e6cc76fc37 Implement DeepSeek V2 (#2744)
* Add deepseek v2

* Fix

* Remove unused

* Add kv cache

* Remove from cargo.toml

* Fix dtype selection logic

* Fix unnecessary u32->f32->gather->u32

* Remove fromstr impl

* Use local scopes for some clarity

* Typo

* Repeat k_pe

* Chain calls to remove mut

* Actually, remove all muts

* Update readme
2025-02-19 10:51:01 +01:00
fd7f7242a1 Bump the crate version to 0.8.3 (#2772)
* update to cudarc to v0.13.5 to support cuda 12.8

* Bump the crate version.

---------

Co-authored-by: Michael McCulloch <michael.james.mcculloch@fastmail.com>
2025-02-15 15:54:48 +01:00
3ddd20a5aa update to cudarc to v0.13.5 to support cuda 12.8 (#2771)
Co-authored-by: Michael McCulloch <michael.james.mcculloch@fastmail.com>
2025-02-15 15:47:23 +01:00
2423d633fc add dynamic position encoding to Siglip (#2770)
* add dynamic position encoding

* remove debug messages
2025-02-14 13:50:50 +01:00
7c2449f623 Metal: Improved reduce and softmax (#1819)
* Improve reduce perf and add contiguous impl

* Improve arg reduce and add contiguous impl

* Improve softmax kernel. 33%-39% higher thrpt

* fmt

* Fixed all bugs. Improved code quality. Added tests.

* Stash for debugging

* Stash for debugging 2

* Fixing argmax bug and improve performance

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>

* Fix test and add is_valid_simgroup_reduce_type trait

* Online softmax. Improved threadgroup reduce. Tidying up a bit.

* Remove redundant threadgroup_barrier from arg reduce

* Mostly tidying up. Some improvements

* Simplify indexed struct

* tidying

* Reuse operation operator instead of passing it in as a parameter

* Fix how operators are applied to indexed<vec<T,N>>

* Vectorized load. Scalar block reduce. Hitting max throughput for f32 reduce.

* Vectorized load for online softmax. Involves a reinterpret_cast of src which may be suboptimal.

* Metal as_type casting vec<bfloat, N> -> vec<float, N/2> for simd and fast math

* Use constant for input instead of const device. Fix strided reduce.

* Use contiguous reduce in tests

* Rename finalize -> to_scalar

* Support integer types max/min (switch with trait-inferred impl later)

* Was worried I was skipping work -> shuffling the 1D test cases

* Add build.rs to avoid metal kernel jit compile overhead

* Improve build. Extract utils

* Compile metal kernels for both macos and ios

* Fixed over xmas and then forgot about it

* Add calculate_reduce_threads util

* Remove old reduce.metal

* Improve f16/bf16 softmax precision by accumulating in f32

* Remove build.rs (for now)

* Move softmax bench to candle-nn

* Remove redundant thread calc util fn

* Use uint over ushort for indices etc

* Use fast exp in MDReduceOp

* Remove nested metal define for softmax

* Fix some clippy lint.

---------

Co-authored-by: Christopher Fleetwood <45471420+FL33TW00D@users.noreply.github.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-02-08 07:27:01 +01:00
0af3e428ec fix: place ug dep behind not wasm32 flag (#2760)
* place `ug` behind not wasm32 attr

so that wasm32 can compile

* mv `ug` to conditional target dep

assuming every non-wasm32 user wants this
2025-02-01 23:05:52 +01:00
43017539ab Adds DebertaV2/V3 (#2743)
* Adds DebertaV2/V3

* Fixes all clippy warnings

* Typos.

* Addresses PR review findings. Some refactorings

* Avoid some unwrap/unwrap_or.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-01-29 08:59:28 +01:00
e142bf9530 use moondream1 model/revision for moondream example (#2748) 2025-01-28 22:19:54 +01:00
d2c53f4f2f Remove the MFA gemm library. (#2755) 2025-01-28 21:48:17 +01:00
2a2852d1c1 Fix flash-attn build. (#2754) 2025-01-28 18:49:46 +01:00
8f20f2a722 Add the MLX merge sort kernels (#2751)
* Add some metal sort kernels imported from MLX.

* Add another test.

* Start adding the multiblock version.

* Proper kernel names.

* Split out the main metal file.

* Multi-block sort.

* More sorting.

* DType parametrization.

* Add a larger test.
2025-01-28 14:09:43 +01:00
ab9019425a Make the metal sdpa tests deterministic. (#2750) 2025-01-28 09:05:24 +01:00
da02b59516 Allow using composed strings as metal kernel names. (#2747) 2025-01-27 22:40:12 +01:00
27996a1a9e Remove the old MFA gemm kernels. (#2742)
* Remove the old MFA gemm kernels.

* Use bf16 in helium on metal.
2025-01-26 20:36:31 +01:00
1a32107fab Add a few metal gather ops. (#2740)
* Add a few metal gather ops.

* Fix some compilation issues.

* Adjust the tolerance.
2025-01-25 23:31:03 +01:00
333d94a19a fix: fix the codegeex4 model examples and transformers model (#2738)
* Update main.rs

* Update codegeex4_9b.rs

* Get things to compile.

* Add some default for when rope_ratio is missing.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-01-25 17:41:12 +01:00
3164a19a5d Add inpainting to the stable diffusion example (#2735)
* Update the stable diffusion example with inpainting support for 1.5, 2 and XL.

* Apply cargo fmt.

* Clippy fixes.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2025-01-23 10:08:38 +01:00
e6cd499e98 Fix candle-flash-attn build on Windows (msvc) (#2734) 2025-01-22 22:19:48 +01:00
77db8396d0 Explicit error when slice-set is called with the same src and dst. (#2733) 2025-01-22 21:31:49 +01:00
85f0aaefe5 Add serde::serialize to activations. (#2732) 2025-01-22 10:23:34 +01:00
e4c3a71f11 Fix GLM4 alignment issue (#2723)
* Fix GLM4 alignment issue

* Cleanups.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-01-20 22:51:46 +01:00
17cbbe4286 Sync upstream MLX sdpa vector kernels with mask (#2718)
* Sync upstream mlx sdpa vector kernels with mask

* Dispatch to the 2pass kernel

* Format
2025-01-16 11:30:10 +01:00
6fd2f63a15 Bump the ug dependency. (#2720)
* Bump the ug dependency.

* Fix some test.

* Fix the ug test.
2025-01-16 09:39:16 +01:00
efd0e6822f Fix the helium weights download. (#2717) 2025-01-13 18:21:37 +01:00
158817f230 Helium repo update. (#2716) 2025-01-13 18:04:14 +01:00
309cd0f7c7 Add the helium model. (#2715) 2025-01-13 17:39:49 +01:00
ab7ff7081e Fixes for running Phi-4 quantized. (#2714) 2025-01-13 14:35:33 +01:00
461e8c1685 ModernBERT model (#2713)
* layer_norm_no_bias

* Modernbert model.

* Format + cleanup error.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2025-01-13 08:39:27 +01:00
2344c4e4b8 Clippy fixes for 1.84. (#2710) 2025-01-10 10:15:15 +01:00
32defdb7d5 Update cudarc. (#2708) 2025-01-08 15:10:23 +01:00
236c35e578 Bump the caret version to 0.8.2. (#2703) 2025-01-07 15:50:16 +01:00
6f8351dfda add link to README (#2701) 2025-01-04 23:07:30 +01:00
57f41da13b Fix mistral attention on Metal (#2699)
Co-authored-by: Luka Zakrajsek <luka.zakrajsek@soniox.com>
2025-01-04 16:11:20 +01:00
cbaa0ad46f UniPC for diffusion sampling (#2684)
* feat: Add unipc multistep scheduler

* chore: Clippy and formatting

* chore: Update comments

* chore: Avoid unsafety in float ordering

* refactor: Update Scheduler::step mutability requirements

* fix: Corrector img2img

* chore: Update unipc ref link to latest diffusers release

* chore: Deduplicate float ordering

* fix: Panic when running with dev profile
2025-01-01 21:34:17 +01:00
b12c7c2888 Update the hf-hub dependency to 0.4.0. (#2691)
* Update the hf-hub dependency to 0.4.0.

* Fix the book.

* Use 0.4.1.
2024-12-31 19:07:47 +01:00
94ffc2ec6f Actually remove the default hf-hub cache path for glm. (#2696) 2024-12-31 11:00:44 +01:00
7354afc673 Use the default hf-hub cache for glm. (#2695) 2024-12-31 10:55:45 +01:00
2a705e6f37 Flash-Attn upgrade / SoftCap Candle-FlashAttn [3/n] (#2690)
* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace

* softcap is working, including test and api.

* make softcap test case better

* unpadded lse added
2024-12-31 10:04:47 +01:00
a594ef669c Flash-Attn upgrade / SoftCap Candle-FlashAttn [2/n] (#2689)
* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace

* softcap is working, including test and api.

* make softcap test case better

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-12-31 09:41:23 +01:00
71cd6d5533 Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)
* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace
2024-12-31 09:32:22 +01:00
d60eba1408 Streamline the glm4 example. (#2694) 2024-12-31 09:21:41 +01:00
e38e2a85dd Fix a cuda warning. (#2693) 2024-12-31 09:06:10 +01:00
460616fc84 Update README.org (#2670)
The command line error in the CPU section of the documentation.
2024-12-30 11:32:02 +01:00
91f1f019b1 Added XLMRobertaModel for Reranking (#2686)
* add xlm-roberta-base

* Add task enum for fill-mask and reranker in xlm-roberta example; update README and fix attention mask dimensions

- Introduced a new `Task` enum to replace string task identifiers in the xlm-roberta example.
- Updated the logic in `main.rs` to handle tasks using the new enum.
- Enhanced README with example output for fill-mask task.
- Fixed dimension retrieval in `prepare_4d_attention_mask` function for better clarity and safety.

* Clippy fix.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-12-30 11:16:57 +01:00
cd639131f0 Fix bug in whisper transformer (#2681)
* Fix bug in whisper transformer
- due to num_threads going to zero
in single threaded case

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-12-24 13:58:21 +01:00
11aa30be10 Fix Batcher iterator break when return_last_incomplete_batch and items.is_empty (#2654) (#2655) 2024-12-24 08:41:26 +01:00
1be6b090c7 Fix position encodings for Pixtral (#2678)
* init commit: add position id in meshgrid

* pass in subsampled positions

* clippy fix

* clippy fix
2024-12-23 13:22:35 +01:00
62ced44ea9 Add a Context trait similar to anyhow::Context. (#2676)
* Add a Context trait similar to anyhow::Context.

* Switch two unwrap to context.
2024-12-22 09:18:13 +01:00
5c2f893e5a make DepthAnythingV2 more reusable (#2675)
* make DepthAnythingV2 more reusable

* Fix clippy lints.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-12-21 12:06:03 +01:00
67cab7d6b8 Bump the crate version to 0.8.1. (#2662) 2024-12-07 17:03:53 +01:00
1807be84f4 Change/bert encoder public (#2658)
* change: BertEncoder struct to public

* change: make certain fields in Config struct public

* change: all fields in bert config struct to be public

* change: add clone to bert encoder and others

* Clippy fix.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-12-04 21:22:30 +01:00
145aa7193c Add Nvembed v2 model (#2649)
* Update mod.rs

* Create mod.rs

* Create decoder.rs

* Create model.rs

* Create main.rs

* Create README.md

* Update README.md

* Update main.rs

* Update and rename decoder.rs to embedding.rs

* Update mod.rs

* Update model.rs
2024-12-03 10:56:01 +01:00
6f715f9256 add scatter add (#2656) 2024-12-01 18:39:38 +01:00
dba7a9c93e add u32 - U32 gather (#2653) 2024-11-30 23:18:07 +01:00
b52c2c6050 Clippy fixes for the cuda feature. (#2650) 2024-11-29 09:01:34 +01:00
4f59ed38b0 Adds support for stella_en_v5 embedding model -400M variant (#2608)
* Adds support for stella_en_v5 embedding model -400M variant

* Unified stella

* WIP: Unified Stella

* Combined stella for both 1.5B and 400M variants

* Cargo fmt for the CI

* removed redundant stella-400m model and example after merge into stella-en-v5

* cargo fmt --all

---------

Co-authored-by: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-11-29 09:01:08 +01:00
54e7fc3c97 Lint fixes introduced with Rust 1.83 (#2646)
* Fixes for lint errors introduced with Rust 1.83

* rustfmt

* Fix more lints.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-28 23:00:21 +01:00
23ed8a9ded Fix for whisper-microphone example failure if audio isn't chunk aligned (#2645)
At least on my macOS Sequoia system (MBP 14" 2021, M1 Pro), when I run
the `whisper-microphone` example after it has gathered 10 seconds of
audio, it fails before the transcription:

```
Error: Insufficient buffer size 384 for input channel 0, expected 1024
```

At least for the audio device I'm using (Airpods Pro Max), there is no
guarantee that each audio buffer is a multiple of 1024 samples.  Thus at
the end of the 10 seconds, `buffered_pcm` can have some samples at the
end that do not form a complete 1024 sample chunk.

This fixes that by tracking when there is a partial chunk at the end of
the buffer, and leaving it in `buffered_pcm` to be processed on the next
loop iteration.

Note that, in the interest of keeping this PR as small as possible, I
didn't make any other changes to this example.
2024-11-27 22:35:11 +01:00
21c686387c Onnx Support for Sign operation #2641 (#2642)
* Support for Sign operation #2641

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-26 23:10:09 +01:00
b4deb5c5a9 Provide a method to allow PTH files with state maps to be loaded. (#2639)
* Provide a method to allow PTH files iwth state maps to be loaded.

* add a line to the doc

* String-. &str
2024-11-26 22:52:53 +01:00
c12db594e3 fix typo (#2606) 2024-11-23 08:40:00 +01:00
f86f4d6224 Tweak the CI to avoid running out of disk space. (#2630)
* Tweak the CI to avoid running out of disk space.

* Linux only.
2024-11-19 04:32:36 +01:00
3159f91b90 20241118 docs (#2629)
* module docs

* varbuilder gguf docs

* add a link to gguf files

* small additonal mod doc titles

* safetensor docs

* more core docs

* more module docs in canlde_core

* 2 more link fixes
2024-11-19 04:07:07 +01:00
1a0f9ccf16 Import the ggml_cuda_dp4a function. (#2628) 2024-11-19 03:41:34 +01:00
e86565624b Fix for clippy. (#2626) 2024-11-18 14:32:38 +01:00
386fd8abb4 Module Docs (#2624)
* update whisper

* update llama2c

* update t5

* update phi and t5

* add a blip model

* qlamma doc

* add two new docs

* add docs and emoji

* additional models

* openclip

* pixtral

* edits on the  model docs

* update yu

* update a fe wmore models

* add persimmon

* add model-level doc

* names

* update module doc

* links in heira

* remove empty URL

* update more hyperlinks

* updated hyperlinks

* more links

* Update mod.rs

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2024-11-18 14:19:23 +01:00
12d7e7b145 More Model Module Docs (#2623)
* dinov2

* add another example

* ad dinov2reg4

* eva2

* efficientvit

* moondream

* update t5

* update t5

* rwkv

* stable diffusion docs

* add wasm link

* add segment_anything

* adjsut for clippy

* ignore bertdoc

* dinov2 ignore

* update block to be text

* remove the rust blocks for the moment

* bump python to 3.11

* add a setup-python step

* add py311 to test as well
2024-11-17 20:27:24 +01:00
a3f200e369 Module Docs (#2620)
* update bert docs

* update based

* update bigcode

* add pixtral

* add flux as well
2024-11-16 09:09:17 +01:00
00d8a0c178 Remove some unused macros. (#2618)
* Remove some unused macros.

* More unused fixes.
2024-11-15 16:46:55 +01:00
f689ce5d39 Documentation Pass for Models (#2617)
* links in chinese_clip

* links for clip model

* add mod docs for flux and llava

* module doc for MMDIT and MIMI

* add docs for a few more modesl

* mod docs for bert naser and beit

* add module docs for convmixer colpali codegeex and chatglm

* add another series of moddocs

* add  fastvit-llama2_c

* module docs mamba -> mobileone

* module docs from moondream-phi3

* mod docs for quantized and qwen

* update to yi

* fix long names

* Update llama2_c.rs

* Update llama2_c_weights.rs

* Fix the link for mimi + tweaks

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2024-11-15 08:30:15 +01:00
0ed24b9852 Add max-all/min-all. (#2616) 2024-11-14 21:08:04 +01:00
06350c31c7 Add some missing index-select metal kernels. (#2613)
* Add some missing index-select metal kernels.

* Make some matrix contiguous pre-matmul.
2024-11-12 17:10:12 +01:00
9453cc3095 Bump the crate version to 0.8.0. (#2612) 2024-11-12 14:11:46 +01:00
3769206583 Update docs (#2553)
* add module docs for candle-core

* doc each of the candle-nn modules and add the links to the doc page
2024-11-11 22:13:52 +01:00
e2b6b367fa Add some fast Metal MLX SDPA kernels (#2584)
* Add some fast Metal MLX SDPA kernels (#32)

* Sketch the sdpa kernel

* Add full sdpa kernel,

* Add test

* Add vectorized kernel for decoding

* Update tests

* Add some docs

* Fix sdpa_vector names

* Add softcapping for vectorized sdpa

* Add softcapping for full sdpa

* Add support for head dim 32, 96, 256

* Add support for head dim 32, 96, 256

* Update docs

* Add update notice

* Clippy and format

* Conditional compilation for bf16

* Use it in quantized llama

* Some review comments

* Use set_params!

* Remove unused

* Remove feature

* Fix metal sdpa for v stride

* Remove comma

* Add the dim method to layout and shape.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-05 09:28:00 +01:00
6454597943 Improved launch config for layer-norm/rms-norm. (#2591)
* Improved launch config for layer-norm/rms-norm.

* Add more testing for the fused layer/rms norm kernels.
2024-11-04 10:42:18 +01:00
3fba2b5fc4 Add the SmolLM2 models. (#2595)
* Add the SmolLM2 models.

* More SmolLM2 support.
2024-11-03 17:11:12 +01:00
530ab96036 Support Skip Layer Guidance (SLG) for Stable Diffusion 3.5 Medium (#2590)
* support skip layer guidance (slg) for stable diffusion 3.5 medium

* Tweak the comments formatting.

* Proper error message.

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-01 18:10:40 +01:00
7ac0de15a9 Lazy upcasting for t5. (#2589) 2024-10-30 18:08:51 +01:00
d232e132f6 Support sd3.5 medium and MMDiT-X (#2587)
* extract attn out of joint_attn

* further adjust attn and joint_attn

* add mmdit-x support

* support sd3.5-medium in the example

* update README.md
2024-10-30 06:19:07 +01:00
139ff56aeb Reduce memory usage for sd 3.5. (#2582) 2024-10-28 22:45:02 +01:00
498bc2cdc9 Release the mmdit model earlier to reduce memory usage. (#2581)
* Stable diffusion 3.5 support.

* Clippy fixes.

* CFG fix.

* Remove some unnecessary clones.

* Avoid duplicating some of the code.

* Release the mmdit model earlier to reduce memory usage.
2024-10-28 16:06:53 +01:00
0e2c8c17fb UG metal integration. (#2580) 2024-10-27 15:20:37 +01:00
594d984f9c Support for UG kernels. (#2579)
* Support for UG kernels.

* Add a dedicated test.
2024-10-27 13:37:19 +01:00
37e0ab8c64 Stable diffusion 3.5 support. (#2578)
* Stable diffusion 3.5 support.

* Clippy fixes.

* CFG fix.

* Remove some unnecessary clones.

* Avoid duplicating some of the code.
2024-10-27 10:01:04 +01:00
07849aa595 Update README.md (#2577) 2024-10-26 18:23:52 +02:00
3699c1a053 Fix the repo name for llama 3.1. (#2576)
* Fix the repo name for llama 3.1.

* Fix the book.
2024-10-26 11:25:04 +02:00
a2e9d41b20 use softmax_last_dim (metal and cuda kernel) in llama attention layer (#2572) 2024-10-23 20:07:09 +02:00
7c09215ef4 ONNX: GatherElements, Xor (#2568) 2024-10-17 20:22:35 +02:00
dcd83336b6 Testcases (#2567) 2024-10-17 13:00:45 +02:00
a01aa89799 onnx: ReduceMin/Max Ops (#2563)
* Stella_en_1.5B_v5

* Separated  creation. This is a critical step for numerical accuracy and would be documented in the readme

* EmbedDim would require clone and copy

* WIP: example

* Examples added

* a litte more in README

* WIP: ONNX Reduce-max ops

* WIP: tests for ReduceMin

* Reduce min/ max v18+

* Reformatting tests for better review readability

* Error on empty set, backward compatibility (13 and below) with 'axes'
2024-10-15 10:34:07 +02:00
3d1dc06cdb Enable stable-diffusion 3 on metal. (#2560) 2024-10-14 08:59:12 +02:00
f553ab5eb4 Adds support for Stella_en_v5 embedding model - 1.5B variant (#2551)
* Stella_en_1.5B_v5

* Separated  creation. This is a critical step for numerical accuracy and would be documented in the readme

* EmbedDim would require clone and copy

* WIP: example

* Examples added

* a litte more in README
2024-10-13 23:09:12 +02:00
41ade774e8 fix: Allow marian configs to deserialize from json. (#2556) 2024-10-13 23:05:50 +02:00
6eab6b57f5 Fix the guide to gain access to Stable Diffusion 3 Medium (#2559) 2024-10-13 22:55:26 +02:00
ca7cf5cb3b Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example

Add get_qkv_linear to handle different dimensionality in linears

Add stable diffusion 3 example

Add use_quant_conv and use_post_quant_conv for vae in stable diffusion

adapt existing AutoEncoderKLConfig to the change

add forward_until_encoder_layer to ClipTextTransformer

rename sd3 config to sd3_medium in mmdit; minor clean-up

Enable flash-attn for mmdit impl when the feature is enabled.

Add sd3 example codebase

add document

crediting references

pass the cargo fmt test

pass the clippy test

* fix typos

* expose cfg_scale and time_shift as options

* Replace the sample image with JPG version. Change image output format accordingly.

* make meaningful error messages

* remove the tail-end assignment in sd3_vae_vb_rename

* remove the CUDA requirement

* use default_value in clap args

* add use_flash_attn to turn on/off flash-attn for MMDiT at runtime

* resolve clippy errors and warnings

* use default_value_t

* Pin the web-sys dependency.

* Clippy fix.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-10-13 22:08:40 +02:00
0d96ec31e8 feat: intergrate chinese clip and add example (#2555)
* start to impl chinese clip

* impl vision model

* copy code from bert

* refactor use

* refactor use again

* fix text model

* refactor

* try to fix text model

* tuning

* tuning chinese clip

* delete useless code

* revert code

* Clippy fixes.

* Also apply cargo fmt.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-10-10 15:18:55 +02:00
937e8eda74 Add BertForMaskedLM to support SPLADE Models (#2550)
* add bert for masked lm

* working example

* add example readme

* Clippy fix.

* And apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-10-07 23:28:21 +02:00
edf7668291 improve (#2548) 2024-10-07 17:30:56 +02:00
e4a96f9e7c Switch to using the MLX matmul by default. (#2547) 2024-10-06 23:24:55 +02:00
f856b5c3a7 pyo3 update. (#2545)
* pyo3 update.

* Stub fix.
2024-10-06 10:09:38 +02:00
d2e432914e Tensor tools print all (#2543)
* Support whisper large-v3 turbo in the whisper-microphone example.

* Print all tensors when no argument is provided.
2024-10-05 10:05:14 +02:00
410c89f72a Add required feature for whisper example in Readme (#2539) 2024-10-04 14:29:55 +02:00
56aacb05da Make the RNN configs accessible from the models. (#2541) 2024-10-04 14:22:23 +02:00
6faecaa616 Fix for cudnn bf16 conv2d. (#2535) 2024-10-02 23:18:55 +02:00
90d04ff622 Support whisper large-v3 turbo in the whisper-microphone example. (#2533) 2024-10-02 22:09:14 +02:00
7b60bda4ed Add support for cuda streams. (#2532) 2024-10-02 21:30:58 +02:00
936300678d Add whisper large-v3 turbo to the example. (#2531) 2024-10-02 21:07:08 +02:00
f479840ce6 Add a seed to the flux example. (#2529) 2024-10-02 10:52:02 +02:00
fd08d3d0a4 Tweak some metal tests. (#2528) 2024-10-02 10:22:31 +02:00
a2bcc227df Efficient implementation of Tensor::ones() for metal (#2512)
* WIP: hopefully better const impl

* with GPU

* More tests on

* Reverting primitive for

* Incorporating review changes - added check elem count check in kerner, using  for call strategy

* rustfmt ran
2024-10-01 19:11:59 +02:00
def4c6cdee Cuda quantized mmv bugfix. (#2526) 2024-10-01 12:57:55 +02:00
888d886dd8 Add ColPali (#2524)
* add colpali

* cleanup

* fix clippy
2024-10-01 11:48:39 +02:00
6110ad8d4f Refactor the whisper microphone example. (#2523)
* Refactor the whisper microphone example.

* Tweak the whisper microphone example more.
2024-10-01 00:24:17 +02:00
aa35bf2ff5 Add/lstm direction (#2455)
* add: direction for lstm layer

* lint: remove unused Error import

* refactor: remove unnecessary int assignment to Direction enum:

* refactor: use &'static str type instead of String for direction_str:

* Run cargofmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-09-30 22:44:07 +02:00
724650446c Yet another cuda qmm padding fix. (#2509) 2024-09-30 21:53:30 +02:00
dfe9a00683 Pixtral polishing. (#2522)
* Pixtral polishing.

* Clippy fix.
2024-09-30 21:23:54 +02:00
683ab698de Add Pixtral. (#2521)
* Add Pixtral.

* More pixtral vision encoder.

* Sketch a pixtral example.

* Sketch a pixtral example.

* Better image loading.

* Support loading images embedded in safetensor files.

* Clippy fixes.

* Add the llava multimodal adapter.

* Add more of the llava bits.

* Add the pixtral config.

* More pixtral inference.

* Add the text generation bits.

* Get the example to work.

* Bugfix.

* Run some bits of the model in f32.

* Blessed version :)

* Better rope frequency computations.

* README update.
2024-09-30 19:31:14 +02:00
2f49e1b534 Add PaliGemma. (#2519)
* Add PaliGemma.

* PaliGemma inference loop.

* Running PaliGemma example.

* Tweak the prompt.
2024-09-29 19:56:56 +02:00
0ebb38813b Paligemma siglip vision config (#2518)
* Add the paligemma siglip vision config.

* More paligemma configs.
2024-09-29 17:53:52 +02:00
3a3c48b14b Bump the crate version to 0.7.2. (#2517) 2024-09-29 10:56:50 +02:00
261ed65f36 Add the SigLIP model. (#2515)
* Add the SigLIP model.

* Add more to the forward pass of the vision model.

* Complete the forward pass.

* Add the siglip example.

* Fix.

* Another fix.

* Get everything in place.

* Add a readme.
2024-09-28 23:48:00 +02:00
62525e8352 Remove some extra whitelines. (#2513) 2024-09-28 14:41:28 +02:00
2c25754281 Clippy fixes for onnx + fix a broken test. (#2510) 2024-09-26 23:37:59 +02:00
ed48f54b54 Expand split ops (#2505)
* candle-onnx: Add Split and Expand operators, Fix Where Op

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining Split examples as tests
TODO: Add.test case that motivates Where fix

* candle-onnx: Add ReduceSum operator

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining ReduceSum examples as tests

* candle-onnx: Add ReduceL2 operator

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining ReduceSum examples as tests

* candle-onnx: Fix Clip operator empty string as default arg issue

Optional input args may be signified by an empty string. The length of the input array is not enough because non optional args may follow optional ones.

I encountered this when trying to use the ONNX model found at https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 for example.

The LSTM op has a utility which I factored to be more generally accessible, and I have used it in the ops I have recently created or debugged.

I believe it is likely that this issue may also manifest in other ops, but I didn't want to change anything that I'm not testing.

* fix formatting

* fix small mistake made during refactor
2024-09-26 22:57:55 +02:00
ad8a4c5e5a Add some llama-3.2 examples. (#2508)
* Add some llama-3.2 examples.

* Support tie-word-embeddings for llama.
2024-09-26 21:00:18 +02:00
c3c392f45c Merge pull request #2507 from huggingface/ci-move
move CI/Cuda runner
2024-09-26 18:48:52 +02:00
a0184a4fe4 move CI/Cuda runner 2024-09-26 17:09:26 +02:00
10d47183c0 Quantized version of flux. (#2500)
* Quantized version of flux.

* More generic sampling.

* Hook the quantized model.

* Use the newly minted gguf file.

* Fix for the quantized model.

* Default to avoid the faster cuda kernels.
2024-09-26 10:23:43 +02:00
d01207dbf3 Add a RotatingKVCache. (#2493)
* Add a RotatingKVCache.

* Add some KvCache tests.

* Test the reset too.

* More kv-cache testing.

* More tests for the rotating kv-cache.

* Improve the api for the rotating cache so that the whole src tensor gets returned when it's overlarge.

* Handle contiguity + bugfix + use in mimi.

* Add a way to test the mimi streaming mode.

* Mimi streaming fixes.

* More rotating kv-cache.

* Fix the attn mask generation.

* Handle the abs case.

* Add some tests for the generated mask.
2024-09-23 13:14:32 +02:00
8097559c1a Move the candle version to 0.7.1. (#2495) 2024-09-22 20:44:39 +02:00
829dcfa8dc Update cudarc to 0.12.1. (#2494) 2024-09-22 20:32:29 +02:00
c2fca0ca11 Bump the crate version. (#2491) 2024-09-21 15:13:12 +02:00
844d45cde4 Bugfix for the metal elu kernel. (#2490)
* Bugfix for the metal elu kernel.

* Add a test.
2024-09-21 15:03:19 +02:00
af2104078f Metal commands refactoring (#2489)
* Split out the commands part of the metal device.

* Make most fields private.

* Move the allocator back.

* Rework the encoder provider type.
2024-09-21 13:18:42 +02:00
5fc4f17727 Adding Granite 7b Instruct model example (#2487)
* Adding Granite 7b Instruct model example

* Minor refactoring to make it a little more idiomatic

* Clippy fixes.

* * Adding a README with some information about supported Granite models
* Changing the default prompt to accomodate better the Language
  modality of the Granite 7b Instruct model

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-09-21 11:52:01 +02:00
c58c5d5b01 Add the mimi audio-tokenizer. (#2488)
* Add the mimi audio-tokenizer.

* Formatting tweaks.

* Add a full example.

* Use the transformers names.

* More renamings.

* Get encoding and decoding to work.

* Clippy fixes.
2024-09-20 14:31:20 -06:00
382c6b51af Improve error message (#2485) 2024-09-20 07:11:41 -06:00
6eea45a761 Add a couple cast metal kernels. (#2479) 2024-09-15 22:27:46 +02:00
ebf722b446 Export TensorIndexer public to candle users (#2477) 2024-09-13 22:21:57 +02:00
c09afc211c Fix for metal tanh. (#2475) 2024-09-13 07:08:36 +02:00
b60faebea4 Missing metal kernels. (#2474) 2024-09-12 13:58:50 +02:00
72d649058b Hook the MLX matmul kernels in candle-core. (#2473) 2024-09-12 13:52:59 +02:00
0cb0bd1dfa Add some metal gemm benchark. (#2471)
* Add some metal gemm benchark.

* More benchmarks.
2024-09-11 22:52:37 +02:00
afb6575835 Use the new MLX kernels to handle the BF16 matmul. (#2470) 2024-09-11 17:34:05 +02:00
5635650d38 Integrate the MLX gemm kernels (#2468)
* Include the MLX gemm kernels.

* Clippy lints.

* Export the gemm_f32 kernel.

* Add the f16/bf16 variants.

* Add the initial dispatch code.

* More plugging of the mlx kernels.

* Add a currently broken test.

* Tweaks.

* Bugfix + get the tests to pass.

* Enable the gemm bf16 tests.

* Add some randomized tests.

* Update candle-metal-kernels/src/lib.rs

Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>

* More fixes.

* More clippy fixes.

---------

Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
2024-09-11 16:56:48 +02:00
13b2a8a4a0 Complete the missing backticks in the comments (#2469) 2024-09-11 16:37:05 +02:00
e3261216b1 Clippy fixes for 1.81.0. (#2461)
* Clippy fixes for 1.81.0.

* Another fix.
2024-09-05 23:46:55 +02:00
c02b7c3272 Fix FLUX.1 weights (#2457)
* fix FLUX.1 weights

* added flux1-dev.safetensors
2024-08-29 17:10:28 +02:00
86613c00e2 MobileCLIP models S1 and S2 (#2454)
* Allow loading images with given std and mean

* OpenCLIP text encoder component

* Two MobileCLIP models

* Clippy fixes.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-08-29 15:38:58 +02:00
29e25c458d FastViT fixes. (#2452)
* correct optional SE layer dimensions.
 * head_dim instead of num_heads is 32.
 * update test example output.
2024-08-28 11:20:09 +02:00
aafa24ed93 Update cudarc to 0.12. (#2451)
* Update cudarc to 0.12.

* Some cudnn tweaks.
2024-08-27 10:10:30 +02:00
fdc2622686 fix: qwen2 lm_head loading #2443 (#2445)
Co-authored-by: Yi Xu <xuyi@me.com>
2024-08-23 16:50:02 +02:00
ccdbe87639 Add FastViT model. (#2444) 2024-08-23 16:06:54 +02:00
2ec8729d51 Fix for parler-tts, do not add the last slice of padding tokens. (#2442)
* Fix for parler-tts, do not add the last slice of padding tokens.

* Support for the mini model.
2024-08-22 23:22:03 +02:00
e3c146ada6 silero-vad v5 example (#2321)
* silero-vad v5 example

This change adds an example of how to run silero-vad v5

* PR: rename 'vad' to 'silero-vad'

* Update README.md

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2024-08-22 22:50:42 +02:00
1e96b8b695 onnx: support negative index in Gather (#2440)
index_select does not support negative indexing, but
this change adds just enough workarounds in onnx to
allow evaluating silero-vad models (which make use of
negative indices).
2024-08-22 15:28:25 +02:00
a8288b7a72 onnx: workaround pow with negative base (#2439)
* onnx: workaround pow with negative base

rather than fully defining pow in the cpu backend (as in #2318),
this implements a much smaller change which is sufficient to evaluate silero-vad
onnx models. Specifically, checking if pow is run with 2.0 exponent, and if so
evaluate as simply `x*x` instead of the cpu backend of `e^(2.0 * ln(x))`.

* PR: use Tensor::powf insead

powf correctly handles a negative base.
2024-08-22 13:34:53 +02:00
6070278a31 Bump the version to 0.6.1. (#2438) 2024-08-22 09:23:52 +02:00
b47c0bc475 Update README.md (#2435) 2024-08-19 09:34:24 +02:00
14fd2d97e0 Add a readme for the parler-tts example. (#2434)
* Add a readme for the parler-tts example.

* Remove the python decode script.

* mp4 tweaks.

* Another readme tweak.
2024-08-19 09:30:12 +02:00
31a1075f4b onnx: implement LSTM op (#2268)
use candle-nn LSTM
2024-08-19 09:06:17 +02:00
236b29ff15 Add the DAC model. (#2433)
* Add the DAC model.

* More quantization support.

* Handle DAC decoding.

* Plug the DAC decoding in parler-tts.
2024-08-19 08:59:51 +02:00
58197e1896 parler-tts support (#2431)
* Start sketching parler-tts support.

* Implement the attention.

* Add the example code.

* Fix the example.

* Add the description + t5 encode it.

* More of the parler forward pass.

* Fix the positional embeddings.

* Support random sampling in generation.

* Handle EOS.

* Add the python decoder.

* Proper causality mask.
2024-08-18 20:42:08 +02:00
736d8eb752 Stream tensor (#2429)
* Support Minus(u) for arbitrary values of u, e.g. Minus(3).

* Forces u to be strictly positive.

* Add StreamTensor.
2024-08-17 21:54:28 +02:00
7cff5898ec Support Minus(u) for arbitrary values of u, e.g. Minus(3). (#2428)
* Support Minus(u) for arbitrary values of u, e.g. Minus(3).

* Forces u to be strictly positive.
2024-08-17 21:29:01 +02:00
b75ef051cf Fix the marian tokenizer importer. (#2426)
* Fix the marian tokenizer importer.

* Ignore the python caches.
2024-08-17 20:58:40 +02:00
c1b9e07e35 Add support for gemma-2. (#2425)
* Add gemma-2.

* Support a couple more models.

* Sliding window support.

* Example + readme updates.

* Update the main readme.
2024-08-17 20:31:23 +02:00
69fdcfe96a Apply rustfmt. (#2421) 2024-08-16 18:57:14 +02:00
2b75dd9551 Fix build issue in EOS Token in llama-multiprocess (#2420) 2024-08-16 18:46:31 +02:00
53ce65f706 Clippy fixes. (#2415)
* Clippy fixes.

* Bump the web_sys required version.
2024-08-14 10:13:53 +02:00
68aa9c7320 Fix the device for the bert attention mask. (#2414) 2024-08-14 10:01:12 +02:00
35e5f31397 Add Based LLM from Hazy Research. (#2411) 2024-08-12 21:21:19 +02:00
d3fe989d08 Add documentation examples for Tensor::i and Tensor::narrow methods (#2308)
* Add documentation examples for `Tensor` methods

* Apply fmt.

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-08-10 08:11:09 +02:00
14db029494 Soft Non-Maximum Suppression (#2400)
* Soft NMS with thresholds

* NMS Test

* Soft nms w/ boxes removed below threshold

* Soft nms test

* No longer removing bounding boxes to fit Soft-NMS focus

* Initialize confidence

* Added comments

* Refactored out updating based on IOU/sigma

* Score_threshold -> confidence_threshold for clarity

* Remove bboxes below confidence threshold

* Softnms basic functionality test

* Softnms confidence decay test

* Softnms confidence threshold test

* Softnms no overlapping bbox test

* Testing confidence after no overlap test

* Single bbox and no bbox tests

* Signify test completion

* Handling result of test functions

* Checking all pairs of bboxes instead of a forward pass

* Equal confidence overlap test

* Clarified tests for implementation

* No longer dropping boxes, just setting to 0.0

* Formatted w/ cargo
2024-08-10 07:57:52 +02:00
6e6c1c99b0 Fix issues in the encodec example README.md (#2407)
Also squeeze the first dimension of the codes tensor in the example file to get the expected three dimensions.
2024-08-10 07:49:05 +02:00
b7d9af00cc fix: usage of actions/checkout@v2 (#2403)
* chore: changes from formatting on save

* fix: usage of `actions/checkout@v2`
2024-08-06 10:59:34 +02:00
59bbc0d287 Add the import script for the T5 tokenizer. (#2399) 2024-08-05 21:03:31 +02:00
dfdce2b602 Add the MMDiT model of Stable Diffusion 3 (#2397)
* add mmdit of stable diffusion 3

lint

add comments

* correct a misplaced comment

* fix cargo fmt

* fix clippy error

* use bail! instead of assert!

* use get_on_dim in splitting qkv
2024-08-05 19:26:15 +02:00
500c9f2882 add models support and example for THUDM/glm-4 (#2362)
* add models support and example for THUDM/glm-4

* fix the ci report

* fmt

* fix

* Update README.org

* Update README.org

* fmt

* Update README.org

* README.md add codegeex4

* README.md add glm4

* Typo.

* change expect into ?

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2024-08-05 17:48:09 +02:00
2be9bd211e Support for mistral-nemo. (#2396) 2024-08-04 19:52:40 +02:00
89eae41efd Support the flux-dev model too. (#2395) 2024-08-04 12:16:24 +02:00
c0a559d427 optimize gradient for silu a bit (#2393) 2024-08-04 11:24:17 +02:00
aa7ac1832d Simplify handling of flux modulations. (#2394) 2024-08-04 11:09:54 +02:00
19db6b9723 Add the flux model for image generation. (#2390)
* Add the flux autoencoder.

* Add the encoder down-blocks.

* Upsampling in the decoder.

* Sketch the flow matching model.

* More flux model.

* Add some of the positional embeddings.

* Add the rope embeddings.

* Add the sampling functions.

* Add the flux example.

* Fix the T5 bits.

* Proper T5 tokenizer.

* Clip encoder path fix.

* Get the clip embeddings.

* No configurable weights in layer norm.

* More weights related fixes.

* Yet another shape fix.

* DType fix.

* Fix a couple more shape issues.

* DType fixes.

* Fix the latent dims.

* Fix more shape issues.

* Autoencoder fixes.

* Get some generations out.

* Bugfix.

* T5 padding.

* Clippy fix.

* Add the decode only mode.

* Fix.

* More fixes.

* Finally get some generations to work.

* Add readme.
2024-08-04 08:14:33 +02:00
0fcb40b229 Revert the bf16 gemm metal changes for now. (#2386) 2024-08-01 23:08:47 +02:00
6991a37b94 update: LSTMState and GRUState fields to be public (#2384) 2024-08-01 16:30:32 +02:00
9ca277a9d7 Fix cargo fmt. (#2383)
* Fix cargo fmt.

* Clippy fix.

* Cosmetic tweaks.
2024-08-01 14:19:41 +02:00
2e9c010609 Jina Bert Example fix and more configuration (#2191)
* fix: fix jina bert example logic

* feat: enable jina embeddings de

* feat: allow more flexibility on Jina Bert
2024-08-01 13:59:20 +02:00
ac51f477eb Add Hiera vision model. (#2382) 2024-08-01 11:59:22 +02:00
d4b6f6eef6 Add a minimal test for the metal bf16 matmul. (#2381) 2024-08-01 11:22:46 +02:00
957d604a78 Enable BF16 on metal. (#2380) 2024-08-01 11:05:07 +02:00
ce90287f45 Add get_ids to GradStore (#2379) 2024-08-01 10:56:13 +02:00
1ba87a9450 Use BF16 on metal when possible. (#2378) 2024-08-01 10:48:58 +02:00
bd80078acf Fix log_sum_exp to handle large positive/negative inputs (#2367) 2024-08-01 10:37:02 +02:00
fea46cb719 Metal bgemm min changes (#2364)
* Add updated mfa metallib

* Add bgemm and tests
2024-08-01 10:06:04 +02:00
8696cf6494 Enable the affine kernel for u8/u32. (#2376) 2024-08-01 10:03:11 +02:00
4a52aeb437 bert attention mask (#1934)
* bert attention mask

* Allow for using None as a mask.

* Revert part of the changes so that the proper default mask applies.

* Cosmetic change.

* Another cosmetic tweak.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-08-01 08:26:19 +02:00
24d54d0ff9 Bump image crate version so ImageReader is available without aliasing (#2365) 2024-07-29 17:41:33 +02:00
636eff652a change DTypes (fixes #2355) (#2363) 2024-07-28 14:36:05 +02:00
0f5cbb08b3 Add support for Llama 3.1 (#2359)
* Add Llama 3.1 rope

* Clippy

* Format

* Clippy

* Add support for multiple eos tokens:

* Untagged either

* Remove either dep and fix settings.json

* Make the max positional embeddings configurable
2024-07-26 21:32:26 +02:00
ddafc61055 Use RAII for terminating the encoding. (#2353) 2024-07-24 16:29:56 +02:00
a925ae6bc6 Use a trait for the encoder provider (so that encoder can ultimately be reused). (#2352) 2024-07-24 09:27:30 +02:00
6056fd5c90 onnx: fix pad, unsqueeze (#2317)
* onnx: fix pad, unsqueeze

both implementations have off-by-one errors:
- Pad 'reflect' cycle for eg `dim==3` is `[0,1,2,1]` which has length of
  4 (or `dim*2 - 2`) not 5 (current code `dim*2 - 1`)
- Unsqueeze(-1) for tensor with `dim==3` should be 3 (ie `dim+index+1`)
  not 2 (ie currently `dim+index`)

in addition, Pad is incorrectly calculating the starting padding.
If we want to pad out 2 elements to the start, and we have this cycle
of indices of length 6, then we should skip 4 elements, but currently
we skip 2. A more visual representation of what's going on is below:

```
pad_start: 2
data:      [a,b,c,d]
indices:   [0, 1, 2, 3, 2, 1, 0, 1, 2, 3, 2, 1, 0, ..] // zigzag between 0..4
actual:    skip [ c  d| c  b  a  b]
expected:  ~  skip  ~ [ c  b| a  b  c  d]
```

The values between `[` and `|` are padding and the values between
`|` and `]` in the example should match the original data being padded.

* Fix clippy lints.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-07-23 23:10:57 +02:00
ebc9aa60bc fix clip example title (#2345) 2024-07-23 22:55:18 +02:00
2489a606fe feat(candle-transformers/models/codegeex4-9b): add codegeex4-9 (#2334)
* feat(candle-transformers/models/codegeex4-9b): add codegeex4-9b transoformers

* change mod.rs

* feat(candle-examples/codegeex4-9b)

* Update codegeex4_9b.rs

* Update main.rs

* Update codegeex4_9b.rs

* Update main.rs

* fmt

* fix

* fmt

* Clippy fix.

* Remove some print statements.

* Avoid using unwrap.

* 1. add README
2. change the print fmt

* Another clippy fix.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-07-21 13:00:41 +02:00
3c815b1dca Pin the revision used by moondream. (#2340) 2024-07-18 10:49:46 +02:00
42891cc613 Add mathstral in the examples. (#2339) 2024-07-18 08:24:49 +02:00
f25173d68b Fix for backprop in ConvTranspose2D with stride of 2 (#2337)
* Add gradient test for conv_transpose2d with stride of 2.

* Swap dilation and stride in ConvTranspose2D backpropagation.

Without this, a shape mismatch occurs with a stride of 2 and dilation of 1.

* Add further tests of the ConvTranspose2D gradient.

Values calculated with torch, minor numerical errors adjusted and commented.
2024-07-17 19:22:23 +02:00
6a4741bbf9 Fix Elu gradient NaN on large input (#2328)
* Fix Elu gradient NaN on large input

* Reuse previously computed exp in Elu
2024-07-16 14:41:16 +02:00
30cdd769f9 Update the flash attn kernels. (#2333) 2024-07-15 20:37:36 +02:00
d74fbed334 Pinning cudarc to 0.11.6 (#2332) 2024-07-15 15:29:08 +02:00
c63048d374 add quantized qwen2 (#2329)
* add quantized version of qwen2 and corresponding example for qwen2-instruct

* fix quantized qwen2 clippy error
2024-07-12 10:00:03 +02:00
a226a9736b Add Mobilenet v4 (#2325)
* Support different resolutions in load_image()

* Added MobilenetV4 model.

* Add MobileNetv4 to README
2024-07-09 13:52:20 +02:00
25960676ca Add a basic metal example with capture (#2324)
* Add some tracing.

* Get the trace to work.
2024-07-09 12:38:11 +02:00
9cd54aa5d4 Add EVA-02 model ( https://arxiv.org/abs/2303.11331 ) (#2311)
* Add EVA-02 model ( https://arxiv.org/abs/2303.11331 )

* Clippy fix.

* And apply fmt.

---------

Co-authored-by: v-espitalier <>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-07-07 20:09:31 +02:00
eec11ce2ce onnx: implement Size op (#2316) 2024-07-07 19:56:36 +02:00
9182f9f5c2 ignore editor config folders (#2315) 2024-07-07 19:43:48 +02:00
ecff05d72b Beit: Add the gen_relative_position_index() function (#2306)
Co-authored-by: v-espitalier <>
2024-07-04 09:45:26 +02:00
7f1ba8038c Add Beit model ( https://arxiv.org/abs/2106.08254 ) (#2305)
Co-authored-by: v-espitalier <>
2024-07-01 22:11:48 +02:00
74e9e41911 make up for the missing last token output of phi2 example (#2299) 2024-06-29 21:34:42 +02:00
e27aac0a06 Add DINOv2Reg4 + PlantCLEF2024 (#2293)
* Add: DINOv2Reg4 with PlantCLEF2024 weights and example ( See https://arxiv.org/abs/2309.16588 and https://zenodo.org/records/10848263 )

* Remove extra files + update README to download them + remove extra lines

* minor fix (README remove extra spaces)

* minor fix (README: Fix image url)

* Modif: Add back interpolate_pos_encoding() + fix when no interpolation + remove extra comments + Update README ( source image changed and so the predictions )

* Fix: Improve code lisibility with '$ cargo clippy' and '$ cargo fmt'

* Another clippy fix.

---------

Co-authored-by: x-VEspit <vincent.espitalier@cirad.fr>
Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-06-29 11:49:15 +02:00
a3dd87f15e Adding Gemm and ArgMax operators to candle-onnx (#2231)
* feat(gemm): implement Gemm operator in candle-onnx

* feat(onnx): Add support for ArgMax operator in candle-onnx

* Apply rustfmt.

* Remove argmax as it was already present.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-06-28 21:40:31 +02:00
242e006bbb Depth Anything v2 (#2279)
* define structs

* construct ResidualConvUnit

* forward() for ResidualConvUnit

* implement FeatureFusionBlock

* implement Scratch

* implement DPTHead

* add identity module

* implement forward for DTPHead

* add get_intermediate_layers to DinoVisionTransformer

* implement DepthAnythingV2

* some minor tweaks

* fix compile errors

* fix var builder prefixes

* setup initial example

* use fixed patch size of 37 (518 / 14)

* debugged until output

* print min and max values

* add some dynamism to the output location

* scale input image

* extract prep function

* extract output path function

* normalize image with magic mean and std

* add spectral coloring

* squeeze in the right place

* make enterpolation optional

* use bail instead of panic

* omit unnecessary Shape call

* remove empty curly braces

* use bail instead of assert

* use vb and pp

* remove closures

* extract config object

* Apply rustfmt.

* Fix some clippy lints.

* More lints.

* Use the array methods.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-06-24 19:12:52 +02:00
6baa1d486b Fix a bug in the metal implemtation of col2im1d. (#2284) 2024-06-22 23:21:20 +02:00
36cf54525d Fix the fast bf16 gemm cublas kernels. (#2274)
* Use flash-attn in gemma.

* Fix for the fast bf16 cublas gemm.

* Fix some clippy lints.

* Fix another lint.

* Proper clippy fix.
2024-06-18 23:46:58 +02:00
2b10aaa05d implement Slice op (#2260) 2024-06-12 07:15:32 +01:00
9f804af29d feat(ci): add trufflehog secrets detection (#2262)
* feat(ci): add trufflehog secrets detection

* fix(ci): remove unnecessary permissions
2024-06-10 21:03:54 +01:00
54ff971e35 Support for the new Qwen2 models. (#2257)
* Support for the new Qwen2 models.

* Add more models.
2024-06-07 10:51:50 +01:00
b9fac7ec00 implement if, and pad reflect mode (#2251)
* implement if, and pad reflect mode

The intent of this change is to allow eval of the current silero_vad.onnx (v4).
This onnx file uses 'If' and 'Pad' nodes, which had not been supported
by simple_eval until now

* Cleanup (fmt, clippy, minor test tweaks).

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-06-06 22:36:23 +02:00
f65e90e7ef Bump the crate version. (#2248) 2024-06-05 15:49:15 +02:00
d39462856b Apply rustfmt. (#2247) 2024-06-04 22:54:09 +02:00
cb180eb23a ONNX: add ArgMin, ArgMax and LeakyRelu (#2246)
* Add basic RandomUniform implementation

* Use is_some to check if seed is present

* Added Exp operator implementation

* Added ArgMin operator implementation

* Added tests for ArgMin

* ArgMin now returns a tensor with i64

* Added tests from pytorch examples

* Added ArgMax operator implementation

* Added tests for ArgMax

* Added LeakyRelu implementation

* Added a test for LeakyRelu

* Typo fix

* Fix a weird automatic RustRover change

---------

Co-authored-by: Mateusz Okulus <mmokulus@gmail.com>
2024-06-04 22:49:02 +02:00
9182c828e6 Automatically upcast for to_u64 (#2244) 2024-06-04 11:32:36 +02:00
3f13ad3d79 Fix dataset id for MNIST (#2238) 2024-06-04 06:27:24 +02:00
cd4d941ed1 Add LLaVA support (#2234)
* first commit

* llava

* clippy and fmt

* some fixes

* minor fixes

* remove useless file

* refactor: Remove llava/constants.rs and update llava/mod.rs

* modify variable name

* modify code after clippy

* Minor tweaks.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-06-03 11:54:09 +02:00
03344d3c19 ONNX: Add Floor and Ceil (#2235) 2024-06-02 21:45:20 +02:00
1ec3b2cc18 add where_cond f32 for metal (#2236) 2024-06-02 14:30:06 +02:00
f7773d498a Deactivate some book test that breaks the CI. (#2233)
* Deactivate some book test that breaks the CI.

* Clippy fix.
2024-06-01 09:44:22 +02:00
7abc3b8cd7 Bump cudarc version to 0.11.4 (#2230) 2024-06-01 08:18:35 +02:00
46012ed31f Another cudarc update. (#2229) 2024-05-30 22:27:06 +02:00
f3fade3b03 Update cudarc to 0.11.2. (#2227) 2024-05-29 18:50:52 +02:00
ea260aeffd Add Debug, Clone, Deserialize to moondream config (#2222) 2024-05-28 06:08:00 +02:00
0814dfd148 Add a metal kernel for col2im1d. (#2214)
* Add a metal kernel for col2im1d.

* Enable the col2im variant.

* Bugfix.

* Revert the quantized tweak.
2024-05-25 11:03:23 +02:00
3ceca9901a Enable the new layer-norm. (#2213)
* Enable the new layer-norm.

* Shape fixes.
2024-05-24 16:48:21 +02:00
1df2bddccf Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.

* Dedicated layer norm op.

* Add the slower variant.

* Plug the cuda implementation.

* Add the metal variant.

* Add a dedicated test.

* Bugfix.
2024-05-24 15:58:01 +02:00
6f0b807ffd More efficient cuda implementation for ConvTranspose1d. (#2211)
* More efficient cuda implementation for ConvTranspose1d.

* Small tweak.
2024-05-24 11:05:43 +02:00
d54e02d73d Avoid a contiguous call in the quantized phi 3 model. (#2209)
* Simplify the KvCache api.

* Avoid a contiguous call in the quantized phi3 model.
2024-05-23 21:24:55 +02:00
45e235a747 Simplify the KvCache api. (#2207) 2024-05-23 17:07:21 +02:00
31cf64147b Add a couple kv-cache helper functions. (#2206) 2024-05-23 16:21:47 +02:00
77ea479a18 Add Phi-3 Medium (#2205) 2024-05-23 13:33:17 +02:00
72e7ca529a Add some missing where-cond kernels for metal. (#2203) 2024-05-22 09:44:52 +02:00
7ff921c538 Add RandomNormal ONNX operator (#2200) 2024-05-21 21:47:32 +02:00
9b8537a62f Remove the deprecated wav crate in favor of hound. (#2202) 2024-05-21 21:43:35 +02:00
7ebc3548e1 Use flash-attn in gemma. (#2195)
* Use flash-attn in gemma.

* Fix flash-attn for head dim 256.
2024-05-18 19:18:59 +02:00
eefc1c77ef Support flash-attn in quantized phi3. (#2194) 2024-05-18 17:12:56 +02:00
01545f7303 Add a slice_set op. (#2193)
* Add a slice_set op.

* Add some testing.

* Add the dedicated kv-cache module.

* Derive debug and clone.

* Expose more kv-cache functions.

* Return the current data when appending.

* Use the new cache in the quantized phi3 model.
2024-05-18 15:58:18 +02:00
349c3e806a Support embedding model gte-Qwen1.5-7B-instruct (#2190)
* Support embedding model gte-Qwen1.5-7B-instruct

This is a text embedding model based on Qwen2. They share same
model architecture except the last MLP module. This commit brings in
minimal modification of the old Qwen2 implementation to support both
models.

An example is provided, and had been verified according to the official
PyTorch implementation.

* Avoid doing the 'last-token filtering' based on the absence of attention mask.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-05-16 21:34:10 +02:00
bdaa34216a chore: add fix for windows cudarc into the readme (#2189) 2024-05-16 14:32:50 +02:00
cc80e065e5 Allow the threshold argumet to be negative in the segment-anything example (#2187)
Threshold is 0.0 by default, negative values make more points included,
expanding the mask. Positive values make it more picky, making the mask
smaller.

Negative numbers start with a minus sign, which normally makes clap
consider it a flag.
2024-05-15 13:17:20 +02:00
13c64f6828 Fix VarBuilder::from_slice_safetensors (#2180)
Also implement SimpleBackend for SliceSafetensors

Signed-off-by: Harry Stern <harry@harrystern.net>
2024-05-12 07:26:06 +02:00
21f82a5155 Add SliceSafetensors. (#2179)
* Add SlicedSafetensors.

* And add some testing.
2024-05-11 13:15:42 +02:00
9cff7bc3f4 Make it possible to use TF32 accumulation in F32 matmuls. (#2178)
* Allow the use of tf32 accumulation in matmul.

* Better timings.

* Dummy versions for use when cuda is not enabled.
2024-05-11 12:28:39 +02:00
d9bc5ec151 Switch cudarc back to dynamic linking. (#2176) 2024-05-09 10:35:44 +02:00
84328e2b60 Update cudarc requirement from 0.11.0 to 0.11.1 (#2174)
* Upgrading cudarc dependency from v0.11.0 to v0.11.1 due to that version having resolved a compile-time bug.

See: https://github.com/huggingface/candle/issues/2173
2024-05-08 20:40:36 +02:00
82b641fd27 Update cudarc requirement from 0.10.0 to 0.11.0 (#2165)
* Update cudarc requirement from 0.10.0 to 0.11.0

Updates the requirements on [cudarc](https://github.com/coreylowman/cudarc) to permit the latest version.
- [Release notes](https://github.com/coreylowman/cudarc/releases)
- [Commits](https://github.com/coreylowman/cudarc/compare/v0.10.0...v0.10.0)

---
updated-dependencies:
- dependency-name: cudarc
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

* Use the default cuda version.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-05-06 17:12:14 +02:00
01794dc16e Use write rather than try-write on the metal rw-locks. (#2162) 2024-05-05 07:22:46 +02:00
a75cd8164f Force the revision for the phi3-llama quantized models. (#2159) 2024-05-04 10:41:18 +02:00
b13a82a438 Separate quantized phi-3 implementation. (#2157)
* Separate quantized phi-3 implementation.

* Integrate the quantized phi3 model.=

* Small fixes, get the generation to work properly.

* Keep the old llama implementation around.

* Change the default.
2024-05-04 10:14:57 +02:00
59b18d974e Pin the version used for the quantized phi 3 gguf file. (#2156) 2024-05-03 15:03:22 +02:00
89f53b9d7b Bump the version number to 0.5.1. (#2155)
* Bump the version number to 0.5.1.

* Fix clippy lints for 1.78.

* More clippy fixes.
2024-05-03 11:17:05 +02:00
a09d451d11 Support top-k in tthe llama example. (#2150) 2024-05-01 22:25:47 +02:00
fa06f5f5f9 F16/BF16 bugfix (bis). (#2143)
* F16/BF16 bugfix (bis).

* Another fix.

* Yet another fix.
2024-04-29 14:08:44 +02:00
09d4845aa8 Bugfix the recent f16/bf16 changes. (#2142) 2024-04-29 13:30:11 +02:00
a0d03aded1 Bug Fix: When converting a tensor to a variable, clone if the tensor is already a variable. (#2124)
* When converting a tensor to a variable, clone if the tensor is already a variable.

* Add a test to ensure training a batch norm works with VarMaps

---------

Co-authored-by: Jeffrey Dallatezza <jeffreydallatezza@Jeffreys-Laptop.local>
2024-04-29 11:21:53 +02:00
3bbb88fcb4 Fix sigmoid gradient calculation and move sigmoid into a specialized op (#2114)
* add sigmoid op

* small fix

* add as a method on `Tensor`

* implement gradient calculation for sigmoid

* add sigmoid tests

* we should have a specialized op for this

* fix clippy

* fix clippy 2

* Revert all previous commits in favor of a `CustomOp` based solution

* use `CustomOp1` implementation

* fix rustfmt

* experimental add metal impl

* add cuda kernel impl

* fix fmt

* Add a test + reduce some cuda duplication.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-29 11:04:43 +02:00
ed7b99f525 Add a toggle for F16/BF16 accumulation in gemm. (#2141)
* Add a toggle to control f16/bf16 gemm precision.

* Use the faster variant in the quantized example.

* Bugfix.
2024-04-29 09:21:07 +02:00
287013ef28 Add a forward_via_f16 method to the qmatmul op. (#2138) 2024-04-28 20:35:01 +02:00
eb26e2467e Add the cuda dequantize f16 kernels. (#2137)
* Add the cuda dequantize f16 kernels.

* Expose the cuda kernels.

* Add some testing + fix.

* Test the other cases too.

* A few more tests.

* Add an environment variable to enable the dequantize f16 + matmul behavior.
2024-04-28 20:05:05 +02:00
c68ed8963f chore: fix some typos in comments (#2121)
Signed-off-by: hardlydearly <799511800@qq.com>
2024-04-28 08:34:32 +02:00
e5c8b88f90 Apply the cast before the scaling. (#2135) 2024-04-28 08:30:35 +02:00
805f3be8e1 Add a sort function. (#2134) 2024-04-28 08:18:04 +02:00
3b429f3023 Make the dtype configurable for phi. (#2133) 2024-04-27 21:32:49 +02:00
96a48e5cc4 Add argsort. (#2132)
* Add the argsort cuda kernels.

* CPU version of arg-sort.

* Hook the cuda kernel + rework the cpu bits.

* Add some dedicated test.

* Working cuda kernel.

* Metal kernel.

* Metal adjustments.

* Bugfix.

* Use the fast rope in qwen.

* Rework the expert selection in qwen.
2024-04-27 20:17:35 +02:00
6cf82fd7a3 Add Olmo models (#2127)
* add olmo support

* add olmo readme

* Fix fmt.

* Fix clippy.

* Get olmo to work on cuda.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-26 11:02:51 +02:00
cfab6e7616 Mention phi-v3 in the readmes. (#2122) 2024-04-24 20:54:24 +02:00
11d4a3c588 Add the phi-3 model. (#2120)
* Add the phi-3 model.

* Faster rope.

* Bugfix.

* Fix the detokenization.
2024-04-24 09:48:13 +02:00
9d3f1c8af5 Add the phi-v3 quantized model. (#2118)
* Add the phi-v3 quantized model.

* Also include phi-3 in the main phi example.
2024-04-24 08:22:23 +02:00
7211009179 Fix for rustfmt. (#2117) 2024-04-23 19:09:33 +02:00
6fadaf2eff candle-onnx: add operators RandomUniform and Exp (#2116)
* Add basic RandomUniform implementation

* Use is_some to check if seed is present

* Added Exp operator implementation

---------

Co-authored-by: Mateusz Okulus <mmokulus@gmail.com>
2024-04-23 19:02:19 +02:00
8a05743a21 Add StorageRef. (#2113)
* Add the storage-ref bits.

* Add the metal implementation.
2024-04-23 13:23:27 +02:00
b2e816752b Use the faster rms-norm kernel for llama. (#2107)
* Use the faster rms-norm kernel for llama.

* Use the fast variant by default.
2024-04-22 18:52:00 +02:00
618ecf5e23 Better time measurement for the llama example. (#2106) 2024-04-22 17:54:27 +02:00
267601eec1 Update tokenizers requirement from 0.15.0 to 0.19.1 (#2104)
Updates the requirements on [tokenizers](https://github.com/huggingface/tokenizers) to permit the latest version.
- [Release notes](https://github.com/huggingface/tokenizers/releases)
- [Changelog](https://github.com/huggingface/tokenizers/blob/main/RELEASE.md)
- [Commits](https://github.com/huggingface/tokenizers/compare/v0.15.0...v0.15.2)

---
updated-dependencies:
- dependency-name: tokenizers
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-04-22 17:10:46 +02:00
08a15cb79e Update zip requirement from 0.6.6 to 1.1.1 (#2103)
* Update zip requirement from 0.6.6 to 1.1.1

---
updated-dependencies:
- dependency-name: zip
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

* Fix for the zip crate update.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-22 16:23:27 +02:00
c388be93e7 Updated quantized phi model (#2099)
* Quantized phi in a separate file.

* Add the quantized phi example + rework the model code.

* Improve the phi model.

* Get some generation out.

* Use the appropriate rope shape.

* Tweak the default prompt.

---------

Co-authored-by: Jane Doe <jane.doe@example.org>
2024-04-21 07:37:07 +02:00
d22f1d4f4e Derive clone and debug traits for Moondream model (#2100)
* moondream implementation

* add moondream example

* change config default activation

* Add assets and integrate phi mixformer with example

* Make use of kv cache and fix seq_len bug; Clean up example code

* Add README link to example

* Remove pos_embed scaling; Remove assets; Add to README; Expand VisionConfig

* Delete image

* Use apply instead of forward

* Use latest release special token; Fix token/s accuracy; Use GeluPytorchTanh in VisionConfig v2

* Derive debug and clone traits for Moondream model.
2024-04-21 07:08:28 +02:00
0067fe00a8 Metal Unary: Add benchmarks and process kernels in a tile based fashion (#2056)
* add basic unary bench for sqrt

* process unary commands in tiles of 4

* re-enable all benchmarks

* rename helper to unary

* modify approach to split up tiled and non-tiled operations

* undo bench ignore for other tests

* update tile size to 2

* only perform the optimization on the contiguous even numbered element case
2024-04-21 00:10:33 +02:00
587ee3bb6f Small cleanups to the llama multi-process example. (#2098) 2024-04-20 22:19:46 +02:00
dd78422701 Handle multiple dimensions in metal QMM + two fixes. (#2097) 2024-04-20 18:55:45 +02:00
9215e9ce8c Add missing onnx operations (#2096)
* Add missing onnx operations

* Add tests and fix errors

* Run rustfmt
2024-04-20 18:44:22 +02:00
52ae332910 Use llama v3 by default + add to readme. (#2094) 2024-04-20 16:11:24 +02:00
8b390ddd29 Only download the weights in the main process (and not in the child processes). (#2093) 2024-04-20 13:01:23 +02:00
c97d639fa0 Multiprocess/multi-GPU support for llama 3. (#2092)
* Multiprocess/multi-GPU support for llama 3.

* Modernize the mp example a bit.
2024-04-20 12:49:21 +02:00
b45c710dbf Fix for gemma MQA. (#2091) 2024-04-19 21:49:55 +02:00
9c532aef47 Also enable llama-v3 8b instruct. (#2088) 2024-04-19 08:50:06 +02:00
f7a6468238 Add support for llama3 on the quantized example (#2086)
* add support for l3b, new tokenizer

* add todo

* Add todo and use k_s model

* Use the official tokenizers.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-18 22:52:00 +02:00
2b93dffe64 Use faster rotary embeddings for llama like models. (#2087) 2024-04-18 22:34:29 +02:00
e6ee7ba4d4 Llama v3. (#2085)
* Llama v3.

* Tweak the default params + handle special tokens.

* Small tweak.
2024-04-18 22:19:54 +02:00
1690ab45d2 Fix the silu gradient issue on 0. (#2083) 2024-04-18 14:31:41 +02:00
8de0ce6cba Add more QMMV cuda kernels. (#2077)
* Add more QMMV cuda kernels.

* Enable the new kernels.

* Adapt the testing.
2024-04-18 08:36:43 +02:00
ce6d08df94 Minor fix to the readme. (#2080)
Co-authored-by: Jane Doe <jane.doe@example.org>
2024-04-17 22:43:00 +02:00
2817643db9 Add the mmv kernels for small batch sizes. (#2075)
* Add the mmv kernels for smaller sizes.

* Support more mmv kernels.

* Use the new kernels.

* Fix the call.

* Silly fix.

* Improve the testing.

* Fix for dmmv.

* Add another dedicated test for the batching mmv.
2024-04-16 21:30:51 +02:00
4d14777673 Utilize batches in Stable Diffusion (#2071)
* Utilize batches in Stable Diffusion that were already there, but unutilized.

Also refactor out the `save_image` function.

* Clippy + cosmetic fixes.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-16 06:49:04 +02:00
f135b7963d Fix for the batch dim in the quantized matmul example. (#2073)
* Fix for the batch dim in the quantized matmul example.

* Enable more tests on cuda.

* Add a test for qmm with a batch.

* Fix the zeros-dim test on metal.
2024-04-15 20:00:28 +02:00
af955f260c Make the falcon model cloneable. (#2067) 2024-04-15 09:39:03 +02:00
8ad822a983 Add a function to clear the KV cache in falcon. (#2066)
* Add a function to clear the KV cache in falcon.

* Clippy.
2024-04-15 09:29:25 +02:00
e198bb0816 Handle zero dims in some simple operations. (#2064)
* Handle zero dims in some simple operations.

* Handle zero-dims in matmul.

* More testing.
2024-04-15 09:18:54 +02:00
f7d5bf5b97 Faster kernels for quantized matmul on cuda (#2060)
* Hook the quantized matmul cuda kernels.

* Add a (currently broken) test.

* Kernel fixes.

* Fix by transposing the rhs matrix.

* Add the q4-1 kernels.

* Proper block sizes.

* More details in the tests.
2024-04-15 08:32:47 +02:00
c119600d6e Move image tensor to device in trocr example (#2063)
Signed-off-by: Harry Stern <harry@harrystern.net>
2024-04-15 06:50:32 +02:00
c449f65b12 Expose the synchronize function on the generic device. (#2062) 2024-04-14 23:02:03 +02:00
db7dbf3071 Add missing bfloat unary strided kernels and fix typo (#2058) 2024-04-14 20:01:13 +02:00
4ecedb1598 Add the full quantized matmul kernels for cuda. (#2057) 2024-04-14 17:52:08 +02:00
53e5380bf6 Add a synchronize method to devices. (#2055)
* Add a synchronize method to devices.

* Metal version.
2024-04-14 16:32:55 +02:00
50e49ecc5f Add a quantized version of recurrent-gemma. (#2054)
* Add a quantized version of recurrent-gemma.

* Share the rglru part.

* Get the quantized gemma model to work.
2024-04-13 20:07:01 +02:00
4c88c3ce06 Add benchmarks for qmatmul operations (#2048)
* Add qmatmul bench

* add all dtypes
2024-04-13 12:30:14 +02:00
8b8fb630df Add a convenient way to rename tensors accessed through a varbuilder. (#2052) 2024-04-13 12:09:41 +02:00
fb805b8ca2 Avoid crashes when running T5 models with F16 tensors on CPU (#2047)
* This change avoids crashes when running T5 models with F16 tensors on CPU.

* This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores it's embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN. You could write more, like: This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point.

* Revert "This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores it's embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN. You could write more, like: This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point."

This reverts commit d886d3ce5e.
2024-04-13 11:07:28 +02:00
79e3bec789 Change for the encoder-only ProstT5 model (#2045)
* This change avoids crashes when running T5 models with F16 tensors on CPU.

* This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores it's embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN.  This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point.
2024-04-13 11:06:24 +02:00
e6d412b156 Add ReduceMean onnx operation (#2049)
* Add ReduceMean onnx operation

* Format code with rustfmt
2024-04-13 11:00:25 +02:00
26cbbf8d84 Mandatory topk sampling for recurrent-gemma. (#2051) 2024-04-13 10:31:39 +02:00
2bf413caa3 Add the recurrent-gemma model. (#2039)
* Start adding the recurrent-gemma model.

* More griffin.

* Add the example + get the weights to load from the HF version.

* More inference code.

* Rope + kv-cache on the attention side.

* Add to the inference code.

* Add more to the recurrent gemma inference.

* Get some first inference to run.

* Add the softcap on logits.

* Fixes.

* Use partial rotary embeddings.

* Get inference to work.

* Add a comment.

* And add a readme.
2024-04-13 00:05:21 +02:00
3ad4770eb6 Use cat for faster MQA computation. (#2043)
* Use cat for faster MQA computation.

* Move the function to utils + use it in mistral.

* Use the shared repeat-kv in a few more models.

* Fix.
2024-04-12 09:15:10 +02:00
a0460cd2b1 Add the code-gemma models. (#2038)
* Add the code-gemma models.

* Tweak to the gemma config.
2024-04-10 21:19:21 +02:00
b81ecf712d Support alternative dtypes for mamba (#2036)
* Allow different dtypes in mamba.

* Add a dtype flag.
2024-04-10 18:10:01 +02:00
a4d5a414e3 Support gather on bf16 for metal. (#2035) 2024-04-10 12:49:25 +02:00
798e0335cd Handle more tensor shapes in onnx "Gather" operation (#2026)
* Handle more tensor shapes in onnx "Gather" operation

* Add more tests

* Add comment

* Fix typo
2024-04-08 14:06:14 +02:00
718671a0d5 Use BufferOffset in metal backend ops. (#2029)
* Use BufferOffset in the metal backend.

* More BufferOffset usage.

* Use in where-cond.
2024-04-08 09:37:25 +02:00
c5fe4a7f89 Rework the buffer offset logic for metal kernels (#2028)
* Move the metal kernels utils in a separate module.

* Use the BufferOffset for unary ops.

* Fix clippy lints.

* Use the new BufferOffset.

* Adapt the binary ops.

* Affine.

* More ops (powf, elu, cast).
2024-04-07 22:37:53 +02:00
7f354473cf Optimize copy-2d for metal. (#2024)
* Optimize copy-2d for metal.

* Add a hacky stopping rule for moondream.
2024-04-07 12:34:16 +02:00
33c9b66554 Add the new gemma models. (#2023)
* Add the new gemma models.

* Revert the lightning changes.

* Support for the 1.1 models.
2024-04-06 21:25:38 +02:00
9fd52b3b71 Handle the batch dimension in quantized MMV on metal. (#2022) 2024-04-06 20:02:24 +02:00
e662431acf Fix the final rmsnorm for quantized-metavoice. (#2021) 2024-04-06 19:35:01 +02:00
ab892274d1 first commit (#2018) 2024-04-05 15:20:28 +02:00
b869a659ec Faster mask implementation for mixformers. (#2017)
* Faster mask implementation for mixformers.

* Clippy.
2024-04-05 09:38:26 +02:00
88f7793598 Moondream tracing. (#2016)
* Moondream tracing.

* A bit more tracing.
2024-04-05 09:11:08 +02:00
2ac302a5d1 Add the rope THD kernel. (#2014)
* Add the rope THD kernel.

* Cuda kernel for rope-thd.

* Add the metal kernels.

* Add a dedicated test.
2024-04-05 08:32:58 +02:00
ace282e5c2 Add flag to run Moondream in f16 precision (#2015)
* moondream implementation

* add moondream example

* change config default activation

* Add assets and integrate phi mixformer with example

* Make use of kv cache and fix seq_len bug; Clean up example code

* Add README link to example

* Remove pos_embed scaling; Remove assets; Add to README; Expand VisionConfig

* Delete image

* Use apply instead of forward

* Use latest release special token; Fix token/s accuracy; Use GeluPytorchTanh in VisionConfig v2

* Add flag to use f16

* Avoid breaking the quantized version on cuda.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-05 07:03:33 +02:00
c87381fc96 Use F16 for moondream on cuda. (#2013) 2024-04-04 23:30:10 +02:00
c5626b8271 Add support for "sign" on tensors (#2012)
* add the sign unary operator

* remove uneeded import

* remove uneeded import

* undo formatting

* undo formatting

* remove unnecessary redefintion

* allow gradient to flow through for sign and round

* fix cpu ops to ensure that negzero and positive zero are handled properly

* clippy fixes

* Properly avoid gradient tracking.

* Use a branchless version.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-04 22:32:47 +02:00
e6a5b82ba6 Fix the matmul layout for accelerate & mkl. (#2011)
* Fix the matmul layout for accelerate & mkl.

* Reduce the required precision for pow (because of accelerate).

* And a fix the gelu f16 test.
2024-04-04 19:18:03 +02:00
5aebe53dd2 update dtypes checks for several metal operations (#2010) 2024-04-04 18:39:06 +02:00
f76bb7794a Bumping the version number to 0.5.0. (#2009) 2024-04-04 17:48:45 +02:00
30b145150f Optimize the gelu f16 opt. (#2008)
* Optimize the gelu f16 opt.

* And add a test.
2024-04-04 16:28:23 +02:00
f48c07e242 Include topk sampling in the quantized example. (#2005)
* Include topk sampling in the quantized example.

* Also sample with top-k on the mistral side.
2024-04-04 09:27:54 +02:00
8967c46563 Split the cuda error file. (#2003) 2024-04-04 08:27:23 +02:00
1e46cf8b19 Minor cleanups in reduce.metal. (#2004) 2024-04-04 08:26:02 +02:00
bd8db2a771 refactor to reduce the amount of code wrapped in template syntax (#2002) 2024-04-04 08:13:12 +02:00
318d143224 Relax the contiguous check for cuda kernels. (#2000)
* Relax the contiguous check for cuda kernels.

* Ensure contiguity for RNNs.

* Unrelated fix for segment anything.

* Better error message + allow concatenating empty slices.
2024-04-03 09:02:38 +02:00
2be1a35710 Added link to the Coursera ML algorithm implementations (#1989)
* Added link to the coursera ML algo implementations

* Fixed link
2024-04-03 07:16:32 +02:00
26226068a4 Moondream WASM (#1999)
* moondream wasm wip

* examples, more

* fix eos token check

* README

* cleanip

* cleanup, clippy
2024-04-03 07:11:50 +02:00
cd6b9e317c Add benchmarks for the candle-nn package (#1995)
* add benchmarks for the candle-nn package

* uncomment test

* format
2024-04-03 07:03:54 +02:00
08c049def3 Improve the handling of matmul with squeezed layouts. (#1998)
* Improve the handling of matmul with squeezed layouts.

* Fix for the cuda backend.

* Revert the temporary fix.
2024-04-02 23:17:05 +02:00
d17b2cdad9 Match Moondream's latest release (#1997)
* moondream implementation

* add moondream example

* change config default activation

* Add assets and integrate phi mixformer with example

* Make use of kv cache and fix seq_len bug; Clean up example code

* Add README link to example

* Remove pos_embed scaling; Remove assets; Add to README; Expand VisionConfig

* Delete image

* Use apply instead of forward

* Use latest release special token; Fix token/s accuracy; Use GeluPytorchTanh in VisionConfig v2
2024-04-02 21:37:09 +02:00
fb918a23c8 first commit (#1994) 2024-04-02 16:31:05 +02:00
b23436bf90 Stable diffusion fix. (#1993)
* Stable diffusion fix.

* And add a comment.
2024-04-02 14:36:28 +02:00
be9c200cbb Expose the t5 config fields + allow t5-large. (#1987) 2024-04-01 20:58:34 +02:00
ea0d8d3753 Quantized moondream implementation and BOS token (#1980)
* moondream implementation

* add moondream example

* change config default activation

* Add assets and integrate phi mixformer with example

* Make use of kv cache and fix seq_len bug; Clean up example code

* Add README link to example

* Remove pos_embed scaling; Remove assets; Add to README; Expand VisionConfig

* Delete image

* Use apply instead of forward

* Pass bos token at the beginning of tensor.

* Quantize moondream.

* Forward with image bos token.

* Clippy.

* Use q4_0 quantization.

* Add pointers for sequence and tokens; Remove seq_len conditional
2024-04-01 19:37:54 +02:00
308ea070ed modify access for conv and op to be pub to allow external packages to have custom backends (#1986) 2024-04-01 17:44:49 +02:00
b20acd622c Update for pyo3 0.21. (#1985)
* Update for pyo3 0.21.

* Also adapt the RL example.

* Fix for the pyo3-onnx bindings...

* Print details on failures.

* Revert pyi.
2024-04-01 17:07:02 +02:00
5522bbc57c Add fn 'get_with_hints_dtype' in VarBuilder (#1877) (#1897)
* quantized models(awq/squeezellm/...) have multiple data type tensors, use 'get_with_hints_dtype' to load tensors with given dtype
2024-04-01 12:10:08 +02:00
888c09a3db add identity op (#1976) 2024-04-01 12:08:25 +02:00
318cb82f16 Quantized cuda tweaks. (#1981)
* Quantized cuda tweaks.

* Add some safety checks.

* Factorize the dequantization bits.
2024-04-01 11:06:42 +02:00
c7557b65dc Switch the default to using the faster kernels. (#1978)
* Switch the default to using the faster kernels.

* Add the force-dmmv flag.
2024-04-01 10:00:11 +02:00
cd29c7ccd4 More ggml cuda kernels (#1977)
* Add more cuda kernels for quantized matmul.

* Add the vec-dot bits.

* Expose the quantized matmul-vec kernels.

* Also include the quantize-q8-1 kernel.

* Glue code for the q8-1 quantization.

* mm-vec product via q8-1 quantization.

* Add a test.

* Add a mm test.

* Get the test to return some sensible results.

* Also test dmmv.

* Fix the launch params.

* Allow for tweaking the force_dmmv parameter while it's experimental.
2024-04-01 00:15:48 +02:00
f9954b73ba Add options to use local files + specify a custom repo or branch. (#1973) 2024-03-31 09:32:50 +02:00
eead1dcead Clippy fix. (#1972) 2024-03-31 08:57:40 +02:00
92f81d2fcb Add Moondream transformer implementation and example (#1970)
* moondream implementation

* add moondream example

* change config default activation

* Add assets and integrate phi mixformer with example

* Make use of kv cache and fix seq_len bug; Clean up example code

* Add README link to example

* Remove pos_embed scaling; Remove assets; Add to README; Expand VisionConfig

* Delete image

* Use apply instead of forward
2024-03-31 08:54:56 +02:00
3144150b8d Move the tensor-tools binary in a separate crate. (#1969) 2024-03-30 15:49:37 +01:00
b190fd8592 Remove some unnecessary calls to contiguous. (#1968)
* Remove some unnecessary calls to contiguous.

* Slightly improved kv cache concatenation.
2024-03-30 13:22:00 +01:00
efe4a0c84b Add a print command to tensor-tools. (#1967)
* Add a print command to tensor-tools.

* Add some flags to tweak the formatting.
2024-03-30 11:34:33 +01:00
665da30487 Backend refactoring. (#1966)
* Backend refactoring.

* Metal tweaks.

* Move the cudnn module.
2024-03-29 23:02:11 +01:00
356a170ae9 Update parquet requirement from 50.0.0 to 51.0.0 (#1867)
Updates the requirements on [parquet](https://github.com/apache/arrow-rs) to permit the latest version.
- [Changelog](https://github.com/apache/arrow-rs/blob/master/CHANGELOG-old.md)
- [Commits](https://github.com/apache/arrow-rs/compare/50.0.0...50.0.0)

---
updated-dependencies:
- dependency-name: parquet
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-03-29 21:58:15 +01:00
7ecbc6d50b fix minor typo (#1924) 2024-03-29 18:09:57 +01:00
8ad12a0e81 Add some examples using the MT5 variants. (#1963) 2024-03-29 18:09:29 +01:00
eb1b27abcd Readme fix. (#1961) 2024-03-28 23:24:46 +01:00
708e422456 Qwen MoE model. (#1960)
* Qwen MoE model.

* Add the MoE model to the example.

* Fix the scaling.

* Readme updates.

* Readme tweaks.
2024-03-28 23:10:57 +01:00
c5092f2c29 Add a couple t5 models. (#1958) 2024-03-28 17:58:06 +01:00
cdc8b57b5c Fix clippy lints + minor cleanups. (#1957)
* Fix clippy lints + minor cleanups.

* fmt.

* Derive clone.
2024-03-28 14:17:46 +01:00
b0340d72ec CLIP model implementation with example (#1950)
* CLIP model implementation with example

* CLIP Implementation fixes, batch images

* CLIP model remove images from git

* CLIP model remove unnecessary use of batch_indices
2024-03-28 13:44:12 +01:00
b3484e7a5e Fix for the RWKV models. (#1955)
* Fix for the RWKV models.

* More general fix + revert the rwkv hack.

* Remove the old hack.
2024-03-28 10:17:38 +01:00
ada5d7c096 add send and sync trait bounds for scheduler config in stable diffusion models (#1952)
* first commit

* add Sync deriving

* static

* remove static
2024-03-28 10:03:00 +01:00
13ae5a34c7 Ensure that the kernels get rebuilt on cuh changes. (#1954) 2024-03-28 06:56:48 +01:00
ab86cd37c8 Support i64 in index-select on metal. (#1951)
* Support i64 in index-select on metal.

* Add some testing of index-select for all dtypes.
2024-03-27 16:30:07 +01:00
a9abde5f93 More flexible matmul contiguity checks. (#1949)
* More flexible matmul contiguity checks.

* Also relax the checks on the metal side.
2024-03-27 10:59:05 +01:00
75b6d4b0da add config for mamba 2.8b model parameter (#1946)
* first commit

* Make the mamba config public.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-03-27 07:47:23 +01:00
66f0a4eeea Another fix for squeezing. (#1943) 2024-03-26 17:05:26 +01:00
4523ecfb2a Faster repeat penalty (#1940)
* Avoid the attention mask where possible.

* Faster repeat penalty.
2024-03-26 11:31:20 +01:00
f5dfe883d7 Extend supported dtypes for metal (im2col & upsample_2d) (#1938)
* update im2col dtype implementations

* update dtypes for upsample
2024-03-26 06:48:56 +01:00
196765e995 Use the new rope kernel in mistral. (#1937)
* Use the new rope kernel in mistral.

* Compute the cos and sin with full precision.

* Bugfix.
2024-03-25 23:26:05 +01:00
60676780a9 Fix detail in new RoPE implementation (#1935) 2024-03-25 18:20:09 +01:00
d3a8d291d5 Avoid the attention mask where possible. (#1933) 2024-03-25 15:31:04 +01:00
cd254074f3 Really unique identifier for metal device ids. (#1932)
* Really unique identifier for metal device ids.

* Same device.
2024-03-25 11:48:16 +01:00
e7f8e72588 Contiguous variant of the rope kernel. (#1929)
* Contiguous variant of the rope kernel.

* Add the cuda kernel.

* Metal kernel.
2024-03-25 09:11:20 +01:00
1b98f84a2b Fast kernels for rotary embeddings. (#1928)
* Fast kernels for rotary embeddings.

* Add a test for the fast CPU kernel.

* Rope cuda bindings.

* Cuda kernel.

* Metal kernel (part 1).

* Cuda kernels.

* Finish the metal kernel.

* Use the new kernels in the quantized example.

* Fix warning.
2024-03-24 22:48:52 +01:00
cf7d7fcf2f Also avoid the mask in the llama example. 2024-03-24 19:04:32 +01:00
8c0db87992 Avoid using the attn mask when not necessary. 2024-03-24 18:55:56 +01:00
e2b4829531 Support more mistral models. (#1927)
* Support more mistral models.

* Use the appropriate rope parameter.
2024-03-24 08:04:04 +01:00
5e70821dd0 Allow for arbitrary temperature modifications. 2024-03-23 15:47:39 +01:00
a62a97340c Add topk sampling. (#1923) 2024-03-23 15:26:09 +01:00
fdfe8fd129 Preliminary support for inplace ops. (#1921)
* Preliminary support for inplace ops.

* Add a test.
2024-03-23 14:16:19 +01:00
790037390c Add cast_bf16_x/cast_x_bf16 when CUDA_ARCH<800 but CUDA_VERSION >= 11000 (#1919)
- it make possible to load bf16 models on T4(sm75)
2024-03-23 13:44:10 +01:00
6f877592a7 Avoid broadcasting on the batch dimension for the attention mask. (#1920) 2024-03-23 13:08:53 +01:00
cc856db9ce Backwards for ConvTranspose2D (#1910)
* add documentation  for nackprop

* add backwards for ConvTranspose2D

* add test python code to test
2024-03-23 07:05:55 +01:00
fc1fe5e45b Support scatter/index_add with i64 indices for f16 (#1915) 2024-03-22 11:51:41 +01:00
32f567bac4 Fix loading the gguf files. (#1913) 2024-03-22 10:28:38 +01:00
fee33b45c2 Add support for strided index-select on Metal (#1909)
* initial implementation

* use correct index, but still not breaking like it should have...

* fix test
2024-03-22 07:30:02 +01:00
6708870e63 Add the alloc_uninit function. (#1901)
* Add the alloc_uninit function.

* Dummy metal fix.

* Lazy initialization.
2024-03-22 07:25:23 +01:00
a00e24d752 Improve the error message on overlong prompts. (#1908) 2024-03-21 21:08:07 +01:00
c07e4057ab Fix for the llama model. (#1906) 2024-03-21 19:36:10 +01:00
c0bdd9c7a6 Use the fast RmsNorm in the quantized model. (#1904) 2024-03-21 18:49:35 +01:00
9563a5fee4 Add support for conv_transpose2d on Metal backend (#1903)
* add support for conv transpose 2d and add bench mark for float types

* update bench calculation

* enable testing all conv operations on metal
2024-03-21 18:08:45 +01:00
ec97c98e81 Async tensor copying. (#1900) 2024-03-21 13:09:42 +01:00
bb3ee48039 whisper readme (#1899) 2024-03-21 12:54:09 +01:00
0c11e055be support distil-large-v3 (#1898) 2024-03-21 11:46:49 +01:00
18036c6ccb Update the image crate + use the re-exported version. (#1893)
* Update the image crate + use the re-exported version.

* Update to using ab_glyph.
2024-03-21 10:56:41 +01:00
0fddec762e RmsNorm kernel for metal. (#1895)
* RmsNorm kernel for metal.

* Wrapper for the metal kernel.

* Get the ops to actually work.

* Fix, get the tests to pass.
2024-03-21 09:48:56 +01:00
74b7f59261 Prepare for the custom-op extension. (#1892) 2024-03-21 07:02:20 +01:00
af7f8b87d3 Custom op for RmsNorm (#1890)
* Trying out a custom RmsNorm cuda kernel.

* CPU implementation for rms-norm.

* Cuda wrappers.

* Add some validation.

* Add some testing.

* More testing.
2024-03-21 06:36:28 +01:00
b219903d0f Cuda backend optimization (#1886)
* Attempt at making the kernel faster.

* Also adapt the cast kernels.

* Also apply to binary ops.
2024-03-20 18:32:55 +01:00
469635a3eb Minor cleanup. (#1885) 2024-03-20 14:38:27 +01:00
455c42aa72 Avoid copying the data on squeeze and unsqueeze. (#1884)
* Avoid copying the data on squeeze and unsqueeze.

* Fix the quantized llama example.

* Unrelated fix for the quantized stable-lm example on cuda.

* Fix for mamba on cuda (unrelated to the PR).
2024-03-20 13:04:36 +01:00
2a8679509e Add support for conv_transpose1d for metal backend (#1874)
* first attempt

* progress

* integrate into metal backend

* finish and get test passing

* add other dtype support

* update transpose1d dtypes supported
2024-03-19 08:46:58 +01:00
143c481c20 Expose candle gather op in pyo3. (#1870) 2024-03-18 21:54:15 +01:00
f115895b9e Apply rustfmt. (#1873) 2024-03-18 21:43:31 +01:00
90fc82211f Use a common with_tracing::RmsNorm in a few models. (#1871)
* Add RmsNorm with tracing.

* Use with_tracing::RmsNorm in some models.
2024-03-18 21:40:06 +01:00
6a966cf9e0 Add a DQN example to the reinforcement-learning section (#1872) 2024-03-18 21:22:53 +01:00
04a61a9c72 Add avg_pool2d metal implementation for the metal backend (#1869)
* implement metal avg pool 2d

* fixX

* add suggested precision workaround for the accumulator
2024-03-18 18:50:14 +01:00
58605252e8 Microphone support for the encodec example. (#1866) 2024-03-18 11:19:46 +01:00
d365ef32d9 Improve the encodec example: handle resampling. (#1865)
* Improve the encodec example: handle resampling.

* Play the audio directly.
2024-03-18 10:09:40 +01:00
754fa1e813 Add support for max_pool2d for Metal backend (#1863)
* first pass at implementation of maxpool2d

* Add definitions for other dtypes

* add tests for other dtypes

* Cosmetic tweaks + re-enable maxpool2d tests for metal.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-03-18 08:33:30 +01:00
184105792f add test for index add and add missing match statements (#1862) 2024-03-17 22:19:12 +01:00
a15f859ab4 Fix for the encodec example. (#1861) 2024-03-17 21:15:12 +01:00
e316cb6997 add support for casting between all datatypes (#1860) 2024-03-17 20:55:11 +01:00
ce9fbc3682 Optimize the cat operation on contiguous tensors (#1855)
* Add a specialized kernel for copy2d.

* Move the cat operations.

* Avoid transpositions in cat.

* Bugfix.

* Bugfix for the cuda kernel.

* Add a benchmark.

* Add more testing.

* Test fix.

* Faster kernel.

* Add the missing kernel.

* Tweak the test.

* Add a metal kernel.

* Fix for the metal kernel.

* Get the tests to pass on metal.

* Also use this opportunity to fix the metal kernel for ELU.

* Add some bf16 kernels.

* Clippy fixes.
2024-03-17 10:49:13 +01:00
db8b24ae92 Add support for index u8/i64 and input f16/bf16 scatter-add on metal (#1849)
* add support and tests for scatter add on metal

* add support for all datatypes
2024-03-17 08:09:43 +01:00
74bf6994b1 Move the image tensor to the appropriate device. (#1856) 2024-03-16 22:25:46 +01:00
cdc4c172c4 Implement the error trait for DTypeParseError. (#1852) 2024-03-15 08:37:27 +01:00
e1f9c3776d StableLM-2 models were updated to use GPT-2 tokenization. (#1847) 2024-03-14 21:01:36 +01:00
3318fe30fb Update gemma README (#1843)
* Update gemma README

* Fixit
2024-03-13 21:41:36 +01:00
2bb9c683b9 Update README.md (#1840)
Adds the candle-einops to the readme as an external resource
2024-03-13 14:36:25 +01:00
ff03fd3fb3 Expose some helper functions to create quantized models. (#1837) 2024-03-12 11:30:24 +01:00
df5f69444e Properly handle the batch dimension in cuda quantized matmul. (#1832) 2024-03-10 20:23:43 +01:00
0c5eecbc0f Add some tracing to metavoice. (#1826) 2024-03-09 12:24:11 +01:00
56c9d3ee7b Fix the model path for rwkv. (#1825) 2024-03-09 11:21:48 +01:00
dd00482ea3 Quantized version of the metavoice model. (#1824)
* Quantized version of the metavoice model.

* Integrate the quantized version of metavoice.
2024-03-09 11:06:04 +01:00
936f6a4840 Fix dequantization. (#1823) 2024-03-08 23:12:13 +01:00
3440cec3a0 Fast CPU kernel for transposed 1d convolutions. (#1822)
* Fast CPU kernel for transposed 1d convolutions.

* Bugfix.
2024-03-08 22:43:07 +01:00
e7fc1daa21 Bump the crate versions to 0.4.2. (#1821) 2024-03-08 22:01:51 +01:00
be5b68cd0b Metal random-generation bug fixes (#1811)
* use_resource API misunderstood. It is not additive. Several usages must be bit-ORed together.

* The seeding was incorrect and used the address instead of the value of the passed in seed.

* Add a check that likely exhibits failure to update the seed between generation of random tensors.

* Buffer overrun, the length given to the std::ptr::copy call was in bytes, and not 32-bit units.

* By default seed the RNG with a time-based value, so that different runs may produce different output, just like the CPU engine.
Use device.set_seed if determinism is warranted.

* Revert "By default seed the RNG with a time-based value, so that different runs may produce different output, just like the CPU engine. Use device.set_seed if determinism is warranted."

This reverts commit d7302de9

Discussion in https://github.com/huggingface/candle/pull/1811#issuecomment-1983079119

* The Metal random kernel failed to set element N/2 of tensors with N elements, N being even.  The reason was that all threads but thread 0 all created 2 random samples, but thread 0 only one, i.e. an odd number.  In order to produce an even number of samples, the early termination of thread 0 should only everr occur for odd sized tensors.

* Add a test catching any deterministic tensor element in rand and randn output.

---------

Co-authored-by: niklas <niklas@appli.se>
Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
2024-03-08 16:11:50 +01:00
ea984d0421 Expose more printer options. (#1817) 2024-03-08 15:04:18 +01:00
9634583781 Expose a couple layout methods. (#1816) 2024-03-08 10:52:22 +01:00
758366160e add clone to candle dropout (#1814) 2024-03-08 08:18:01 +01:00
0a3487a776 Add a --seed argument to the stable-diffusion example. (#1812)
* Add a --seed argument to the stable-diffusion example.

* Make the case when no seed is specified, that it will not be set, but use the engine's default.  This will make the CPU engine work again when no --seed is given, and will cause a bailout when a seed is there, as the engine does not currently support it.

---------

Co-authored-by: niklas <niklas@appli.se>
2024-03-08 08:17:36 +01:00
0c09d10f32 Improve metal buffer usage (#1807)
* Improve metal buffer usage

* Clone cpu storage when loading to reduce wait_until_complete calls
* Use powers of two for buffer sizes so reuse is more likely.
* Select best available buffer by size.
* Add count to MetalStorage -> can use buffer with different size

Co-authored-by: Chris Fleetwood <christopher.fleetwood@huggingface.co>

* Simplify new buffer creation without blit copy. Revert &[] -> Vec

* Add documentation on newBufferWithBytes safety / synchronization

* Drop unused buffers after command buffer is done syncing.

---------

Co-authored-by: Chris Fleetwood <christopher.fleetwood@huggingface.co>
2024-03-07 09:42:34 +01:00
8a99cf7dd2 Add a flag to select the dtype used in metavoice. (#1805) 2024-03-05 12:16:00 +01:00
bd9ab9bc04 Add a cuda kernel for dequantizing q8_0. (#1804) 2024-03-05 09:50:37 +01:00
8cc0a183ba Speaker embeddings computation for metavoice. (#1800)
* Speaker embeddings computation for metavoice.

* Compute the speaker embeddings.
2024-03-04 14:13:01 +01:00
6530932285 Add the new models to the main readme. (#1797) 2024-03-03 16:25:14 +01:00
924ccae30c Add an initial Segformer implementation (#1617)
* add segformer

* Make the id2label field optional.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-03-03 16:01:46 +01:00
60dc72b96b More metavoice tweaks. (#1796) 2024-03-03 15:05:25 +01:00
20abb72fec Normalize loudness of the generated audio (#1795)
* Normalize loudness of the generated audio.

* Lints.

* One more lint.

* Avoid running the bs1770 tests.

* Another attempt at discarding doc comments.

* Also normalize the loudness in the encodec example.
2024-03-03 14:00:42 +01:00
ca5d727ba2 Use the same padding in metavoice as in the python version. (#1794) 2024-03-03 12:04:48 +01:00
09e0148cce Tweaks to run metavoice on metal (#1792)
* Enable tanh + tweak conv-transpose.

* Run the encodec decoding on cpu.

* Clippy fixes.
2024-03-03 07:46:44 +01:00
de11623752 Metavoice position fix (#1791)
* Add the metavoice transformer.

* Sketch the speaker-encoder module.

* Adding to the metavoice model.

* Start adding the metavoice example.

* Get some logits out.

* Load the second stage model.

* Get the second step to run.

* Tweak the example.

* Add encodec tilting.

* Glue the different bits together.

* Fix a shape issue.

* Use a constant.

* BPE tokenization.

* Fix the position index in metavoice.
2024-03-02 21:00:35 +01:00
21f1d04976 Add the instruction finetuned gemma variants. (#1790) 2024-03-02 18:56:59 +01:00
4fff5b51f5 Metavoice - first cut (#1717)
* Add the metavoice transformer.

* Sketch the speaker-encoder module.

* Adding to the metavoice model.

* Start adding the metavoice example.

* Get some logits out.

* Load the second stage model.

* Get the second step to run.

* Tweak the example.

* Add encodec tilting.

* Glue the different bits together.

* Fix a shape issue.

* Use a constant.

* BPE tokenization.

* Add a warning.
2024-03-02 18:50:01 +01:00
314630638d Rustfmt fix. (#1788) 2024-03-02 10:35:07 +01:00
3e3def4134 Update StableLM config (#1787) 2024-03-02 09:56:57 +01:00
6980774a91 fix rwkv example eos token (#1785) 2024-03-01 10:22:28 +01:00
64d4038e4f Mention rwkv v6 in the readmes. (#1784) 2024-03-01 08:58:30 +01:00
979deaca07 EfficientVit (MSRA) model (#1783)
* Add EfficientVit (Microsoft Research Asia) model.

* Mention models in README
2024-03-01 08:53:52 +01:00
b485e4b6ee add models of rwkv v6 and quantized rwkv v6 (#1781)
* add models of rwkv v6 and quantized rwkv v6

* fix ci clippy fail
2024-03-01 08:37:56 +01:00
2c95b7394a Handle Q5_0 and Q5_1 quants in cuda. 2024-02-29 10:54:01 +01:00
4fd00b8900 Add the StarCoder2 model. (#1779)
* Add the StarCoder2 model.

* Add the example code and get things to work.

* And also tweak the readme.
2024-02-28 21:02:41 +01:00
57267cd536 Add a flag to force running the quantized model on CPUs. (#1778)
* Add a flag to force running the quantized model on CPUs.

* Add encodec to the readme.
2024-02-28 14:58:42 +01:00
60ee5cfd4d Support more modes in the encodec example. (#1777)
* Support more modes in the encodec example.

* Remove the old encodec model from the musicgen bits.
2024-02-28 09:22:33 +01:00
56e44aabe3 Make some dependencies optional in the examples. (#1776) 2024-02-28 07:17:03 +01:00
d0aca6c3c6 Encodec encoding demo. (#1775) 2024-02-28 06:49:03 +01:00
15e8644149 Apply dilations in the encodec model. (#1772)
* Apply dilations in the encodec model.

* Add some encoding bits.
2024-02-27 23:26:35 +01:00
0c49e95dfb Encodec model. (#1771)
* Encodec model.

* Fixes.

* Add the padding functions.

* Get the LSTM bit to work.

* Get the encodec model to generate some tokens (decoder only for now).

* Minor tweak.

* Minor tweak.
2024-02-27 22:59:40 +01:00
205767f9de Avoid tensor copying in the quantized example. (#1770) 2024-02-27 20:32:30 +01:00
5e526abc8c Bump the version number to 0.4.1. (#1768)
* Fix the block size for some cuda kernels.

* Bump the version number to 0.4.1.
2024-02-27 14:19:59 +01:00
6400e1b0a0 Fix the block size for some cuda kernels. (#1767) 2024-02-27 14:08:33 +01:00
32544a2ad6 Add an option to split the prompt. (#1766) 2024-02-27 11:24:11 +01:00
badf886583 Cuda kernel for dequantizing q8k. (#1760)
* Cuda kernel for dequantizing q8k.

* Clippy lints.
2024-02-26 08:42:44 +01:00
918136ba46 add quantized rwkv v5 model (#1743)
* and quantized rwkv v5 model

* Integrate the quantized rwkv model in the initial example.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-02-25 21:43:40 +01:00
1a6043af51 Tweak the VarMap set type. (#1758) 2024-02-25 20:50:08 +01:00
2f22afd80e Cuda acceleration for quantized model. (#1754)
* Boilerplate for the quantized cuda support.

* More basic cuda support.

* More cuda quantization (quantize on cpu for now).

* Add the dequantization bit.

* Start adding some dedicated cuda kernels from llama.cpp.

* Move the kernel code.

* Start interfacing with the kernel.

* Tweak the kernel launch params.

* Bugfix for quantized metal.

* Fix some clippy lints.

* Tweak the launch parameters.

* Tweak cuda basics to perform a quantized matmul.

* Perform the dequantization on the cpu + use cublas for matmul.

* Add the dequantization kernel.

* Test the qmatmul.

* More kernels.

* Matmul-vec kernel.

* Add a couple kernels.

* More dequantization kernels.
2024-02-25 18:11:47 +01:00
8d04f70f4d Fix the eos token for gemma. (#1753) 2024-02-24 11:07:02 +01:00
eeb7e2b683 Apply rustfmt to the newly added tests. (#1749) 2024-02-23 06:48:28 +01:00
11ea7aac4d tests (#1724) 2024-02-23 06:35:46 +01:00
32eb56d6b3 Fix typo in README (#1740) 2024-02-22 12:35:26 +01:00
28057781aa Make the cache for the llama model explicit too. (#1745) 2024-02-22 12:04:33 +01:00
544018b6d0 Explicit caching in llama2.c. 2024-02-22 10:22:03 +01:00
c753f72c85 Support for attention bias in gemma + refactor things a bit. (#1744)
* Support for attention bias in gemma + refactor things a bit.

* Fix the cuda tests.
2024-02-22 09:35:28 +01:00
8013b50829 Add grads for interpolate1d (#1742)
* add backprop for interpolate1d

* fix clippy lint

* correct fix clippy lint
2024-02-22 08:44:01 +01:00
45d5322d62 Add the Gemma models. (#1741)
* Add the Gemma models.

* Add the gemma example.

* Adapt the RmsNorm.

* Get the 2b model to work.

* 7b support.

* Use the config head dim.

* Yet another fix.

* Make the matrixes contiguous.

* Also get the 7b model to work.

* And add to the readme.
2024-02-21 22:02:50 +01:00
a2cb2edead Add a couple backtraces on cpu errors. (#1738) 2024-02-20 19:54:13 +01:00
fc67d878bb Bugfix for conv-transpose1d (#1734)
* Add a currently broken test.

* Bugfix + fix test.
2024-02-19 09:04:49 +01:00
3ba37443e5 Bugfix for applying the bias in conv1d-transpose. (#1732) 2024-02-18 22:51:20 +01:00
1fb728772d Support for groups in conv-transpose1d. (#1731)
* Groups support in conv-transpose-1d.

* Remove dangling file.
2024-02-18 21:28:07 +01:00
cb86b0c82c Fix float unpickling. (#1730) 2024-02-18 19:33:55 +01:00
6284ad784c Module implementation for options. (#1728) 2024-02-18 14:12:55 +01:00
678d44a7f6 Expose the weights and biases in transposed convolutions. (#1727) 2024-02-18 10:35:01 +01:00
41416d2376 Expose more conv1d functions/structs. (#1726) 2024-02-17 18:50:55 +01:00
5ebcfeaf0f Make the r, k, v tensors contiguous. (#1719) 2024-02-16 09:17:35 +01:00
7c7400fb63 Use the tokenizer-output-stream in the llama example. (#1715)
* Use the tokenizer-output-stream in the llama example.

* Also use tokenizer-output-stream for llama2-c.
2024-02-15 16:47:33 +01:00
058a910d0e Add a readme for rwkv. (#1712) 2024-02-14 15:31:33 +01:00
26fe162ab5 Custom tokenizer for rwkv. (#1711)
* Custom tokenizer for rwkv.

* Custom tokenizer.

* Getting the tokenizer to work.
2024-02-14 15:11:38 +01:00
121a71e01f Fix the silu cuda kernel. (#1710) 2024-02-14 11:08:18 +01:00
2d5f2a728d Add the RWKV model (v5). (#1707)
* Start adding the RWKV model.

* More of the forward step.

* Handle rescaling.

* FeedForward.

* More work on RWKV.

* Better state tracking.

* Finish a first pass on forward.

* Fix the shape mismatches.

* Do not rescale in f32.

* Rename to rwkv-v5.

* Add the new models to the readme.
2024-02-14 10:58:32 +01:00
68f7655895 Add ConvNeXt-V2 and smaller model variants. (#1709) 2024-02-14 10:53:07 +01:00
b60064780d feat: add silu activation function (#1706)
* feat: add silu activation function

* use silu/arg in grad

* update candle-nn

* use node
2024-02-14 10:27:22 +01:00
14010a8498 Update our cuda runner. (#1705)
* Update our cuda runner.

* Fix install rust.

* Simplify.

* Docker in docker.

* Install curl

* Install curl

* No sudo.

* devel

* Put curl again.

* Add missing deps.

* pkg-config.

* Cleanup.
2024-02-13 19:06:15 +01:00
0de0795220 Qmetal tweaks (#1704)
* Add the dummy qmetal backend.

* Fix the metal compilation.
2024-02-13 18:11:17 +01:00
c1b418586c Fixing quantized llama demo on metal. (#1703) 2024-02-13 16:28:56 +01:00
ad73e93da2 Detach the tensors on batch-norm eval. (#1702)
* Detach the tensors on batch-norm eval.

* Fix pyo3 bindings.

* Black tweak.

* Formatting.

* Also update the pyo3-onnx formatting.

* Apply black.
2024-02-13 14:26:32 +01:00
13c67226e6 feat: support microphone whisper streaming (#1678)
* feat: support microphone whisper streaming

* fix: cleanup print stmts and adjust how input is read

* fix: remove incorrect comment

* feat: split into new example and simplify

* fix: feature flag example file

* fix: fmt fixes

* feat: simplify and remove redundant files
2024-02-12 18:01:21 +01:00
d0aa197b07 ConvTranspose1d cuda support. (#1697)
* ConvTranspose1d cuda support.

* Add the conv-transpose1d kernel.

* Remove some unused variables.
2024-02-12 15:03:18 +01:00
274bf11633 Support defaultdict in PyTorch checkpoints. (#1696)
* Support defaultdict in PyTorch checkpoints.

* Fix clippy lint.
2024-02-12 10:26:56 +01:00
1e26d539d9 Improved mamba model optimized for inference (#1694)
* Sketch the mamba model for inference.

* Complete the forward pass.

* Add the mamba example.

* Optimize the selective-scan part.

* Fix a couple shape mismatches and get inference to work.

* Tweak the readmes.

* More readme tweaks.
2024-02-11 17:04:57 +01:00
74497e6bf7 Fixing the qwen tokenizer location. (#1693)
Using the chatglm one causes a bug where the "<|endoftext|>" is not
found.
2024-02-11 08:52:36 +01:00
8ab384e63d docs: add trocr examples (#1692) 2024-02-10 16:14:50 +01:00
27ffd644a9 Mention TrOCR in the readmes. (#1691) 2024-02-10 15:49:38 +01:00
bf20cc854c Support sinusoidal embeddings in trocr. (#1690)
* Support sinusoidal embeddings in trocr.

* Support tie-word-embeddings.
2024-02-10 15:17:51 +01:00
42ce593ec6 Use the repo config for trocr rather than hardcoding it + small tweaks. (#1689)
* Use the repo config for trocr rather than hardcoding it + small tweaks.

* Add support for the printed models.

* Fail with an appropriate error message on missing position embeddings.
2024-02-10 13:15:03 +01:00
67589791d2 Remove the unused pragma in vit + handle the final layernorm. (#1688) 2024-02-10 11:08:50 +01:00
1c8d61f051 ChatGLM custom tokenizer. (#1687) 2024-02-10 10:47:04 +01:00
90447bc993 Add the custom tokenizer. (#1686) 2024-02-09 17:36:50 +01:00
40ce16001b Use the proper endoftext token for gwen. (#1685) 2024-02-09 17:02:03 +01:00
5657e596cd Add the Qwen2 model (#1684)
* Initial check-in for the qwen2 model.

* More qwen2 inference.

* Polish the qwen example.

* Fix the rope basis.

* Get the inference to work.

* Support different model sizes.
2024-02-09 15:02:49 +01:00
0dee8ea19b Add the ChatGLM model. (#1237)
* Add the ChatGLM model.

* Rotary embeddings.

* Add to the forward pass.

* Add to the forward pass.

* Add the rotary embeddings.

* Add the KV cache.

* Add the chatglm example.

* Bugfix.

* More glm fixes.

* Fix some shape issues.

* Get the inference to work.
2024-02-09 11:51:38 +01:00
9cadd4e644 feat: support multithread spectrogram and small perf tweaks (#1674)
* feat: support multithread spectrogram and small perf tweaks

* feat: clippy improvement for loop variable

* fix: add back speed up scale down logic

* fix: readd mirroring logic

* feat: prefer scoped thread and simplify/improve logic/traits
2024-02-08 21:54:12 +01:00
020a979de2 Fix clippy lints for 1.76. (#1682) 2024-02-08 16:48:47 +01:00
cdc3823d8f Pickle support: dig within the _rebuild_parameter calls. (#1681) 2024-02-08 13:09:49 +01:00
e5eb9602d0 Add support for loading Fortran contiguous tensors (#1672)
* Add support for loading Fortran contiguous tensors

This commit introduces the ability to handle Fortran contiguous tensors in the tensor loading process. Previously, the code only supported loading tensors that were contiguous in memory, failing with an error for non-contiguous tensors. With this update, tensors identified as Fortran contiguous (column-major order) are now correctly handled by reversing their dimensions after loading. This enhancement ensures broader compatibility with different tensor layouts, improving the robustness of tensor loading operations.

- Check if a tensor is Fortran contiguous using the `is_fortran_contiguous` flag.
- For Fortran contiguous tensors, reverse the dimensions after loading to correctly represent their layout in memory.
- Continue to bail out with an error for tensors that are neither C contiguous nor Fortran contiguous, maintaining the previous behavior for non-contiguous tensors without explicit support.

This change addresses the issue of loading Fortran contiguous tensors, which was previously unsupported, thereby extending the functionality of the tensor loading mechanism to accommodate a wider variety of tensor layouts.

* Add reshape step to handle fortran contiguous case

* Skip fortran contiguous fix if rank is < 2

* Fail on rank 0, 1 if contiguous
2024-02-07 21:49:59 +01:00
b75e8945bc Enhance pickle to retrieve state_dict with a given key (#1671) 2024-02-06 21:17:33 +01:00
a90fc5ca5a Add VarBuilder::from_backend (#1670)
`candle-nn` already exposes a trait to define custom backends. However,
it's not possible to actually construct a `VarBuilder` with a custom
backend because the constructor is not exposed.

This change makes the constructor public and renames it from `new` to
`from_backend` to avoid that it is seen as the primary
constructor (which could be confusing to users).
2024-02-06 15:26:11 +01:00
adfae2460a Fix rustfmt. (#1669) 2024-02-06 12:06:06 +01:00
678f64dd27 Fix token generation in bilingual models (non-English outputs) (#1668)
Co-authored-by: Guoqing Bao <guoqing.bao@enflame-tech.com>
2024-02-06 12:03:53 +01:00
b545f54a19 Fix clippy lints. (#1667) 2024-02-06 09:03:36 +01:00
1ba11f22d6 Fix: pth files don't load on Windows (#1661)
* Don't treat zip path as OS path

* Add a test case

* Add code to generate test pth data
2024-02-06 08:50:55 +01:00
982722019b add roll function to tensor (#1666) 2024-02-06 08:49:45 +01:00
a83ca2ece0 Bump the crate version to 0.4.0. (#1658) 2024-02-04 19:08:01 +01:00
153c940a9c Update docs to reflect current usage of example (#1610)
modified:   candle-examples/examples/onnx/README.md
2024-02-04 11:59:47 +01:00
50be8a98ba Quantized support for stable-lm2. (#1654)
* Quantized support for stable-lm2.

* Quantized support for v2-zephyr.
2024-02-04 11:57:05 +01:00
58cc896e69 make llama derive clone (#1648)
Co-authored-by: danielclough <danielclough@users.noreply.github.com>
2024-02-04 11:56:03 +01:00
5cdd84e0f6 onnx: add the Flatten operator. (#1638)
* onnx: add the Flatten operator.

* onnx flatten: merge axis condition

---------

Co-authored-by: 王泽龙 <wangzelong@shenqishen.com>
2024-02-03 16:28:47 +01:00
a510ddec4e Mention the new models in the readme. (#1651) 2024-02-03 15:19:57 +01:00
d32abbce53 Add StableLM-2, StableLM Code and Zephyr variants (#1650)
* Add StableLM Code and Zephyr variants

* Add V2 models

* Update README
2024-02-03 14:58:41 +01:00
dfab45e1c8 Supports more audio formats (#1628)
* Supports more audio formats

* Simplify the handling of the different buffer types.

* Check the sample rate.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-02-03 14:26:04 +01:00
96bc704d17 Update mixformer.rs (#1601)
Update the source of the configuration_mixformer_sequential.py
It has been removed, therefore, it is still available in this -> d38e6f954ec29b96fe2cf033937dad64e279b5d9
2024-02-03 13:42:16 +01:00
a52d407ae6 Add ConvNeXt model. (#1604) 2024-02-03 13:34:28 +01:00
9e824ec810 Explicit version for packages that are not in the workspace. (#1642) 2024-01-31 18:57:38 +01:00
beadb1b434 Explicit candle version so that cargo publish can be used easily. (#1641) 2024-01-31 18:42:22 +01:00
6d83d42efb Merge pull request #1606 from FL33TW00D/feature/larger-batches
fix: larger batches
2024-01-29 15:31:10 +00:00
b6afb46601 chore: final 2024-01-22 15:15:19 +00:00
fd7c856564 Merge pull request #1533 from huggingface/ivarflakstad/metal-prng 2024-01-22 07:30:20 +01:00
73d79e6092 chore: actual fix 2024-01-19 09:35:42 +00:00
b1879f17f6 chore: switch to buffer 2024-01-19 08:57:49 +00:00
4f79f5df8a fix: larger batches 2024-01-18 14:30:14 +00:00
1cf34368b7 Merge pull request #1602 from mimiquate/fix-metal-kernel-type
Metal: Use uint8_t as output type in int64_t binary op kernel
2024-01-18 08:40:34 +01:00
17e6e2d7ee Fixes metal kernel u8 type 2024-01-17 15:47:08 -03:00
80b1c689f9 Revert public EncoderParam 2024-01-17 18:09:28 +01:00
db923517b3 Merge branch 'main' into ivarflakstad/metal-prng 2024-01-17 18:03:57 +01:00
403680f17d Quantized GGUF style (#1523)
* Metal quantized modifications proposal.

- Add a device param, wherever needed.
- Create new QMetal storage thing that implements QuantizedType.
- Update everywhere needed.

Fix Python.

Fixing examples.

Fix: fmt + clippy + stub.

Moving everything around.

Only missing the actual implems.

Fixing everything + adding dequantized kernels.

More work.

Fixing matmul.

Fmt + Clippy

Some clippy fixes.

Working state.

Q2K Metal -> Bugged (also present in GGML).
Q4K CPU -> Bugged (present previously, new test catch it).
Q5K CPU -> Bugged (present previously).
Q8_1 Both -> Never really implemented it seems
Q8K metal -> Never implemented in metal

Fixing Q2K bug (present in ggml).

* Cleanup.

* Fix the rebase.

* Removing the fences speeds everything up and *is* correct this time...

* Cleanup the fence.

* After rebase.

* Bad code removal.

* Rebase after phi2 merge + fix replit default to CPU.

* Making the CI happy.

* More happy tests.

---------

Co-authored-by: Nicolas Patry <nicolas@Nicolass-MacBook-Pro.local>
2024-01-17 10:27:58 +01:00
86a8e58897 Update metal random kernel and set_seed method
* set_seed via buffer content pointer copy + did_modify_range

* ensure random.metal kernel does not write outside of buffer range when tid==0
2024-01-17 09:12:44 +01:00
5270224f40 Add MobileOne model. (#1595)
* Add MobileOne model.

* Clippy fixes

* Remove a comment.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-01-16 06:34:16 +01:00
7e3349d7c3 Update parquet requirement from 45.0.0 to 50.0.0 (#1592)
Updates the requirements on [parquet](https://github.com/apache/arrow-rs) to permit the latest version.
- [Changelog](https://github.com/apache/arrow-rs/blob/master/CHANGELOG-old.md)
- [Commits](https://github.com/apache/arrow-rs/compare/45.0.0...45.0.0)

---
updated-dependencies:
- dependency-name: parquet
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-01-15 22:35:01 +01:00
1257fc6719 Update safetensors requirement from 0.3.1 to 0.4.1 (#1591)
Updates the requirements on [safetensors](https://github.com/huggingface/safetensors) to permit the latest version.
- [Release notes](https://github.com/huggingface/safetensors/releases)
- [Changelog](https://github.com/huggingface/safetensors/blob/main/RELEASE.md)
- [Commits](https://github.com/huggingface/safetensors/compare/v0.3.1...v0.3.3)

---
updated-dependencies:
- dependency-name: safetensors
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-01-15 22:34:40 +01:00
ea36f3b11f Use the new phi model by default. (#1589) 2024-01-15 12:30:27 +01:00
79478ff5a1 Seed should be updated by random kernel result. 2024-01-15 11:58:25 +01:00
86b7c01b30 Update gemm to the latest version. (#1587) 2024-01-15 09:44:51 +01:00
bdd8107fda Expose the ndarray trait. (#1586) 2024-01-14 20:09:49 +01:00
ecf88a6d38 Merge branch 'main' into ivarflakstad/metal-prng 2024-01-14 17:10:54 +01:00
e6d86b0819 Add the pow operator. (#1583)
* Add the pow operator.

* Support the pow operation in onnx.
2024-01-13 20:24:06 +01:00
88618255cb Fix the rotary embeddings for the new phi implementation. (#1582)
* Fix the rotary embeddings for the new phi implementation.

* Match the activation.

* KV cache fix.

* Use the config activation function.
2024-01-13 19:44:41 +01:00
539ead927a Update the Phi model to use the updated architecture. (#1580)
* Update the Phi model to use the updated architecture.

* Add more of the phi model.

* Repeat KV + caching.

* Apply the rotary embeddings.

* Add support for the new phi model in the phi example.

* Fix a couple glitches.

* Fix a couple more glitches.
2024-01-13 17:38:27 +01:00
a46864bd56 Fix "Minimal Mamba" link in README. (#1577) 2024-01-12 17:47:07 +01:00
bafe95b660 Fix format. (#1576) 2024-01-12 14:23:17 +01:00
a3d92ab226 Metal: Activate bfloat affine and add benchmark (#1543)
* Use cfg to seperate benchmark results based on features

* Add bfloat affine and benchmarks

* Fix flops calculation

* Remove allow pragma

* Avoid some unnecessary returns.

* Improve benchmarks layout

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-01-12 11:19:49 +01:00
e90bcdcc7c Metal: f16 and bf16 where_cond + benchmark (#1545)
* Use cfg to seperate benchmark results based on features

* Add metal where_cond for f16 and bf16. Add benchmark

* Remove allow pragma

* Avoid some unnecessary returns.

* Improve benchmarks layout

* Updated feature separated benchmarks

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-01-12 11:18:11 +01:00
8e06bfb4fd Mention VGG in the readme. (#1573) 2024-01-12 09:59:29 +01:00
6242276c09 Pin the revision used for phi-v2 + make it the default. (#1572)
* Pin the revision used for phi-v2 + make it the default.

* Tweak the custom-ops build.
2024-01-12 09:19:30 +01:00
e06e8d0dbe fmt 2024-01-12 07:26:42 +01:00
e63bb8661b Merge branch 'main' into ivarflakstad/metal-prng 2024-01-12 07:19:58 +01:00
41915184bb Bugfix for dequantizing q5k layers. (#1569) 2024-01-11 23:15:11 +01:00
c1876b8041 Merge pull request #1567 from bayedieng/close-ifdef 2024-01-11 22:14:38 +01:00
85e5680277 remove metal version check 2024-01-11 21:02:03 +00:00
1327419776 close ifdef 2024-01-11 17:14:12 +00:00
402349d120 feat(bf16): add cast support + tests for cast + bin ops (#1524) 2024-01-11 15:49:13 +01:00
9f0c99f0c1 Seperate benchmarks by enabled features (#1538)
* Use cfg to seperate benchmark results based on features

* Remove allow pragma

* Avoid some unnecessary returns.

* Improve benchmarks layout

* Derive bench_name from actual device

* Run CPU benchmarks even when GPU feature is enabled

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-01-11 15:35:38 +01:00
0fc95c9f0c Add a dequantize command to tensor-tools. (#1565)
* Add a dequantize command to tensor-tools.

* Clippy fixes.
2024-01-11 11:21:01 +01:00
2480c5dbdd Add RepVGG model. (#1561)
* Add RepVGG model.

* Add RepVGG README

* Extract var to top level

* Replace hashmap with a match

* Add a variant for the model kind + avoid some unnecessary config cloning.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-01-11 07:07:40 +01:00
63944714f2 Use candle_nn::embedding instead of local copies in a few models. (#1562) 2024-01-10 21:36:27 +01:00
d3bdd788cf Use __HAVE_BFLOAT__ to check for bfloat support instead of metal version check (#1540) 2024-01-10 18:50:30 +01:00
ae06cb74bb Add relu kernel for metal (#1488)
* Add relu kernel for metal

* Copy error messages proposed in #1491

* Revert non relu changes

* Fix name changes

* Fix the last of us (:

* Fix copy and paste mistakes

* Fix typo

* Revert order changes

* Revert order change

* Add deleted functions back

* Run rustfmt
2024-01-10 18:27:17 +01:00
a897fda74e Update memmap2 requirement from 0.7.1 to 0.9.3 (#1556)
Updates the requirements on [memmap2](https://github.com/RazrFalcon/memmap2-rs) to permit the latest version.
- [Changelog](https://github.com/RazrFalcon/memmap2-rs/blob/master/CHANGELOG.md)
- [Commits](https://github.com/RazrFalcon/memmap2-rs/compare/v0.7.1...v0.7.1)

---
updated-dependencies:
- dependency-name: memmap2
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-01-10 16:27:59 +01:00
1f1179913a Update gloo requirement from 0.8 to 0.11 (#1558)
Updates the requirements on [gloo](https://github.com/rustwasm/gloo) to permit the latest version.
- [Release notes](https://github.com/rustwasm/gloo/releases)
- [Changelog](https://github.com/rustwasm/gloo/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rustwasm/gloo/commits)

---
updated-dependencies:
- dependency-name: gloo
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-01-10 16:27:20 +01:00
6e98cf2a92 Update cudarc requirement from 0.9.14 to 0.10.0 (#1559)
Updates the requirements on [cudarc](https://github.com/coreylowman/cudarc) to permit the latest version.
- [Release notes](https://github.com/coreylowman/cudarc/releases)
- [Commits](https://github.com/coreylowman/cudarc/compare/v0.9.14...v0.9.15)

---
updated-dependencies:
- dependency-name: cudarc
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-01-10 16:27:05 +01:00
2cc1247999 Update tokenizers requirement from 0.13.4 to 0.15.0 (#1555)
Updates the requirements on [tokenizers](https://github.com/huggingface/tokenizers) to permit the latest version.
- [Release notes](https://github.com/huggingface/tokenizers/releases)
- [Changelog](https://github.com/huggingface/tokenizers/blob/main/RELEASE.md)
- [Commits](https://github.com/huggingface/tokenizers/commits)

---
updated-dependencies:
- dependency-name: tokenizers
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-01-10 16:26:53 +01:00
edf3fcd1c4 fix: deprecated option field (open-pull-requests-limit-per-dependency) (#1554) 2024-01-10 15:12:46 +01:00
53e4755015 feat: add dependabot to the project (#1553)
* feat: add dependabot to the project

* feat: add let's accept patches/fix from other libs

* Revert "feat: add let's accept patches/fix from other libs"

This reverts commit d31a956f81.
2024-01-10 14:57:20 +01:00
87efb5d8eb Updated feature separated benchmarks 2024-01-09 19:04:31 +01:00
ad181f9cdc Merge branch 'ivarflakstad/seperate-benchmarks-by-feature' into ivarflakstad/metal-prng 2024-01-09 18:55:40 +01:00
88945f2c22 Improve benchmarks layout 2024-01-09 18:31:28 +01:00
12b2a337f3 Handle start-offset when loading a tensor from a pickle file. (#1546) 2024-01-08 09:20:48 +01:00
fb05af4c42 Avoid some unnecessary returns. 2024-01-08 07:19:59 +01:00
ad075a5f7e Remove allow pragma 2024-01-08 06:48:33 +01:00
0eb90ed783 Simpler repro for the neon optimization issue + bugfix (#1544)
* Simpler repro for the neon optimization issue.

* Bugfix for q4k.

* Improve the fix, share the dot-prod bit.

* Clippy fixes.

* Fix for q6k.

* Also fix for q2k.

* Use the new shared dotprod.

* Add more testing.
2024-01-07 20:21:49 +01:00
89b5a06858 Use bindgen-cuda for the custom-kernel example. (#1536)
* Use bindgen-cuda for the custom-kernel example.

* Only depend on the kernels when cuda is enabled.

* Skip rustfmt.
2024-01-07 17:18:46 +01:00
3f04a79ada Use cfg to seperate benchmark results based on features 2024-01-07 14:40:15 +01:00
30313c3081 Moving to a proper build crate bindgen_cuda. (#1531)
* Moving to a proper build crate `bindgen_cuda`.

* Fmt.
2024-01-07 12:29:24 +01:00
e72d52b1a2 Unpin more of the workplace relative dependencies. (#1535) 2024-01-07 12:26:20 +01:00
b4cb982e49 Simplifying our internal cargo dependencies. (#1529) 2024-01-07 12:04:14 +01:00
6ebe043273 Merge branch 'main' into ivarflakstad/metal-prng 2024-01-07 11:52:03 +01:00
6bf52b9fdf Gaussian normal distribution of PRNG via Box-Muller transform 2024-01-07 11:39:46 +01:00
84250bf52f fix index_pos bug when kv cache is disabled. (#1517)
* fix index_pos bug when kv cache is disabled

* Tweak the fix.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-01-06 11:43:01 +01:00
8d1a57c9a0 chore: update flash attention kernels (#1518)
* chore: update flash attention kernels

* fmt

* remove unused kernels

* force f32

* correct stride
2024-01-05 18:28:55 +01:00
955e63c803 Implement hybrid Tausworthe + LCG psuedo random number generator in metal 2024-01-05 13:27:59 +01:00
3a7304cb0d add link to gpt-from-scratch-rs (#1525) 2024-01-05 11:59:46 +01:00
fa3ea98ba9 Adding bfloat16 support for the cast kernels. (#1520) 2024-01-04 12:12:56 +01:00
135ae5f3eb Simplify the one-hot implementation, support arbitrary rank. (#1514)
* Simplify the one-hot implementation, support arbitrary rank.

* More cleanup.
2024-01-01 11:40:17 +01:00
41614b4a9b Add one-hot/cold encoding (#1489)
* add one-hot encoding

* one_hot: improve error handling, use generic to_vecN::<D>

Bails if the index value is equal to or greater than the depth value,
which would result in an out-of-bounds error.

A redundant check is added to ensure the index value does not exceed
the length of the one-hot matrix size, which would also result in an
out-of-bounds error.

Bails if the index value is less than -1. If the index value is -1,
then it ignores the setting of the on_value for the index value. Only
values that are less than -1 are considered errors.

* one-hot: use two generics, one_hot::<I, O>, for input and output data types

Separating the input and output data types allows the input tensor
indices to be a different data type than the output encoded tensor data type.

For example, one_hot::<i64, u8>(...) will take an input tensor of i64 values
and encode the output tensor using u8 values.

The generic I::DTYPE must match the data type of the input indices, otherwise
the method will bail.

Additionally, this method adds an `allow_f64` option to enable the input indices
data type to be f64 values. f64 values are disabled by default.

TODO: indices data type and the generic I data type are currently not compile-time
checked.

* one_hot: remove input generic, use indices dtype matching

This commit removes the to_f64() type cast and explicitly
matches the DType from the input tensor. Currently, only U8,
U32 and I64 is supported for input tensors.

The match arms on the dtype is verbose. It would be nice
to use a generic type with the WithDtype traitbound to
pass to the to_vecN method and then return an inner value.

Open to suggestions for better approaches here to reduce
the match arm verbosity.

* one_hot: use flat_map iterator over dims instead of nested for loop

This commit replaces the nested for loops with an flat map iter over
the dimensions of the input tensor.

This commit also adds a test for a rank 3 input tensor.

* one_hot: use mandatory on/off-values, remove const msgs

This commit also updates doc tests, comments and test cases.

* Small cleanups.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-01-01 11:18:40 +01:00
03ce8caf40 Format properly the Stable Diffusion example run with params (#1511)
Move out the --sd-version flag out of the prompt.
2024-01-01 11:13:35 +01:00
b0fe5e4453 Do not implement Module for BatchNorm. (#1513) 2024-01-01 10:13:13 +01:00
1fb2dd905c Add support for tiny-llama-1.1b. (#1512) 2023-12-31 12:18:25 +01:00
a0facd0e67 Small tweaks to batch-norm. (#1505) 2023-12-30 17:06:07 +01:00
4290b81244 [Breaking] Add training to batchnorm with exponential moving average (#1504)
* Add training to batchnorm with exponential moving average

* Add more checks to batch norm

* Resolve some review comments

* Add with_momentum varients of `new` methods

* Add check for range of momentum variable; update batch norm test

* Run cargo fmt

* Add back num_features parameter

* Format; tiny simplification
2023-12-30 16:42:08 +01:00
51e577a682 Add Policy Gradient to Reinforcement Learning examples (#1500)
* added policy_gradient, modified main, ddpg and README

* fixed typo in README

* removed unnecessary imports

* small refactor

* Use clap for picking up the subcommand to run.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-12-30 09:01:29 +01:00
0a245e6fa4 Metal: support unary abs (#1503)
* Metal: support unary abs

* cargo fmt
2023-12-30 00:00:12 +01:00
87d7f81b43 Metal: more u8/u32 (#1502)
* Adds more metal u8

* Metal: more u32
2023-12-29 23:56:21 +01:00
4373534d59 Metal: i64 basic support (#1495)
* Adds basic metal i64 support

* metal copy i64
2023-12-29 19:42:50 +01:00
f4a2787217 Merge pull request #1498 from huggingface/debugging_windows_ci
Fix CI
2023-12-29 12:33:50 +01:00
488e02a3f6 Merge pull request #1496 from bayedieng/unary
Implement urecip op for metal backend
2023-12-29 12:20:52 +01:00
adc95ca2bf Ignore skipped. 2023-12-29 12:15:57 +01:00
4907c63ea1 Ignore stop on remote forks. 2023-12-29 12:12:10 +01:00
d76ac20e0e Fix. 2023-12-29 12:06:38 +01:00
f5c98f22c7 Merge pull request #1491 from mimiquate/metal-errors
Improves metal's not implemented error messages
2023-12-29 12:03:40 +01:00
5b12fbb143 Trying to fix flakyness by making hub_2 and hub_3 serial tests (potential issue on mingw with mmap). 2023-12-29 11:13:33 +01:00
cc06ba2294 fix bad pattern matching and function name 2023-12-29 09:46:24 +00:00
a6bd0b47a5 Fix the CI. 2023-12-29 10:17:52 +01:00
b59b1b2bb6 remove generated png 2023-12-28 21:50:58 +00:00
3922b42c18 add urecip op to metal backend 2023-12-28 21:50:12 +00:00
1e442d4bb9 Fix lints for clippy 1.75. (#1494) 2023-12-28 20:26:20 +01:00
cd889c0f8a add config_amazon_mistral_lite (#1493)
Co-authored-by: Ubuntu <danielclough@users.noreply.github.com>
2023-12-28 19:59:58 +01:00
8e93e76a91 fixes error message 2023-12-28 15:03:05 -03:00
b3e838f3e2 cargo fmt 2023-12-28 14:07:34 -03:00
8bf892403a Improves metal's not implemented error messages 2023-12-28 11:04:06 -03:00
d35f0a1376 Bump the crate version to 0.3.3. (#1490) 2023-12-28 13:38:30 +01:00
65cb90bd40 Add some mention to SOLAR-10.7B in the readme. (#1487) 2023-12-27 15:25:39 +01:00
996a7f2e24 Rework the llama example config, add the solar model. (#1485) 2023-12-26 22:24:04 +01:00
3071ea6c3e Use the new hub helper function. (#1484) 2023-12-26 09:44:30 +01:00
37c539f2b7 Helper function to load sharded safetensors files (#1481)
* Fix the quantized mistral example.

* Add a helper function to load sharded safetensors weights.

* Use the sharded loader.
2023-12-25 21:49:21 +01:00
eae3a20d43 Merge pull request #1479 from huggingface/upsample_metal
Adding upsample_nearest_2d.
2023-12-25 14:25:53 +01:00
13a5d15ebc Adding upsample_nearest_2d. 2023-12-25 14:25:19 +01:00
1505d85276 Merge pull request #1461 from huggingface/metal-conv
Adding the convolutions (1d + 2d) to candle on metal.
2023-12-25 12:48:09 +01:00
95e18ef675 Fixing matmul for convolutions. 2023-12-25 12:29:34 +01:00
7135791dd5 Fix the quantized mistral example. (#1478) 2023-12-25 09:31:24 +01:00
88589d8815 Support mistral instruct v0.2. (#1475)
* Support mistral instruct v0.2.

* Use the safetensors model now that they are available.
2023-12-23 16:18:49 +01:00
5b35fd0fcf MMLU evaluation for Phi. (#1474)
* MMLU evaluation for Phi.

* Improve the evaluation.
2023-12-23 15:28:36 +01:00
ba1fae590e Validate the kernel size in pooling ops. (#1473)
* Validate the kernel size in pooling ops.

* Revert the changes to basics.
2023-12-23 11:19:22 +01:00
78d982e1bd Fix for mamba 2.8b. (#1472) 2023-12-23 11:01:39 +01:00
d8b9a727fc Support different mamba models. (#1471) 2023-12-23 10:46:02 +01:00
ceb78d3e28 Sketch the minimal mamba example. (#1465)
* Sketch the minimal mamba example.

* Fix rustfmt.

* Forward pass for mamba.

* Finish the forward pass.

* Inference fixes.

* Bugfixes.

* More fixes.

* Add a readme.
2023-12-22 00:28:50 +01:00
f6408a3779 feat: add clear_kv_cache to mistral and qmistral models (#1464) 2023-12-21 21:19:19 +01:00
10d94659c3 Adding the convolutions (1d + 2d) to candle on metal. 2023-12-21 10:39:24 +01:00
563a79afa1 make fn name generic (#1459)
Co-authored-by: Ubuntu <danielclough@users.noreply.github.com>
2023-12-21 02:16:31 +01:00
8ede5f4210 add fn config_chat_ml (#1458)
* add fn config_chat_ml

* Add a link to the original config.

---------

Co-authored-by: Ubuntu <danielclough@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
2023-12-20 21:03:24 +01:00
9fc210fae8 Merge pull request #1318 from huggingface/metal4
Starting to fix some tests.
2023-12-20 15:37:31 +01:00
9b5e4843a6 Optimizing decode matmul (Phi at 28tok/s on M3).
Adding some benchmark in order to help checking out matmul performance.
2023-12-20 09:54:19 +01:00
03641293ee Clippy pass. 2023-12-18 15:22:43 +01:00
064ba17bd7 Remove print. 2023-12-18 11:04:16 +01:00
e8ee253ee0 Missing cast. 2023-12-18 11:01:18 +01:00
8bd3d6b94b Index add. 2023-12-18 10:46:01 +01:00
6a3ca7da0c Scatter add. 2023-12-18 10:32:22 +01:00
96f1a28e39 Add a simple full method. (#1455)
* Add a simple implementation of the full method.

* Add the docstring.
2023-12-17 20:15:57 -05:00
586b6f6fff Adding gather op. 2023-12-17 23:34:12 +01:00
e4b0cc59f5 Adding CMP 2023-12-17 22:32:25 +01:00
0a6e0a8c9a Implement randn (CPU-> device) 2023-12-17 19:09:08 +01:00
972903021c Finish reduce kernels. 2023-12-17 19:07:00 +01:00
94817dac56 Bump the crate version to 0.3.2. (#1452) 2023-12-17 05:34:53 -06:00
1e86717bf2 Fix a couple typos (#1451)
* Mixtral quantized instruct.

* Fix a couple typos.
2023-12-17 05:20:05 -06:00
c630622a07 Expose AdamW parameters (#1449)
* Expose AdamW parameters

* Use reference
2023-12-16 18:41:56 -06:00
c4cfcf1539 Tweak the readme for phi and the default sample length. (#1450) 2023-12-16 18:11:36 -06:00
1782e93de6 Mixtral quantized instruct. (#1447) 2023-12-16 16:16:39 -06:00
cfdf9640a3 Readme tweaks. (#1446) 2023-12-16 06:23:12 -06:00
e12cbfd73b Update the readme to mention mixtral. (#1443) 2023-12-15 19:29:03 -06:00
30a958e5dd Quantized mixtral model (#1442)
* Add the Mixtral model.

* Add more of the mixtral layers.

* Add the final layers for mixtral.

* Sketch the expert selection.

* Add some expert routing logic.

* Hopefully finish the routing logic for mixtral.

* Add the mixtral example.

* Fix the weight filenames.

* Bugfix.

* Another fix.

* Yet another fix + remove the unused pragma.

* Shape fix.

* Support for quantized mixtral.

* Support mixtral in the quantized example.

* Mlp or moe type.

* Fix the expert field namings.

* Refactor the mlp bit.

* More MoE logic.

* Add the MoE quantized logic.

* Fix the experts length.
2023-12-15 19:16:06 -06:00
614842b311 Add the Mixtral model. (#1437)
* Add the Mixtral model.

* Add more of the mixtral layers.

* Add the final layers for mixtral.

* Sketch the expert selection.

* Add some expert routing logic.

* Hopefully finish the routing logic for mixtral.

* Add the mixtral example.

* Fix the weight filenames.

* Bugfix.

* Another fix.

* Yet another fix + remove the unused pragma.

* Shape fix.

* Add a readme.
2023-12-15 14:19:56 -06:00
79eab519fd Fix phi example (#1436)
* Fix phi example

* Remove the cuda mention.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-12-15 07:01:10 -06:00
6bc92e63cb Addressing a lot of comments. 2023-12-15 13:06:04 +01:00
aa04015098 Remove unwrap(). 2023-12-15 12:23:28 +01:00
8b5059e951 Remove test file. 2023-12-15 11:55:30 +01:00
26540641c1 Renamed all kernel names. 2023-12-15 11:24:47 +01:00
34d83377f6 Better error message on older macos 2023-12-15 11:18:54 +01:00
77197379cc More cleanup. 2023-12-15 11:17:05 +01:00
916a8c5464 Revert candle-transformers. 2023-12-15 11:15:21 +01:00
243e83f2b9 Adding a bunch of docs !
Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
2023-12-15 11:03:05 +01:00
cf27868b57 More cleanup. 2023-12-15 01:44:22 +01:00
40c3e1bd5a cleanup. 2023-12-15 01:41:14 +01:00
ece4c69a68 Fixing softmax. 2023-12-15 01:35:08 +01:00
4eeaf205d6 Fix softmax for long sequences (missing barrier). 2023-12-14 19:37:03 +01:00
f419a38e1a Fix use resource. 2023-12-14 16:52:37 +01:00
361f2ad2af Working with merging encoders and using fences. 2023-12-14 16:05:33 +01:00
e60f9b5dfc Speedup ShardedSafeTensors to load Tensors with default hints (#1384)
* Speedup ShardedSafeTensors to load Tensors with default hints

* Tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-12-14 08:08:56 -06:00
7be982f6f7 Mention phi-2 in the readme. (#1434) 2023-12-14 08:02:27 -06:00
104e196d46 Phi 2 wasm (#1432)
* add phi 2.0 quantized model wasm

* cols

* spell

* bug
2023-12-14 06:04:17 -06:00
5e33c85c8f Quantized version for phi-v2. (#1430)
* Quantized version for phi-v2.

* More quantized support.
2023-12-13 21:16:34 -06:00
2b3a018be7 Support for phi-2. (#1429)
* Support for phi-2.

* Use the v2 naming scheme.
2023-12-13 20:59:29 -06:00
931432ed55 Fixing tests + matmul from MFA 2023-12-13 16:58:36 +01:00
0404a3eb5b Removed MPSMatrix entirely (buggy). 2023-12-13 16:21:48 +01:00
a9d0657432 Better version ? 2023-12-13 12:09:20 +01:00
4cb443d00a Fix the logsumexp test. (#1426) 2023-12-12 10:56:11 -06:00
87dc559817 Lots of updates including some stack of command buffers. 2023-12-12 17:41:56 +01:00
77252ffb82 Add logsumexp function (#1424) 2023-12-12 10:32:17 -06:00
18eb87f25f Upsample grad (#1420)
* encode size of upsample in enum

* working convolution method for limited 2d kernels

* add test for sf 3 interpolation

* add higher dimensional tests, fix to work with multichannel input

* Remove commented out line.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-12-10 08:43:24 +01:00
da0af3cb3e Merge pull request #1408 from jbochi/metal_gelu2
Fix NaN errors for Gelu in Metal
2023-12-09 19:46:36 +01:00
9bd94c1ffa Speed up bert with approx gelu (#1410) 2023-12-06 17:46:37 +01:00
803ac8405b Put back affine strided tests
Co-Authored-By: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
2023-12-06 17:04:15 +01:00
6e25822d4f Fix gelu for large x 2023-12-06 09:59:44 -05:00
236b820e28 Another prelu bugfix. (#1407) 2023-12-06 09:54:41 +01:00
2648e797c2 Use the proper broadcasting for prelu. (#1406) 2023-12-05 07:09:31 +01:00
b5c283e86f Add the prelu layer. (#1402) 2023-12-03 16:06:09 +00:00
8418154ee0 Add nvcc ccbin support to examples (#1401) 2023-12-03 16:01:16 +00:00
99b7273b03 Add compute cap env support to examples (#1400) 2023-12-03 16:00:24 +00:00
16161145ae Add the leo models to the quantized examples. (#1398) 2023-12-03 12:30:41 +00:00
0738df5290 Add more mentions to SDXL Turbo in the readme. (#1397) 2023-12-03 10:41:21 +00:00
37bf1ed012 Stable Diffusion Turbo Support (#1395)
* Add support for SD Turbo

* Set Leading as default in euler_ancestral discrete

* Use the appropriate default values for n_steps and guidance_scale.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-12-03 08:37:10 +01:00
dd40edfe73 Add Euler Ancestral Discrete Scheduler (#1390)
* Add Euler Ancestral Discrete Scheduler

* Fix a bug of init_noise_sigma generation

* minor fixes

* use partition_point instead of custom bsearch

* Fix some clippy lints.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2023-12-02 19:59:23 +00:00
5aa1a65dab Add quantized Starling, fix open-chat prompt (#1393)
* Add quantized Starling, fix open-chat prompt

* Fix open-chat and starling prompts
2023-12-02 16:47:19 +00:00
2ca086939f Put back affine strided tests 2023-11-30 11:40:39 +01:00
4349ff1fc2 Starting to fix some tests.
Few fixes.

Going back on remote metal-rs.

Reusing a single buffer (for now) to speed things up.

Adding some half kernels.

All tests are panicking instead of random failure.

Putting back f16 index select.

Add erf.

Working version for llama2-c.

Fixes + cache compute_pipeline_state.

BF16 metal fix.

Remove some prints.

new_owned -> new()..to_owned().

Better batched matmul.

Metal operational.

Reuse buffers on our own reference counts.

Tmp gemm.

Revert "Tmp gemm."

This reverts commit c65f68e988.

Interleave committing.

Speeding up copies using blit.

Fmt.

Fmt.

Remove the assert!

Fmt all.

Fixes after big rebase.

Add softmax for half and bfloat + tests

Fixing Llama example + accumulate softmax in float.
2023-11-30 11:30:31 +01:00
7c3cfd1086 Use the llama weight names for the Yi example. (#1381) 2023-11-27 20:42:52 +00:00
e2eb6590ed Merge pull request #1323 from huggingface/metal3
Adding the test scaffolding.
2023-11-27 13:06:01 +01:00
481c45d78d Add a basic implementation for slice-assign. (#1377) 2023-11-26 17:31:22 +00:00
14a2bdc062 Small tweak: remove the macro usage for the range indexing trait. (#1376) 2023-11-26 16:30:59 +00:00
bfa7c8fc01 Implement the module trait directly for QMatMul. (#1372) 2023-11-25 10:09:45 +00:00
762e996ce6 Distibert (#1366)
* add bce with logit loss

* add bce with logit loss

* remove imports

* fix tiny bug

* add test documentation and refactor function

* fix test cases and formatting

* distilbet files

* Apply various cleanups.

* More cleanups.

* More polish.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2023-11-24 15:09:14 +00:00
ca19a9af62 Fix linspace implementation (#1358)
* Fix linspace implementation

`steps` should be strictly greater than 1 to make it consistent with the context.

* Handle steps == 0 and steps == 1.

* Fix rustfmt.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2023-11-23 07:35:13 +00:00
ec23427d60 Ensure to copy data to cpu before iterating. (#1360) 2023-11-23 07:24:25 +00:00
f83e14f68d Add candle-lora transformers to readme? (#1356)
* Demonstrate lora transformers in readme

* Shorten readme
2023-11-21 17:54:24 +00:00
c7e613ab5e Update the readme. (#1354) 2023-11-21 09:38:27 +00:00
8f63f68289 Fix the kalosm link (#1353) 2023-11-21 06:18:14 +01:00
1edc3ddf24 Allowing feature metal to compile. 2023-11-20 20:17:16 +01:00
b380657bfe Merge pull request #1309 from huggingface/metal2
Adding the actual backend
2023-11-20 17:24:01 +01:00
60f624a902 Moving tests around. 2023-11-20 16:17:19 +01:00
8d6c6de8e0 Missing new test. 2023-11-20 14:38:35 +01:00
7ec345c2eb Adding the test scaffolding. 2023-11-20 14:38:35 +01:00
671fc29b36 Fmt. 2023-11-20 14:38:20 +01:00
dc64adb8e4 Fixing cos_f16 test. 2023-11-20 14:17:07 +01:00
c66e5d4716 Fix comments. 2023-11-20 14:13:44 +01:00
bd3b243725 Update candle-metal-kernels/Cargo.toml 2023-11-20 14:12:57 +01:00
2813fb5dbc Cleanup fixed a few ops removed debugging scaffolding. 2023-11-20 14:12:57 +01:00
7cfffcac10 Debugging rope. 2023-11-20 14:12:57 +01:00
38de52bc4b Fixed matmul (display still broken without casting back to CPU first? ) 2023-11-20 14:12:57 +01:00
d46670f7c0 Tmp state. 2023-11-20 14:12:57 +01:00
f710fab02e Fixing the kernels + launches to make them faster.
Cool work by @ivarflakstad

Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
2023-11-20 14:12:57 +01:00
f82bf2d915 Adding indexing.
Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
2023-11-20 14:12:57 +01:00
df6814f34e Refactor to simplify our lives for settings the params in the encoder. 2023-11-20 14:12:57 +01:00
39406a6721 Adding the actual backend 2023-11-20 14:12:56 +01:00
976ad9f9c2 Remove tracing. 2023-11-20 14:12:29 +01:00
a4c4a56429 Metal part 1 - Scaffolding for metal. 2023-11-20 14:12:05 +01:00
f49bf6a81d Fix OpenChat 3.5 tokenizer (#1347) 2023-11-19 18:48:04 +00:00
992a788da1 Add OpenChat 3.5 to quantized examples (#1346)
* Add OpenChat to quantized examples

* Add chat prompt

* Make the openchat example more in line with the other models.

* Fix a typo.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2023-11-19 18:28:52 +00:00
8d8f48c60c feat: add test for individual onnx ops (#1332)
* feat: add test for individual onnx ops

* fix: prefer consts when possible

* feat: add move op tests
2023-11-19 08:17:09 +01:00
d31f11035f Support for CumSum in ONNX models. (#1340) 2023-11-17 22:03:40 +00:00
9ab3f9729f Use the whisper-v3 tokenizer now that it has been added. (#1337)
* Use the whisper-v3 tokenizer now that it has been added.

* Use the appropriate nospeech token.
2023-11-16 22:10:31 +00:00
a1f41ab37b feat: adds reset_kv_cache (#1335) 2023-11-16 21:17:42 +00:00
92a05b51cf fix: address clippy 0.1.74 issues (#1336)
- clippy::needless-borrows-for-generic-args
- clippy::reserve-after-initialization
2023-11-16 21:15:22 +00:00
c6763e3b41 Add a simple implementation of cumsum. (#1334)
* Add a simple implementation of cumsum.

* Add another test.
2023-11-15 21:11:15 +00:00
347e31c9ff Add the tril/triu/eye ops. (#1333)
* Add tril/triu/eye.

* Revert the metal crate tweak.
2023-11-15 20:34:37 +00:00
f4fcf60900 Update readme.md (#1322)
Updating the readme to coincide with other examples. If you try to run it as previously written, you will get a "cannot find the path specified" error.
2023-11-12 09:46:19 +00:00
12561b31d3 Fix pose estimation image path (#1326) 2023-11-12 09:45:26 +00:00
a209ce8ceb Update for 0.3.1. (#1324) 2023-11-11 18:48:52 +00:00
f1e678b39c Mention the Yi-6b/Yi-34b models in the readme. (#1321) 2023-11-11 12:39:11 +01:00
a007f8fdb4 Add the Yi-6b and Yi-34b models. (#1320)
* Add the Yi-6b model.

* Add the 34b model.

* Add the yi example.

* Fix the weight file names.
2023-11-11 12:00:48 +01:00
2341aa079e Fix quantized zephyr chat prompt (#1314) (#1317)
* Fix quantized zephyr chat prompt (#1314)

* Avoid using a mutable variable.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-11 09:14:12 +01:00
9e666d4229 Add the var method. (#1315)
* Add the var method.

* Add a test.
2023-11-10 22:47:57 +01:00
1b12142a02 Add min to buckets in relative_position_bucket (#1312) 2023-11-10 11:57:25 +01:00
d2c3f14773 Fix for flash-attn. (#1310)
Co-authored-by: laurent <laurent@par2dc5-ai-prd-cl01dgx02.cm.cluster>
2023-11-10 10:27:27 +01:00
26c4e5bf1d Metal part 1 - Scaffolding for metal. (#1308)
* Metal part 1 - Scaffolding for metal.

* Remove tracing.
2023-11-10 08:35:48 +01:00
18d30005c5 Add support to UL2 model family (#1300)
* Add support to UL2 model family

* Update docs with UL2

* Create ActivationWithOptionalGating to avoid polluting activations

* Also refactor quantized t5

* Remove useless conversion

* Revert Activation::NewGelu name change

* Remove useless return

* Apply rustfmt and clippy recommendations

* Reuse t5::ActivationWithOptionalGating in quantized version

* (cosmetic change) use a match rather than ifs + avoid early returns.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-09 18:55:09 +01:00
6958384327 Add support for TrOCR Model (#1303)
* add bce with logit loss

* add bce with logit loss

* remove imports

* fix tiny bug

* add test documentation and refactor function

* fix test cases and formatting

* add trocr model

* fix formatting

* commit the actual model lol

* more formatting

* remove tokenizer config
2023-11-09 18:49:17 +01:00
e6697471bb Add weight and bias functions to LayerNorm (#1306) 2023-11-09 16:09:01 +01:00
73d02f4f57 fix: negative axis (#1296)
* fix: negative axis

* Use normalize_axis.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-08 23:28:21 +01:00
f772213e84 Fix bug introduced in madlad PR (#1298) 2023-11-08 17:55:46 +01:00
2feb0b054f Add the mel filters for 128 bins. (#1295) 2023-11-08 08:23:53 +01:00
2d28497197 Preliminary support for whisper v3. (#1294)
* Preliminary support for whisper v3.

* Add the missing files.
2023-11-08 06:42:52 +01:00
f3a4f3db76 PyO3: Add optional candle.onnx module (#1282)
* Start onnx integration

* Merge remote-tracking branch 'upstream/main' into feat/pyo3-onnx

* Implement ONNXModel

* `fmt`

* add `onnx` flag to python ci

* Pin `protoc` to `25.0`

* Setup `protoc` in wheel builds

* Build wheels with `onnx`

* Install `protoc` in manylinux containers

* `apt` -> `yum`

* Download `protoc` via bash script

* Back to `manylinux: auto`

* Disable `onnx` builds for linux
2023-11-08 06:37:50 +01:00
7920b45c8a Support for timegroupnorm in encodec. (#1291) 2023-11-07 22:39:59 +01:00
d4a45c936a Quantized model small tweaks (#1290)
* Support the shape op in ONNX.

* Share the axis normalization bits.

* Add some limited support for gather.

* Unsqueeze.

* Comparison with broadcasting.

* Add Not + handle i32.

* Tweaks for the quantized model.
2023-11-07 21:21:37 +01:00
c912d24570 Update README: Move T5 to Text to Text section (#1288)
I think it makes more sense to have it there, since it's a seq2seq model with cross attention, and not a LM. There are also Decoder only T5 models that work as LMs, but that's not the standard.
2023-11-07 16:14:04 +01:00
d5c2a7b64b Add info about MADLAD-400 in readme files (#1287) 2023-11-07 15:21:59 +01:00
508f811b93 Add support for MADLAD400 (#1285)
* Add support for madlad

* Add support for quantized MADLAD
2023-11-07 05:35:37 +01:00
a773a4b22b [ONNX] Support a couple more ops. (#1284)
* Support the shape op in ONNX.

* Share the axis normalization bits.

* Add some limited support for gather.

* Unsqueeze.

* Comparison with broadcasting.

* Add Not + handle i32.
2023-11-06 22:44:58 +01:00
5a363dbc26 Adds check for 7b-zephyr and uses correct template (#1283)
* Adds check for 7b-zephyr and uses correct template

* Handle zephyr as mistral.

* Disable the protoc bits of the CI.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-06 21:05:39 +01:00
abc4f698c5 Add candle-sampling (#1278) 2023-11-06 12:53:29 +01:00
a923e8b53a Add a link to candle-ext to README.md (#1277) 2023-11-06 12:44:39 +01:00
2a45bcf943 Put the onnx example behind a feature flag. (#1276)
* Put the onnx example behind a feature flag.

* Exclude the onnx bits from the workspace.

* README tweaks.
2023-11-06 07:45:07 +01:00
47f4ddb011 Added info about missing protoc (#1275)
Co-authored-by: figgefigge <fredric.1337mail.com>
2023-11-06 06:47:32 +01:00
f365a075e5 Add more models to the onnx example. (#1273)
* Add more models to the onnx example.

* Input validation.

* Input validation.

* Bugfix.

* Implement clip.

* BatchNorm support.

* Get the efficientnet onnx to work.
2023-11-05 16:57:26 +01:00
60fdab4e17 Detach all grads during backprop. (#1243)
* Detach all grads during backprop.

* Add an environment variable to select the backprop behavior.

* Update the comment.
2023-11-05 14:07:41 +01:00
928a9d906e [ONNX] Do not generate values for constants. (#1272)
* Do not generate values for constants.

* Add an onnx based example using squeezenet.
2023-11-05 11:23:14 +01:00
d1d89bac1f feat: download cifar dataset parquet files (#1259) 2023-11-05 10:55:49 +01:00
39ad840a90 Better tensor initialization in ONNX. (#1270)
* Better tensor initialization in ONNX.

* MaxPool support.

* Add AvgPool.

* Get the squeezenet example to work.
2023-11-04 22:17:45 +01:00
b5e4f84bed Refactor the onnx attribute getters. (#1268)
* Refactor the onnx attribute getters.

* Add get-attr-opt.

* Add support for convolutions.

* Add support for convolutions.
2023-11-04 21:31:48 +01:00
7051fb8098 feat: add backprop for elu (#1269)
* feat: add backprop for elu

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-04 21:26:41 +01:00
dc68c130e4 Support more ONNX ops. (#1267)
* Add LogSoftmax.

* Support for Transpose.
2023-11-04 15:10:14 +01:00
bc9a1bf239 Improve the ONNX basic example + bugfixes (#1266)
* Generate some zeros tensor in the onnx simple-eval example.

* Fix the casting operation.

* Support more ops.

* Handle reshape.

* Concat.

* Softmax.
2023-11-04 10:02:47 +01:00
f7c957d64f ONNX casting support. (#1265)
* ONNX casting support.

* Handle tensor constants.

* Bugfix the binary ops.
2023-11-04 08:34:24 +01:00
8cbb9d0e6c Add some preliminary ONNX support (#1260)
* Add the onnx protos.

* Move the reading bits.

* Install protoc on the CI.

* Install protoc on the cuda CI too.

* Use clap for the onnx tool.

* Tweak the CI protoc install.

* Add some simple evalution function.

* Add some binary operator support.
2023-11-04 06:36:05 +01:00
bfe95115c6 Update README.md (#1264) 2023-11-04 05:32:32 +01:00
6fa3151820 Allow using gguf-v3 files. (#1262) 2023-11-03 23:07:53 +01:00
0a58886ccb add distil-whisper link (#1261) 2023-11-03 21:34:42 +01:00
3173b1ce3b feat: impl backprop for erf and gelu-erf (#1258)
* impl backprop for erf anf gelu-erf

* feat: unary tests added for erf and gelu-erf

* fix: (clippy) remove immediately dereferenced ref

* fix: improve comments with pytorch code snippet

* fix: adjust comment typo in backprop impl
2023-11-03 21:32:30 +01:00
ad63f20781 add Kalosm to the list of external resources (#1257) 2023-11-03 19:16:46 +01:00
1cfc5d6d0c Backprop support for conv1d (cpu only for now). (#1255) 2023-11-03 14:23:53 +01:00
b07b2350b6 Test for the transposed conv1d. (#1254) 2023-11-03 13:10:28 +01:00
1b5063f3ca Add vllm external resource (#1253) 2023-11-03 12:40:31 +01:00
3b0d1e7d03 Transposed conv1d in candle-nn. (#1252) 2023-11-03 11:18:25 +01:00
be4555c5a5 Add the conv-transpose1d op. (#1251)
* Skeleton structure for conv-transpose1d.

* CPU implementation for conv-transpose1d.
2023-11-03 09:44:46 +01:00
6975c65112 Share the layer-norm implementation. (#1248) 2023-11-03 06:30:05 +01:00
a2a20aeecc Add the swiglu activation from the chatglm PR. (#1246) 2023-11-02 20:01:34 +01:00
e08fbb6543 Add support for distil whisper (#1245)
* Add support for distil-whisper.

* Add distil-large.

* Rename the large model.
2023-11-02 19:32:35 +01:00
d39d0c40fd Add hard-sigmoid and hard-swish activations (#1244)
* Add hard-sigmoid and hard-swish activations

* Update ops.rs

* Use / rather than div.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-02 18:20:27 +01:00
b97463098c llama2-c wasm fix. 2023-11-02 10:31:47 +01:00
fbd69f952c Lazy detach. (#1242) 2023-11-02 07:33:48 +00:00
6c990a33ea Remove the unused pragma for marian. (#1236) 2023-11-01 20:04:52 +00:00
1704f1b3ae Consolidate the with-tracing usage. (#1234) 2023-11-01 18:21:36 +00:00
693fad511c Preliminary support for ssd1b. (#1233) 2023-11-01 14:37:52 +00:00
36fb84f038 Add a hack for generating random uniform/normal for f16/bf16. (#1228) 2023-10-31 20:27:59 +00:00
c12ad45562 Add a KV cache to marian decoding. (#1226) 2023-10-31 08:47:44 +00:00
7d0202710b Instructions for generating the tokenizer configs for marian-mt. (#1225) 2023-10-31 07:56:26 +01:00
392a00a147 Add support for the marian base model. (#1221) 2023-10-30 19:20:36 +00:00
4c967b9184 Use the hub files for the marian example. (#1220)
* Use the hub files for the marian example.

* Use the secondary decoder.

* Add a readme.

* More readme.
2023-10-30 17:29:36 +00:00
c05c0a8213 PyO3: Add equal and __richcmp__ to candle.Tensor (#1099)
* add `equal` to tensor

* add `__richcmp__` support  for tensors and scalars

* typo

* more typos

* Add `abs` + `candle.testing`

* remove duplicated `broadcast_shape_binary_op`

* `candle.i16` => `candle.i64`

* `tensor.nelements` -> `tensor.nelement`

* Cleanup `abs`
2023-10-30 15:17:28 +00:00
969960847a Bugfixes for marian-mt. (#1219)
* Bugfixes for marian-mt.

* Apply the final decoding head.

* More fixes.
2023-10-30 11:44:19 +00:00
5fc66bd4ba Support negative steps in arange. (#1218) 2023-10-30 07:40:54 +00:00
174b208052 PyO3: Better shape handling (#1143)
* Negative and `*args` shape handling

* Rename to `PyShapeWithHole` + validate that only one hole exists

* Regenerate stubs

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2023-10-29 15:41:44 +00:00
154c674a79 Add i64-abs. (#1216) 2023-10-29 15:28:53 +00:00
7bbde55c61 Marian MT model (#1210)
* Skeleton files for the marian MT model.

* Marian initialization.

* Implement the attention forward method.

* Forward pass for the encoder side.

* Expose the encoder and decoder.

* Start plugging the decoder.

* Forward pass for the decoder layer.

* Set up the marian example.

* Add some missing backtraces.

* Bugfix.
2023-10-29 15:12:22 +00:00
c3f2676d49 PyO3: Add CI to build & upload wheels as artifacts. (#1215)
* Add maturin ci

* fix paths

* Change sdist path
2023-10-29 13:44:05 +00:00
46d6566c99 Fix the conv2d gradient computation. (#1214) 2023-10-29 09:50:04 +00:00
55bc3382cf Allow for different behavior between training and eval (#1213)
* Forward with training.

* Do not use dropout on vgg evaluation.
2023-10-29 07:53:09 +01:00
dece37c6f4 feat: implement VGG13, VGG16 and VGG19 (#1211)
* feat: implement VGG13, VGG16 and VGG19

* Cosmetic fixes.

* More cosmetic tweaks + avoid re-loading the weights on each final layer.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-10-29 06:10:23 +00:00
498c50348c Add DDPG and fix Gym wrapper (#1207)
* Fix Gym wrapper
- It was returning things in the wrong order
- Gym now differentiates between terminated and truncated

* Add DDPG

* Apply fixes

* Remove Result annotations

* Also remove Vec annotation

* rustfmt

* Various small improvements (avoid cloning, mutability, get clippy to pass, ...)

---------

Co-authored-by: Travis Hammond <travis.hammond@alexanderthamm.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-10-28 19:53:34 +01:00
012ae0090e Infer the config for llama2-c. (#1208) 2023-10-28 19:00:39 +01:00
95a857cf57 Move the llama2-c model in transformers. (#1205) 2023-10-28 16:51:19 +01:00
612f5b8156 Make more models cloneable. (#1203) 2023-10-28 07:43:08 +01:00
ef33df7ae2 No need for the even constraint on vecdot-q40-q80. (#1202) 2023-10-28 07:23:59 +01:00
c8face3f95 Add the relu2 and relu6 activations. (#1201) 2023-10-27 20:51:16 +01:00
85bea43e5b Make the whisper model cloneable (#1200)
* Add a quantized variant of llama2.c

* Clippy fixes.

* Make the whisper model cloneable.
2023-10-27 16:59:19 +01:00
b3181455d5 Add fuse-conv-bn method for Conv2d (#1196)
* Add fuse-conv-bn method for Conv2d

* no unwrap

* run rustfmp and clippy
2023-10-27 15:56:50 +01:00
e2826e70b3 Add a quantized variant of llama2.c (#1197)
* Add a quantized variant of llama2.c

* Clippy fixes.
2023-10-27 15:34:06 +01:00
916619f70b Minor cleanup (#1194)
* Add some missing backtraces.

* Small cleanup.
2023-10-27 14:08:29 +01:00
9b1158b315 Add some missing backtraces. (#1193) 2023-10-27 06:09:11 +01:00
70d06ab4b0 Add support for the phi-hermes finetuned model. (#1192) 2023-10-27 05:57:08 +01:00
0ec5ebcec4 Use the hub model file when possible. (#1190)
* Use the hub model file when possible.

* And add a mention in the main readme.
2023-10-26 20:00:50 +01:00
c8e197f68c Fixes for jina-bert. (#1189) 2023-10-26 18:52:30 +01:00
5f20697918 Add the jina-bert embeddings model. (#1187)
* Add the jina-bert model.

* Use alibi.

* Remove the unused pragma.

* Recompute the alibi embeddings.

* Generate the token type ids.

* Use the module trait.

* Add the jina-bert example.

* DType fix.

* Get the inference to work.
2023-10-26 16:54:36 +01:00
e37b487767 Add Blip to online demos README.md (#1184)
* Add Blip to online demos README.md

* Punctuation.

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2023-10-26 11:07:01 +01:00
e5dc8cb4f4 [Wasm] BLIP Example (#1183)
* blip wasm start

* fix dependency issue, move token stream here

* vanilla js worker

* roll back vscode

* spell
2023-10-26 07:24:02 +01:00
e7b886d56f Add a link to the optimisers crate. (#1180) 2023-10-25 21:51:45 +01:00
6a446d9d73 convert pytorch's tensor in Python API (#1172)
* convert pytorch's tensor

* separate tests for convert pytorch tensor
2023-10-25 19:39:14 +01:00
0acd16751d Expose the fields from batch-norm. (#1176) 2023-10-25 15:35:32 +01:00
c698e17619 Enable the test for meshgrid + fix the implementation. (#1175) 2023-10-25 13:47:54 +01:00
e4c9adfdbe Implemented meshgrid (#1174)
* Implemented meshgrid

* Resolved feedback from LaurentMazare

* Rustfmt

* Updated docstring

* Removed outdated error mode from docstring
2023-10-25 12:49:11 +01:00
b6053b938b [Wasm] Add puffin phi model to wasm (#1166)
* load config from file, add puffin phi links

* format

* add prompt examples
2023-10-25 07:09:03 +01:00
45dbe541bc fix ucopy for f64 tensors (#1170) 2023-10-24 17:06:03 +01:00
7bd0faba75 Add support for accelerate in the pyo3 bindings. (#1167) 2023-10-24 06:34:37 +01:00
807e3f9f52 derivative for GELU (#1160)
* derivative for GELU

* add tests
2023-10-23 20:23:45 +01:00
eae94a451b PyO3: Add mkl support (#1159)
* Add `mkl` support

* Set `mkl` path on linux
2023-10-23 20:10:59 +01:00
86e1803191 Add Binary Cross Entropy With Logit Loss to nn crate (#1157)
* add bce with logit loss

* add bce with logit loss

* remove imports

* fix tiny bug

* add test documentation and refactor function

* fix test cases and formatting
2023-10-23 17:12:44 +01:00
25c3cc4149 Mention the flash-attention restriction in the readme. (#1158) 2023-10-23 10:26:56 +01:00
a11af79e23 Add a quantized blip model. (#1155)
* Add a quantized blip model.

* Integrate the quantized blip model to the actual example.
2023-10-22 20:33:25 +01:00
8a82d623e5 Handle LongStorage in pytorch checkpoints. (#1152) 2023-10-22 18:34:36 +01:00
df2f89b6cf Add some KV cache to blip. (#1150)
* Add some KV cache to blip.

* Mention BLIP in the readme.
2023-10-22 09:44:48 +01:00
62fc965617 Expose the track-op method. (#1148) 2023-10-22 06:57:03 +01:00
5b32c2a41e Remove the unused pragma and properly apply the bias. (#1147) 2023-10-22 06:47:40 +01:00
3115fe42e4 Blip attention mask + readme (#1146)
* Add the attention mask to the blip model.

* Add a readme.
2023-10-21 22:44:13 +01:00
2531b13bf8 Blip fixes (#1145)
* Some fixes for the blip example.

* Stop generating on sep tokens.

* Clippy fixes.

* rustfmt.
2023-10-21 21:34:48 +01:00
0d9bb4eb18 Add the blip example. (#1144)
* Add the blip example.

* Tweak the example.

* Implement the cross-attn logic.

* Fix some shape mismatches.

* Get some logits out.

* Get some caption to be generated.
2023-10-21 20:05:02 +01:00
e8f760ee44 Add get_on_dim. (#1142) 2023-10-21 15:01:38 +01:00
94e3373883 Blip forward pass (#1141)
* More forward methods for the blip model.

* Blipping continues.
2023-10-21 10:19:23 +01:00
34d9e91748 Add the blip image captioning model (#1140)
* Blip text model.

* Blip vision bits.

* Blippity.

* More blip.
2023-10-20 22:09:11 +01:00
cfb423ab76 PyO3: Add CI (#1135)
* Add PyO3 ci

* Update python.yml

* Format `bert.py`
2023-10-20 19:05:14 +01:00
7366aeac21 Make func cloneable. (#1137) 2023-10-20 16:28:50 +01:00
99cf13e8e2 Add the sequential layer. (#1136) 2023-10-20 16:08:50 +01:00
b43ab6cd1d PyO3: Add None and Tensor indexing to candle.Tensor (#1098)
* Add proper `None` and `tensor` indexing

* Allow indexing via lists + allow tensor/list indexing outside of first dimension
2023-10-20 09:59:00 +01:00
31ca4897bb Readme updates. (#1134) 2023-10-20 09:08:39 +01:00
55351ef57d Add some vision transformers models (#1132)
* Start adding vision-transformers.

* Add self-attn.

* More vision transformers.

* vit-vit.

* Add the actual vit model.

* Add the example code for the vision transformers.
2023-10-19 22:24:18 +01:00
6684b7127a PyO3: Add pytorch like .to() operator to candle.Tensor (#1100)
* add `.to()` operator

* Only allow each value to be provided once via `args` or `kwargs`
2023-10-19 21:46:21 +01:00
93c25e8844 Expose the larger resnets (50/101/152) in the example. (#1131) 2023-10-19 13:48:28 +01:00
cd53c472df Support ResNet 50/101/152. (#1130) 2023-10-19 10:48:31 +01:00
6f76383f38 Add a readme for the resnet example. (#1129) 2023-10-19 09:58:50 +01:00
8e773cc0c6 Experiment with resnet (#1128)
* Add some preliminary support for resnet.

* Add an actual resnet example.
2023-10-19 09:25:03 +01:00
87eb1658e1 Add pad_with_same. (#1127)
* More model cloning.

* More cloning on quantized models.

* Add pad-with-same.

* Add some tests.
2023-10-18 23:13:37 +01:00
902d0b9166 More model cloning. (#1126)
* More model cloning.

* More cloning on quantized models.
2023-10-18 21:55:46 +01:00
185b54a33b Make some model cloneable. (#1125) 2023-10-18 19:30:47 +01:00
620c94d12e Add support for Zephyr-7b in the quantized model. (#1124) 2023-10-18 17:31:26 +01:00
86e7d539d2 Add the quantized mpt model. (#1123)
* Add the quantized mpt model.

* Support the quantized model for replit-code.
2023-10-18 16:29:38 +01:00
cb034506cd Remove the unused pragma in mpt. (#1122) 2023-10-18 15:47:50 +01:00
63c204c79e Add a mention to the replit-code model in the readme. (#1121) 2023-10-18 11:27:23 +01:00
767a6578f1 MPT alibi fixes. (#1120)
* MPT alibi fixes.

* Some more fixes.

* Finally get the model to return some sensible outputs.

* Add a readme.
2023-10-18 10:58:05 +01:00
662c186fd5 Better error message when overflowing in narrow. (#1119) 2023-10-18 08:40:14 +01:00
2cd745a97c MPT fixes. (#1117)
* MPT fixes.

* Another couple fixes.

* Another shape fix.
2023-10-17 21:53:31 +01:00
a72b50e2c0 Build alibi bias. (#1115)
* Build alibi bias.

* Apply the alibi attention bias.

* Add the replit-code example.
2023-10-17 20:41:37 +01:00
872c3f14b0 Add the MPT model. (#1114)
* Add the MPT model.

* Add ffn and block.

* Forward pass for the mpt block.

* Repeat-kv.
2023-10-17 16:06:48 +01:00
f9e93f5b69 Extend stub.py to accept external typehinting (#1102) 2023-10-17 11:07:26 +01:00
b355ab4e2e Always broadcast magic methods (#1101) 2023-10-17 10:57:12 +01:00
2fe24ac5b1 Rework the cuda casting bits. (#1112) 2023-10-17 09:44:51 +01:00
00948eb656 Formatting tweak. (#1111) 2023-10-16 21:02:53 +01:00
af67672207 Add support for Puffin-Phi-v2. (#1110)
* Add support for Puffin-Phi-v2.

* Tweak the file name.

* Support the config for puffin-phi-v2.

* Update the readme.
2023-10-16 20:54:21 +01:00
6c588c4792 Refactor the pth tensor exctraction. (#1109) 2023-10-16 18:16:34 +01:00
122da87580 feat: add pth varbuilder (#1108) 2023-10-16 16:20:36 +01:00
75629981bc feat: parse Cuda compute cap from env (#1066)
* feat: add support for multiple compute caps

* Revert to one compute cap

* fmt

* fix
2023-10-16 15:37:38 +01:00
0106b0b04c Read all the tensors in a PyTorch pth file. (#1106) 2023-10-16 13:50:07 +01:00
588ad4835a Fix the verbose prompt for phi. (#1097) 2023-10-15 10:53:25 +01:00
b73c35cc57 Improve the reshape error messages. (#1096)
* Improve the reshape error messages.

* Add the verbose-prompt flag to the phi example.
2023-10-15 10:43:10 +01:00
8f310cc666 Avoid trying to backprop through non-differentiable layers. (#1094) 2023-10-14 22:03:41 +01:00
8921d5027c Add support for phi-1.0 (#1093)
* Add support for phi-1.0

* Update the readme.
2023-10-14 20:15:43 +01:00
29c7f2565d Add some reinforcement learning example. (#1090)
* Add some reinforcement learning example.

* Python initialization.

* Get the example to run.

* Vectorized gym envs for the atari wrappers.

* Get some simulation loop to run.
2023-10-14 16:46:43 +01:00
9309cfc47d Create a new curand instead of reseeding. (#1089) 2023-10-14 10:03:59 +01:00
a193bf5f60 Another gemm update. (#1088) 2023-10-14 09:36:52 +01:00
2c110ac7d9 Add the pooling operators to the pyo3 layer. (#1086) 2023-10-13 20:18:10 +01:00
75989fc3b7 Use an attention mask in the e5 padding case. (#1085) 2023-10-13 18:53:40 +01:00
07af87a1d8 Typos. (#1084) 2023-10-13 16:21:20 +01:00
eefad2b95f Update to gemm 0.16.1 (#1083) 2023-10-13 06:40:20 +01:00
5e6df4a3f7 Update to gemm-0.16. (#1082)
* Update to gemm-0.16.

* Enable wasm-simd128.
2023-10-12 21:56:59 +01:00
7473c4ceca Fix the npy read function and add some testing. (#1080) 2023-10-12 15:25:05 +02:00
c096f02411 Add a matvec cpu benchmark. (#1076) 2023-10-12 09:29:18 +01:00
e7560443e4 Convmixer example (#1074)
* Add a convmixer based example.

* Mention the model in the readme.
2023-10-11 19:51:10 +01:00
89b525b5e7 Convmixer (#1073)
* Only optimize float tensors.

* Use full tensors for zeros and ones.

* Add a benchmark for the matmul slowness.

* Add the convmixer model.

* Proper adaptive pooling.
2023-10-11 18:24:32 +01:00
37dbbff261 Use full tensors for zeros and ones (#1071)
* Only optimize float tensors.

* Use full tensors for zeros and ones.
2023-10-11 08:16:04 +01:00
9fea56d28e Only optimize float tensors. (#1069) 2023-10-10 09:05:41 +01:00
bc3351bce4 Tracing for StableLM and quantized StableLM. (#1068) 2023-10-10 08:09:25 +02:00
b34d7f0248 Remove some unusued bits. (#1067) 2023-10-09 19:49:57 +01:00
4d04ac83c7 Override the repo for SDXL f16 vae weights. (#1064)
* Override the repo for SDXL f16 vae weights.

* Slightly simpler change.
2023-10-09 06:52:28 +01:00
392fe02fba Move the common quantized-nn code to a shared module. (#1063) 2023-10-09 06:22:22 +01:00
59ab6d7832 Quantized version of StableLM. (#1058)
* Quantized version of StableLM.

* Adapt the stable-lm example to support quantizsed.

* Use some separate hub repo.

* Another repo name tweak.
2023-10-08 15:42:38 +01:00
783735cf22 Use softmax-last-dim where possible. (#1057) 2023-10-08 13:16:42 +01:00
9abeddd750 Make the cuda rng seedable. (#1056) 2023-10-08 09:32:36 +01:00
2e5fb0b251 Do not use the kv-cache on external key-value states. (#1054) 2023-10-07 22:37:19 +01:00
823fe23f9b Add flash-attn support for stable-lm. (#1052) 2023-10-07 21:12:54 +01:00
d833527fda Use candle_nn::LSTM in encodec. (#1051)
* Use candle_nn::LSTM in encodec.

* More Encodec implementation.

* Decoder implementation.
2023-10-07 19:43:06 +01:00
a4967600d0 More general seq forward functions for RNNs. (#1050) 2023-10-07 15:08:01 +01:00
aa53368aeb Better control on the optional dequantization in QMatMul (#1049)
* Cosmetic change to the quantized whisper model.

* Fix the dequantization.

* Add the dequantize all variable.
2023-10-07 10:16:18 +01:00
955e00b2e8 Add to the readmes for stable-lm. (#1047) 2023-10-06 21:26:04 +01:00
d5f7267087 Add the stable-lm example. (#1046)
* Add the stable-lm example.

* Get stable-lm to generate some proper text.
2023-10-06 19:20:35 +01:00
904bbdae65 Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations

* Add `state_dict` and `load_state_dict` functionality

* Move modules around and create `candle.nn.Linear`

* Add `nn.Embedding` and `nn.LayerNorm`

* Add BERT implementation

* Batch q-matmul

* Automatically dequantize `QTensors` if a `Tensor` is expected

* Add Module `.to()`, `.cuda()`, `cpu()` and `.type()` functionality

* Unittests for `Module`, `Tensor` and `candle.utils`

* Add `pytorch` like slicing to `Tensor`

* Cleanup and BERT fixes

* `black` formatting + unit-test for `nn.Linear`

* Refactor slicing implementation
2023-10-06 19:01:07 +01:00
b0442eff8a Sketch the stable-lm model. (#1045) 2023-10-06 18:19:06 +01:00
4631c48273 Remove some todos. (#1042) 2023-10-05 22:42:20 +01:00
716883e9b0 Add the clamping for stable-diffusion. (#1041) 2023-10-05 22:20:39 +01:00
47c25a567b feat: [SAM] able to download the result as png (#1035)
* feat: able to download the result as png

* feat: update function and wording
2023-10-05 22:14:47 +01:00
7f7d95e2c3 Add the round-to function. (#1039) 2023-10-05 20:28:09 +01:00
f47bd9bab5 Delete invalid comment (#1038) 2023-10-05 19:28:08 +01:00
8f7973958c fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0 (#1037)
* fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0

* cargo fmt
2023-10-05 18:46:13 +01:00
f0c619a4af Use AsRef<str> for set_one. (#1033) 2023-10-05 06:05:44 +01:00
b86ac0c507 Quant t5: Add coedit model to wasm demo and readme (#1031) 2023-10-04 20:57:33 +01:00
27e70a5093 Whisper quantized wasm (#1028)
* [Whisper] Update to use quantized model

* [whisper] add language detection

* [whisper] change assets location

* [whisper] adapt js example with quantized models

* [whisper] better task parsing

* [whisper] minor fixes
2023-10-04 20:22:57 +01:00
c18a856e76 Add the rounding operators. (#1030)
* Add the rounding operators.

* Avoid tracking gradients for the rounding operations.

* Add some rounding tests.
2023-10-04 17:58:44 +01:00
3349c89252 Add quantized t5 args for weight and config (#1029) 2023-10-04 17:02:49 +01:00
11d3687cc6 Simd128 optimized q8k vecdot. (#1026) 2023-10-03 15:29:48 +01:00
dac73edb34 AVX optimized q8k vecdot. (#1024) 2023-10-03 12:10:58 +01:00
b4da19d1be Merge pull request #1023 from evgenyigumnov/simlified-book-polish
small misspeling and polish fix
2023-10-03 12:29:41 +02:00
ff513314fc small misspeling and polish fix 2023-10-03 15:47:04 +06:00
043cc25766 Fix for the index-select cuda setup. (#1022)
* Fix for index-select.

* Better fix + add some testing.
2023-10-03 10:21:46 +01:00
7b06872f90 Merge pull request #926 from evgenyigumnov/book-trainin-simplified
Book train simlified example
2023-10-03 10:41:30 +02:00
65825e7240 [SAM] Add undo button and background point mode (#1020)
* [SAM] Add undo button and background point mode

* [SAM] remove pts on near clicks

* [SAM] check shiftKey toggle point mode

* [SAM] clear points when clearing image
2023-10-02 23:33:46 +01:00
7670fe7d1f neon optimized q8k multiplication. (#1021)
* neon optimized q8k multiplication.

* Bugfixes.

* simdification.
2023-10-02 23:26:34 +01:00
cddfc3944c Add the q8k vec-dot multiplication. (#1019) 2023-10-02 21:53:34 +01:00
089fc3b584 Improve the quantized whisper setup. (#1018)
* Improve the quantized whisper setup.

* Fix the config file paths.

* Use the standard matmul where possible.
2023-10-02 17:17:46 +01:00
e04c789230 Add a quantized variant of whisper (#1017)
* Add the quantized-whisper model.

* Quantized the whisper model.

* Adapt the whisper example to handle quantization.

* Add the quantized flag.

* Load the proper weights.
2023-10-02 14:59:53 +01:00
263a172202 Improve the testing of the optimized quantized vec-dot ops (#1016)
* Expose the unopt functions for testing.

* Better testing of the optimized quantized computations.
2023-10-02 09:50:43 +01:00
638ccf9f46 Fix include code. 2023-10-02 10:22:44 +02:00
0baf5a1e19 Fixed PR warnings. 2023-10-02 10:15:10 +02:00
5130a7da32 Simd128 version of q6k vec-dot. (#1015)
* Add a specific function for the simd128 q6k vec-dot.

* Simdification.

* More simdification.
2023-10-01 19:44:12 +01:00
41143db1af [segment-anything] add multi point logic for demo site (#1002)
* [segment-anything] add multi point logic for demo site

* [segment-anything] remove libs and update functions
2023-10-01 18:25:22 +01:00
096dee7073 Bump the version to 0.3.0. (#1014)
* Bump the version to 0.3.0.

* Changelog update.
2023-10-01 13:51:57 +01:00
f6054e9d60 Fix the prompt for mistral when using instruct/interactive mode. (#1013) 2023-10-01 06:44:30 +01:00
328167ec04 Integrate TheBloke quantized mistral weights. (#1012) 2023-09-30 22:39:42 +01:00
4e55aaa51f Simd128 version of the q2k-q8k vecdot product. (#1011)
* Sketch the simd128 version of q2k vecdot.

* Use a single accumulator.

* Simdify the q2k-q8k vecdot product.

* Cosmetic change.
2023-09-30 20:12:41 +01:00
deee7612da Quantized version of mistral. (#1009)
* Quantized version of mistral.

* Integrate the quantized mistral variant.

* Use the quantized weight files.

* Tweak the quantization command.

* Fix the dtype when computing the rotary embeddings.

* Update the readme with the quantized version.

* Fix the decoding of the remaining tokens.
2023-09-30 18:25:47 +01:00
06207332bc Streaming mode for reporting the generated tokens (#1007)
* Token streaming.

* Use the token output stream.

* Flush the output.

* Ensure that the last characters get reported.
2023-09-30 15:04:11 +01:00
4021272875 Use flash-attn for mistral. (#1004) 2023-09-30 12:15:10 +01:00
87e3a4e175 Mistral: exit on eos token. (#1001)
* Mistral: exit on eos token.

* Print the proper stats.

* Also add a short flag.
2023-09-30 07:07:06 +01:00
6203ced495 Add negative prompts to segment-anything. (#1000) 2023-09-30 06:17:42 +01:00
34842fb234 [segment-anything] Print IOU values to help with debugging (#999) 2023-09-30 05:44:42 +01:00
d188d6a764 Fix the multiple points case for sam. (#998) 2023-09-29 22:39:43 +02:00
0ac2db577b Add an entry about WSL slowness to the faq. (#997) 2023-09-29 17:04:52 +01:00
fc59bc31bf fix: add missing gpu fill_* (#996) 2023-09-29 15:49:30 +01:00
03348e2e6f Update mistral README.md (#995) 2023-09-29 12:24:32 +01:00
49fa184a35 Mistral readme (#994)
* Mistral: print the generated text.

* Add mistral to the readmes.
2023-09-29 11:50:50 +01:00
6f17ef82be Mistral: print the generated text. (#992) 2023-09-29 10:56:11 +01:00
01b92cd959 fixes slice_scatter dim type (#988) 2023-09-29 07:54:45 +01:00
53510ce427 Use a silu activation in mistral. (#991) 2023-09-29 07:06:54 +01:00
23b3576c47 Add the sliding window. (#986) 2023-09-28 17:26:33 +01:00
716ab2ccdc Mistral gpu fix (#985)
* Add the mistral example.

* Use the two model files.

* Adjust the dtype.

* Tweak the weight paths.

* Remove the end of text token.

* Get the mistral model to generate some text.

* Fix when running on the gpu.

* More gpu fixes.
2023-09-28 16:38:13 +01:00
ada8851a23 Add the mistral example. (#984)
* Add the mistral example.

* Use the two model files.

* Adjust the dtype.

* Tweak the weight paths.

* Remove the end of text token.

* Get the mistral model to generate some text.
2023-09-28 16:19:18 +01:00
c05a348e36 Add the Mistral 7b model (#983)
* Start sketching the mistral 7b model.

* Add the kv cache.

* Add the decoder layer.

* Add the mistral model.

* Rotary embeddings.

* Add the attention mask.
2023-09-28 14:29:41 +01:00
25657804ef Simd128 q2k vecdot (#982)
* Sketch the simd128 version of q2k vecdot.

* Use a single accumulator.
2023-09-28 12:16:35 +01:00
5e1c595e00 Optimize the index-select cuda kernel. (#976) 2023-09-28 09:05:29 +01:00
8a49e01b9d Add the remaining quantized tests to the wasm suite. (#980) 2023-09-28 08:42:56 +01:00
9cb110c44c Sketch a simd128 optimized q4k vecdot. (#977)
* Sketch a simd128 optimized q4k vecdot.

* Simdify.

* More quantization optimizations.

* Again more simdification.

* Simdify the splitting loop.
2023-09-27 20:19:38 +01:00
667f01c173 Simd128 vec-dot for q4_0. (#974)
* Simd128 vec-dot for q4_0.

* Bugfix.

* Add wasm tests.

* Bugfix for the q40 vecdot.

* More quantization tests.
2023-09-27 14:15:30 +01:00
e59784e353 simd128 optimized q8_0 vecdot (#972)
* wasm/simd128 version of the quantized q8_0 vecdot.

* Add the missing conversion.
2023-09-27 11:03:20 +01:00
29bd6b2979 Phi 1.5 wasm module (#966)
* add phi wasm module

* replace input with textarea

* trim input prompt

* stop on <|endoftext|>

* formatting

* clean up

* add blurb, and syntax highlighting

* add phi-v1.5 wasm

* add note

* hide Options on details

* add first token to generated text

* whitespaces for new line

* fix: abort -> aborted
2023-09-27 06:07:11 +01:00
9571b200c9 fix firstToken, minor ui changes (#971) 2023-09-27 06:01:59 +01:00
ce0a4e3a85 Use the gelu-erf activation. (#969) 2023-09-26 22:30:21 +01:00
4abc1ea34d Avoid some overflows on wasm32. (#968) 2023-09-26 11:15:38 +01:00
2dd43d6cdd add eos token to phi example (#965)
* add eos token to phi example

* rustfmt + get the token directly.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2023-09-26 09:21:22 +01:00
1fcac4afed Expose a function to clear the KV cache on mixformers. (#964) 2023-09-26 05:41:07 +01:00
a084f65f9a fix rep penalty min value (#963) 2023-09-26 05:23:50 +01:00
c798184c2b Configurable layer idx for the lstm layer. (#962) 2023-09-25 21:31:14 +01:00
c78a294323 Add some repeat penalty to the phi example. (#961) 2023-09-25 20:53:30 +01:00
a36d883254 Use a single flag for the point argument. (#958) 2023-09-25 12:53:24 +01:00
7f2bbcf746 [segment-anything] Support multi-point as the prompt input (#945)
* [sam] Support multi-point prompts

* [segment-anything] Pass points by reference

* [segment-anything] Update example code and image

* Fix clippy lint.

---------

Co-authored-by: Yun Ding <yunding@nvidia.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
2023-09-25 12:14:10 +01:00
dc47224ab9 Override the default cudnn heuristics. (#957) 2023-09-25 10:31:53 +01:00
1ce7fe2543 Add more examples to the phi readme. (#956) 2023-09-24 18:19:05 +01:00
402ddcfcb4 Add the missing kernel. (#955) 2023-09-24 17:21:37 +01:00
f5069dd354 Use the repo for the quantized phi model. (#954) 2023-09-24 16:30:26 +01:00
0007ae9c11 Add the quantized mixformer model. (#953)
* Add the quantized mixformer model.

* Add the quantized option in the phi example.
2023-09-24 15:03:48 +01:00
e15862cfdb Shared the quantized var-builder code. (#952)
* Shared the quantized var-builder code.

* Fix compilation.
2023-09-24 12:55:07 +01:00
4aeb449017 Depreate the VarBuilder::from_safetensors function. (#951) 2023-09-24 11:18:17 +01:00
bcb0ed8f1c Self-contained safetensors for the multiprocess llama example. (#950) 2023-09-24 06:54:49 +01:00
7edd755756 Pass directly the buffer ownership. (#949) 2023-09-24 06:34:44 +01:00
e32c89d90c Add the buffered safetensor wrapper. (#948) 2023-09-23 22:57:42 +01:00
bb3471ea31 Adapt more examples to the updated safetensor api. (#947)
* Simplify the safetensor usage.

* Convert more examples.

* Move more examples.

* Adapt stable-diffusion.
2023-09-23 21:26:03 +01:00
890d069092 Self-contained safetensor wrappers (#946)
* Self-contained safetensor wrappers.

* Use the new safetensor container in varbuilders.
2023-09-23 20:39:52 +01:00
5dbe46b389 Add tracing. (#943) 2023-09-23 16:55:46 +01:00
ccf352f3d1 Use yoke to provide a self-referential container for mmaped safetenso… (#939)
* Use yoke to provide a self-referential container for mmaped safetensor files.

* Add the new self-owned type for safetensor files without removing the previous version.

* Add routing.

* Add an initializer for the case of multiple files.
2023-09-23 15:43:11 +01:00
402d207f0f VarMap setter functions (#938)
* Add some setter helper functions for varmap.

* Add more comments.
2023-09-23 10:27:51 +01:00
7582937a32 Add the causal mask in mixformer. (#937) 2023-09-23 09:50:26 +01:00
b54acfa3d0 Tracing for the phi model (#936)
* Add some tracing bits to mixformers.

* Add the missing file.

* Add the conv2d layer to with-tracing.

* Improve the tracing usage.
2023-09-23 09:19:34 +01:00
cda1786eed smaller t5 models quantized (#934) 2023-09-22 22:31:23 +01:00
912a3d63b0 Use the proper block size for quantizing models. (#933)
* Use the proper block size for quantizing models.

* Use the proper dimension.
2023-09-22 21:36:56 +01:00
3ef328c53d Mention the new phi model in the readme. (#932) 2023-09-22 21:24:51 +01:00
0c8e983514 update link to t5 (#931) 2023-09-22 20:30:01 +01:00
df6f5240ba Complete the mixformer implementation. (#930)
* Complete the mixformers implementation.

* Tweak the attention.

* Add the phi-1.5 example.

* Improve the phi example.

* Bugfix.

* Get the phi example to work.
2023-09-22 20:03:16 +01:00
a46b1b4657 Mixformer (#929)
* Sketch the mixformer model.

* More modeling code.

* More mixformers.

* MixFormer creation.

* More mixformers.
2023-09-22 16:17:14 +01:00
19e52e5007 T5 Wasm (#918)
* init t5 wasm model

* split workers for each model

* clean up

* add some ui

* readme

* index

* typo

* remove cache param, clear_kv_cache

* add max_length as param

* add model tasks option to ui

* add method to load quantized gguf from buffer

* Add quantized wasm module

* add quantized models to UI, dynamic import wasms

* link to quantized

* fix copy

* fix ModelEncoder

* fix README.md
2023-09-22 15:31:10 +01:00
8601537e31 Add slice-scatter. (#927)
* Add slice-scatter.

* Add the op.

* Make transpose be a no-op when the dimensions are identical.

* Add the backprop.

* And add some gradient test.
2023-09-22 12:18:16 +01:00
4ac6039a42 Merge branch 'main' into book-trainin-simplified 2023-09-22 11:01:23 +06:00
52a60ca3ad https://github.com/huggingface/candle/issues/637 2023-09-22 10:57:11 +06:00
a96878f235 cuda cast i64 (#925) 2023-09-21 19:52:39 +01:00
aa8ec06fd2 Add the t5-xxl version. (#924) 2023-09-21 14:48:13 +01:00
b43ca493f6 Add more quantized flan t5 variants (#923)
* Add the quantized flan-t5-large variant.

* Add more sizes.
2023-09-21 13:23:30 +01:00
3b557765e8 T5 quantized example (#922)
* Load gguf files for the quantized t5.

* Add the quantized t5 example.

* Allow for loading local files.

* Add some support for quantizing safetensor files.

* Transpose before quantizing.

* Quantized t5.

* Retrieve the weights from the hub.
2023-09-21 12:33:15 +01:00
2619c4307f Add a quantized version of the t5 model. (#921) 2023-09-21 11:13:39 +01:00
c89b82b2d4 Add a clear cache function to the t5 model. (#919) 2023-09-21 09:01:06 +01:00
7b26e513f1 Add the erf function. (#917) 2023-09-21 06:19:10 +01:00
ab1d40ea97 Add more t5 tracing. (#915) 2023-09-20 20:20:54 +01:00
3a0d3e05df Add more t5 tracing. (#914)
* Add more t5 tracing.

* Rever the sm change.
2023-09-20 16:37:51 +01:00
9b24d89d2d Tracing mode for T5. (#913)
* Tracing mode for T5.

* Tracing for the linear layer.
2023-09-20 15:03:35 +01:00
fb1c2ac535 Add flash-attn support. (#912)
* Add flash-attn support.

* Add the use-flash-attn flag.

* Re-enable flash-attn.
2023-09-20 14:07:55 +01:00
728e167334 Add details on wuerstchen. (#911) 2023-09-20 13:09:35 +01:00
7b1ddcff47 Add clone to various nn layers. (#910) 2023-09-20 11:33:51 +01:00
f685b2231c Add some missing biases. (#908) 2023-09-20 10:14:51 +01:00
c0b49d5a50 Wuerstchen parameter tweaks. (#907) 2023-09-20 09:26:24 +01:00
098dd0d1e9 fix: add missingtop_p in llama_multiprocess (#905) 2023-09-20 08:54:56 +01:00
05626ef492 Flan T5: Read lm_head when word embeddings are not tied (#903)
* Read lm_head when word embeddings are not tied

* Fix formatting

* Address comments
2023-09-19 22:36:47 +01:00
67a486d18d Line-up the wuerstchen model with the python implementation. (#901)
* Line-up the wuerstchen model with the python implementation.

* Missing cos.

* Fix the picture denormalization.
2023-09-19 21:59:44 +01:00
7ad82b87e4 BERT Wasm (#902)
* implement wasm module

* add example to workspace

* add UI explore semantic similiarity

* change status messages

* formatting

* minor changes
2023-09-19 21:31:37 +01:00
8696f64bae Fix T5 kv cache (#899)
* Fix T5 kv cache

* Add argument for decoder prompt

* Fix range
2023-09-19 20:36:15 +01:00
d7e48234d4 Add an erf based gelu op (#900)
* Erf based gelu.

* Add the erf backed gelu.

* Test the new gelu op (which is not gelu_new).
2023-09-19 19:54:28 +01:00
34f2ecbc3b Fix the leaky relu. (#898) 2023-09-19 18:17:17 +01:00
4f91c8e109 Improve the error message on shape mismatch for cat. (#897)
* Improve the error message on shape mismatch for cat.

* Cosmetic tweak.
2023-09-19 15:09:47 +01:00
06e46d7c3b Only use classifier free guidance for the prior. (#896)
* Only use classifier free guidance for the prior.

* Add another specific layer-norm structure.

* Tweaks.

* Fix the latent shape.

* Print the prior shape.

* More shape fixes.

* Remove some debugging continue.
2023-09-19 14:13:05 +01:00
9cf26c5cff Fix typo in error_manage.md (#888)
occured -> occurred
2023-09-19 07:14:15 +01:00
aaa9d4ed6c W decoding. (#893)
* W decoding.

* Add the diffusion loop.

* Use the appropriate config.
2023-09-19 07:13:44 +01:00
92db8cecd3 Specialized attention module for Wuerstchen. (#890)
* Specialized attention module for Wuerstchen.

* Reshaping ops.

* Attention processor.

* Finish the forward pass.

* Hook the new attention processor.

* Get the prior forward pass to work.

* Make it contiguous.
2023-09-18 21:16:09 +01:00
1542e92629 T5: Add option to override use_cache from config (#892)
* Add option to override use_cache from config

* Disable cache by default and cleanup code
2023-09-18 20:20:21 +01:00
82a98f6da0 Prior denoising. (#889) 2023-09-18 16:51:38 +01:00
5082954c52 Fix the W clip embeddings. (#887)
* Fix the W clip embeddings.

* Add the specialized ddpm scheduler.
2023-09-18 14:50:14 +01:00
7dd8e12472 Bump the crate versions to v0.2.3. (#886)
* Bump the crate version.

* Also update the python bindings.
2023-09-18 12:14:03 +01:00
12696b7b2d Fix typos in SAM WASM example (#884) 2023-09-18 09:41:50 +01:00
ef8cd8fea0 Update the candle-gemm version. (#885) 2023-09-18 09:36:20 +01:00
03e194123d Add return types to *.pyi stubs (#880)
* Start generating return types

* Finish tensor type hinting

* Add `save_gguf` to `utils`

* Typehint `quant-llama.py`
2023-09-17 22:11:01 +01:00
c2b866172a More Wuerstchen fixes. (#882)
* More Weurstchen fixes.

* More shape fixes.

* Add more of the prior specific bits.

* Broadcast add.

* Fix the clip config.

* Add some masking options to the clip model.
2023-09-17 22:08:11 +01:00
06cc329e71 Remove the parameters for the Wuerstchen layer-norm. (#879)
* Remove the parameters for the Wuerstchen layer-norm.

* Fixes.

* More fixes (including conv-transpose2d.

* More fixes.

* Again more fixes.
2023-09-17 15:59:27 +01:00
5f83c13f17 Add the DDPM scheduler. (#877)
* Add the DDPM scheduler.

* Minor tweaks.
2023-09-17 15:03:01 +01:00
db3e9dae04 Wuerstchen main (#876)
* Wuerstchen main.

* More of the wuerstchen cli example.

* Paella creation.

* Build the prior model.

* Fix the weight file names.
2023-09-17 12:46:38 +01:00
7f65af1f0d Avoid re-encoding the input in the T5 example. (#875) 2023-09-17 10:25:54 +01:00
eeb54716dd Tweaks for the T5 example. (#874) 2023-09-17 10:05:15 +01:00
1a276b5da7 Add a KV cache to T5. (#873)
* Add a KV cache to T5.

* Suggest using release mode.

* Use the kv cache in decoding.

* Add a comment.
2023-09-17 08:00:45 +01:00
8658df3485 Generate *.pyi stubs for PyO3 wrapper (#870)
* Begin to generate typehints.

* generate correct stubs

* Correctly include stubs

* Add comments and typhints to static functions

* ensure candle-pyo3 directory

* Make `llama.rope.freq_base` optional

* `fmt`
2023-09-16 17:23:38 +01:00
7cafca835a readme tweaks. (#867) 2023-09-16 07:22:24 +01:00
04ca2b9ebd Update README + SAM (#866)
* use serde-wasm-bindgen, faster serialization

* update readme with demos
2023-09-16 07:34:13 +02:00
635012d770 Do not backprop through argmin/argmax. (#865) 2023-09-15 22:15:40 +01:00
3e49f8fce5 Implement T5 decoding (#864)
* Load t5 decoder

* Run enc, dec, and lm head, but no cross attn

* Cross-attention over key_value_states

* New arg for decoder input ids

* Add mask, don't forward position biases through decoder

* Update t5 examples

* Clippy + rustfmt
2023-09-15 22:05:12 +02:00
c2007ac88f W fixes. (#862) 2023-09-15 15:11:11 +01:00
30be5b6660 Replication pad (#861)
* Add the embed mapper convolutions.

* Add the replication pad layer.

* Use the replication-pad op.

* Tweak a todo.
2023-09-15 14:06:21 +01:00
107d3d9530 Add the embed mapper convolutions. (#860) 2023-09-15 11:38:38 +02:00
2746f2c4be DiffNeXt/unet (#859)
* DiffNeXt/unet

* Start adding the vae.

* VAE residual block.

* VAE forward pass.

* Add pixel shuffling.

* Actually use pixel shuffling.
2023-09-15 10:14:02 +01:00
81a36b8713 Add link error info (#851)
* add link error info

* grammar fix
2023-09-15 07:25:10 +01:00
0633c85514 Add leaky-relu in the activation enum. (#858) 2023-09-15 07:05:38 +01:00
39157346cb Add SAM UI Demo (#854)
* fix tensor flattening

* send image data back

* sam ui worker example

* SAM example

* resize container

* no need for this
2023-09-15 06:31:58 +01:00
5cefbba757 minor UI fixes (#856)
* fixes

* remove listener

* remove event listener
2023-09-15 06:30:50 +01:00
130fe5a087 Add the upblocks. (#853) 2023-09-14 22:24:56 +01:00
91ec546feb More DiffNeXt. (#847)
* More DiffNeXt.

* Down blocks.
2023-09-14 22:16:31 +02:00
0a647875ec Use softmax-last-dim in the quantized example. (#848) 2023-09-14 17:29:24 +01:00
a0c6d5548c Add the attention block. (#846)
* Add the attention block.

* Add more to clipnext.
2023-09-14 15:40:09 +01:00
286f01db14 Start adding the Wuerstchen diffusion pipeline (#843)
* Wuerstchen common bits.

* Add the prior layer.

* Start adding diffnext.
2023-09-14 10:56:07 +01:00
d6447ad635 Tensor based indexing. (#842) 2023-09-14 07:47:07 +01:00
49d3f7f708 Add support to flan-t5 (#840) 2023-09-13 19:27:20 +02:00
9a465e1b26 Add 1d upsampling. (#839)
* Add 1d upsampling.

* Add the interpolate functions.
2023-09-13 16:50:39 +01:00
31ab2ddaeb Remove the padding. (#838) 2023-09-13 13:00:59 +01:00
b11a2a7b9d Move the constant to avoid some unused warning. (#837) 2023-09-13 11:56:53 +01:00
1c09164021 Add CANDLE_NVCC_CCBIN support for candle-kernels, and eliminate warning. (#836) 2023-09-13 11:39:22 +01:00
3e94324012 Add some sentence similarity part to the t5 example. (#835)
* Add some sentence similarity part to the t5 example.

* Clippy fix.
2023-09-13 10:44:02 +01:00
e6f040d6e3 Readme gallery (#834)
* More readme tweaks.

* Update README.md
2023-09-13 09:05:47 +01:00
cbd36157ac Add a gif to the quantized readme. (#833)
* Add a gif to the quantized readme.

* gif update.
2023-09-13 08:43:52 +01:00
18d3c803a8 Scalar support in minimum/maximum. (#832)
* Scalar support in minimum/maximum.

* Add a clamp method to tensors.
2023-09-13 08:24:58 +01:00
e4553fb355 T5 tweaks (#831)
* Use default values rather than options.

* Avoid exposing the device field.

* More tweaks.
2023-09-13 07:37:04 +01:00
d801e1d564 Clippy fix. (#830) 2023-09-13 07:16:20 +01:00
9daa6dbe87 Extract T5 module and add main function to use it (#829)
* Extract t5 out of musicgen

* Add main for t5 module
2023-09-13 07:14:05 +01:00
e82fcf1c59 Add more example readmes. (#828)
* Add more readmes.

* Add a readme for dinov2.

* Add some skeleton files for a couple more examples.

* More whisper details.
2023-09-12 17:21:24 +01:00
805bf9ffa7 Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling

* Update changelog

* rustfmt

* Add tests

* Fix clippy warning

* Fix another clippy error
2023-09-12 18:10:16 +02:00
42da17694a Segment Anything readme (#827)
* Add a readme for the segment-anything model.

* Add the original image.

* Clean-up the segment anything cli example.

* Also print the mask id in the outputs.
2023-09-12 14:35:55 +01:00
25aacda28e Add useful libraries section (#825)
* Add useful libraries section

* Add link
2023-09-12 11:06:21 +01:00
7a62aad24a Add a readme for yolo-v8. (#824) 2023-09-12 11:01:06 +01:00
bb23b90b1d Add a small readme for the quantized example. (#823) 2023-09-12 10:17:31 +01:00
2257f4d475 Bump the crate version + update the changelog. (#822) 2023-09-12 06:39:24 +01:00
871efc0307 Bugfix for the conv2d cpu kernel. (#820) 2023-09-11 23:11:27 +01:00
c5a058b169 Use the module trait in stable-diffusion. (#817) 2023-09-11 20:40:07 +01:00
59e63d690c Add weight, bias, and hidden_size methods (#816)
* Add weight, bias methods to Conv(1|2)

* Add hidden_size method to Embedding

* Expose hidden_size
2023-09-11 16:01:11 +01:00
dbd4561416 im2col version of the conv1d kernel. (#815)
* im2col version of the cuda conv1d kernel.

* im2col version of the conv1d cpu kernel.
2023-09-11 14:40:09 +01:00
5c35fbbb13 Stable-Diffusion readme (#814)
* Stable Diffusion readme.

* Fix the image path.

* Move the assets.

* Resize the sample image.

* Lower resolution.
2023-09-11 13:06:29 +01:00
70f38c2069 Proper error on unsupported dtypes when using gemm. (#813) 2023-09-11 12:10:51 +01:00
d7b9fec849 Move the stable-diffusion modeling code so that it's easier to re-use. (#812) 2023-09-11 11:45:57 +01:00
84ee870efd Use softmax-last-dim in whisper. (#810) 2023-09-11 11:05:05 +01:00
df712ecf64 Handle the case where the kernel is not contiguous in the cuda backend. (#809) 2023-09-11 09:48:31 +01:00
6fb665004c Enable im2col on the cpu side. (#805)
* Enable im2col on the cpu side.

* Hook im2col on the cpu backend.

* Use the kernel offset.

* Avoid an unnecessary copy.

* Handle non-contiguous kernels.

* Add a const to select the conv2d kernel.
2023-09-11 09:28:13 +01:00
1cd74129d4 Add Im2Col support on the gpu side. (#808)
* Add Im2Col support on the gpu side.

* Actually enable.
2023-09-11 08:52:33 +01:00
98d1242b8f im2col based conv2d (#802)
* im2col implementation for conv2d.

* Fix for the im2col implementation to match the current conv2d.

* Small optimization.

* Add a cuda kernel.

* Handle arbitrary layouts.

* Im2Col cuda code.
2023-09-10 21:02:42 +01:00
18d6db2180 more doc fixes (#804) 2023-09-10 20:36:29 +01:00
4f18180fc7 Bugfix so that im2col produce the same results as conv2d. (#801) 2023-09-10 16:59:46 +01:00
559944146f Add an im2col based benchmark. (#800)
* Add an im2col based benchmark.

* Reshape the final result.
2023-09-10 16:56:28 +01:00
3dd5804299 Fix typo in readme. (#799) 2023-09-10 13:49:47 +01:00
90e077e409 Return the low res mask in the wasm segment-anything module. (#798)
* Return the low res mask.

* Add some validations.
2023-09-10 13:03:02 +01:00
584171cae1 Add a wasm module for the segment anything example. (#797) 2023-09-10 12:29:37 +01:00
6c58fc59fd Little docs changes (#791)
* Little doc fixes

* change imports in lib

* rename candle_core to candle

* revert "rename candle_core to candle"
2023-09-10 12:02:52 +01:00
35f72514f5 Move more models to candle-transformers (#796)
* Move dinov2.

* Move efficientnet.

* Move the quantized llama model.

* Move segment-anything.
2023-09-10 10:20:18 +01:00
d3f05eae8c Move some models to candle-transformers so that it's easier to re-use. (#794)
* Move some models to candle-transformers so that they can be shared.

* Also move falcon.

* Move Llama.

* Move whisper (partial).
2023-09-10 09:40:27 +01:00
258ac32c38 Fix cuda randn when generating an odd number of values. (#793) 2023-09-09 18:44:21 +01:00
31936c08fe ViT tracing. (#790) 2023-09-09 17:26:39 +01:00
74ad4deb42 Get the MobileSAM TinyViT based version to work. (#789)
* More TinyViT support in SA.

* More mobilesam work.

* Add the mobile-sam weights to the hub.
2023-09-09 16:21:44 +01:00
b7cd58473b TinyViT backbone for segment-anything. (#787)
* TinyViT.

* More TinyViT.

* Add more to the tinyvit backbone.

* Proper padding.

* Plus ViT.

* Add the tiniest vit spec.
2023-09-09 15:10:06 +01:00
3cd7e7b51d Fuse the rel-pos additions via a custom-op. (#786)
* Fuse the rel-pos additions via a custom-op.

* Run with rayon.

* Add more tracing.
2023-09-09 10:46:09 +01:00
722c50bb0c Use byteorder in mnist. (#785) 2023-09-09 09:03:59 +01:00
976a1086ee feat: u32 from_be_bytes (#765) 2023-09-09 08:55:35 +01:00
c88d6fd4b9 Remove set_training. (#784) 2023-09-09 08:27:37 +01:00
057f7909bc Accelerate support for gelu. (#782) 2023-09-08 21:58:56 +01:00
acf8f10ae1 Get the comparison operation to work on scalar values. (#780)
* Get the comparison operation to work on scalar values.

* Add some time measurement.
2023-09-08 20:13:29 +01:00
0906acab91 Automatic mask generation (#779)
* A few more contiguous fixes for cuda.

* Mask generation.

* Generic bbox.

* Generate all the masks.
2023-09-08 19:11:34 +01:00
158ff3c609 Add tracing to segment-anything (#777)
* Tracing support for segment-anything.

* More tracing.

* Handle the empty slice case.
2023-09-08 15:31:29 +01:00
e5703d2f56 Draw the mask on a merged image. (#775)
* Draw the mask on a merged image.

* Clippy fix.

* Enable the target point by default.

* Add to the readme.
2023-09-08 14:04:34 +01:00
98172d46fa Fix some errors about BlockQ8_1 (#776)
* use int8 type instead of uint8 for BlockQ8_1.qs

The uint8 type of BlockQ8_1.qs causes great loss for negative weights
Ref: ebc96086af/ggml.c (L904)

Signed-off-by: Zhang Miaolei <zmlcc@outlook.com>

* fix sum error in vec_dot of BlockQ4_1

Ref: ebc96086af/ggml.c (L2840)

Signed-off-by: Zhang Miaolei <zmlcc@outlook.com>

* fix sum error in vec_dot of BlockQ5_1

Ref: ebc96086af/ggml.c (L3490)

Signed-off-by: Zhang Miaolei <zmlcc@outlook.com>

---------

Signed-off-by: Zhang Miaolei <zmlcc@outlook.com>
2023-09-08 13:29:40 +01:00
28c87f6a34 Automatic mask generator + point base mask (#773)
* Add more to the automatic mask generator.

* Add the target point.

* Fix.

* Remove the allow-unused.

* Mask post-processing.
2023-09-08 12:26:56 +01:00
c1453f00b1 Improve the safetensor loading in the segment-anything example. (#772)
* Improve the safetensor loading in the segment-anything example.

* Properly handle the labels when embedding the point prompts.
2023-09-08 09:39:10 +01:00
989a4807b1 Use shape with holes. (#771) 2023-09-08 08:50:27 +01:00
0e250aee4f Shape with holes (#770)
* Shape with holes.

* rustfmt.
2023-09-08 08:38:13 +01:00
cfcbec9fc7 Add small customization to the build (#768)
* Add ability to override the compiler used by NVCC from an environment variable

* Allow relative paths in CANDLE_FLASH_ATTN_BUILD_DIR

* Add the compilation failure to the readme, with a possible solution

* Adjust the error message, and remove the special handling of the relative paths
2023-09-08 08:15:14 +01:00
3898e500de Generate a mask image + the scaled input image. (#769)
* Also round-trip the original image.

* Make it possible to use a safetensors input.
2023-09-08 05:53:08 +01:00
79c27fc489 Segment-anything fixes: avoid normalizing twice. (#767)
* Segment-anything fixes: avoid normalizing twice.

* More fixes for the image aspect ratio.
2023-09-07 21:45:16 +01:00
7396b8ed1a Segment Anything - process images (#766)
* Start processing images.

* Add LayerNorm2d.

* Properly use LayerNorm2d.

* Tweak eps.

* Use LayerNorm on inputs with a rank different from 3.

* Window partitioning.

* Fix a couple todos.

* More todos.

* Hard-code the einsums.

* More padding support.

* Some sizes tweaks.

* Use the hub to get the weights.

* Use a batch matmul.

* Tweaks.

* More fixes.

* Get some predictions to be generated.
2023-09-07 19:22:45 +01:00
7b50f3e106 More segment-anything again. (#764)
* More segment-anything again.

* Transformer block forward.

* Two-ways transformer.

* Position embeddings.

* Sketch the prompt encoder.

* More prompt-encoder.

* More prompt-encoder.

* Add the main sam module.

* Embed the transformer.

* And hook the transformer forward step.

* Build the model.

* Handle the global attn indexes.

* Get the model to load.
2023-09-07 12:06:55 +01:00
8c991df394 More segment-anything. (#763)
* More segment-anything.

* Split the model in multiple files.

* Start adding the transformer.

* Add the attention block.

* Move the MLP Block.
2023-09-07 07:28:30 +01:00
000fa00e31 Expose the conv2d-transpose layers. (#761) 2023-09-07 06:04:52 +01:00
a17a7c42c1 Add a nn layer for conv-transpose2d. (#760) 2023-09-07 05:47:28 +01:00
6527ab81a3 Sketch the segment anything model. (#759)
* Sketch the segment anything model.

* Fix some clippy lint.

* Add the mask decoder.
2023-09-07 05:34:05 +01:00
7b1f2da828 Cudnn fix. (#758) 2023-09-06 17:39:39 +01:00
bdc9d46fe3 Use an arc in the varbuilder rather than rc. (#757)
* Use an arc in the varbuilder rather than rc.

* Require the backends to be send.

* Request send and sync.
2023-09-06 15:29:09 +01:00
dcf708559d Fix for cudnn to work with img2img. (#753) 2023-09-06 07:49:28 +01:00
7299a68353 img2img pipeline for stable diffusion. (#752)
* img2img pipeline for stable diffusion.

* Rename the arguments + fix.

* Fix for zero strength.

* Another fix.

* Another fix.

* Revert.

* Include the backtrace.

* Noise scaling.

* Fix the height/width.
2023-09-06 07:06:49 +01:00
16bf44f6e9 force model cache (#751) 2023-09-06 05:53:31 +02:00
a4f40f3dc8 Use rayon directly rather than constraining the number of threads. (#749) 2023-09-05 20:26:15 +01:00
6a40decc76 Minor WASM UI improvements (#748)
* add stats

* random seed btn

* minor ui improvoments
2023-09-05 19:24:43 +01:00
a0d65585db Softmax implementation for cuda. (#747) 2023-09-05 18:38:03 +01:00
94c6a8d3d3 Add a dedicated cuda kernel for softmax. (#746) 2023-09-05 17:53:20 +02:00
6615daf242 Tweaks to softmax. (#745) 2023-09-05 15:22:27 +01:00
1c9e5394a5 Add a custom softmax implementation. (#744)
* Add a custom softmax implementation.

* Add softmaxlastdim to the benchmarks.

* And add a test.

* Support more dtypes.

* Polish the code.

* Use the slow implementation on cuda.

* Add a todo for the cuda kernel.
2023-09-05 14:20:23 +01:00
a8410bf35e Add some documentation. (#743) 2023-09-05 09:51:12 +01:00
cda45a7443 Let outside CustomOp2 implementations use binary_map/binary_map_vec (#741) 2023-09-05 09:27:32 +01:00
4698eb5cb6 Fix typo in the nll function document (#742) 2023-09-05 09:25:11 +01:00
000487c36f Add a python function to save as safetensors. (#740) 2023-09-04 20:32:14 +01:00
ab0d9fbdd1 Properly set the is_bf16 flag. (#738) 2023-09-04 16:45:26 +01:00
f80fd44201 BF16 support for flash-attn. (#737) 2023-09-04 16:35:43 +01:00
0d00c06a83 Fix clippy lint. (#736) 2023-09-04 16:09:19 +01:00
8395152d20 Llama2c WASM UI improvements (#732)
* pass seed, expose model seq_len

* wip new llama2.c ui

* final new UI example

* small coppy

* copy
2023-09-04 15:59:22 +01:00
e2f9f60ac2 Avoid some redundant clone. (#731) 2023-09-04 09:18:32 +02:00
d0cdea95a5 Add back the bf16 flash-attn kernels. (#730) 2023-09-04 07:50:52 +01:00
20512ba408 Return the metadata in the gguf pyo3 bindings. (#729)
* Return the metadata in the gguf pyo3 bindings.

* Read the metadata in the quantized llama example.

* Get inference to work on gguf files.
2023-09-04 07:07:00 +01:00
9c61b0fc9b Proper log buckets for t5. (#727)
* Proper log buckets for t5.

* Properly pass the position bias.
2023-09-03 20:33:50 +01:00
26cd266e65 Musicgen text embeddings. (#726)
* Musicgen text embeddings.

* Bugfix for layer norm.

* Proper position bias.

* Expose the weights.
2023-09-03 18:27:48 +01:00
bbec527bb9 Fix the musicgen example. (#724)
* Fix the musicgen example.

* Retrieve the weights from the hub.
2023-09-03 14:50:39 +01:00
f7980e07e0 Add ggufv2 support (#725) 2023-09-03 14:41:57 +01:00
74a82c358a Add the mse loss. (#723) 2023-09-03 10:51:40 +01:00
84d003ff53 Handle arbitrary shapes in Tensor::new. (#718) 2023-09-02 19:59:21 +01:00
21109e1983 Recommend using maturin. (#717) 2023-09-02 16:19:35 +01:00
ad796eb4be More quantized llama in python. (#716)
* More quantized llama in python.

* Expose a couple more functions.

* Apply the last layer.

* Use the vocab from the ggml files.
2023-09-02 13:41:48 +01:00
e8e33752f4 Sketch a quantized llama using the pyo3 api. (#715)
* Sketch a quantized llama using the pyo3 api.

* Add more ops.

* Expose a few more functions to use in the quantized model.

* Rope embeddings.

* Get the forward pass to work.
2023-09-02 11:26:05 +01:00
dabaa479b9 Update README.md (#714) 2023-09-02 07:56:12 +01:00
2c1df6bba1 Add a repeat penality to the llama2-c command line example. (#713)
* Add a repeat penality to the llama2-c command line example.

* Another fix attempt.
2023-09-01 20:38:58 +01:00
4d56cef583 Handle the empty sequence case properly. (#712)
* Handle the empty sequence case properly.

* Proper fix.
2023-09-01 20:12:30 +01:00
19042962d5 Whisper fix (#711)
* Remove unnecessary file.

* Whisper fix.
2023-09-01 20:04:07 +01:00
731e3ffb03 Remove unnecessary file. (#710) 2023-09-01 19:42:23 +01:00
2fef14cb14 Add a repeat penalty to the llama2.c wasm example. (#709) 2023-09-01 19:32:28 +01:00
1e5b2cc1d5 Add some quantized functions to pyo3. (#708) 2023-09-01 19:45:36 +02:00
2ed78ab336 Support for quantized tensors in the python api. (#706)
* Add more pyo3 support.

* Add some support for quantized tensors in pyo3.

* Add an arc layer on qmatmul.

* Add the quantized matmul.

* Quantization support.

* More quantization support.

* Test the python quantization.
2023-09-01 15:53:42 +01:00
237323c2bc Cleanup the pyo3 setup. (#705) 2023-09-01 14:26:18 +01:00
af552a5274 Fix the rnn tests for accelerate. (#704) 2023-09-01 13:21:38 +01:00
7529531056 Add the optimizer trait. (#702) 2023-09-01 12:55:39 +01:00
f2d476ca65 Replace the discord link. (#701) 2023-09-01 09:43:55 +01:00
f9f482d4e5 Add some doc to the varbuilder. (#700) 2023-09-01 08:28:35 +01:00
9736236175 Allow retrieving and setting prefix of VarBuilder (#699) 2023-09-01 08:08:41 +01:00
30a4b593d7 More ops again. (#697) 2023-08-31 22:28:48 +01:00
949f1eae6f Implement a couple more binary ops. (#693) 2023-08-31 21:30:15 +01:00
7cef35c84d Tweak some quantized args (#692)
* Print the args + change the default temp/repeat penalty.

* Minor formatting tweak.
2023-08-31 17:25:21 +01:00
7509c98970 Interactive mode for the quantized model. (#690) 2023-08-31 10:52:42 +01:00
94aa234dfd Add the kv-cache to the whisper wasm version. (#689)
* Add the kv-cache to the whisper wasm version.

* Improve the handling of special tokens.
2023-08-31 09:37:44 +01:00
db59816087 Add a GRU layer. (#688)
* Add a GRU layer.

* Fix the n gate computation.
2023-08-31 08:43:10 +01:00
d210c71d77 Set the learning rate. (#687) 2023-08-31 08:03:40 +01:00
8e84d8a59b Llama2.c wasm module. (#686) 2023-08-31 07:44:32 +01:00
9bd486fb96 Add Yolo Pose to JS Example (#684)
* add support for yolo pose models

* fix copy
2023-08-31 06:32:57 +01:00
eaf760a751 Add a python variant for the lstm test. (#682) 2023-08-30 22:32:08 +01:00
1d0bb48fae Improve Whisper WASM UI example (#669)
* wip add module and js worker example

* params

* clean up, send error

* final UI with whisper webworker

* add simple instructions
2023-08-30 20:35:41 +02:00
21e1c73892 Add a LSTM test. (#681)
* Add a LSTM test.

* Clippy.
2023-08-30 20:05:42 +02:00
2047d34b7c More robust tests (so that they pass on accelerate). (#679) 2023-08-30 18:10:10 +01:00
9874d843f1 Fix the accelerate build (#678)
* Cosmetic changes.

* Fix the accelerate build for tanh.
2023-08-30 18:31:14 +02:00
7d753d3acd Mnist training dropout (#677)
* Use dropout in the mnist training.

* Fix.
2023-08-30 16:41:01 +01:00
3159982a89 Add a Dropout layer (#676)
* Add a dropout layer.

* Add an actual layer.
2023-08-30 16:19:28 +01:00
ad8a62dbf5 Add tanh. (#675)
* Add tanh.

* Use tanh in the lstm block.

* Add a test for tanh forward and backward passes.
2023-08-30 13:54:50 +01:00
f35b9f6baa Add some recurrent neural networks (#674)
* Add the rnn module.

* More LSTM.

* Implement the RNN forward pass.

* More forward pass for LSTM.
2023-08-30 13:27:09 +01:00
618f4e4c78 Add some documentation. (#673)
* Add some documentation.

* Bump the crate version.
2023-08-30 11:54:00 +01:00
5ac0a98f01 Changelog update. (#672) 2023-08-30 09:27:56 +01:00
393690387f Support dilation in conv-transpose2d. (#671) 2023-08-30 09:22:00 +01:00
9b25113393 Small cleanups (avoid some possible mutations) (#670)
* More mut cleanup.

* Factor out some common bits.
2023-08-30 08:54:00 +01:00
a1a5ab8b0a Neon optimized vecdot (#666)
* Q5k vecdot.

* Add the q3k vecdot.

* Q2k vecdot.

* Move the quantized model to its own file.
2023-08-29 22:28:46 +01:00
59b731de99 Add the powf op. (#664)
* Add the powf op.

* Cuda kernels and backprop.

* Add a test.
2023-08-29 20:48:18 +01:00
2d3fcad267 Simplify usage of the pool functions. (#662)
* Simplify usage of the pool functions.

* Small tweak.

* Attempt at using apply to simplify the convnet definition.
2023-08-29 19:12:16 +01:00
b31d41e26a Add a convnet training example. (#661)
* Add a convnet example.

* Dataset fix.

* Randomize batches.
2023-08-29 18:23:01 +01:00
71221559d3 Fix the dilated convolutions. (#659) 2023-08-29 16:37:42 +01:00
a044907ffc Dilated convolutions (#657)
* Add the dilation parameter.

* Restore the basic optimizer example.

* Dilation support in cudnn.

* Use the dilation parameter in the cpu backend.

* More dilation support.

* No support for dilation in transposed convolutions.

* Add dilation to a test.

* Remove a print.

* Helper function.
2023-08-29 16:12:11 +01:00
ee8bb1bde1 Add avx implemenetations of q2k, q3k and q5k vec-dot functions (#654)
* `q2k` avx implementation

* `q3k` avx implementation

* `q5k` avx implementation

* `avx` make masks constant

* clippy stuff
2023-08-29 13:35:56 +01:00
3d2d3c7edb Merge pull request #658 from huggingface/upgrade_hf_hub2
Upgrading hf-hub (for windows support, removing symlink requirement).
2023-08-29 14:32:15 +02:00
1aca6fa291 Upgrading hf-hub. 2023-08-29 14:18:54 +02:00
4ed202447e Upgrading hf-hub. 2023-08-29 14:14:26 +02:00
1d6bff53fc Changelog update. (#656) 2023-08-29 12:55:56 +01:00
14b4d456e8 Merge pull request #439 from huggingface/training_hub_dataset
[Book] Add small error management + start training (with generic dataset inclusion).
2023-08-29 13:10:05 +02:00
2d5b7a735d Update the book with new layout of datasets. 2023-08-29 12:51:59 +02:00
62ef494dc1 Use multiple transformer layer in the same cross-attn blocks. (#653)
* Use multiple transformer layer in the same cross-attn blocks.

* Make the context contiguous if required.
2023-08-29 11:13:43 +01:00
d0a330448d Backprop support for pooling ops. (#652)
* Backprop support for pooling ops.

* max-pool gradient.
2023-08-29 10:17:59 +01:00
4b8d57ba15 AVX version of the q4k vecdot. (#651) 2023-08-29 09:41:17 +01:00
d5a525f7a7 Fix clippy + save_image. 2023-08-29 10:19:44 +02:00
33c23c19b6 Preliminary support for SDXL. (#647)
* Preliminary support for SDXL.

* More SDXL support.

* More SDXL.

* Use the proper clip config.

* Querying for existing tensors.

* More robust test.
2023-08-29 09:00:04 +01:00
Lei
49326fb925 Update .gitignore (#649) 2023-08-29 08:41:33 +01:00
fd3131a4ce Fix the debug implementation. (#648) 2023-08-28 22:51:39 +01:00
037b41c9dc Cuda conv transpose (#645)
* Cuda kernel for conv-transpose.

* Fix the cuda kernel.

* Fix the tests.
2023-08-28 20:58:49 +01:00
72fae3140c Optimize the conv2d transpose cpu kernel. (#644)
* Optimize the conv2d transpose cpu kernel.

* Use multiple cores.
2023-08-28 20:06:31 +01:00
ca26198b95 Fix the cpu kernel for conv-transpose. (#643) 2023-08-28 16:45:12 +01:00
b292047882 Backprop for conv2d. (#638)
* Start adding backprop for conv2d.

* Backprop for conv2d.

* Bugfix + start adding a conv2d test.

* Conv2d backprop testing.

* More conv fixes.
2023-08-28 16:08:55 +01:00
09c5bd1881 Rebased 2023-08-28 15:47:03 +02:00
fe6c88713d Fix waiting upgrade for SSL ? 2023-08-28 15:15:27 +02:00
6f3f9285e6 Remove image dep. 2023-08-28 15:15:27 +02:00
baca3cf69d Fix deps. 2023-08-28 15:15:27 +02:00
d726484a6d Re-enable local dir for mnist. 2023-08-28 15:15:27 +02:00
dd06d93d0b Cleanup:
- Moved around book from `examples` to `candle-book` proper (overlapping
  the book and the lib structures)
2023-08-28 15:15:26 +02:00
c109c93db7 Update candle-book/src/SUMMARY.md 2023-08-28 15:15:02 +02:00
d7a273be51 Training:
- Removed a lot of surface (SerializedFileReader ownership is really
  painful).
- Moved example + vision to hf.co version.
- Removed feature gate.
2023-08-28 15:15:01 +02:00
dd02f589c0 Better training+hub 2023-08-28 15:14:43 +02:00
7602323667 [Book] Add small error management + start training (with generic dataset
inclusion).
2023-08-28 15:14:17 +02:00
9137c63175 Update README.md (#640) 2023-08-28 11:34:54 +01:00
3cca89cc70 Add conv-transpose. (#635)
* Add conv-transpose.

* Return zeros for now.

* Naive CPU implementation.

* Add a conv-transpose test + fix the cpu implementation.

* Add a second test.
2023-08-28 10:10:12 +01:00
26e1b40992 Repeat-penalty in the falcon example. (#634) 2023-08-28 08:13:40 +01:00
1da71a5da1 Neon optimized version of the q4k vecdot product. (#632) 2023-08-27 21:30:47 +01:00
24dda44c27 Add wasm support for yolo-v8 pose detection. (#630)
* Add wasm support for yolo-v8 pose detection.

* Better bbox handling.

* Add the pose model in the wasm example lib.
2023-08-27 19:49:24 +01:00
72ebb12bca Remove some dead-code annotations. (#629)
* Remove some dead-code annotations.

* More dead code removal.

* One more.

* CI fix.
2023-08-27 18:52:33 +01:00
a3f97c143d Bump the crate version + update CHANGELOG. (#628) 2023-08-27 18:17:11 +01:00
4c338b0cd9 VarBuilder cleanup (#627)
* VarBuilder cleanup.

* Implement the basic varbuilders.

* Add the sharded code.

* Proper support for tensor sharding.
2023-08-27 18:03:26 +01:00
be471d50ab Llama quantization. (#625) 2023-08-27 14:08:15 +01:00
7151f2cf63 Add the quantize command. (#624)
* Add the quantize command.

* Bugfix for writing gguf files.

* And add a comment.
2023-08-27 11:35:19 +01:00
6e485f2deb Add some optional repeat penalty. (#623)
* Add some optional repeat penalty.

* Add the missing files.
2023-08-27 10:48:45 +01:00
5320aa6b7d Move the test-utils bits to a shared place. (#619) 2023-08-27 09:42:22 +01:00
a8b39dd7b7 Fix for q5_1 quantization. (#617)
* Fix for q5_1 quantization.

* Fix some typos.
2023-08-27 08:31:18 +01:00
fa0d75b18d Quantization tests + fix some issues. (#616) 2023-08-27 08:17:38 +01:00
28658054ff More missing quantized bits. (#615)
* Q4_1 support.

* Add Q5_1 quantization.

* Tweak.
2023-08-27 07:52:26 +01:00
ab36a7f3e3 Fix for when f16c is not available. (#614) 2023-08-27 07:19:52 +01:00
f704e39761 Missing quants ops (#611)
* Another transmute tweak.

* Changelog tweak.

* Add some missing quantized ops.
2023-08-26 20:09:04 +01:00
fdf15f0e05 Another transmute tweak. (#610)
* Another transmute tweak.

* Changelog tweak.
2023-08-26 13:00:24 +01:00
06b37ea7ad Avoid using tmp values. (#609) 2023-08-26 12:28:28 +01:00
c72eb3d75b Add reference implementation for q4k and q5k (#586)
* add `q2k` vec-dot

* `q3k` vec-dot + quantization bugfix

* `q4k` vec-dot

* `q5k` vec-dot

* Validate against GGML unit test results.

* Remove some more `transmutes`
2023-08-26 12:07:54 +01:00
864227edbf [WIP] Improve Yolo WASM UI example (#591)
* return detections with classes names

* ignore .DS_Store

* example how to load wasm module

* add param to set model size

* add param for model size

* accept iou and confidence threshold on run

* conf and iou thresholds

* clamp only

* remove images from branch

* a couple of renamings, add readme with instructions

* final design

* minor font + border update
2023-08-26 11:40:41 +01:00
b23b347b35 Merge pull request #601 from huggingface/repair_bf16_f16_cast
Repairing cast bf16/f16
2023-08-26 12:34:41 +02:00
71518caeee Align tensor device print more with PyTorch (#590)
* Improve tensor print

* Use CudaDevice only if enabled with cuda feature

* run rust fmt

* up

* improve

* rustfmt
2023-08-26 11:20:22 +01:00
6559eae72c Avoid some transmutes. (#607) 2023-08-25 18:21:37 +01:00
46eb225ba5 Add some missing entries to the changelog. (#606) 2023-08-25 18:01:38 +01:00
aa67e5107d Merge pull request #600 from huggingface/codellama_gpu_support
Adding support for codellama in examples.
2023-08-25 18:25:26 +02:00
c105550405 s/panic/bail/ 2023-08-25 18:05:07 +02:00
ca6c050b04 Cleanup the pose reporting code. (#605) 2023-08-25 16:49:21 +01:00
9c8d6dbc2a Neon intrinsics for the q8_0 vecdot. (#604)
* Neon intrinsics for the q8_0 vecdot.

* Get the tests to run with accelerate (with some numerical error failures).
2023-08-25 14:42:18 +01:00
0afbc435df Add some configurable legend for yolo detection. (#603)
* Add some configurable legend for yolo detection.

* Clippyness.
2023-08-25 13:50:31 +01:00
d4e75d5825 Let's keep the dirty code on its own. 2023-08-25 12:01:58 +00:00
be371e827c Intermediary float cast is necessary for cuda 11.8 2023-08-25 11:54:30 +00:00
97909e5068 Move the yolo model bits in a separate file. (#602)
* Move the yolo model bits in a separate file.

* Improve the drawing.

* Bugfix.
2023-08-25 12:47:55 +01:00
1c1e34735e static_cast ? 2023-08-25 11:40:36 +00:00
db8bab8b7a Different casting ? 2023-08-25 10:49:22 +00:00
bc131b402b Repairing cast bf16/f16 2023-08-25 10:38:19 +00:00
8bc5fffa45 More support for pose estimation in yolo-v8. (#599)
* More support for pose estimation in yolo-v8.

* Support both object detection and pose-estimation in the yolo-v8 example.
2023-08-25 11:21:11 +01:00
4826a4212e Adding support for codellama in examples.
Codellama requires bf16 for now (error to convert from bf16 to f16).
Multiprocess demo not functional for it because flash-attn only supports
f16 for now.
2023-08-25 09:56:11 +00:00
afc10a3232 AVX version for the q8-0 multiplications. (#598) 2023-08-25 10:14:49 +01:00
d728e646c2 Use resolver 2 explicitely. (#597) 2023-08-25 09:35:40 +01:00
c093b03d51 Generic implementation of vecdot for q80. (#596)
* Generic implementation of vecdot for q80.

* Add support for code-llama 7b.

* Support more code-llama.
2023-08-25 09:04:05 +01:00
d8ba0452dc Fail on bf16. (#594) 2023-08-25 06:10:38 +01:00
189442a0fa Add the pose estimation head for yolo. (#589)
* Add the pose estimation head for yolo.

* Properly handle the added position dimensions.

* Integrate the pose estimation head in the forward pass.

* Renaming.

* Fix for pose estimation.
2023-08-24 22:12:34 +01:00
2cde0cb74b More pickle support. (#588)
* More pickle support.

* Be more verbose.
2023-08-24 18:45:10 +01:00
e21c686cdc Fixes for clippy 1.72. (#587) 2023-08-24 17:46:17 +01:00
c265ac50fa Add a function to write gguf files. (#585)
* Add a function to write gguf files.

* More GGUF file writing.

* Write the tensor data in GGUF files.
2023-08-24 17:03:06 +01:00
a87c6f7652 Merge pull request #561 from patrickvonplaten/add_installation
Improve installation section and "get started"
2023-08-24 16:25:52 +02:00
afd965f77c More non square testing (#582)
* Add more non square testing.

* More testing.
2023-08-24 13:01:04 +01:00
d2f42ab086 Referenze implementations of q2k and q3k vec-dot functions (#580)
* add `q2k` vec-dot

* `q3k` vec-dot + quantization bugfix
2023-08-24 12:35:54 +01:00
ca318a6ec7 Add to the cuda example a reproduction of the issue. (#579)
* Add to the cuda example a reproduction of the issue.

* Tweak.

* Add a test using non-square matrixes.

* Fix the conv2d kernel.

* Display the error.

* And tweak the comment.
2023-08-24 12:07:31 +01:00
dd64465899 Add a test for conv2d with padding + bugfix the random number generation on cuda. (#578)
* Add a test for conv2d with padding.

* Cosmetic changes.

* Bugfix the rand function on the cuda backend.
2023-08-24 10:16:37 +01:00
79916c2edb Use the hub weights for efficientnet. (#573) 2023-08-23 18:20:21 +01:00
431051cc32 Add Efficientnet (#572)
* EfficientNet.

* Complete the efficientnet implementation.

* Improve group handling.

* Get the efficientnet to work.
2023-08-23 18:02:58 +01:00
eedd85ffa7 Move the imagenet specific bits to a separate file. (#571) 2023-08-23 16:42:09 +01:00
7478dda255 Cosmetic tweaks. (#570) 2023-08-23 15:45:40 +01:00
329f661d9b Trace softmax (#568)
* Trace the softmax op.

* Inline the sum.

* Add min/max vec operations.
2023-08-23 15:25:50 +01:00
075b505480 Mirror GGML's unit tests (#569)
* Add ggml unit tests

* simplify random matmul test for other test cases
2023-08-23 15:25:17 +01:00
aba1e90797 Add some group parameter to convolutions. (#566)
* Add some group parameter to convolutions.

* Avoid some unnecessary groups checks.

* Move the tensor convolution bits.

* Properh handling of groups.

* Bump the crate version.

* And add a changelog.
2023-08-23 12:58:55 +01:00
1f58bdbb1d Apply suggestions from code review 2023-08-23 13:33:45 +02:00
c98d3cfd8b Update candle-book/src/guide/installation.md 2023-08-23 13:31:54 +02:00
c5e43ad0ab Apply suggestions from code review 2023-08-23 13:27:29 +02:00
2c280007e8 Apply suggestions from code review 2023-08-23 13:26:21 +02:00
4ee1cf038a Get the rms epsilon from GGUF. (#565) 2023-08-23 11:40:20 +01:00
0f4ff8a739 Fix the quantized example. (#564) 2023-08-23 11:09:55 +01:00
89a00b56cc add chat models in quantized example (#551)
* add chat models in quantized example

* cargo fmt
2023-08-23 11:05:33 +01:00
9a5c7db91a Add support for i64 (#563)
* Add the i64 dtype.

* Adapt the cuda kernels.
2023-08-23 10:42:19 +01:00
649202024c fix code snippets 2023-08-23 09:05:07 +00:00
283f6c048d fix code snippets 2023-08-23 09:04:36 +00:00
c8211fc474 fix code snippets 2023-08-23 09:04:08 +00:00
7732bf6238 correct 2023-08-23 08:54:48 +00:00
7c0ca80d3a move installation to book 2023-08-23 08:52:53 +00:00
b558d08b85 improve 2023-08-23 08:42:47 +00:00
34cb9f924f improve 2023-08-23 08:40:23 +00:00
d4968295a0 improve 2023-08-23 08:37:08 +00:00
65e146c72d Add installation section 2023-08-23 08:32:59 +00:00
3743bed2d7 Fix the ? operator cannot be applied to type Device of example (#560)
According to the API:

```rust
inp = inp.to_device(&Device::Cuda(0)?)?;
```

cannot work as `Cuda(...)` expects a type `Device` not an integer.

I'd recommend to instead use `new_cuda(...)`
2023-08-23 09:29:50 +01:00
508d34daf2 GGUF support in the quantized model. (#559)
* GGUF support in the quantized model.

* Get the GGUF support to work on llama.
2023-08-23 09:20:57 +01:00
0764741cc4 Handle GGUF files in tensor-tools. (#558) 2023-08-23 06:32:07 +01:00
6a30ecefad Preliminary GGUF support. (#557)
* Preliminary GGUF support.

* Tensor reading.
2023-08-23 00:14:10 +01:00
7687a0f453 Also fix the aspect ratio in the wasm example. (#556)
* Also fix the aspect ratio in the wasm example.

* Add the yolo lib.

* Update the build script.
2023-08-22 22:20:08 +01:00
f9ecc84477 GQA support in the quantized model. (#555)
* GQA support in the quantized model.

* Fix the reshaping.

* Fix the main llama model.

* Infer the proper gqa from the model kind.
2023-08-22 19:41:10 +01:00
07067b01dc Avoid some mutable variables (take 2). (#554)
* Avoid some mutable variables (take 2).

* Fix.
2023-08-22 18:51:20 +01:00
cc22d4db20 Put the transcribe token before the language one. (#553) 2023-08-22 16:46:34 +01:00
ec665acad7 Revert "Avoid some mut in quantized functions. (#550)" (#552)
This reverts commit cf27b9b636.
2023-08-22 15:57:46 +01:00
cf27b9b636 Avoid some mut in quantized functions. (#550)
* Avoid a couple more 'let mut'.

* Tweaks.
2023-08-22 15:44:26 +01:00
352383cbc3 Add quantization support for q2k, q3k, q4k and q5k (#524)
* first q2 implementation

* First Q4K and Q5K implementations

* fix `q2k` and `q5k`

* Some first cleanups

* run `clippy` on tests

* finally implement `q3k`

* deactivate `q3k` test on macos

* also disable the test on linux

* Fix floating bits in `q3k` dequantization

* Refactoring pass + reorder quants in file

* `fmt`

* Re-add `src` asserts and redefine `dst`
2023-08-22 15:04:55 +01:00
9bc811a247 Improve the aspect ratio handling on yolo-v8. (#549)
* Fix the aspect ratio handling in yolo-v8.

* Typo.
2023-08-22 14:55:33 +01:00
bb69d89e28 Move the yolo shared bits to a common place. (#548)
* Move the yolo shared bits to a common place.

* Share more code.

* Configurable thresholds.
2023-08-22 13:03:07 +01:00
20ce3e9f39 Sketch the yolo wasm example. (#546)
* Sketch the yolo wasm example.

* Web ui.

* Get the web ui to work.

* UI tweaks.

* More UI tweaks.

* Use the natural width/height.

* Add a link to the hf space in the readme.
2023-08-22 11:56:43 +01:00
44420d8ae1 Add some llama-v2 variants. (#545) 2023-08-22 08:35:15 +01:00
f16bb97401 Use the yolo-v8 weights from the hub. (#544)
* Use the weights from the hub.

* Add to the readme.
2023-08-21 22:07:36 +01:00
3507e14c0c Yolo v8 fixes (#542)
* Fixes for the yolo-v8 layout.

* Bugfixes.

* Another silly bugfix.

* Remove the hf-hub dependency.

* Remove the transformers dependency.
2023-08-21 21:05:40 +01:00
de50e66af1 Add yolo v8 as an example (#541)
* Sketching yolo-v8.

* Get the model to load.

* yolo-v8 forward pass.

* Complete(?) the forward pass.

* Fix some shape issues.

* Add the missing padding.

* Process the predictions.
2023-08-21 18:40:09 +01:00
cc2d6cf2e0 Improve the timestamps support in whisper (#539)
* Timestamp support for whisper.

* Properly display the timestamps.

* Bugfix for the timestamp units.
2023-08-21 12:26:59 +01:00
e3b71851e6 Retrieve the yolo-v3 weights from the hub. (#537) 2023-08-21 10:55:09 +01:00
4300864ce9 Add some optional repeat penalty. (#535) 2023-08-21 09:59:13 +01:00
d70cffdab6 Fix the minimum/maximum gradient computations. (#534) 2023-08-21 08:28:41 +01:00
912561614f Better handling of zero temperatures. (#532) 2023-08-21 07:51:46 +01:00
8c232d706b Small tweaks to the pickle handling to be able to use libtorch files. (#530)
* Small tweaks to the pickle handling to be able to use libtorch files.

* Move the pytorch specific bits in a different function.
2023-08-20 23:25:34 +01:00
11c7e7bd67 Some fixes for yolo-v3. (#529)
* Some fixes for yolo-v3.

* Use the running stats for inference in the batch-norm layer.

* Get some proper predictions for yolo.

* Avoid the quadratic insertion.
2023-08-20 23:19:15 +01:00
a1812f934f Add a yolo-v3 example. (#528)
* Add a couple functions required for yolo.

* Add the yolo-v3 example.

* Add minimum and maximum.

* Use the newly introduced maximum.

* Cuda support for min/max + add some testing.

* Allow for more tests to work with accelerate.

* Fix a typo.
2023-08-20 18:19:37 +01:00
e3d2786ffb Add a couple functions required for yolo. (#527) 2023-08-20 17:02:05 +01:00
372f8912c5 Minor readme tweaks. (#526) 2023-08-20 14:33:21 +01:00
d2622a8160 Move the VarMap to a separate file (#525)
* Move the var-map struct in a separate file.

* Fix some typos.
2023-08-20 14:25:07 +01:00
2fcb386f17 Add a broadcast variant to matmul. (#523)
* Add a broadcast variant to matmul.

* Get the test to pass.
2023-08-20 13:20:42 +01:00
a8f61e66cc Bump the crates version to 0.1.2. (#522) 2023-08-20 08:07:07 +01:00
aa207f2dd9 Print some per-step timings in stable-diffusion. (#520)
* Skeleton files for neon support of quantization.

* SIMD version for q4 vecdot.

* Also simdify the q6k multiplication.

* Add some timings to stable-diffusion.
2023-08-20 05:45:12 +01:00
82410995a2 Neon support for quantization. (#519)
* Skeleton files for neon support of quantization.

* SIMD version for q4 vecdot.

* Also simdify the q6k multiplication.
2023-08-19 22:07:29 +01:00
d73ca3d28e Line up the llama.cpp implementation with the candle one. (#518)
* Separate the prompt stats from the post-prompt ones in the quantized example.

* Slightly nicer output printing.

* Line up with the llama.cpp implementation.
2023-08-19 20:12:07 +01:00
551409092e Small tweaks to tensor-tools. (#517) 2023-08-19 16:50:26 +01:00
6431140250 Retrieve tensor data from PyTorch files. (#516) 2023-08-19 15:57:18 +01:00
607ffb9f1e Retrieve more information from PyTorch checkpoints. (#515)
* Retrieve more information from PyTorch checkpoints.

* Add enough support to load dino-v2 backbone weights.
2023-08-19 15:05:34 +01:00
f861a9df6e Add ggml support to tensor-tools (#512)
* Pickle work-in-progress.

* More unpickling.

* More pickling.

* Proper handling of setitems.

* Clippy.

* Again more pickling.

* Restore the example.

* Add enough pickle support to get the list of tensors.

* Read the data from zip files.

* Retrieve the tensor shape.

* Extract the size and dtype.

* More storage types.

* Improve the destructuring.

* Also support ggml files.
2023-08-19 11:45:22 +01:00
ad33715c61 Preliminary support for importing PyTorch weights. (#511)
* Pickle work-in-progress.

* More unpickling.

* More pickling.

* Proper handling of setitems.

* Clippy.

* Again more pickling.

* Restore the example.

* Add enough pickle support to get the list of tensors.

* Read the data from zip files.

* Retrieve the tensor shape.

* Extract the size and dtype.

* More storage types.

* Improve the destructuring.
2023-08-19 11:26:32 +01:00
90ff04e77e Add the tensor-tools binary. (#510) 2023-08-19 09:06:44 +01:00
42e1cc8062 Add a batch normalization layer (#508)
* Add BatchNormalization.

* More batch-norm.

* Add some validation of the inputs.

* More validation.
2023-08-18 20:05:56 +01:00
b64e782c2d Use the hub to retrieve dinov2 model weights. (#507) 2023-08-18 18:27:31 +01:00
e5dd5fd1b3 Print the recognized categories in dino-v2. (#506) 2023-08-18 17:32:58 +01:00
cb069d6063 Add the permute op (similar to pytorch). (#504)
* Add the permute op (similar to pytorch).

* Add the backprop for dimension permutation.
2023-08-18 16:30:53 +01:00
4f1541526c dinov2 - read images from disk and compute the class probabilities (#503)
* Load the image from disk and convert it to a tensor.

* Tweak the function name.
2023-08-18 15:50:33 +01:00
95462c6a2e Add a vision transformer example (dino-v2). (#502)
* Add a vision transformer example (dino-v2).

* Add some documentation + test.

* CI fix.

* Another fix (still unable to replicate the errors locally :( )
2023-08-18 11:58:06 +01:00
b9661a1c25 Enable the image crate by default in examples (#501)
* Enable the image crate by default so that it's easier to compile the stable diffusion example.

* Also update the readme.
2023-08-18 10:00:05 +01:00
109e95b189 Basic qmatmul parallelization (#492)
* Basic `par_iter` parallelization

* Pass errors up

* Disable `avx` for x86 macs
2023-08-18 09:45:37 +01:00
c78ce76501 Add a simple Module trait and implement it for the various nn layers (#500)
* Start adding the module trait.

* Use the module trait.

* Implement module for qmatmul.
2023-08-18 09:38:22 +01:00
13401df4d1 Add an abstract type for RmsNorm. (#499) 2023-08-18 08:52:14 +01:00
a22b1bed7b Tensor -> QTensor conversion (#496)
* Sketch some qmatmul test.

* Add the quantization function.

* More testing.

* Make the test smaller and faster.

* Add some shape checking.
2023-08-18 08:19:20 +01:00
26fd37b348 Use the main branch of the HF repo where possible. (#498)
* Use the main branch of the HF repo where possible.

* And add the large model.
2023-08-18 08:18:30 +01:00
f056dcab21 Add medium model (#497) 2023-08-18 08:08:59 +01:00
557b2c28dd Q6K quantization (#495)
* Print the detected arch options.

* Add the q6k quantization.

* Add a currently broken test.

* Bugfix.

* Bugfix.

* Another bugfix.

* Another bugfix + get the test to work.
2023-08-17 22:22:57 +01:00
fc81af1712 AVX version of the q6k vec-dot. (#493)
* AVX version of the q6k vec-dot.

* Use the avx sum.
2023-08-17 20:13:18 +01:00
3164cd24fa Replicate the sot-token logic from the Python implementation more acc… (#491)
* Replicate the sot-token logic from the Python implementation more accurately.

* Add a flag to control the timestamp mode.
2023-08-17 16:59:36 +01:00
5f30c1e1e0 Add the whisper small model. (#490) 2023-08-17 15:48:34 +01:00
ad7c53953b Add a verbose-prompt mode, similar to llama.cpp. (#489) 2023-08-17 15:26:44 +01:00
5d99026fd2 F16 support for stable diffusion (#488)
* F16 support for stable diffusion.

* Keep the attention bits in F32.

* Keep more of the attention bits in F32.

* More mixed precision support.
2023-08-17 13:48:56 +01:00
c3176f0dfb Flash-attention support in stable diffusion (#487)
* Add flash-attention for the stable-diffusion example.

* Change the dtype.

* Silly fix.

* Another fix.

* Revert the dtype back to the query dtype after apply flash-attn.
2023-08-17 12:16:40 +01:00
03be33eea4 Relax the requirements on CustomOp. (#486)
* Relax the requirements on CustomOp.

* Simplify the custom-ops when no backward is required.
2023-08-17 11:12:05 +01:00
d32e8199cd Layer norm tweaks (#482)
* Add some options to make layer-norm more configurable.

* Add the rms-norm variant.

* Replace the RmsNorm with the shared bits.
2023-08-17 10:07:13 +01:00
d99cac3ec3 Move the avx specific bits to a separate file. (#481) 2023-08-17 09:01:06 +01:00
f708efb19c Add some accelerate details on the readme. (#480) 2023-08-17 08:26:02 +01:00
306c8eee7a AVX version of the vecdot for q4_0. (#474)
* AVX version of the vecdot for q4_0.

* Tweak the avx bits.

* Add a qmatmul benchmark.

* Fix the quantized test.
2023-08-17 07:03:32 +01:00
098909de40 Add vecdot for q6k-q8k. (#476)
* Add vecdot for q6k-q8k.

* Add some testing for q8k.

* Use QMatMul for the output layer.
2023-08-16 20:59:40 +01:00
3bedba1fce Use a zipped iterator. (#475)
* Use a zipped iterator.

* Add to/from float for q8k.
2023-08-16 20:15:11 +01:00
c5f45887dc Add some tracing to the quantized example. (#473) 2023-08-16 18:49:08 +01:00
fa4590d7fd Merge pull request #469 from huggingface/fix_llama_v1
Fixing llamav1
2023-08-16 17:47:40 +02:00
2e206e269d Add the model argument. (#471) 2023-08-16 16:41:06 +01:00
575e88a999 Add a quantized test that use negative values. (#470)
* Add a quantized test that use negative values.

* Add a default tokenizer.
2023-08-16 16:32:58 +01:00
a9101700b6 Add a kv-cache to the quantized llama example. (#466)
* Add a kv-cache to the quantized llama example.

* Also print the prompt.

* Bugfix in q6k dequantizing.

* Another bugfix.
2023-08-16 14:28:42 +01:00
102fa4c2e3 Fixing llamav1 2023-08-16 14:53:29 +02:00
3071134788 Get the ggml based llama to generate some text. (#464)
* Add more stats to the ggml example.

* Build a quantized model from the file content.

* Move the tensor retrieval in the main crate.

* Start adding the forward pass.

* Add more to the forward pass of the quantized llama.

* Apply the attention layers.

* Add the sampling loop.

* Get the sampling loop to work.

* Minor tweak.

* Add a quantize/dequantize test.

* Bugfix.

* Add a comment + swap the order.

* Bugfixes.
2023-08-16 12:41:07 +01:00
fec87e86f5 Merge pull request #465 from huggingface/llama_hub_config
Using the real config from the hub when available.
2023-08-16 13:28:59 +02:00
33c882ea74 Clippy. 2023-08-16 10:41:00 +02:00
76804730c6 Using the real config from the hub when available. 2023-08-16 10:36:01 +02:00
965597a873 Add a test for qmatmul. (#459) 2023-08-16 06:36:27 +01:00
ca449f9ee1 Add quantized tensors. (#458)
* Add quantized tensors.

* Implement the debug trait for QTensor.

* Add the QMatMul custom op.
2023-08-15 22:45:53 +01:00
b8263aa15c Quantized support for f16 and f32 (#457)
* Add f32 as a quantized type.

* Add f16 as a quantized type too.
2023-08-15 21:09:37 +01:00
e68b2accb4 Split out the quantized file. (#456) 2023-08-15 20:26:27 +01:00
08effe3762 More quantization support (#455)
* Properly initialize wdata.

* Simplify the matmul bits.

* Add from_float for q4_0.

* Fix a couple bugs.

* Get the test to work.

* Get clippy to be happy.
2023-08-15 18:58:04 +01:00
8ad4a21ffc Add a basic optimizer example. (#454) 2023-08-15 17:19:18 +01:00
5e49922be2 Basic quantization support (#453)
* Add a vecdot trait.

* Start implementing mul_mat.

* Add to the mul mat implementation.

* Add q8_0 quantization.

* Implement the GgmlType trait for all types.

* Add the missing block.

* Add a TODO.
2023-08-15 15:53:19 +01:00
ebcfd96d94 add c++17 flags (#452) 2023-08-15 15:29:34 +01:00
5b1690fffa Tweak the llama example. (#450) 2023-08-15 12:18:20 +01:00
3cc87058b7 Support local weights & dynamic outputs (#447)
* Support local weights & dynamic outputs

* Revise as suggested

* Cargo code format
2023-08-15 11:51:57 +01:00
531f23b4d0 Rename vec-dot to vec-ops. (#449)
* Rename vec-dot to vec-ops.

* Also bump the crate version.

* Add a currently empty readme.
2023-08-15 10:48:57 +01:00
495e0b7580 Simd support (#448)
* Import the simd intrinsics in candle-core.

* simd version of reduce-sum.

* Bugfix.

* Fix some clippy lints.
2023-08-15 09:50:38 +01:00
90374097dc Cudnn support (#445)
* Add a cudnn feature to be used for conv2d.

* Allocate the proper workspace.

* Only create a single cudnn handle per cuda device.

* Proper cudnn usage.

* Bugfix.
2023-08-14 21:30:41 +01:00
c84883ecf2 Add a cuda kernel for upsampling. (#441)
* Add a cuda kernel for upsampling.

* Update for the latest tokenizers version.
2023-08-14 13:12:17 +01:00
a094dc503d Add a cuda kernel for avg-pool2d. (#440)
* Add a cuda kernel for avg-pool2d.

* Avoid running out of bounds.

* Finish wiring the avg pool kernel + add some testing.

* Support for max-pool + testing.
2023-08-14 12:32:05 +01:00
34f4b3187e Add a naive conv2d cuda kernel. (#438)
* Add a naive conv2d cuda kernel.

* Proper conv2d support on the rust side.

* Conv1d testing on gpu.

* Also use the test on gpus.

* Fix the clean-ptx target.
2023-08-14 10:34:42 +01:00
eab54e4490 Fix the tests for mkl. (#437) 2023-08-14 08:09:27 +01:00
9e7e6e0288 Add dequantization for ggmls q4_0, q4_1, q5_0, q5_1 and q8_0 (#407)
* Added dequantization for `q4_0`, `q4_1`, `q5_0`, `q5_1` and `q8_0`

* expose `tensor_from_ggml` for external usage

* bugfixes & example
2023-08-13 23:22:57 +01:00
8bd2b22b33 Optimize the logit computations in the whisper example. (#434) 2023-08-13 22:00:13 +01:00
d379a76a9e Add a softmax bench. (#433)
* Add a softmax bench.

* Add the vectorized sum reduce.
2023-08-13 20:09:18 +01:00
9af438ac1b Track the conv2d operations in stable-diffusion. (#431)
* Track the conv2d operations in stable-diffusion.

* Add more tracing to stable-diffusion.

* Also trace the resnet bits.

* Trace the attention blocks.

* Also trace the attention inner part.

* Small tweak.
2023-08-13 15:58:26 +01:00
b1ff78f762 Allow using accelerate with stable-diffusion. (#430) 2023-08-13 14:14:20 +01:00
5a63b51f14 Add a matmul benchmark. (#429) 2023-08-13 13:41:03 +01:00
6d694554b8 Support longer sequences in language detection. (#428) 2023-08-13 13:16:15 +01:00
9aca398a4f More accelerate optimizations (#427)
* Add more tracing to the whisper example.

* Support accelerate in more examples.

* Use accelerate for pointwise functions.

* Use accelerate for binary operations too.

* Bugfix for binary operation: use the rhs before the lhs.
2023-08-13 12:53:34 +01:00
60cd1551ca Add a KV cache to whisper. (#426) 2023-08-12 21:17:08 +01:00
a0908d212c Add a -language argument. (#425) 2023-08-12 17:08:40 +01:00
972078e1ae Update the readme with the discord server and common errors. (#423) 2023-08-12 16:45:58 +01:00
16b89f5b83 fix: can directly save the loaded weights (#421) 2023-08-12 16:33:29 +01:00
0741ebbd51 More multilingual support for whisper. (#419)
* More multilingual support for whisper.

* Use the language token appropriately.
2023-08-12 15:32:52 +01:00
0c3f109faa Basic multilingual support for whisper (#417)
* Multi-lingual support for whisper.

* Avoid hardcoding the token names.

* More multi-lingual support.

* Remove the todo.
2023-08-12 11:23:04 +01:00
2ba6b2826f Fix the readme instructions for stable-diffusion. (#415) 2023-08-11 18:59:04 +01:00
1d0157bbc4 Stable diffusion: retrieve the model files from the HF hub. (#414)
* Retrieve the model files from the HF hub in the stable diffusion example.

* Add to the readme.
2023-08-11 18:57:06 +01:00
91dbf907d3 Add more whisper variants. (#413) 2023-08-11 17:33:55 +01:00
e12372021b Expose the tensor write-bytes function. (#412) 2023-08-11 17:13:42 +01:00
55e428c8ae Expose the varmap inner data. (#411) 2023-08-11 16:58:56 +01:00
01ea57da8c Fix the conv tests. (#409) 2023-08-11 14:59:54 +01:00
662db45fc3 Use zero padding in conv1d and conv2d (same as pytorch). (#408) 2023-08-11 14:53:05 +01:00
906c0f3eb5 Remove the checkpoint conversion script. (#405)
* Remove the checkpoint conversion script.

* Remove references to the script.
2023-08-11 05:59:48 +01:00
e29c7809ec Parallelise the CPU kernels for the conv ops. (#401)
* Parallelise the conv2d op.

* Tighter control on threading.

* Also parallelise conv1d.

* Add some safety comment.
2023-08-11 05:51:58 +01:00
a325c1aa50 Upsample test + bugfix. (#399) 2023-08-10 21:02:35 +02:00
b6cf26e48e Merge pull request #393 from huggingface/older_gpus
Working on older GPUs (still not compute 52 it seems but > 6 could be OK)
2023-08-10 20:49:23 +02:00
379eadc68e Working now. 2023-08-10 19:43:25 +02:00
7e4fbc1e17 [DO NOT MERGE] temporary PR so users can try out on older GPUs. 2023-08-10 19:36:31 +02:00
80f0482f26 Fix the stable-diffusion vae. (#398)
* Fix the stable-diffusion vae.

* Fix for saving images.
2023-08-10 18:24:31 +01:00
94eff56aee Optimize the cpu conv2d kernel (#396)
* Conv2d simd optimization.

* Fix the contiguous copying.

* Small tweak.
2023-08-10 17:40:09 +01:00
a55133effd Merge pull request #395 from huggingface/fix_compat_windows
Compat windows.
2023-08-10 18:05:12 +02:00
ff53f38467 Small example for benchmarking some cpu ops (#394)
* Refactor the benchmark example.

* Rename the example.

* Add some comments.
2023-08-10 17:00:17 +01:00
4a95d34c83 Compat windows. 2023-08-10 17:46:47 +02:00
7f710a573d Merge pull request #374 from Rocketknight1/readme_fixes
README.md typos and grammar fixes
2023-08-10 16:34:19 +02:00
c8039579a5 Conv1d optimize (#392)
* Reorder the conv1d loops in the cpu backend.

* Optimize the 1d convolution.

* Conv1D optimize.

* Fix some clippy lints.
2023-08-10 15:23:52 +01:00
0b0fa56978 Merge pull request #386 from huggingface/enabling_61_maybe
This is duplicated code on Cuda 12.2.
2023-08-10 16:23:17 +02:00
385f0d261c Normalize embeddings in the bert example. (#390) 2023-08-10 13:05:55 +01:00
b765f2c37f Update the wasm build instructions. (#389) 2023-08-10 11:29:43 +01:00
66d1c093e0 This is duplicated code on Cuda 12.2.
Without it we can compile for 52 (but I get Operation Not supported
when actually trying to use those kernels).
2023-08-10 09:20:18 +02:00
de7c31bfe9 Merge pull request #368 from huggingface/add_cuda_ci
Adding cuda CI
2023-08-10 08:49:39 +02:00
8e7ef96588 Fix CI cuda. 2023-08-10 08:47:15 +02:00
f3fe730a30 Npy tweaks & error with path (#384)
* Simplify the npy writing.

* Wrap the file path so as to provide better errors.
2023-08-10 06:21:58 +01:00
c7f92f985e Further randn tweaks: use the appropriate rng rather than the f64 one, some cleanup. (#383) 2023-08-10 05:48:19 +01:00
Lei
3bbc08a8df Fix randn cpu (#382)
* Change distributions

Standard generates in [0, 1), Normal is correct.

* Add test

Not sure if this is the best place to put  the test

* Remove unnecessary use
2023-08-10 05:33:44 +01:00
6a2137af4f Update README.md 2023-08-10 00:19:58 +01:00
0dc1e5f387 Merge branch 'main' into readme_fixes 2023-08-10 00:19:20 +01:00
bd2fb6216b Testing in release mode because debug is too slow. 2023-08-09 23:19:55 +02:00
3542b26143 ssl update. 2023-08-09 23:11:45 +02:00
a690f14a77 Fix by hardcoding paths 2023-08-09 23:08:50 +02:00
90d778c059 ? 2023-08-09 23:02:11 +02:00
171fcbe539 CI ssh in the meantime. 2023-08-09 22:58:47 +02:00
07e83c55c0 Attempt nb2 2023-08-09 22:47:01 +02:00
25ec2d9f6b fix: remove incorrect unwrap (#379) 2023-08-09 21:45:24 +01:00
da26e2832c Update gemm to 0.15.6. (#378) 2023-08-09 21:04:28 +01:00
fcfdcbd337 Add a conv1d benchmark based on the whisper sizes. (#377)
* Add a conv1d benchmark based on the whisper sizes.

* Enforce the batch-dim in conv1d.
2023-08-09 20:27:03 +01:00
653ec5abc1 Update README.md (#376)
add missing word
2023-08-09 20:09:21 +01:00
c3a0761e62 Add some tracing to the whisper example. (#375) 2023-08-09 19:58:36 +01:00
0cef3998fd README.md typos and grammar fixes 2023-08-09 19:36:03 +01:00
e5f510d209 SSH to debug. 2023-08-09 19:54:40 +02:00
0dd94eff4c Merge pull request #367 from eltociear/eltociear-patch-1
Update README.md
2023-08-09 19:48:31 +02:00
a3b1699409 Embed the mel filters in the whisper binary. (#373) 2023-08-09 18:27:26 +01:00
5b79b38bc7 Remove extra square bracket (#372) 2023-08-09 18:24:28 +01:00
a5c5a893aa add max_pool2d (#371)
Co-authored-by: 赵理山 <ls@zhaolishandeMacBook-Air.local>
2023-08-09 18:05:26 +01:00
e6ce47f9e0 ? 2023-08-09 19:00:25 +02:00
1892bd139c Extract the strides in the conv ops. (#370) 2023-08-09 17:57:05 +01:00
749c8c7f51 Better rust GH action. 2023-08-09 18:42:53 +02:00
d9b4fef189 Chnage name 2023-08-09 18:14:29 +02:00
8fa329aca2 Adding cuda CI 2023-08-09 18:13:27 +02:00
cd225bd3b1 More testing for avg-pool2d. (#366)
* More testing for avg-pool2d.

* Another fix.

* Add a max-pool test with non-divisible kernel sizes.
2023-08-09 16:12:23 +01:00
a4f6977087 Update README.md
dauting -> daunting
2023-08-10 00:11:11 +09:00
dece0b8a76 Merge pull request #263 from huggingface/book_3
Book 3 (advanced loading + hub)
2023-08-09 16:50:11 +02:00
b80348d22f Bugfix for avg-pool + add some test. (#365) 2023-08-09 15:44:16 +01:00
3a62aee91f Write the generated images using the image crate. (#363)
* Use the image crate to write the generated images.

* Make the dependency optional.
2023-08-09 15:26:44 +01:00
be21d7e75a Fix the padding used in stable diffusion. (#362) 2023-08-09 13:23:59 +01:00
9c4cf6804b Merge pull request #355 from cksac/fix_book
fix repo link
2023-08-09 09:08:16 +02:00
dbc6f281c9 Conv1d test with padding. (#356) 2023-08-09 05:45:38 +01:00
47a5bee249 fix repo link 2023-08-09 11:29:48 +08:00
cf965ecaa8 Simplify the conv1d and conv2d code. (#352) 2023-08-08 22:10:59 +01:00
b9864e1357 Fix size-in-bytes for u8. (#351) 2023-08-08 21:15:18 +01:00
608b2358c6 Add some conv1d test + bugfix using padding. (#349) 2023-08-08 20:50:20 +01:00
1e6dbeac01 Add some conv2d tests. (#347)
* Add some conv2d tests.

* Add a simpler conv2d test.

* More conv2d testing + bugfix.

* Add a todo.
2023-08-08 19:02:42 +01:00
13ce68ff9b Bugfix for conv2d. (#343) 2023-08-08 15:20:00 +01:00
89d3926c9b Fixes for the stable diffusion example. (#342)
* Fixes for the stable diffusion example.

* Bugfix.

* Another fix.

* Fix for group-norm.

* More fixes to get SD to work.
2023-08-08 14:57:09 +01:00
ab35684326 Naive implementation for conv2d. (#341) 2023-08-08 06:34:36 +01:00
b5bb5e056d Add more conv2d support. (#340)
* Add more conv2d support.

* Conv2d cpu work.

* Conv2d output shape.
2023-08-08 06:04:32 +01:00
d0d7010682 CPU implementation for upsample-nearest2d. (#339) 2023-08-07 20:07:10 +01:00
fc265d9dcf Some CLIP fixes for stable diffusion. (#338)
* Some CLIP fixes for stable diffusion.

* Add the avg-pool2d operation on cpu.
2023-08-07 18:31:45 +01:00
2345b8ce3f Skeleton for the avg-pool2d and upsample-nearest2d ops. (#337)
* Skeleton for the avg-pool2d and upsample-nearest2d ops.

* Preliminary conv2d support.
2023-08-07 16:15:38 +01:00
f53a333ea9 Simple pad support. (#336)
* Simple pad support.

* Fix the tensor indexing when padding.
2023-08-07 15:24:56 +01:00
e72ba0b9e7 Add the license files. (#335) 2023-08-07 14:11:27 +01:00
5bb2fce998 Implement group-norm. (#334)
* Implement group-norm.

* Add some testing for group-norm.
2023-08-07 06:53:05 +01:00
2c9f605976 Add rand-like/randn-like. (#333) 2023-08-06 21:51:08 +01:00
141df4ad2b Main diffusion loop for the SD example. (#332) 2023-08-06 21:39:53 +01:00
166bfd5847 Add the recip op + use it in stable-diffusion. (#331)
* Add the recip unary op.

* Fix the cuda kernel.

* Use the recip op in sigmoid.
2023-08-06 21:14:52 +01:00
1c062bf06b Add the ddim scheduler. (#330) 2023-08-06 20:44:00 +01:00
d34039e352 Add a stable diffusion example (#328)
* Start adding a stable-diffusion example.

* Proper computation of the causal mask.

* Add the chunk operation.

* Work in progress: port the attention module.

* Add some dummy modules for conv2d and group-norm, get the attention module to compile.

* Re-enable the 2d convolution.

* Add the embeddings module.

* Add the resnet module.

* Add the unet blocks.

* Add the unet.

* And add the variational auto-encoder.

* Use the pad function from utils.
2023-08-06 17:49:43 +01:00
93cfe5642f Pyo3 dtype (#327)
* Better handling of dtypes in pyo3.

* More pyo3 dtype.
2023-08-06 10:17:43 +01:00
88bd3b604a Add some tensor creation functions to the pyo3 bindings. (#326) 2023-08-06 06:50:33 +01:00
b278834267 Support the Accelerate BLAS on macOS. (#325)
* Add the accelerate feature.

* Ffi tweaks.
2023-08-05 17:25:24 +01:00
0b175fcbbd Fix the pyo3 build for macos. (#324)
* Fix the pyo3 build for macos.

* rustfmt fix.
2023-08-05 14:53:57 +01:00
dba31473d4 Typos and format and CD only when PR lands. 2023-08-02 19:18:43 +02:00
1b2b32e58d Remove dead page.t 2023-08-02 18:59:36 +02:00
166f4d1101 s/candle/candle_core/g 2023-08-02 18:40:24 +02:00
ae68635af9 Add small error management. 2023-08-02 18:40:24 +02:00
c11e78b334 Odd rebase artifact. 2023-08-02 18:40:24 +02:00
1b705a426f Remove duplicate. 2023-08-02 18:40:24 +02:00
a70b95f9e7 Marking unwritten chapters as Draft (disables the link). 2023-08-02 18:40:24 +02:00
a44471a305 Adding more details on how to load things.
- Loading with memmap
- Loading a sharded tensor
- Moved some snippets to `candle-examples/src/lib.rs` This is because
managing book specific dependencies is a pain https://github.com/rust-lang/mdBook/issues/706
- This causes a non aligned inclusion  https://github.com/rust-lang/mdBook/pull/1856 which we have
to ignore fmt to remove.

mdbook might need some more love :)
2023-08-02 18:40:24 +02:00
45642a8530 Fixing examples. 2023-08-02 18:40:24 +02:00
82464166e4 3rd phase. 2023-08-02 18:40:24 +02:00
839 changed files with 180968 additions and 9707 deletions

View File

@ -1,8 +1,8 @@
[target.x86_64-unknown-linux-gnu]
rustflags = ["-C", "target-cpu=native"]
[target.aarch64-apple-darwin]
[build]
rustflags = ["-C", "target-cpu=native"]
[target.wasm32-unknown-unknown]
rustflags = ["-C", "target-feature=+simd128"]
[target.x86_64-apple-darwin]
rustflags = ["-C", "target-feature=-avx,-avx2"]

7
.github/dependabot.yml vendored Normal file
View File

@ -0,0 +1,7 @@
version: 2
updates:
- package-ecosystem: "cargo"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 5

View File

@ -1,42 +0,0 @@
name: Deploy Rust book
on:
# TODO put this back only when merging after this PR lands.
pull_request:
push:
branches:
- main
jobs:
deploy:
runs-on: ubuntu-latest
permissions:
contents: write # To push a branch
pull-requests: write # To create a PR from that branch
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Install latest mdbook
run: |
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
mkdir mdbook
curl -sSL $url | tar -xz --directory=./mdbook
echo `pwd`/mdbook >> $GITHUB_PATH
- name: Deploy GitHub Pages
run: |
# This assumes your book is in the root of your repository.
# Just add a `cd` here if you need to change to another directory.
cd candle-book
mdbook build
git worktree add gh-pages
git config user.name "Deploy from CI"
git config user.email ""
cd gh-pages
# Delete the ref to avoid keeping history.
git update-ref -d refs/heads/gh-pages
rm -rf *
mv ../book/* .
git add .
git commit -m "Deploy $GITHUB_SHA to gh-pages"
git push --force --set-upstream origin gh-pages

View File

@ -1,29 +0,0 @@
name: CI
on:
pull_request:
jobs:
test:
name: Test candle-book
runs-on: ubuntu-latest
permissions:
contents: write # To push a branch
pull-requests: write # To create a PR from that branch
steps:
- uses: actions/checkout@master
- name: Install Rust
run: |
rustup set profile minimal
rustup toolchain install stable
rustup default stable
- name: Install latest mdbook
run: |
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
mkdir bin
curl -sSL $url | tar -xz --directory=bin
echo "$(pwd)/bin" >> $GITHUB_PATH
- name: Run tests
run: cd candle-book && cargo build && mdbook test -L ../target/debug/deps/

34
.github/workflows/ci_cuda.yaml vendored Normal file
View File

@ -0,0 +1,34 @@
name: CI / cuda
on:
workflow_dispatch:
pull_request:
jobs:
test-cuda:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on:
group: aws-g4dn-2xlarge
container:
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
options: --gpus 0
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
permissions:
contents: write
packages: write
# This is used to complete the identity challenge
# with sigstore/fulcio when running outside of PRs.
id-token: write
security-events: write
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Install dependencies
run: apt-get update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
- name: Install Rust Stable
uses: actions-rust-lang/setup-rust-toolchain@v1
- uses: Swatinem/rust-cache@v2
- name: Test (cuda)
run: cargo test --features cuda

BIN
.github/workflows/maturin.yml vendored Normal file

Binary file not shown.

68
.github/workflows/python.yml vendored Normal file
View File

@ -0,0 +1,68 @@
name: PyO3-CI
on:
workflow_dispatch:
push:
branches:
- main
paths:
- candle-pyo3/**
pull_request:
paths:
- candle-pyo3/**
jobs:
build_and_test:
name: Check everything builds & tests
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest] # For now, only test on Linux
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
toolchain: stable
- name: Install Python
uses: actions/setup-python@v4
with:
python-version: 3.11
architecture: "x64"
- name: Cache Cargo Registry
uses: actions/cache@v1
with:
path: ~/.cargo/registry
key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
- name: Install Protoc
uses: arduino/setup-protoc@v2
with:
version: "25.0"
repo-token: ${{ secrets.GITHUB_TOKEN }}
- name: Install
working-directory: ./candle-pyo3
run: |
python -m venv .env
source .env/bin/activate
pip install -U pip
pip install pytest maturin black
python -m maturin develop -r --features onnx
- name: Check style
working-directory: ./candle-pyo3
run: |
source .env/bin/activate
python stub.py --check
black --check .
- name: Run tests
working-directory: ./candle-pyo3
run: |
source .env/bin/activate
python -m pytest -s -v tests

View File

@ -1,6 +1,6 @@
on:
on:
push:
branches:
branches:
- main
pull_request:
@ -15,7 +15,10 @@ jobs:
os: [ubuntu-latest, windows-latest, macOS-latest]
rust: [stable]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@ -34,7 +37,13 @@ jobs:
os: [ubuntu-latest, windows-latest, macOS-latest]
rust: [stable]
steps:
- uses: actions/checkout@v2
- name: Delete huge unnecessary tools folder
if: runner.os == 'Linux'
run: rm -rf /opt/hostedtoolcache
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@ -49,7 +58,7 @@ jobs:
name: Rustfmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@ -65,7 +74,7 @@ jobs:
name: Clippy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
profile: minimal

15
.github/workflows/trufflehog.yml vendored Normal file
View File

@ -0,0 +1,15 @@
on:
push:
name: Secret Leaks
jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main

22
.gitignore vendored
View File

@ -9,6 +9,10 @@ target/
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock
# editor config
.helix
.vscode
# These are backup files generated by rustfmt
**/*.rs.bk
@ -20,11 +24,25 @@ Cargo.lock
perf.data
flamegraph.svg
*.dylib
*.so
*.swp
*.swo
trace-*.json
candle-wasm-examples/*/build
candle-wasm-examples/*/*.bin
candle-wasm-examples/*/*.wav
candle-wasm-examples/*/*.safetensors
candle-wasm-examples/*/*.jpeg
candle-wasm-examples/*/audios/*.wav
candle-wasm-examples/**/*.safetensors
candle-wasm-examples/**/*.gguf
candle-wasm-examples/*/package-lock.json
candle-wasm-examples/**/config*.json
.DS_Store
.idea/*
__pycache__
out.safetensors
out.wav
bria.mp3
bria.safetensors
bria.wav

11
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,11 @@
{
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "none",
"python.testing.pytestArgs": [
"candle-pyo3"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}

113
CHANGELOG.md Normal file
View File

@ -0,0 +1,113 @@
# Changelog
This documents the main changes to the `candle` crate.
## v0.3.1 - Unreleased
### Added
### Modified
## v0.3.0 - 2023-10-01
### Added
- Added the Mistral 7b v0.1 model
[983](https://github.com/huggingface/candle/pull/983).
- Quantized version of the Mistral model
[1009](https://github.com/huggingface/candle/pull/1009).
- Add the gelu-erf op and activation function
[969](https://github.com/huggingface/candle/pull/969).
- Add the mixformer/phi-v1.5 model
[930](https://github.com/huggingface/candle/pull/930).
- Add the sclice-scatter op
[927](https://github.com/huggingface/candle/pull/927).
- Add the Wuerstchen diffusion model
[911](https://github.com/huggingface/candle/pull/911).
### Modified
- Support for simd128 intrinsics in some quantized vecdots
[982](https://github.com/huggingface/candle/pull/982).
- Optimize the index-select cuda kernel
[976](https://github.com/huggingface/candle/pull/976).
- Self-contained safetensor wrappers
[946](https://github.com/huggingface/candle/pull/946).
## v0.2.2 - 2023-09-18
### Added
- Support for `top_p` sampling
[819](https://github.com/huggingface/candle/pull/819).
- T5 model including decoding
[864](https://github.com/huggingface/candle/pull/864).
- 1-d upsampling
[839](https://github.com/huggingface/candle/pull/839).
### Modified
- Bugfix for conv2d
[820](https://github.com/huggingface/candle/pull/820).
- Support tensor based indexing using `.i`
[842](https://github.com/huggingface/candle/pull/842).
## v0.2.1 - 2023-09-11
### Added
- Add some RNNs (GRU and LSTM) in `candle-nn`
[674](https://github.com/huggingface/candle/pull/674),
[688](https://github.com/huggingface/candle/pull/688).
- gguf v2 support
[725](https://github.com/huggingface/candle/pull/725).
- Quantized llama example in Python using the pyo3 api
[716](https://github.com/huggingface/candle/pull/716).
- `candle-nn` layer for conv2d-transposed
[760](https://github.com/huggingface/candle/pull/760).
- Add the Segment-Anything Model (SAM) as an example
[773](https://github.com/huggingface/candle/pull/773).
- TinyViT backbone for the segment anything example
[787](https://github.com/huggingface/candle/pull/787).
- Shape with holes support
[770](https://github.com/huggingface/candle/pull/770).
### Modified
- Dilations are now supported in conv-transpose2d.
[671](https://github.com/huggingface/candle/pull/671).
- Interactive mode for the quantized model
[690](https://github.com/huggingface/candle/pull/690).
- Faster softmax operation
[747](https://github.com/huggingface/candle/pull/747).
- Faster convolution operations on CPU and CUDA via im2col
[802](https://github.com/huggingface/candle/pull/802).
- Moving some models to a more central location
[796](https://github.com/huggingface/candle/pull/796).
## v0.2.0 - 2023-08-30
### Added
- Add the powf op
[664](https://github.com/huggingface/candle/pull/664).
- Stable Diffusion XL support
[647](https://github.com/huggingface/candle/pull/647).
- Add the conv-transpose2d op
[635](https://github.com/huggingface/candle/pull/635).
- Refactor the VarBuilder api
[627](https://github.com/huggingface/candle/pull/627).
- Add some quantization command
[625](https://github.com/huggingface/candle/pull/625).
- Support more quantized types, e.g. Q2K, Q4K, Q5K...
[586](https://github.com/huggingface/candle/pull/586).
- Add pose estimation to the yolo example
[589](https://github.com/huggingface/candle/pull/589).
- Api to write GGUF files
[585](https://github.com/huggingface/candle/pull/585).
- Support more quantization types
[580](https://github.com/huggingface/candle/pull/580).
- Add EfficientNet as an example Computer Vision model
[572](https://github.com/huggingface/candle/pull/572).
- Add a group parameter to convolutions
[566](https://github.com/huggingface/candle/pull/566).
- New dtype: int64
[563](https://github.com/huggingface/candle/pull/563).
- Handling of the GGUF file format.
[559](https://github.com/huggingface/candle/pull/559).
## v0.1.2 - 2023-08-21

View File

@ -6,49 +6,76 @@ members = [
"candle-nn",
"candle-pyo3",
"candle-transformers",
"candle-wasm-examples/llama2-c",
"candle-wasm-examples/whisper",
"candle-wasm-examples/*",
"candle-wasm-tests",
"tensor-tools",
]
exclude = [
"candle-flash-attn",
"candle-kernels",
"candle-book",
"candle-flash-attn",
"candle-kernels",
"candle-metal-kernels",
"candle-onnx",
]
resolver = "2"
[workspace.package]
version = "0.1.0"
version = "0.9.0-alpha.2"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT/Apache-2.0"
license = "MIT OR Apache-2.0"
[workspace.dependencies]
ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.2" }
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.2" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.2" }
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.2" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.2" }
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.2" }
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.2" }
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.2" }
clap = { version = "4.2.4", features = ["derive"] }
cudarc = { version = "0.9.13", features = ["f16"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
gemm = { version = "0.15.5", package = "candle-gemm" }
hf-hub = "0.2.0"
half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.15.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.4.1"
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.24.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = "0.7.1"
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
num_cpus = "1.15.0"
num-traits = "0.2.15"
rand = "0.8.5"
safetensors = "0.3.1"
parquet = { version = "51.0.0" }
rand = "0.9.0"
rand_distr = "0.5.1"
rayon = "1.7.0"
safetensors = "0.4.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99"
thiserror = "1"
tokenizers = { version = "0.13.3", default-features = false }
tokenizers = { version = "0.21.0", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
wav = "1.0.0"
zip = { version = "0.6.6", default-features = false }
ug = "0.3.1"
ug-cuda = "0.3.1"
ug-metal = "0.3.1"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}
[profile.release-with-debug]
inherits = "release"

201
LICENSE-APACHE Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

23
LICENSE-MIT Normal file
View File

@ -0,0 +1,23 @@
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

View File

@ -1,7 +1,11 @@
.PHONY: clean-ptx clean test
clean-ptx:
find target -name "*.ptx" -type f -delete
echo "" > candle-kernels/src/lib.rs
touch candle-kernels/build.rs
touch candle-examples/build.rs
touch candle-flash-attn/build.rs
clean:
cargo clean

392
README.md
View File

@ -1,90 +1,273 @@
# candle
[![discord server](https://dcbadge.vercel.app/api/server/hugging-face-879548962464493619)](https://discord.gg/hugging-face-879548962464493619)
[![Latest version](https://img.shields.io/crates/v/candle-core.svg)](https://crates.io/crates/candle-core)
[![Documentation](https://docs.rs/candle-core/badge.svg)](https://docs.rs/candle-core)
![License](https://img.shields.io/crates/l/candle-core.svg)
[![License](https://img.shields.io/github/license/base-org/node?color=blue)](https://github.com/huggingface/candle/blob/main/LICENSE-MIT)
[![License](https://img.shields.io/badge/license-Apache%202.0-blue?style=flat-square)](https://github.com/huggingface/candle/blob/main/LICENSE-APACHE)
Candle is a minimalist ML framework for Rust with a focus on easiness of use and
on performance (including GPU support). Try our online demos:
Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
and ease of use. Try our online demos:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
[yolo](https://huggingface.co/spaces/lmz/candle-yolo),
[Segment
Anything](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
## Get started
Make sure that you have [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) correctly installed as described in [**Installation**](https://huggingface.github.io/candle/guide/installation.html).
Let's see how to run a simple matrix multiplication.
Write the following to your `myapp/src/main.rs` file:
```rust
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
use candle_core::{Device, Tensor};
let c = a.matmul(&b)?;
println!("{c}");
fn main() -> Result<(), Box<dyn std::error::Error>> {
let device = Device::Cpu;
let a = Tensor::randn(0f32, 1., (2, 3), &device)?;
let b = Tensor::randn(0f32, 1., (3, 4), &device)?;
let c = a.matmul(&b)?;
println!("{c}");
Ok(())
}
```
`cargo run` should display a tensor of shape `Tensor[[2, 4], f32]`.
Having installed `candle` with Cuda support, simply define the `device` to be on GPU:
```diff
- let device = Device::Cpu;
+ let device = Device::new_cuda(0)?;
```
For more advanced examples, please have a look at the following section.
## Check out our examples
Check out our [examples](./candle-examples/examples/):
These online demos run entirely in your browser:
- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
object recognition.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
- [Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [Llama and Llama-v2](./candle-examples/examples/llama/): general LLM.
We also provide a some command line based examples using state of the art models:
- [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
the SOLAR-10.7B variant.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
generation.
- [Codegeex4](./candle-examples/examples/codegeex4-9b/): Code completion,code interpreter,web search,fuction calling,repository-level
- [GLM4](./candle-examples/examples/glm4/): Open Multilingual Multimodal Chat LMs by THUDM
- [Gemma v1 and v2](./candle-examples/examples/gemma/): 2b and 7b+/9b general LLMs from Google Deepmind.
- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
Griffin based models from Google that mix attention with a RNN like state.
- [Phi-1, Phi-1.5, Phi-2, and Phi-3](./candle-examples/examples/phi/): 1.3b,
2.7b, and 3.8b general LLMs with performance on par with 7b models.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets. Also supports
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
- [Mamba](./candle-examples/examples/mamba/): an inference only
implementation of the Mamba state space model.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
better performance than all publicly available 13b models as of 2023-09-28.
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
experts 8x7b general LLM with better performance than a Llama 2 70B model with
much faster inference.
- [StarCoder](./candle-examples/examples/bigcode/) and
[StarCoder2](./candle-examples/examples/starcoder2/): LLM specialized to code generation.
- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
- [RWKV v5 and v6](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
performance.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
(English/Chinese) general LLMs with 6b and 34b parameters.
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
the LLaMA model using the same quantization techniques as
[llama.cpp](https://github.com/ggerganov/llama.cpp).
Run them using the following commands:
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions.
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
- [Wuerstchen](./candle-examples/examples/wuerstchen/): another text to
image generative model.
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/wuerstchen/assets/cat.jpg" width="200">
- [yolo-v3](./candle-examples/examples/yolo-v3/) and
[yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
estimation models.
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.od.jpg" width="200"><img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.pose.jpg" width="200">
- [segment-anything](./candle-examples/examples/segment-anything/): image
segmentation model with prompt.
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmentation model.
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
model using residual vector quantization.
- [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
text-to-speech.
- [Parler-TTS](./candle-examples/examples/parler-tts/): large text-to-speech
model.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
using self-supervision (can be used for imagenet classification, depth
evaluation, segmentation).
- [VGG](./candle-examples/examples/vgg/),
[RepVGG](./candle-examples/examples/repvgg): computer vision models.
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
generate captions for an image.
- [CLIP](./candle-examples/examples/clip/): multi-model vision and language
model.
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
dedicated submodels for hand-writing and printed recognition.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
model, generates the translated text from the input text.
- [Moondream](./candle-examples/examples/moondream/): tiny computer-vision model
that can answer real-world questions about images.
Run them using commands like:
```
cargo run --example whisper --release
cargo run --example llama --release
cargo run --example falcon --release
cargo run --example bert --release
cargo run --example bigcode --release
cargo run --example quantized --release
```
In order to use **CUDA** add `--features cuda` to the example command line.
In order to use **CUDA** add `--features cuda` to the example command line. If
you have cuDNN installed, use `--features cudnn` for even more speedups.
There are also some wasm examples for whisper and
[llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
`trunk` or try them online:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
[Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
For llama2, run the following command to retrieve the weight files and start a
For LLaMA2, run the following command to retrieve the weight files and start a
test server:
```bash
cd candle-wasm-examples/llama2-c
wget https://karpathy.ai/llama2c/model.bin
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
trunk serve --release --public-url /candle-llama2/ --port 8081
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
trunk serve --release --port 8081
```
And then browse to
[http://localhost:8081/candle-llama2](http://localhost:8081/candle-llama2).
And then head over to
[http://localhost:8081/](http://localhost:8081/).
<!--- ANCHOR: useful_libraries --->
## Useful External Resources
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
very detailed tutorial showing how to convert a PyTorch model to Candle.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and
ergonomic LoRA implementation for Candle. `candle-lora` has
out-of-the-box LoRA support for many models from Candle, which can be found
[here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
serving local LLMs including an OpenAI compatible API server.
- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
- [`candle-coursera-ml`](https://github.com/vishpat/candle-coursera-ml): Implementation of ML algorithms from Coursera's [Machine Learning Specialization](https://www.coursera.org/specializations/machine-learning-introduction) course.
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.
- [`atoma-infer`](https://github.com/atoma-network/atoma-infer): A Rust library for fast inference at scale, leveraging FlashAttention2 for efficient attention computation, PagedAttention for efficient KV-cache memory management, and multi-GPU support. It is OpenAI api compatible.
- [`llms-from-scratch-rs`](https://github.com/nerdai/llms-from-scratch-rs): A comprehensive Rust translation of the code from Sebastian Raschka's Build an LLM from Scratch book.
If you have an addition to this list, please submit a pull request.
<!--- ANCHOR_END: useful_libraries --->
<!--- ANCHOR: features --->
## Features
- Simple syntax, looks and like PyTorch.
- CPU and Cuda backends, m1, f16, bf16.
- Enable serverless (CPU), small and fast deployments
- WASM support, run your models in a browser.
- Model training.
- Distributed computing using NCCL.
- Models out of the box: Llama, Whisper, Falcon, StarCoder...
- Embed user-defined ops/kernels, such as [flash-attention
v2](https://github.com/LaurentMazare/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
- Simple syntax, looks and feels like PyTorch.
- Model training.
- Embed user-defined ops/kernels, such as [flash-attention v2](https://github.com/huggingface/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
- Backends.
- Optimized CPU backend with optional MKL support for x86 and Accelerate for macs.
- CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
- WASM support, run your models in a browser.
- Included models.
- Language Models.
- LLaMA v1, v2, and v3 with variants such as SOLAR-10.7B.
- Falcon.
- StarCoder, StarCoder2.
- Phi 1, 1.5, 2, and 3.
- Mamba, Minimal Mamba
- Gemma v1 2b and 7b+, v2 2b and 9b.
- Mistral 7b v0.1.
- Mixtral 8x7b v0.1.
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
- Replit-code-v1.5-3B.
- Bert.
- Yi-6B and Yi-34B.
- Qwen1.5, Qwen1.5 MoE.
- RWKV v5 and v6.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.
- Mistral 7b, and 7b instruct.
- Mixtral 8x7b.
- Zephyr 7b a and b (Mistral-7b based).
- OpenChat 3.5 (Mistral-7b based).
- Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation).
- Text to image.
- Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2.
- Image to text.
- BLIP.
- TrOCR.
- Audio.
- Whisper, multi-lingual speech-to-text.
- EnCodec, audio compression model.
- MetaVoice-1B, text-to-speech model.
- Parler-TTS, text-to-speech model.
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera, FastViT.
- yolo-v3, yolo-v8.
- Segment-Anything Model (SAM).
- SegFormer.
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments.
- Quantization support using the llama.cpp quantized types.
<!--- ANCHOR_END: features --->
## How to use ?
## How to use
<!--- ANCHOR: cheatsheet --->
Cheatsheet:
| | Using PyTorch | Using Candle |
|------------|------------------------------------------|------------------------------------------------------------------|
| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(&[[1f32, 2.]], [3., 4.]], &Device::Cpu)?` |
| Creation | `torch.Tensor([[1, 2], [3, 4]])` | `Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?` |
| Creation | `torch.zeros((2, 2))` | `Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?` |
| Indexing | `tensor[:, :4]` | `tensor.i((.., ..4))?` |
| Operations | `tensor.view((2, 2))` | `tensor.reshape((2, 2))?` |
| Operations | `a.matmul(b)` | `a.matmul(&b)?` |
| Arithmetic | `a + b` | `&a + &b` |
| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::Cuda(0))?` |
| Device | `tensor.to(device="cuda")` | `tensor.to_device(&Device::new_cuda(0)?)?` |
| Dtype | `tensor.to(dtype=torch.float16)` | `tensor.to_dtype(&DType::F16)?` |
| Saving | `torch.save({"A": A}, "model.bin")` | `candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")?` |
| Loading | `weights = torch.load("model.bin")` | `candle::safetensors::load("model.safetensors", &device)` |
@ -95,63 +278,148 @@ Cheatsheet:
## Structure
- [candle-core](./candle-core): Core ops, devices, and `Tensor` struct definition
- [candle-nn](./candle-nn/): Facilities to build real models
- [candle-examples](./candle-examples/): Real-world like examples on how to use the library in real settings
- [candle-nn](./candle-nn/): Tools to build real models
- [candle-examples](./candle-examples/): Examples of using the library in realistic settings
- [candle-kernels](./candle-kernels/): CUDA custom kernels
- [candle-datasets](./candle-datasets/): Datasets and data loaders.
- [candle-transformers](./candle-transformers): Transformer related utilities.
- [candle-transformers](./candle-transformers): transformers-related utilities.
- [candle-flash-attn](./candle-flash-attn): Flash attention v2 layer.
- [candle-onnx](./candle-onnx/): ONNX model evaluation.
## FAQ
### Why Candle?
### Why should I use Candle?
Candle stems from the need to reduce binary size in order to *enable serverless*
possible by making the whole engine smaller than PyTorch very large library volume.
This enables creating runtimes on a cluster much faster.
Candle's core goal is to *make serverless inference possible*. Full machine learning frameworks like PyTorch
are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight
binaries.
And simply *removing Python* from production workloads.
Python can really add overhead in more complex workflows and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-future/) is a notorious source of headaches.
Secondly, Candle lets you *remove Python* from production workloads. Python overhead can seriously hurt performance,
and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-future/) is a notorious source of headaches.
Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers).
Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers).
### Other ML frameworks
- [dfdx](https://github.com/coreylowman/dfdx) is a formidable crate, with shapes being included
in types preventing a lot of headaches by getting compiler to complain about shape mismatch right off the bat
However we found that some features still require nightly and writing code can be a bit dauting for non rust experts.
in types. This prevents a lot of headaches by getting the compiler to complain about shape mismatches right off the bat.
However, we found that some features still require nightly, and writing code can be a bit daunting for non rust experts.
We're leveraging and contributing to other core crates for the runtime so hopefully both crates can benefit from each
other
other.
- [burn](https://github.com/burn-rs/burn) is a general crate that can leverage multiple backends so you can choose the best
engine for your workload
engine for your workload.
- [tch-rs](https://github.com/LaurentMazare/tch-rs.git) Bindings to the torch library in Rust. Extremely versatile, but they
do bring in the entire torch library into the runtime. The main contributor of `tch-rs` is also involved in the development
bring in the entire torch library into the runtime. The main contributor of `tch-rs` is also involved in the development
of `candle`.
### Missing symbols when compiling with the mkl feature.
### Common Errors
#### Missing symbols when compiling with the mkl feature.
If you get some missing symbols when compiling binaries/tests using the mkl
features, e.g.:
or accelerate features, e.g. for mkl you get:
```
= note: /usr/bin/ld: (....o): in function `blas::sgemm':
.../blas-0.22.0/src/lib.rs:1944: undefined reference to `sgemm_' collect2: error: ld returned 1 exit status
= note: some `extern` functions couldn't be found; some native libraries may need to be installed or have their path specified
= note: use the `-l` flag to specify native libraries to link
= note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo (see https://doc.rust-lang.org/cargo/reference/build-scripts.html#cargorustc-link-libkindname)
= note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo
```
or for accelerate:
```
Undefined symbols for architecture arm64:
"_dgemm_", referenced from:
candle_core::accelerate::dgemm::h1b71a038552bcabe in libcandle_core...
"_sgemm_", referenced from:
candle_core::accelerate::sgemm::h2cf21c592cba3c47 in libcandle_core...
ld: symbol(s) not found for architecture arm64
```
This is likely due to some missing linker flag that enable the mkl library. You
can try adding the following at the top of your binary:
```
This is likely due to a missing linker flag that was needed to enable the mkl library. You
can try adding the following for mkl at the top of your binary:
```rust
extern crate intel_mkl_src;
```
or for accelerate:
```rust
extern crate accelerate_src;
```
### How to know where an error comes from.
#### Cannot run the LLaMA examples: access to source requires login credentials
```
Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401
```
This is likely because you're not permissioned for the LLaMA-v2 model. To fix
this, you have to register on the huggingface-hub, accept the [LLaMA-v2 model
conditions](https://huggingface.co/meta-llama/Llama-2-7b-hf), and set up your
authentication token. See issue
[#350](https://github.com/huggingface/candle/issues/350) for more details.
#### Missing cute/cutlass headers when compiling flash-attn
```
In file included from kernels/flash_fwd_launch_template.h:11:0,
from kernels/flash_fwd_hdim224_fp16_sm80.cu:5:
kernels/flash_fwd_kernel.h:8:10: fatal error: cute/algorithm/copy.hpp: No such file or directory
#include <cute/algorithm/copy.hpp>
^~~~~~~~~~~~~~~~~~~~~~~~~
compilation terminated.
Error: nvcc error while compiling:
```
[cutlass](https://github.com/NVIDIA/cutlass) is provided as a git submodule so you may want to run the following command to check it in properly.
```bash
git submodule update --init
```
#### Compiling with flash-attention fails
```
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ...:
```
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the NVCC_CCBIN environment variable.
```
env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
```
#### Linking error on windows when running rustdoc or mdbook tests
```
Couldn't compile the test.
---- .\candle-book\src\inference\hub.md - Using_the_hub::Using_in_a_real_model_ (line 50) stdout ----
error: linking with `link.exe` failed: exit code: 1181
//very long chain of linking
= note: LINK : fatal error LNK1181: cannot open input file 'windows.0.48.5.lib'
```
Make sure you link all native libraries that might be located outside a project target, e.g., to run mdbook tests, you should run:
```
mdbook test candle-book -L .\target\debug\deps\ `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.42.2\lib `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib
```
#### Extremely slow model load time with WSL
This may be caused by the models being loaded from `/mnt/c`, more details on
[stackoverflow](https://stackoverflow.com/questions/68972448/why-is-wsl-extremely-slow-when-compared-with-native-windows-npm-yarn-processing).
#### Tracking down errors
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
error is generated.
#### CudaRC error
If you encounter an error like this one `called `Result::unwrap()` on an `Err` value: LoadLibraryExW { source: Os { code: 126, kind: Uncategorized, message: "The specified module could not be found." } }` on windows. To fix copy and rename these 3 files (make sure they are in path). The paths depend on your cuda version.
`c:\Windows\System32\nvcuda.dll` -> `cuda.dll`
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cublas64_12.dll` -> `cublas.dll`
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\curand64_10.dll` -> `curand.dll`

48
candle-book/Cargo.toml Normal file
View File

@ -0,0 +1,48 @@
[package]
name = "candle-book"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { workspace = true }
candle-datasets = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
candle-flash-attn = { workspace = true, optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true, optional = true }
anyhow = { workspace = true }
tokio = "1.43.0"
[dev-dependencies]
byteorder = { workspace = true }
hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
parquet = { workspace = true }
image = { workspace = true }
[build-dependencies]
anyhow = { workspace = true }
[features]
default = []

View File

@ -10,18 +10,19 @@
# Reference Guide
- [Running a model](inference/README.md)
- [Running a model](inference/inference.md)
- [Using the hub](inference/hub.md)
- [Serialization](inference/serialization.md)
- [Advanced Cuda usage](inference/cuda/README.md)
- [Writing a custom kernel](inference/cuda/writing.md)
- [Porting a custom kernel](inference/cuda/porting.md)
- [Error management](error_manage.md)
- [Creating apps](apps/README.md)
- [Creating a WASM app](apps/wasm.md)
- [Creating a REST api webserver](apps/rest.md)
- [Creating a desktop Tauri app](apps/dekstop.md)
- [Training](training/README.md)
- [Training](training/training.md)
- [Simplified](training/simplified.md)
- [MNIST](training/mnist.md)
- [Fine-tuning](training/finetuning.md)
- [Using MKL](advanced/mkl.md)
- [Fine-tuning]()
- [Serialization]()
- [Advanced Cuda usage]()
- [Writing a custom kernel]()
- [Porting a custom kernel]()
- [Using MKL]()
- [Creating apps]()
- [Creating a WASM app]()
- [Creating a REST api webserver]()
- [Creating a desktop Tauri app]()

View File

@ -0,0 +1 @@
# Advanced Cuda usage

View File

@ -0,0 +1 @@
# Porting a custom kernel

View File

@ -0,0 +1 @@
# Writing a custom kernel

View File

@ -1 +1,51 @@
# Error management
You might have seen in the code base a lot of `.unwrap()` or `?`.
If you're unfamiliar with Rust check out the [Rust book](https://doc.rust-lang.org/book/ch09-02-recoverable-errors-with-result.html)
for more information.
What's important to know though, is that if you want to know *where* a particular operation failed
You can simply use `RUST_BACKTRACE=1` to get the location of where the model actually failed.
Let's see on failing code:
```rust,ignore
let x = Tensor::zeros((1, 784), DType::F32, &device)?;
let y = Tensor::zeros((1, 784), DType::F32, &device)?;
let z = x.matmul(&y)?;
```
Will print at runtime:
```bash
Error: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }
```
After adding `RUST_BACKTRACE=1`:
```bash
Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
```
Not super pretty at the moment, but we can see error occurred on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
Another thing to note, is that since Rust is compiled it is not necessarily as easy to recover proper stacktraces
especially in release builds. We're using [`anyhow`](https://docs.rs/anyhow/latest/anyhow/) for that.
The library is still young, please [report](https://github.com/LaurentMazare/candle/issues) any issues detecting where an error is coming from.
## Cuda error management
When running a model on Cuda, you might get a stacktrace not really representing the error.
The reason is that CUDA is async by nature, and therefore the error might be caught while you were sending totally different kernels.
One way to avoid this is to use `CUDA_LAUNCH_BLOCKING=1` as an environment variable. This will force every kernel to be launched sequentially.
You might still however see the error happening on other kernels as the faulty kernel might exit without an error but spoiling some pointer for which the error will happen when dropping the `CudaSlice` only.
If this occurs, you can use [`compute-sanitizer`](https://docs.nvidia.com/compute-sanitizer/ComputeSanitizer/index.html)
This tool is like `valgrind` but for cuda. It will help locate the errors in the kernels.

View File

@ -6,7 +6,7 @@ Open `src/main.rs` and fill in this content:
```rust
# extern crate candle_core;
use candle_core::{DType, Device, Result, Tensor};
use candle_core::{Device, Result, Tensor};
struct Model {
first: Tensor,
@ -25,11 +25,11 @@ fn main() -> Result<()> {
// Use Device::new_cuda(0)?; to use the GPU.
let device = Device::Cpu;
let first = Tensor::zeros((784, 100), DType::F32, &device)?;
let second = Tensor::zeros((100, 10), DType::F32, &device)?;
let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
let model = Model { first, second };
let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
@ -50,7 +50,7 @@ the classical `Linear` layer. We can do as such
```rust
# extern crate candle_core;
# use candle_core::{DType, Device, Result, Tensor};
# use candle_core::{Device, Result, Tensor};
struct Linear{
weight: Tensor,
bias: Tensor,
@ -80,7 +80,7 @@ This will change the model running code into a new function
```rust
# extern crate candle_core;
# use candle_core::{DType, Device, Result, Tensor};
# use candle_core::{Device, Result, Tensor};
# struct Linear{
# weight: Tensor,
# bias: Tensor,
@ -110,15 +110,15 @@ fn main() -> Result<()> {
let device = Device::cuda_if_available(0)?;
// Creating a dummy model
let weight = Tensor::zeros((784, 100), DType::F32, &device)?;
let bias = Tensor::zeros((100, ), DType::F32, &device)?;
let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
let first = Linear{weight, bias};
let weight = Tensor::zeros((100, 10), DType::F32, &device)?;
let bias = Tensor::zeros((10, ), DType::F32, &device)?;
let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
let second = Linear{weight, bias};
let model = Model { first, second };
let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
// Inference on the model
let digit = model.forward(&dummy_image)?;
@ -128,17 +128,17 @@ fn main() -> Result<()> {
```
Now it works, it is a great way to create your own layers.
But most of the classical layers are already implemented in [candle-nn](https://github.com/LaurentMazare/candle/tree/main/candle-nn).
But most of the classical layers are already implemented in [candle-nn](https://github.com/huggingface/candle/tree/main/candle-nn).
## Using `candle_nn`.
For instance [Linear](https://github.com/LaurentMazare/candle/blob/main/candle-nn/src/linear.rs) is already there.
For instance [Linear](https://github.com/huggingface/candle/blob/main/candle-nn/src/linear.rs) is already there.
This Linear is coded with PyTorch layout in mind, to reuse better existing models out there, so it uses the transpose of the weights and not the weights directly.
So instead we can simplify our example:
```bash
cargo add --git https://github.com/LaurentMazare/candle.git candle-nn
cargo add --git https://github.com/huggingface/candle.git candle-nn
```
And rewrite our examples using it
@ -146,8 +146,8 @@ And rewrite our examples using it
```rust
# extern crate candle_core;
# extern crate candle_nn;
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::Linear;
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};
struct Model {
first: Linear,
@ -167,15 +167,15 @@ fn main() -> Result<()> {
let device = Device::Cpu;
// This has changed (784, 100) -> (100, 784) !
let weight = Tensor::zeros((100, 784), DType::F32, &device)?;
let bias = Tensor::zeros((100, ), DType::F32, &device)?;
let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
let first = Linear::new(weight, Some(bias));
let weight = Tensor::zeros((10, 100), DType::F32, &device)?;
let bias = Tensor::zeros((10, ), DType::F32, &device)?;
let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
let second = Linear::new(weight, Some(bias));
let model = Model { first, second };
let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
@ -188,8 +188,8 @@ Feel free to modify this example to use `Conv2d` to create a classical convnet i
Now that we have the running dummy code we can get to more advanced topics:
- [For PyTorch users](./guide/cheatsheet.md)
- [Running existing models](./inference/README.md)
- [Training models](./training/README.md)
- [For PyTorch users](../guide/cheatsheet.md)
- [Running existing models](../inference/inference.md)
- [Training models](../training/training.md)

View File

@ -1,24 +1,59 @@
# Installation
Start by creating a new app:
**With Cuda support**:
1. First, make sure that Cuda is correctly installed.
- `nvcc --version` should print information about your Cuda compiler driver.
- `nvidia-smi --query-gpu=compute_cap --format=csv` should print your GPUs compute capability, e.g. something
like:
```bash
compute_cap
8.9
```
You can also compile the Cuda kernels for a specific compute cap using the
`CUDA_COMPUTE_CAP=<compute cap>` environment variable.
If any of the above commands errors out, please make sure to update your Cuda version.
2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.
Start by creating a new cargo:
```bash
cargo new myapp
cd myapp
cargo add --git https://github.com/LaurentMazare/candle.git candle
```
At this point, candle will be built **without** CUDA support.
To get CUDA support use the `cuda` feature
Make sure to add the `candle-core` crate with the cuda feature:
```bash
cargo add --git https://github.com/LaurentMazare/candle.git candle --features cuda
cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"
```
You can check everything works properly:
Run `cargo build` to make sure everything can be correctly built.
```bash
cargo build
```
**Without Cuda support**:
Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) as follows:
```bash
cargo new myapp
cd myapp
cargo add --git https://github.com/huggingface/candle.git candle-core
```
Finally, run `cargo build` to make sure everything can be correctly built.
```bash
cargo build
```
**With mkl support**
You can also see the `mkl` feature which could be interesting to get faster inference on CPU. [Using mkl](./advanced/mkl.md)

View File

@ -1 +0,0 @@
# Running a model

View File

@ -1 +1,104 @@
# Using the hub
Install the [`hf-hub`](https://github.com/huggingface/hf-hub) crate:
```bash
cargo add hf-hub
```
Then let's start by downloading the [model file](https://huggingface.co/bert-base-uncased/tree/main).
```rust
# extern crate candle_core;
# extern crate hf_hub;
use hf_hub::api::sync::Api;
use candle_core::Device;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights = repo.get("model.safetensors").unwrap();
let weights = candle_core::safetensors::load(weights, &Device::Cpu);
```
We now have access to all the [tensors](https://huggingface.co/bert-base-uncased?show_tensors=true) within the file.
You can check all the names of the tensors [here](https://huggingface.co/bert-base-uncased?show_tensors=true)
## Using async
`hf-hub` comes with an async API.
```bash
cargo add hf-hub --features tokio
```
```rust,ignore
# This is tested directly in examples crate because it needs external dependencies unfortunately:
# See [this](https://github.com/rust-lang/mdBook/issues/706)
{{#include ../lib.rs:book_hub_1}}
```
## Using in a real model.
Now that we have our weights, we can use them in our bert architecture:
```rust
# extern crate candle_core;
# extern crate candle_nn;
# extern crate hf_hub;
# use hf_hub::api::sync::Api;
#
# let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string());
#
# let weights = repo.get("model.safetensors").unwrap();
use candle_core::{Device, Tensor, DType};
use candle_nn::{Linear, Module};
let weights = candle_core::safetensors::load(weights, &Device::Cpu).unwrap();
let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();
let linear = Linear::new(weight.clone(), Some(bias.clone()));
let input_ids = Tensor::zeros((3, 768), DType::F32, &Device::Cpu).unwrap();
let output = linear.forward(&input_ids).unwrap();
```
For a full reference, you can check out the full [bert](https://github.com/LaurentMazare/candle/tree/main/candle-examples/examples/bert) example.
## Memory mapping
For more efficient loading, instead of reading the file, you could use [`memmap2`](https://docs.rs/memmap2/latest/memmap2/)
**Note**: Be careful about memory mapping it seems to cause issues on [Windows, WSL](https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/5893)
and will definitely be slower on network mounted disk, because it will issue more read calls.
```rust,ignore
{{#include ../lib.rs:book_hub_2}}
```
**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
In practice model files should never be modified, and the mmaps should be mostly READONLY anyway, so the caveat most likely does not apply, but always keep it in mind.
## Tensor Parallel Sharding
When using multiple GPUs to use in Tensor Parallel in order to get good latency, you can load only the part of the Tensor you need.
For that you need to use [`safetensors`](https://crates.io/crates/safetensors) directly.
```bash
cargo add safetensors
```
```rust,ignore
{{#include ../lib.rs:book_hub_3}}
```

View File

@ -0,0 +1,7 @@
# Running a model
In order to run an existing model, you will need to download and use existing weights.
Most models are already available on https://huggingface.co/ in [`safetensors`](https://github.com/huggingface/safetensors) format.
Let's get started by running an old model : `bert-base-uncased`.

199
candle-book/src/lib.rs Normal file
View File

@ -0,0 +1,199 @@
#[cfg(test)]
pub mod simplified;
#[cfg(test)]
mod tests {
use anyhow::Result;
use candle::{DType, Device, Tensor};
use parquet::file::reader::SerializedFileReader;
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
#[rustfmt::skip]
#[tokio::test]
async fn book_hub_1() {
// ANCHOR: book_hub_1
use candle::Device;
use hf_hub::api::tokio::Api;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").await.unwrap();
let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_1
assert_eq!(weights.len(), 206);
}
#[rustfmt::skip]
#[test]
fn book_hub_2() {
{
// ANCHOR: book_hub_2
use candle::Device;
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_2
assert_eq!(weights.len(), 206);
}
// #[rustfmt::skip]
// #[test]
// fn book_hub_3() {
{
// ANCHOR: book_hub_3
use candle::{DType, Device, Tensor};
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use safetensors::slice::IndexOp;
use safetensors::SafeTensors;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
// Use safetensors directly
let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
let view = tensors
.tensor("bert.encoder.layer.0.attention.self.query.weight")
.unwrap();
// We're going to load shard with rank 1, within a world_size of 4
// We're going to split along dimension 0 doing VIEW[start..stop, :]
let rank = 1;
let world_size = 4;
let dim = 0;
let dtype = view.dtype();
let mut tp_shape = view.shape().to_vec();
let size = tp_shape[0];
if size % world_size != 0 {
panic!("The dimension is not divisible by `world_size`");
}
let block_size = size / world_size;
let start = rank * block_size;
let stop = (rank + 1) * block_size;
// Everything is expressed in tensor dimension
// bytes offsets is handled automatically for safetensors.
let iterator = view.slice(start..stop).unwrap();
tp_shape[dim] = block_size;
// Convert safetensors Dtype to candle DType
let dtype: DType = dtype.try_into().unwrap();
// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_3
assert_eq!(view.shape(), &[768, 768]);
assert_eq!(tp_tensor.dims(), &[192, 768]);
}
}
#[allow(unused)]
#[rustfmt::skip]
fn book_training_1() -> Result<()>{
// ANCHOR: book_training_1
use hf_hub::{api::sync::Api, Repo, RepoType};
let dataset_id = "mnist".to_string();
let api = Api::new()?;
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
// ANCHOR_END: book_training_1
// Ignore unused
let _train = train_parquet;
// ANCHOR: book_training_2
for row in test_parquet {
for (idx, (name, field)) in row?.get_column_iter().enumerate() {
println!("Column id {idx}, name {name}, value {field}");
}
}
// ANCHOR_END: book_training_2
let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
// ANCHOR: book_training_3
let test_samples = 10_000;
let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
for row in test_parquet{
for (_name, field) in row?.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
test_buffer_images.extend(image.to_luma8().as_raw());
}
}
}else if let parquet::record::Field::Long(label) = field {
test_buffer_labels.push(*label as u8);
}
}
}
let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?;
let train_samples = 60_000;
let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
for row in train_parquet{
for (_name, field) in row?.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
train_buffer_images.extend(image.to_luma8().as_raw());
}
}
}else if let parquet::record::Field::Long(label) = field {
train_buffer_labels.push(*label as u8);
}
}
}
let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?;
let mnist = candle_datasets::vision::Dataset {
train_images,
train_labels,
test_images,
test_labels,
labels: 10,
};
// ANCHOR_END: book_training_3
assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
assert_eq!(mnist.test_labels.dims(), &[10_000]);
assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
assert_eq!(mnist.train_labels.dims(), &[60_000]);
Ok(())
}
}

View File

@ -0,0 +1,196 @@
//! #A simplified example in Rust of training a neural network and then using it based on the Candle Framework by Hugging Face.
//! Author: Evgeny Igumnov 2023 igumnovnsk@gmail.com
//! This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.
//!
//! ##Basic moments:
//!
//! A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
//! The input is a vector of 2 numbers - the percentage of votes for the first and second candidates in the first stage.
//! The output is the number 0 or 1, where 1 means that the first candidate will win in the second stage, 0 means that he will lose.
//! For training, samples with real data on the results of the first and second stages of different elections are used.
//! The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
//! Model parameters (weights of neurons) are initialized randomly, then optimized during training.
//! After training, the model is tested on a deferred sample to evaluate the accuracy.
//! If the accuracy on the test set is below 100%, the model is considered underfit and the learning process is repeated.
//! Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.
#[rustfmt::skip]
mod tests {
use candle::{DType, Result, Tensor, D, Device};
use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap, Optimizer};
// ANCHOR: book_training_simplified1
const VOTE_DIM: usize = 2;
const RESULTS: usize = 1;
const EPOCHS: usize = 10;
const LAYER1_OUT_SIZE: usize = 4;
const LAYER2_OUT_SIZE: usize = 2;
const LEARNING_RATE: f64 = 0.05;
#[derive(Clone)]
pub struct Dataset {
pub train_votes: Tensor,
pub train_results: Tensor,
pub test_votes: Tensor,
pub test_results: Tensor,
}
struct MultiLevelPerceptron {
ln1: Linear,
ln2: Linear,
ln3: Linear,
}
impl MultiLevelPerceptron {
fn new(vs: VarBuilder) -> Result<Self> {
let ln1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?;
let ln2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?;
let ln3 = candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?;
Ok(Self { ln1, ln2, ln3 })
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.ln1.forward(xs)?;
let xs = xs.relu()?;
let xs = self.ln2.forward(&xs)?;
let xs = xs.relu()?;
self.ln3.forward(&xs)
}
}
// ANCHOR_END: book_training_simplified1
// ANCHOR: book_training_simplified3
#[tokio::test]
async fn simplified() -> anyhow::Result<()> {
let dev = Device::cuda_if_available(0)?;
let train_votes_vec: Vec<u32> = vec![
15, 10,
10, 15,
5, 12,
30, 20,
16, 12,
13, 25,
6, 14,
31, 21,
];
let train_votes_tensor = Tensor::from_vec(train_votes_vec.clone(), (train_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
let train_results_vec: Vec<u32> = vec![
1,
0,
0,
1,
1,
0,
0,
1,
];
let train_results_tensor = Tensor::from_vec(train_results_vec, train_votes_vec.len() / VOTE_DIM, &dev)?;
let test_votes_vec: Vec<u32> = vec![
13, 9,
8, 14,
3, 10,
];
let test_votes_tensor = Tensor::from_vec(test_votes_vec.clone(), (test_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
let test_results_vec: Vec<u32> = vec![
1,
0,
0,
];
let test_results_tensor = Tensor::from_vec(test_results_vec.clone(), test_results_vec.len(), &dev)?;
let m = Dataset {
train_votes: train_votes_tensor,
train_results: train_results_tensor,
test_votes: test_votes_tensor,
test_results: test_results_tensor,
};
let trained_model: MultiLevelPerceptron;
loop {
println!("Trying to train neural network.");
match train(m.clone(), &dev) {
Ok(model) => {
trained_model = model;
break;
},
Err(e) => {
println!("Error: {}", e);
continue;
}
}
}
let real_world_votes: Vec<u32> = vec![
13, 22,
];
let tensor_test_votes = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
let final_result = trained_model.forward(&tensor_test_votes)?;
let result = final_result
.argmax(D::Minus1)?
.to_dtype(DType::F32)?
.get(0).map(|x| x.to_scalar::<f32>())??;
println!("real_life_votes: {:?}", real_world_votes);
println!("neural_network_prediction_result: {:?}", result);
Ok(())
}
// ANCHOR_END: book_training_simplified3
// ANCHOR: book_training_simplified2
fn train(m: Dataset, dev: &Device) -> anyhow::Result<MultiLevelPerceptron> {
let train_results = m.train_results.to_device(dev)?;
let train_votes = m.train_votes.to_device(dev)?;
let varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F32, dev);
let model = MultiLevelPerceptron::new(vs.clone())?;
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), LEARNING_RATE)?;
let test_votes = m.test_votes.to_device(dev)?;
let test_results = m.test_results.to_device(dev)?;
let mut final_accuracy: f32 = 0.0;
for epoch in 1..EPOCHS + 1 {
let logits = model.forward(&train_votes)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
let loss = loss::nll(&log_sm, &train_results)?;
sgd.backward_step(&loss)?;
let test_logits = model.forward(&test_votes)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&test_results)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::<f32>()?;
let test_accuracy = sum_ok / test_results.dims1()? as f32;
final_accuracy = 100. * test_accuracy;
println!("Epoch: {epoch:3} Train loss: {:8.5} Test accuracy: {:5.2}%",
loss.to_scalar::<f32>()?,
final_accuracy
);
if final_accuracy == 100.0 {
break;
}
}
if final_accuracy < 100.0 {
Err(anyhow::Error::msg("The model is not trained well enough."))
} else {
Ok(model)
}
}
// ANCHOR_END: book_training_simplified2
}

View File

@ -1 +0,0 @@
# Training

View File

@ -1 +1,10 @@
# MNIST
So we now have downloaded the MNIST parquet files, let's put them in a simple struct.
```rust,ignore
{{#include ../lib.rs:book_training_3}}
```
The parsing of the file and putting it into single tensors requires the dataset to fit the entire memory.
It is quite rudimentary, but simple enough for a small dataset like MNIST.

View File

@ -0,0 +1,45 @@
# Simplified
## How its works
This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.
Basic moments:
1. A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
2. The input is a vector of 2 numbers - the percentage of votes for the first and second candidates in the first stage.
3. The output is the number 0 or 1, where 1 means that the first candidate will win in the second stage, 0 means that he will lose.
4. For training, samples with real data on the results of the first and second stages of different elections are used.
5. The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
6. Model parameters (weights of neurons) are initialized randomly, then optimized during training.
7. After training, the model is tested on a deferred sample to evaluate the accuracy.
8. If the accuracy on the test set is below 100%, the model is considered underfit and the learning process is repeated.
Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.
```rust,ignore
{{#include ../simplified.rs:book_training_simplified1}}
```
```rust,ignore
{{#include ../simplified.rs:book_training_simplified2}}
```
```rust,ignore
{{#include ../simplified.rs:book_training_simplified3}}
```
## Example output
```bash
Trying to train neural network.
Epoch: 1 Train loss: 4.42555 Test accuracy: 0.00%
Epoch: 2 Train loss: 0.84677 Test accuracy: 33.33%
Epoch: 3 Train loss: 2.54335 Test accuracy: 33.33%
Epoch: 4 Train loss: 0.37806 Test accuracy: 33.33%
Epoch: 5 Train loss: 0.36647 Test accuracy: 100.00%
real_life_votes: [13, 22]
neural_network_prediction_result: 0.0
```

View File

@ -0,0 +1,39 @@
# Training
Training starts with data. We're going to use the huggingface hub and
start with the Hello world dataset of machine learning, MNIST.
Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist).
This requires [`hf-hub`](https://github.com/huggingface/hf-hub).
```bash
cargo add hf-hub
```
This is going to be very hands-on for now.
```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_training_1}}
```
This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset.
Our handles are now [`parquet::file::serialized_reader::SerializedFileReader`].
We can inspect the content of the files with:
```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_training_2}}
```
You should see something like:
```bash
Column id 1, name label, value 6
Column id 0, name image, value {bytes: [137, ....]
Column id 1, name label, value 8
Column id 0, name image, value {bytes: [137, ....]
```
So each row contains 2 columns (image, label) with image being saved as bytes.
Let's put them into a useful struct.

View File

@ -10,8 +10,11 @@ license.workspace = true
readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.1.0", optional = true }
candle-kernels = { workspace = true, optional = true }
candle-metal-kernels = { workspace = true, optional = true }
metal = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
@ -21,14 +24,39 @@ memmap2 = { workspace = true }
num-traits = { workspace = true }
num_cpus = { workspace = true }
rand = { workspace = true }
rand_distr = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
ug-cuda = { workspace = true, optional = true }
ug-metal = { workspace = true, optional = true }
yoke = { workspace = true }
zip = { workspace = true }
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
ug = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
clap = { workspace = true }
criterion = { workspace = true }
[features]
default = []
cuda = ["dep:cudarc", "dep:candle-kernels"]
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"]
[[bench]]
name = "bench_main"
harness = false
[[example]]
name = "metal_basics"
required-features = ["metal"]
[[example]]
name = "cuda_basics"
required-features = ["cuda"]

View File

@ -0,0 +1,14 @@
mod benchmarks;
use criterion::criterion_main;
criterion_main!(
benchmarks::affine::benches,
benchmarks::matmul::benches,
benchmarks::random::benches,
benchmarks::reduce::benches,
benchmarks::where_cond::benches,
benchmarks::conv_transpose2d::benches,
benchmarks::qmatmul::benches,
benchmarks::unary::benches
);

View File

@ -0,0 +1,43 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor) {
a.affine(12.34, 56.78).unwrap();
}
fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let b = 1;
let m = 1024;
let k = 1024;
let tensor = Tensor::zeros((b, m, k), dtype, device).unwrap();
let flops = b * m * k * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_affine_benchmark(c, &device, DType::F32, "affine_f32");
run_affine_benchmark(c, &device, DType::F16, "affine_f16");
run_affine_benchmark(c, &device, DType::BF16, "affine_bf16");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -0,0 +1,59 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(
x: &Tensor,
k: &Tensor,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) {
x.conv_transpose2d(k, padding, output_padding, stride, dilation)
.unwrap();
}
fn run_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let t = Tensor::arange(0.0f32, 10000.0, device)
.unwrap()
.reshape((1, 4, 50, 50))
.unwrap()
.to_dtype(dtype)
.unwrap();
let kernel = Tensor::arange(0.0f32, 100.0, device)
.unwrap()
.reshape((4, 1, 5, 5))
.unwrap()
.to_dtype(dtype)
.unwrap();
let flops = t.dims().iter().product::<usize>() * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&t), black_box(&kernel), 1, 0, 1, 2);
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_benchmark(c, &device, DType::F32, "conv_transpose2d_f32");
run_benchmark(c, &device, DType::F16, "conv_transpose2d_f16");
run_benchmark(c, &device, DType::BF16, "conv_transpose2d_bf16");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -0,0 +1,44 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor, b: &Tensor) {
a.matmul(&b.t().unwrap()).unwrap();
}
fn run_bench(c: &mut Criterion, device: &Device) {
let b = 1;
let m = 1;
let n = 2048;
let k = 2048;
let dtype = DType::F32;
let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
let flops = b * m * n * k;
let mut group = c.benchmark_group(device.bench_name("matmul"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&lhs), black_box(&rhs));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_bench(c, &device);
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -0,0 +1,72 @@
pub(crate) mod affine;
pub(crate) mod conv_transpose2d;
pub(crate) mod matmul;
pub(crate) mod qmatmul;
pub(crate) mod random;
pub(crate) mod reduce;
pub(crate) mod unary;
pub(crate) mod where_cond;
use candle_core::{Device, Result};
pub(crate) trait BenchDevice {
fn sync(&self) -> Result<()>;
fn bench_name<S: Into<String>>(&self, name: S) -> String;
}
impl BenchDevice for Device {
fn sync(&self) -> Result<()> {
match self {
Device::Cpu => Ok(()),
Device::Cuda(device) => {
#[cfg(feature = "cuda")]
return Ok(device
.synchronize()
.map_err(|e| candle_core::Error::Cuda(Box::new(e)))?);
#[cfg(not(feature = "cuda"))]
panic!("Cuda device without cuda feature enabled: {:?}", device)
}
Device::Metal(device) => {
#[cfg(feature = "metal")]
return Ok(device.wait_until_completed()?);
#[cfg(not(feature = "metal"))]
panic!("Metal device without metal feature enabled: {:?}", device)
}
}
}
fn bench_name<S: Into<String>>(&self, name: S) -> String {
match self {
Device::Cpu => {
let cpu_type = if cfg!(feature = "accelerate") {
"accelerate"
} else if cfg!(feature = "mkl") {
"mkl"
} else {
"cpu"
};
format!("{}_{}", cpu_type, name.into())
}
Device::Cuda(_) => format!("cuda_{}", name.into()),
Device::Metal(_) => format!("metal_{}", name.into()),
}
}
}
struct BenchDeviceHandler {
devices: Vec<Device>,
}
impl BenchDeviceHandler {
pub fn new() -> Result<Self> {
let mut devices = Vec::new();
if cfg!(feature = "metal") {
devices.push(Device::new_metal(0)?);
} else if cfg!(feature = "cuda") {
devices.push(Device::new_cuda(0)?);
}
devices.push(Device::Cpu);
Ok(Self { devices })
}
}

View File

@ -0,0 +1,72 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{
quantized::{self, GgmlDType, QMatMul},
Device, Module, Tensor,
};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(matmul: &QMatMul, x: &Tensor) {
matmul.forward(x).unwrap();
}
fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
let b = 1;
let m = 1;
let n = 1024;
let k = 1024;
let lhs = (0..(m * k))
.map(|v| v as f32 / (m * k) as f32)
.collect::<Vec<_>>();
let rhs = (0..(k * n))
.map(|v| v as f32 / (n * k) as f32)
.collect::<Vec<_>>();
let lhs = Tensor::from_slice(&lhs, (m, k), device).unwrap();
let rhs = Tensor::from_slice(&rhs, (k, n), device).unwrap();
let qtensor = quantized::QTensor::quantize(&rhs.t().unwrap(), dtype).unwrap();
let matmul = quantized::QMatMul::from_qtensor(qtensor).unwrap();
let flops = b * m * n * k;
let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype)));
group.sample_size(200);
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&matmul), black_box(&lhs));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
for dtype in [
GgmlDType::F32,
GgmlDType::F16,
GgmlDType::Q4_0,
GgmlDType::Q4_1,
GgmlDType::Q5_0,
GgmlDType::Q5_1,
GgmlDType::Q8_0,
GgmlDType::Q2K,
GgmlDType::Q3K,
GgmlDType::Q4K,
GgmlDType::Q5K,
GgmlDType::Q6K,
] {
run_bench(c, &device, dtype);
}
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -0,0 +1,63 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn rand_uniform(a: &Tensor) {
a.rand_like(-1.0, 123.0).unwrap();
}
fn rand_normal(a: &Tensor) {
a.randn_like(100.0, 15.0).unwrap();
}
fn run_random_bench(c: &mut Criterion, device: &Device) {
let b = 1;
let rows = 2048;
let cols = 2048;
let dtype = DType::F32;
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let flops = b * rows * cols * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name("random_uniform"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_uniform(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let mut group = c.benchmark_group(device.bench_name("random_normal"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_normal(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_random_bench(c, &device);
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -0,0 +1,158 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use half::{bf16, f16};
use std::time::Instant;
fn run_sum(a: &Tensor) {
a.sum_keepdim(2).unwrap();
}
fn run_arg_min(a: &Tensor) {
a.argmin_keepdim(2).unwrap();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
let (lo, up) = (-1000.0f32, 1000.0f32);
for device in handler.devices {
run_reduce(c, &device, (lo, up), false);
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
run_arg_reduce(c, &device, (lo, up), false);
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
run_reduce(c, &device, (lo, up), true);
run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
run_arg_reduce(c, &device, (lo, up), true);
run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
}
}
fn run_reduce<T: candle_core::FloatDType>(
c: &mut Criterion,
device: &Device,
(lo, up): (T, T),
strided: bool,
) {
let b = 1;
let m = 1024;
let k = 1024;
let a = if strided {
Tensor::rand(lo, up, (b, m, k), &device)
.unwrap()
.transpose(0, 2)
.unwrap()
} else {
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
};
let flops = b * m * k * T::DTYPE.size_in_bytes();
let name = match T::DTYPE {
DType::F32 => {
if strided {
"reduce_f32_strided"
} else {
"reduce_f32"
}
}
DType::F16 => {
if strided {
"reduce_f16_strided"
} else {
"reduce_f16"
}
}
DType::BF16 => {
if strided {
"reduce_bf16_strided"
} else {
"reduce_bf16"
}
}
_ => "unknown",
};
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run_sum(black_box(&a));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn run_arg_reduce<T: candle_core::FloatDType>(
c: &mut Criterion,
device: &Device,
(lo, up): (T, T),
strided: bool,
) {
let b = 1;
let m = 1024;
let k = 1024;
let a = if strided {
Tensor::rand(lo, up, (b, m, k), &device)
.unwrap()
.transpose(0, 2)
.unwrap()
} else {
Tensor::rand(lo, up, (b, m, k), &device).unwrap()
};
let flops = b * m * k * T::DTYPE.size_in_bytes();
let name = match T::DTYPE {
DType::F32 => {
if strided {
"arg_reduce_f32_strided"
} else {
"arg_reduce_f32"
}
}
DType::F16 => {
if strided {
"arg_reduce_f16_strided"
} else {
"arg_reduce_f16"
}
}
DType::BF16 => {
if strided {
"arg_reduce_bf16_strided"
} else {
"arg_reduce_bf16"
}
}
_ => "unknown",
};
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run_arg_min(black_box(&a));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
criterion_group!(benches, criterion_benchmark);

View File

@ -0,0 +1,49 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor) {
a.sqrt().unwrap();
}
fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let b = 1;
let m = 1024;
let k = 1024;
let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, device)
.unwrap()
.to_dtype(dtype)
.unwrap()
.reshape((b, m, k))
.unwrap();
let flops = b * m * k * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
for dtype in [DType::F32, DType::BF16, DType::F16] {
let name = format!("sqrt_{:?}", dtype);
run_unary_benchmark(c, &device, dtype, &name);
}
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -0,0 +1,64 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor, b: &Tensor, c: &Tensor) {
a.where_cond(b, c).unwrap();
}
const fn create_cond_arr<const N: usize>() -> [u8; N] {
let mut arr = [0u8; N];
let mut i = 0;
while i < N {
arr[i] = (i % 2) as u8;
i += 1;
}
arr
}
const B: usize = 1;
const M: usize = 1024;
const K: usize = 1024;
const SIZE: usize = B * M * K;
const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap();
let on_true = Tensor::ones((B, M, K), dtype, device).unwrap();
let on_false = Tensor::zeros((B, M, K), dtype, device).unwrap();
let elements = B * M * K;
// E.g. 2 f32 tensors + 1 u8 tensor
let flops = (2 * elements * dtype.size_in_bytes()) + elements;
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(
black_box(&tensor),
black_box(&on_true),
black_box(&on_false),
);
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let device = BenchDeviceHandler::new().unwrap();
for d in device.devices {
run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32");
run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16");
run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,29 +1,17 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
let c = a.matmul(&b)?;
println!("{a} {b} {c}");
let data = &[[3f32, 1., 4., 1., 5.], [2., 7., 1., 8., 2.]];
let t1 = Tensor::new(data, &Device::Cpu)?;
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 7., 1., 8., 2.]];
let t2 = Tensor::new(data2, &Device::Cpu)?;
assert_eq!(
Tensor::cat(&[&t1.t()?, &t2.t()?], 1)?
.t()?
.to_vec2::<f32>()?,
[
[3.0, 1.0, 4.0, 1.0, 5.0],
[2.0, 7.0, 1.0, 8.0, 2.0],
[5.0, 5.0, 5.0, 5.0, 5.0],
[2.0, 7.0, 1.0, 8.0, 2.0]
]
);
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
let new_a = a.slice_scatter(&b, 1, 2)?;
assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
Ok(())
}

View File

@ -1,3 +1,6 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@ -6,10 +9,25 @@ use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let t = Tensor::new(&[[1f32, 2., 3., 4.2]], &device)?;
let sum = t.sum_keepdim(0)?;
println!("{sum}");
let sum = t.sum_keepdim(1)?;
println!("{sum}");
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
.to_dtype(candle_core::DType::BF16)?;
candle_core::cuda::set_gemm_reduced_precision_f32(false);
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("fp32: {:?}", start_time.elapsed());
drop(_x1);
candle_core::cuda::set_gemm_reduced_precision_f32(true);
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("tf32: {:?}", start_time.elapsed());
drop(_x1);
Ok(())
}

View File

@ -1,6 +1,9 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use std::str::FromStr;
use anyhow::Result;

View File

@ -0,0 +1,28 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::Result;
use candle_core::{Device, Tensor};
fn main() -> Result<()> {
// This requires the code to be run with MTL_CAPTURE_ENABLED=1
let device = Device::new_metal(0)?;
let metal_device = match &device {
Device::Metal(m) => m,
_ => anyhow::bail!("unexpected device"),
};
metal_device.capture("/tmp/candle.gputrace")?;
// This first synchronize ensures that a new command buffer gets created after setting up the
// capture scope.
device.synchronize()?;
let x = Tensor::randn(0f32, 1.0, (128, 128), &device)?;
let x1 = x.add(&x)?;
println!("{x1:?}");
// This second synchronize ensures that the command buffer gets commited before the end of the
// capture scope.
device.synchronize()?;
Ok(())
}

View File

@ -0,0 +1,476 @@
#![allow(dead_code)]
use libc::{c_char, c_double, c_float, c_int, c_long, c_ulong};
mod ffi {
use super::*;
extern "C" {
// It would be nice to be able to switch to the NEWLAPACK version of the function but this
// seems to trigger some link error. Available function names can be seen here:
// /Library/Developer/CommandLineTools/SDKs/MacOSX13.3.sdk/System/Library/Frameworks/Accelerate.framework/Versions/A/Accelerate.tbd
#[link_name = "sgemm_"]
pub fn sgemm_ffi(
transa: *const c_char,
transb: *const c_char,
m: *const c_int,
n: *const c_int,
k: *const c_int,
alpha: *const c_float,
a: *const c_float,
lda: *const c_int,
b: *const c_float,
ldb: *const c_int,
beta: *const c_float,
c: *mut c_float,
ldc: *const c_int,
);
#[link_name = "dgemm_"]
pub fn dgemm_ffi(
transa: *const c_char,
transb: *const c_char,
m: *const c_int,
n: *const c_int,
k: *const c_int,
alpha: *const c_double,
a: *const c_double,
lda: *const c_int,
b: *const c_double,
ldb: *const c_int,
beta: *const c_double,
c: *mut c_double,
ldc: *const c_int,
);
pub fn vvexpf(dst: *mut c_float, src: *const c_float, len: *const c_int);
pub fn vvexp(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vvsqrtf(dst: *mut c_float, src: *const c_float, len: *const c_int);
pub fn vvsqrt(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vvsinf(dst: *mut c_float, src: *const c_float, len: *const c_int);
pub fn vvsin(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vvcosf(dst: *mut c_float, src: *const c_float, len: *const c_int);
pub fn vvcos(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vvlogf(dst: *mut c_float, src: *const c_float, len: *const c_int);
pub fn vvlog(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vvtanhf(dst: *mut c_float, src: *const c_float, len: *const c_int);
pub fn vvtanh(dst: *mut c_double, src: *const c_double, len: *const c_int);
pub fn vDSP_vaddD(
_: *const c_double,
_: c_long,
_: *const c_double,
_: c_long,
_: *mut c_double,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vadd(
_: *const c_float,
_: c_long,
_: *const c_float,
_: c_long,
_: *mut c_float,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vsubD(
_: *const c_double,
_: c_long,
_: *const c_double,
_: c_long,
_: *mut c_double,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vsub(
_: *const c_float,
_: c_long,
_: *const c_float,
_: c_long,
_: *mut c_float,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vmulD(
_: *const c_double,
_: c_long,
_: *const c_double,
_: c_long,
_: *mut c_double,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vmul(
_: *const c_float,
_: c_long,
_: *const c_float,
_: c_long,
_: *mut c_float,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vdivD(
_: *const c_double,
_: c_long,
_: *const c_double,
_: c_long,
_: *mut c_double,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vdiv(
_: *const c_float,
_: c_long,
_: *const c_float,
_: c_long,
_: *mut c_float,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vminD(
_: *const c_double,
_: c_long,
_: *const c_double,
_: c_long,
_: *mut c_double,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vmin(
_: *const c_float,
_: c_long,
_: *const c_float,
_: c_long,
_: *mut c_float,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vmaxD(
_: *const c_double,
_: c_long,
_: *const c_double,
_: c_long,
_: *mut c_double,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vmax(
_: *const c_float,
_: c_long,
_: *const c_float,
_: c_long,
_: *mut c_float,
_: c_long,
_: c_ulong,
);
}
}
#[allow(clippy::too_many_arguments)]
#[inline]
pub unsafe fn sgemm(
transa: u8,
transb: u8,
m: i32,
n: i32,
k: i32,
alpha: f32,
a: &[f32],
lda: i32,
b: &[f32],
ldb: i32,
beta: f32,
c: &mut [f32],
ldc: i32,
) {
ffi::sgemm_ffi(
&(transa as c_char),
&(transb as c_char),
&m,
&n,
&k,
&alpha,
a.as_ptr(),
&lda,
b.as_ptr(),
&ldb,
&beta,
c.as_mut_ptr(),
&ldc,
)
}
#[allow(clippy::too_many_arguments)]
#[inline]
pub unsafe fn dgemm(
transa: u8,
transb: u8,
m: i32,
n: i32,
k: i32,
alpha: f64,
a: &[f64],
lda: i32,
b: &[f64],
ldb: i32,
beta: f64,
c: &mut [f64],
ldc: i32,
) {
ffi::dgemm_ffi(
&(transa as c_char),
&(transb as c_char),
&m,
&n,
&k,
&alpha,
a.as_ptr(),
&lda,
b.as_ptr(),
&ldb,
&beta,
c.as_mut_ptr(),
&ldc,
)
}
#[inline]
pub fn vs_exp(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvexpf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vd_exp(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvexp(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vs_sqrt(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvsqrtf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vd_sqrt(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvsqrt(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vs_sin(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvsinf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vd_sin(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvsin(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vs_cos(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvcosf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vd_cos(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvcos(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvtanhf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvtanh(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vs_ln(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvlogf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vd_ln(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
unsafe { ffi::vvlog(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
}
#[inline]
pub fn vs_sqr(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
}
#[inline]
pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
panic!("a and y have different lengths {a_len} <> {y_len}")
}
y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
}
#[inline]
pub fn vs_tanh_inplace(y: &mut [f32]) {
unsafe { ffi::vvtanhf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vs_exp_inplace(y: &mut [f32]) {
unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vd_exp_inplace(y: &mut [f64]) {
unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
}
vs_tanh_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = 0.5 * v * (1.0 + *y)
}
}
#[inline]
pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
}
vd_tanh_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = 0.5 * v * (1.0 + *y)
}
}
#[inline]
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vs_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
#[inline]
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vd_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
macro_rules! binary_op {
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
#[inline]
pub fn $fn_name(a: &[$ty], b: &[$ty], y: &mut [$ty]) {
let a_len = a.len();
let b_len = b.len();
let y_len = y.len();
if a_len != y_len || b_len != y_len {
panic!(
"{} a,b,y len mismatch {a_len} {b_len} {y_len}",
stringify!($fn_name)
);
}
unsafe {
// Weird quirk of accelerate, the rhs comes before the lhs.
ffi::$accelerate_name(
b.as_ptr(),
1,
a.as_ptr(),
1,
y.as_mut_ptr(),
1,
a_len as u64,
)
}
}
};
}
binary_op!(vs_add, f32, vDSP_vadd);
binary_op!(vd_add, f64, vDSP_vaddD);
binary_op!(vs_sub, f32, vDSP_vsub);
binary_op!(vd_sub, f64, vDSP_vsubD);
binary_op!(vs_mul, f32, vDSP_vmul);
binary_op!(vd_mul, f64, vDSP_vmulD);
binary_op!(vs_div, f32, vDSP_vdiv);
binary_op!(vd_div, f64, vDSP_vdivD);
binary_op!(vs_max, f32, vDSP_vmax);
binary_op!(vd_max, f64, vDSP_vmaxD);
binary_op!(vs_min, f32, vDSP_vmin);
binary_op!(vd_min, f64, vDSP_vminD);

View File

@ -1,3 +1,5 @@
//! Traits to Define Backend Behavior
//!
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
@ -15,6 +17,8 @@ pub trait BackendStorage: Sized {
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
fn powf(&self, _: &Layout, _: f64) -> Result<Self>;
fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
@ -37,6 +41,35 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv1D,
) -> Result<Self>;
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self>;
fn conv2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConv2D,
) -> Result<Self>;
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self>;
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
fn scatter_add(
&self,
@ -67,6 +100,19 @@ pub trait BackendStorage: Sized {
) -> Result<Self>;
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
#[allow(clippy::too_many_arguments)]
// Similar to cudaMemcpy2D, though values are in elements and not in bytes.
fn copy2d(
&self,
_: &mut Self,
_d1: usize,
_d2: usize,
_src_stride1: usize,
_dst_stride1: usize,
_src_offset: usize,
_dst_offset: usize,
) -> Result<()>;
}
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
@ -83,9 +129,24 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
/// # Safety
/// This function is unsafe as it doesn't initialize the underlying data store.
/// The caller should ensure that the data is properly initialized as early as possible
/// after this call.
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn set_seed(&self, _: u64) -> Result<()>;
/// Synchronize should block until all the operations on the device are completed.
fn synchronize(&self) -> Result<()>;
}

View File

@ -1,3 +1,4 @@
//! Methods for backpropagation of gradients.
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;
@ -15,12 +16,23 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
}
}
thread_local! {
static CANDLE_GRAD_DO_NOT_DETACH: bool = {
match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
impl Tensor {
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
/// elements having dependencies on the latter ones, e.g. the first element if any is the
/// argument.
/// This assumes that the op graph is a DAG.
fn sorted_nodes(&self) -> Vec<&Tensor> {
pub fn sorted_nodes(&self) -> Vec<&Tensor> {
// The vec of sorted nodes is passed as an owned value rather than a mutable reference
// to get around some lifetime limitations.
fn walk<'a>(
@ -36,6 +48,8 @@ impl Tensor {
// Do not call recursively on the "leaf" nodes.
track_grad = true;
nodes
} else if node.dtype().is_int() {
nodes
} else if let Some(op) = node.op() {
match op {
Op::IndexAdd(t1, t2, t3, _)
@ -55,11 +69,27 @@ impl Tensor {
kernel: rhs,
..
}
| Op::ConvTranspose1D {
arg: lhs,
kernel: rhs,
..
}
| Op::Conv2D {
arg: lhs,
kernel: rhs,
..
}
| Op::ConvTranspose2D {
arg: lhs,
kernel: rhs,
..
}
| Op::CustomOp2(lhs, rhs, _)
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
| Op::IndexSelect(lhs, rhs, _)
| Op::Matmul(lhs, rhs) => {
| Op::Matmul(lhs, rhs)
| Op::SliceScatter0(lhs, rhs, _) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(rhs, nodes, already_seen);
@ -80,22 +110,41 @@ impl Tensor {
nodes
}
}
Op::Unary(_node, UnaryOp::Ceil)
| Op::Unary(_node, UnaryOp::Floor)
| Op::Unary(_node, UnaryOp::Round)
| Op::Unary(_node, UnaryOp::Sign) => nodes,
Op::Reshape(node)
| Op::UpsampleNearest1D { arg: node, .. }
| Op::UpsampleNearest2D { arg: node, .. }
| Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. }
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
| Op::Reduce(node, _, _)
| Op::ToDType(node)
| Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
| Op::Permute(node, _)
| Op::Narrow(node, _, _, _)
| Op::Unary(node, _)
| Op::Elu(node, _)
| Op::Powf(node, _)
| Op::CustomOp1(node, _) => {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
nodes
}
Op::ToDType(node) => {
if node.dtype().is_float() {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
nodes
} else {
nodes
}
}
Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
}
} else {
nodes
@ -119,10 +168,16 @@ impl Tensor {
if node.is_variable() {
continue;
}
let grad = grads.remove(node).unwrap();
// TODO: We should perform all these operations in place (or at least not track the
// whole graph). The only drawback would be if we wanted to support grad of grad but
// this is out of scope.
let grad = grads
.remove(node)
.expect("candle internal error - grad not populated");
// https://github.com/huggingface/candle/issues/1241
// Ideally, we would make these operations in place where possible to ensure that we
// do not have to allocate too often. Here we just call `.detach` to avoid computing
// the backprop graph of the backprop itself. This would be an issue for second order
// derivatives but these are out of scope at the moment.
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
let grad = if do_not_detach { grad } else { grad.detach() };
if let Some(op) = node.op() {
match op {
Op::Binary(lhs, rhs, BinaryOp::Add) => {
@ -153,6 +208,21 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
}
Op::Binary(lhs, rhs, BinaryOp::Minimum)
| Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;
// If both masks are 1 one the same point, we want to scale the
// gradient by 0.5 rather than 1.
let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::WhereCond(pred, t, f) => {
let zeros = grad.zeros_like()?;
let t_sum_grad = grads.or_insert(t)?;
@ -162,7 +232,189 @@ impl Tensor {
let f_grad = pred.where_cond(&zeros, &grad)?;
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv1D {
arg,
kernel,
padding,
stride,
dilation,
} => {
// The output height for conv_transpose1d is:
// (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
let grad_l_in = grad.dim(2)?;
let k_size = kernel.dim(2)?;
let out_size =
(grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
let out_padding = arg.dim(2)? - out_size;
let grad_arg = grad.conv_transpose1d(
kernel,
*padding,
out_padding,
*stride,
*dilation,
/* groups */ 1,
)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = arg
.transpose(0, 1)?
.conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0) = kernel.dims3()?;
let (_, _, g_k0) = grad_kernel.dims3()?;
let grad_kernel = if g_k0 != k0 {
grad_kernel.narrow(2, 0, k0)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::Conv2D {
arg,
kernel,
padding,
stride,
dilation,
} => {
// The output height for conv_transpose2d is:
// (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
let grad_h = grad.dim(2)?;
let k_h = kernel.dim(2)?;
let out_size =
(grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
let out_padding = arg.dim(2)? - out_size;
let grad_arg = grad.conv_transpose2d(
kernel,
*padding,
out_padding,
*stride,
*dilation,
)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = arg
.transpose(0, 1)?
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0, k1) = kernel.dims4()?;
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose1d",
})?,
Op::ConvTranspose2D {
arg,
kernel,
padding,
stride,
dilation,
output_padding: _output_padding,
} => {
let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = grad
.transpose(0, 1)?
.conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0, k1) = kernel.dims4()?;
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::AvgPool2D {
arg,
kernel_size,
stride,
} => {
if kernel_size != stride {
crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
}
let (_n, _c, h, w) = arg.dims4()?;
let grad_arg = grad.upsample_nearest2d(h, w)?;
let grad_arg =
(grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
}
Op::MaxPool2D {
arg,
kernel_size,
stride,
} => {
if kernel_size != stride {
crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
}
let (_n, _c, h, w) = arg.dims4()?;
// For computing the max-pool gradient, we compute a mask where a 1 means
// that the element is the maximum, then we apply this mask to the
// upsampled gradient (taking into account that multiple max may exist so
// we scale the gradient for this case).
let node_upsampled = node.upsample_nearest2d(h, w)?;
let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
}
Op::UpsampleNearest1D { arg, target_size } => {
let (_n, c, size) = arg.dims3()?;
if target_size % size != 0 {
crate::bail!("backward not supported for non integer upscaling factors")
}
let scale = target_size / size;
let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = conv_sum;
}
Op::UpsampleNearest2D {
arg,
target_h,
target_w,
} => {
let (_n, c, h, w) = arg.dims4()?;
if target_h % h != 0 || target_w % w != 0 {
crate::bail!("backward not supported for non integer upscaling factors")
}
let scale_h = target_h / h;
let scale_w = target_w / w;
if scale_h != scale_w {
crate::bail!("backward not supported for non uniform upscaling factors")
};
let kernel =
Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = conv_sum;
}
Op::SliceScatter0(lhs, rhs, start_rhs) => {
let rhs_sum_grad = grads.or_insert(rhs)?;
let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
}
Op::Gather(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
@ -237,7 +489,6 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad)?;
}
Op::Cmp(_args, _) => {}
Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
let node = broadcast_back(arg, node, reduced_dims)?;
let grad = broadcast_back(arg, &grad, reduced_dims)?;
@ -254,7 +505,7 @@ impl Tensor {
}
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
*sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
}
Op::Copy(arg) => {
let sum_grad = grads.or_insert(arg)?;
@ -277,6 +528,11 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
}
Op::Unary(arg, UnaryOp::Tanh) => {
let sum_grad = grads.or_insert(arg)?;
let minus_dtanh = (node.sqr()? - 1.)?;
*sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
}
Op::Unary(arg, UnaryOp::Abs) => {
let sum_grad = grads.or_insert(arg)?;
let ones = arg.ones_like()?;
@ -291,6 +547,11 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&grad)?
}
Op::Unary(arg, UnaryOp::Recip) => {
let sum_grad = grads.or_insert(arg)?;
let grad = (grad / arg.sqr()?)?;
*sum_grad = sum_grad.sub(&grad)?
}
&Op::Narrow(ref arg, dim, start_idx, len) => {
let arg_dims = arg.dims();
let left_pad = if start_idx == 0 {
@ -317,20 +578,72 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Reduce(_, ReduceOp::ArgMin, _) => {}
Op::Reduce(_, ReduceOp::ArgMax, _) => {}
Op::Unary(_, UnaryOp::Floor)
| Op::Unary(_, UnaryOp::Round)
| Op::Reduce(_, ReduceOp::ArgMin, _)
| Op::Reduce(_, ReduceOp::ArgMax, _)
| Op::Unary(_, UnaryOp::Sign)
| Op::Cmp(_, _) => {}
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
Op::Unary(arg, UnaryOp::Gelu) => {
let sum_grad = grads.or_insert(arg)?;
let cube = arg.powf(3.)?;
let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
let gelu_grad = (((0.5 * &tanh)?
+ (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
+ 0.5)?;
*sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
}
Op::Unary(arg, UnaryOp::Erf) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
let erf_grad =
(2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
*sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
}
Op::Unary(arg, UnaryOp::GeluErf) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
*sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
}
Op::Unary(arg, UnaryOp::Relu) => {
let sum_grad = grads.or_insert(arg)?;
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
}
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::Unary(arg, UnaryOp::Silu) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) = sigmoid(x) * (1 - node) + node
let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
let silu_grad = &sigmoid_arg * (1. - *node) + *node;
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
}
Op::Elu(arg, alpha) => {
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
let sum_grad = grads.or_insert(arg)?;
let zeros = arg.zeros_like()?;
let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
// node == alpha * (e^x - 1) for x <= 0, reuse it
let negative_exp_mask = (negative_mask * (*node + *alpha))?;
let combined_mask = (positive_mask + negative_exp_mask)?;
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?
}
Op::Powf(arg, e) => {
let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::CustomOp1(arg, c) => {
if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
let sum_grad = grads.or_insert(arg)?;
@ -384,6 +697,15 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Permute(arg, dims) => {
let mut inv_dims = vec![0; dims.len()];
for (i, &dim_idx) in dims.iter().enumerate() {
inv_dims[dim_idx] = i
}
let arg_grad = grad.permute(inv_dims)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
};
}
}
@ -391,29 +713,38 @@ impl Tensor {
}
}
/// A store for gradients, associating a tensor id to the corresponding gradient tensor, used for back propagation.
#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>);
impl GradStore {
/// Create a new gradient store
fn new() -> Self {
GradStore(HashMap::new())
}
/// Get the gradient tensor corresponding to the given tensor id
pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
self.0.get(&id)
}
/// Get the gradient tensor associated with the given tensor
pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
self.0.get(&tensor.id())
}
/// Remove the gradient tensor associated with the given tensor, returning it if it exists
pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
self.0.remove(&tensor.id())
}
/// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed
pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
self.0.insert(tensor.id(), grad)
}
/// Get the gradient tensor associated with the given tensor, or, if it does not exist,
/// insert a tensor of zeroes, with the same shape and type as the given tensors and return it
fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
use std::collections::hash_map::Entry;
let grad = match self.0.entry(tensor.id()) {
@ -425,4 +756,9 @@ impl GradStore {
};
Ok(grad)
}
/// Get the tensor ids of the stored gradient tensors
pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
self.0.keys()
}
}

View File

@ -1,6 +1,10 @@
//! 1D and 2D Convolutions
//!
use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
pub(crate) b_size: Option<usize>,
pub(crate) b_size: usize,
// Maybe we should have a version without l_in as this bit depends on the input and not only on
// the weights.
pub(crate) l_in: usize,
@ -9,19 +13,350 @@ pub struct ParamsConv1D {
pub(crate) k_size: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
impl ParamsConv1D {
pub(crate) fn l_out(&self) -> usize {
let dilation = 1;
(self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1
(self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
let l_out = self.l_out();
match self.b_size {
None => vec![self.c_out, l_out],
Some(n) => vec![n, self.c_out, l_out],
}
vec![self.b_size, self.c_out, l_out]
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose1D {
pub(crate) b_size: usize,
pub(crate) l_in: usize,
pub(crate) c_out: usize,
pub(crate) c_in: usize,
pub(crate) k_size: usize,
pub(crate) padding: usize,
pub(crate) output_padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
}
impl ParamsConvTranspose1D {
pub(crate) fn l_out(&self) -> usize {
(self.l_in - 1) * self.stride - 2 * self.padding
+ self.dilation * (self.k_size - 1)
+ self.output_padding
+ 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
let l_out = self.l_out();
vec![self.b_size, self.c_out, l_out]
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
ImplicitGemm,
ImplicitPrecompGemm,
Gemm,
Direct,
Fft,
FftTiling,
Winograd,
WinogradNonFused,
Count,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
pub(crate) b_size: usize,
pub(crate) i_h: usize,
pub(crate) i_w: usize,
pub(crate) k_h: usize,
pub(crate) k_w: usize,
pub(crate) c_out: usize,
pub(crate) c_in: usize,
pub(crate) padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
impl ParamsConv2D {
pub(crate) fn out_h(&self) -> usize {
(self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1
}
pub(crate) fn out_w(&self) -> usize {
(self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConvTranspose2D {
pub(crate) b_size: usize,
pub(crate) i_h: usize,
pub(crate) i_w: usize,
pub(crate) k_h: usize,
pub(crate) k_w: usize,
pub(crate) c_out: usize,
pub(crate) c_in: usize,
pub(crate) padding: usize,
pub(crate) output_padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
}
impl ParamsConvTranspose2D {
pub(crate) fn out_h(&self) -> usize {
(self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1
- 2 * self.padding
}
pub(crate) fn out_w(&self) -> usize {
(self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1
- 2 * self.padding
}
pub(crate) fn out_dims(&self) -> Vec<usize> {
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
}
}
impl Tensor {
fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
let storage =
self.storage()
.conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
arg,
kernel,
padding: params.padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
/// Applies a 1D convolution over the input tensor.
pub fn conv1d(
&self,
kernel: &Self,
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
let (b_size, c_in, l_in) = self.dims3()?;
if c_in != c_in_k * groups {
Err(Error::Conv1dInvalidArgs {
inp_shape: self.shape().clone(),
k_shape: kernel.shape().clone(),
padding,
stride,
msg: "the number of in-channels on the input doesn't match the kernel size",
}
.bt())?
}
let params = ParamsConv1D {
b_size,
l_in,
c_out: c_out / groups,
c_in: c_in / groups,
k_size,
padding,
stride,
dilation,
cudnn_fwd_algo: None,
};
if groups == 1 {
self.conv1d_single_group(kernel, &params)
} else {
let blocks = self.chunk(groups, 1)?;
let kernel = kernel.chunk(groups, 0)?;
let blocks = blocks
.iter()
.zip(&kernel)
.map(|(block, kernel)| block.conv1d_single_group(kernel, &params))
.collect::<Result<Vec<_>>>()?;
Tensor::cat(&blocks, 1)
}
}
fn conv_transpose1d_single_group(
&self,
kernel: &Self,
params: &ParamsConvTranspose1D,
) -> Result<Self> {
let storage = self.storage().conv_transpose1d(
self.layout(),
&kernel.storage(),
kernel.layout(),
params,
)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
arg,
kernel,
padding: params.padding,
output_padding: params.output_padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
/// Applies a 1D transposed convolution over the input tensor.
pub fn conv_transpose1d(
&self,
kernel: &Self,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
let (c_in_k, c_out, k_size) = kernel.dims3()?;
let (b_size, c_in, l_in) = self.dims3()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
if c_in % groups != 0 {
crate::bail!("in_channel {c_in} is not divisible by the number of groups")
}
let params = ParamsConvTranspose1D {
b_size,
l_in,
k_size,
c_out,
c_in: c_in / groups,
padding,
output_padding,
stride,
dilation,
};
if groups == 1 {
self.conv_transpose1d_single_group(kernel, &params)
} else {
let blocks = self.chunk(groups, 1)?;
let kernel = kernel.chunk(groups, 0)?;
let blocks = blocks
.iter()
.zip(&kernel)
.map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, &params))
.collect::<Result<Vec<_>>>()?;
Tensor::cat(&blocks, 1)
}
}
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
let storage =
self.storage()
.conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
arg,
kernel,
padding: params.padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
/// Applies a 2D convolution over the input tensor.
pub fn conv2d(
&self,
kernel: &Self,
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
if c_in != c_in_k * groups {
crate::bail!(
"in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
)
}
let params = ParamsConv2D {
b_size,
i_h,
i_w,
k_h,
k_w,
c_out: c_out / groups,
c_in: c_in / groups,
padding,
stride,
dilation,
cudnn_fwd_algo: None,
};
if groups == 1 {
self.conv2d_single_group(kernel, &params)
} else {
let blocks = self.chunk(groups, 1)?;
let kernel = kernel.chunk(groups, 0)?;
let blocks = blocks
.iter()
.zip(&kernel)
.map(|(block, kernel)| block.conv2d_single_group(kernel, &params))
.collect::<Result<Vec<_>>>()?;
Tensor::cat(&blocks, 1)
}
}
/// Applies a 2D transposed convolution over the input tensor.
pub fn conv_transpose2d(
&self,
kernel: &Self,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = ParamsConvTranspose2D {
b_size,
i_h,
i_w,
k_h,
k_w,
c_out,
c_in,
padding,
output_padding,
stride,
dilation,
};
let storage = self.storage().conv_transpose2d(
self.layout(),
&kernel.storage(),
kernel.layout(),
&params,
)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose2D {
arg,
kernel,
padding: params.padding,
output_padding: params.output_padding,
stride: params.stride,
dilation: params.dilation,
});
let out_dims = params.out_dims();
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
}

View File

@ -1,6 +1,6 @@
//! Implement conversion traits for tensors
use crate::{Device, Error, Tensor, WithDType};
use half::{bf16, f16};
use crate::{DType, Device, Error, Tensor, WithDType};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::convert::TryFrom;
impl<T: WithDType> TryFrom<&Tensor> for Vec<T> {
@ -92,5 +92,54 @@ from_tensor!(f64);
from_tensor!(f32);
from_tensor!(f16);
from_tensor!(bf16);
from_tensor!(i64);
from_tensor!(u32);
from_tensor!(u8);
impl Tensor {
pub fn write_bytes<W: std::io::Write>(&self, f: &mut W) -> crate::Result<()> {
use byteorder::{LittleEndian, WriteBytesExt};
let vs = self.flatten_all()?;
match self.dtype() {
DType::BF16 => {
let vs = vs.to_vec1::<bf16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F16 => {
let vs = vs.to_vec1::<f16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F32 => {
// TODO: Avoid using a buffer when data is already on the CPU.
for v in vs.to_vec1::<f32>()? {
f.write_f32::<LittleEndian>(v)?
}
}
DType::F64 => {
for v in vs.to_vec1::<f64>()? {
f.write_f64::<LittleEndian>(v)?
}
}
DType::U32 => {
for v in vs.to_vec1::<u32>()? {
f.write_u32::<LittleEndian>(v)?
}
}
DType::I64 => {
for v in vs.to_vec1::<i64>()? {
f.write_i64::<LittleEndian>(v)?
}
}
DType::U8 => {
let vs = vs.to_vec1::<u8>()?;
f.write_all(&vs)?;
}
}
Ok(())
}
}

148
candle-core/src/cpu/avx.rs Normal file
View File

@ -0,0 +1,148 @@
use super::{Cpu, CpuF16};
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use half::f16;
pub struct CurrentCpu {}
const STEP: usize = 32;
const EPR: usize = 8;
const ARR: usize = STEP / EPR;
impl Cpu<ARR> for CurrentCpu {
type Unit = __m256;
type Array = [__m256; ARR];
const STEP: usize = STEP;
const EPR: usize = EPR;
fn n() -> usize {
ARR
}
unsafe fn zero() -> Self::Unit {
_mm256_setzero_ps()
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}
unsafe fn from_f32(v: f32) -> Self::Unit {
_mm256_set1_ps(v)
}
unsafe fn load(mem_addr: *const f32) -> Self::Unit {
_mm256_loadu_ps(mem_addr)
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
_mm256_add_ps(a, b)
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
_mm256_add_ps(_mm256_mul_ps(b, c), a)
}
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
_mm256_storeu_ps(mem_addr, a);
}
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
for i in 0..ARR / 2 {
x[2 * i] = _mm256_add_ps(x[2 * i], x[2 * i + 1]);
}
for i in 0..ARR / 4 {
x[4 * i] = _mm256_add_ps(x[4 * i], x[4 * i + 2]);
}
#[allow(clippy::reversed_empty_ranges)]
for i in 0..ARR / 8 {
x[8 * i] = _mm256_add_ps(x[8 * i], x[8 * i + 4]);
}
let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
let t1 = _mm_hadd_ps(t0, t0);
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}
pub struct CurrentCpuF16 {}
impl CpuF16<ARR> for CurrentCpuF16 {
type Unit = __m256;
type Array = [__m256; ARR];
const STEP: usize = STEP;
const EPR: usize = EPR;
fn n() -> usize {
ARR
}
unsafe fn zero() -> Self::Unit {
_mm256_setzero_ps()
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}
unsafe fn from_f32(v: f32) -> Self::Unit {
_mm256_set1_ps(v)
}
#[cfg(target_feature = "f16c")]
unsafe fn load(mem_addr: *const f16) -> Self::Unit {
_mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
}
#[cfg(not(target_feature = "f16c"))]
unsafe fn load(mem_addr: *const f16) -> Self::Unit {
let mut tmp = [0.0f32; 8];
for i in 0..8 {
tmp[i] = (*mem_addr.add(i)).to_f32();
}
_mm256_loadu_ps(tmp.as_ptr())
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
_mm256_add_ps(a, b)
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
_mm256_add_ps(_mm256_mul_ps(b, c), a)
}
#[cfg(target_feature = "f16c")]
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
_mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
}
#[cfg(not(target_feature = "f16c"))]
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit) {
let mut tmp = [0.0f32; 8];
_mm256_storeu_ps(tmp.as_mut_ptr(), a);
for i in 0..8 {
*mem_addr.add(i) = f16::from_f32(tmp[i]);
}
}
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
let mut offset = ARR >> 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
offset >>= 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
offset >>= 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
let t1 = _mm_hadd_ps(t0, t0);
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}

763
candle-core/src/cpu/erf.rs Normal file
View File

@ -0,0 +1,763 @@
#![allow(clippy::excessive_precision)]
// Code taken from https://github.com/statrs-dev/statrs
//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
//! related functions
mod evaluate {
//! Provides functions that don't have a numerical solution and must
//! be solved computationally (e.g. evaluation of a polynomial)
/// evaluates a polynomial at `z` where `coeff` are the coeffecients
/// to a polynomial of order `k` where `k` is the length of `coeff` and the
/// coeffecient
/// to the `k`th power is the `k`th element in coeff. E.g. [3,-1,2] equates to
/// `2z^2 - z + 3`
///
/// # Remarks
///
/// Returns 0 for a 0 length coefficient slice
pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
let n = coeff.len();
if n == 0 {
return 0.0;
}
let mut sum = *coeff.last().unwrap();
for c in coeff[0..n - 1].iter().rev() {
sum = *c + z * sum;
}
sum
}
}
use std::f64;
/// `erf` calculates the error function at `x`.
pub fn erf(x: f64) -> f64 {
if x.is_nan() {
f64::NAN
} else if x >= 0.0 && x.is_infinite() {
1.0
} else if x <= 0.0 && x.is_infinite() {
-1.0
} else if x == 0. {
0.0
} else {
erf_impl(x, false)
}
}
/// `erf_inv` calculates the inverse error function
/// at `x`.
pub fn erf_inv(x: f64) -> f64 {
if x == 0.0 {
0.0
} else if x >= 1.0 {
f64::INFINITY
} else if x <= -1.0 {
f64::NEG_INFINITY
} else if x < 0.0 {
erf_inv_impl(-x, 1.0 + x, -1.0)
} else {
erf_inv_impl(x, 1.0 - x, 1.0)
}
}
/// `erfc` calculates the complementary error function
/// at `x`.
pub fn erfc(x: f64) -> f64 {
if x.is_nan() {
f64::NAN
} else if x == f64::INFINITY {
0.0
} else if x == f64::NEG_INFINITY {
2.0
} else {
erf_impl(x, true)
}
}
/// `erfc_inv` calculates the complementary inverse
/// error function at `x`.
pub fn erfc_inv(x: f64) -> f64 {
if x <= 0.0 {
f64::INFINITY
} else if x >= 2.0 {
f64::NEG_INFINITY
} else if x > 1.0 {
erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
} else {
erf_inv_impl(1.0 - x, x, 1.0)
}
}
// **********************************************************
// ********** Coefficients for erf_impl polynomial **********
// **********************************************************
/// Polynomial coefficients for a numerator of `erf_impl`
/// in the interval [1e-10, 0.5].
const ERF_IMPL_AN: &[f64] = &[
0.00337916709551257388990745,
-0.00073695653048167948530905,
-0.374732337392919607868241,
0.0817442448733587196071743,
-0.0421089319936548595203468,
0.0070165709512095756344528,
-0.00495091255982435110337458,
0.000871646599037922480317225,
];
/// Polynomial coefficients for a denominator of `erf_impl`
/// in the interval [1e-10, 0.5]
const ERF_IMPL_AD: &[f64] = &[
1.0,
-0.218088218087924645390535,
0.412542972725442099083918,
-0.0841891147873106755410271,
0.0655338856400241519690695,
-0.0120019604454941768171266,
0.00408165558926174048329689,
-0.000615900721557769691924509,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [0.5, 0.75].
const ERF_IMPL_BN: &[f64] = &[
-0.0361790390718262471360258,
0.292251883444882683221149,
0.281447041797604512774415,
0.125610208862766947294894,
0.0274135028268930549240776,
0.00250839672168065762786937,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [0.5, 0.75].
const ERF_IMPL_BD: &[f64] = &[
1.0,
1.8545005897903486499845,
1.43575803037831418074962,
0.582827658753036572454135,
0.124810476932949746447682,
0.0113724176546353285778481,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [0.75, 1.25].
const ERF_IMPL_CN: &[f64] = &[
-0.0397876892611136856954425,
0.153165212467878293257683,
0.191260295600936245503129,
0.10276327061989304213645,
0.029637090615738836726027,
0.0046093486780275489468812,
0.000307607820348680180548455,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [0.75, 1.25].
const ERF_IMPL_CD: &[f64] = &[
1.0,
1.95520072987627704987886,
1.64762317199384860109595,
0.768238607022126250082483,
0.209793185936509782784315,
0.0319569316899913392596356,
0.00213363160895785378615014,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [1.25, 2.25].
const ERF_IMPL_DN: &[f64] = &[
-0.0300838560557949717328341,
0.0538578829844454508530552,
0.0726211541651914182692959,
0.0367628469888049348429018,
0.00964629015572527529605267,
0.00133453480075291076745275,
0.778087599782504251917881e-4,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [1.25, 2.25].
const ERF_IMPL_DD: &[f64] = &[
1.0,
1.75967098147167528287343,
1.32883571437961120556307,
0.552528596508757581287907,
0.133793056941332861912279,
0.0179509645176280768640766,
0.00104712440019937356634038,
-0.106640381820357337177643e-7,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [2.25, 3.5].
const ERF_IMPL_EN: &[f64] = &[
-0.0117907570137227847827732,
0.014262132090538809896674,
0.0202234435902960820020765,
0.00930668299990432009042239,
0.00213357802422065994322516,
0.00025022987386460102395382,
0.120534912219588189822126e-4,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [2.25, 3.5].
const ERF_IMPL_ED: &[f64] = &[
1.0,
1.50376225203620482047419,
0.965397786204462896346934,
0.339265230476796681555511,
0.0689740649541569716897427,
0.00771060262491768307365526,
0.000371421101531069302990367,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [3.5, 5.25].
const ERF_IMPL_FN: &[f64] = &[
-0.00546954795538729307482955,
0.00404190278731707110245394,
0.0054963369553161170521356,
0.00212616472603945399437862,
0.000394984014495083900689956,
0.365565477064442377259271e-4,
0.135485897109932323253786e-5,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [3.5, 5.25].
const ERF_IMPL_FD: &[f64] = &[
1.0,
1.21019697773630784832251,
0.620914668221143886601045,
0.173038430661142762569515,
0.0276550813773432047594539,
0.00240625974424309709745382,
0.891811817251336577241006e-4,
-0.465528836283382684461025e-11,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [5.25, 8].
const ERF_IMPL_GN: &[f64] = &[
-0.00270722535905778347999196,
0.0013187563425029400461378,
0.00119925933261002333923989,
0.00027849619811344664248235,
0.267822988218331849989363e-4,
0.923043672315028197865066e-6,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [5.25, 8].
const ERF_IMPL_GD: &[f64] = &[
1.0,
0.814632808543141591118279,
0.268901665856299542168425,
0.0449877216103041118694989,
0.00381759663320248459168994,
0.000131571897888596914350697,
0.404815359675764138445257e-11,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [8, 11.5].
const ERF_IMPL_HN: &[f64] = &[
-0.00109946720691742196814323,
0.000406425442750422675169153,
0.000274499489416900707787024,
0.465293770646659383436343e-4,
0.320955425395767463401993e-5,
0.778286018145020892261936e-7,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [8, 11.5].
const ERF_IMPL_HD: &[f64] = &[
1.0,
0.588173710611846046373373,
0.139363331289409746077541,
0.0166329340417083678763028,
0.00100023921310234908642639,
0.24254837521587225125068e-4,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [11.5, 17].
const ERF_IMPL_IN: &[f64] = &[
-0.00056907993601094962855594,
0.000169498540373762264416984,
0.518472354581100890120501e-4,
0.382819312231928859704678e-5,
0.824989931281894431781794e-7,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [11.5, 17].
const ERF_IMPL_ID: &[f64] = &[
1.0,
0.339637250051139347430323,
0.043472647870310663055044,
0.00248549335224637114641629,
0.535633305337152900549536e-4,
-0.117490944405459578783846e-12,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [17, 24].
const ERF_IMPL_JN: &[f64] = &[
-0.000241313599483991337479091,
0.574224975202501512365975e-4,
0.115998962927383778460557e-4,
0.581762134402593739370875e-6,
0.853971555085673614607418e-8,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [17, 24].
const ERF_IMPL_JD: &[f64] = &[
1.0,
0.233044138299687841018015,
0.0204186940546440312625597,
0.000797185647564398289151125,
0.117019281670172327758019e-4,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [24, 38].
const ERF_IMPL_KN: &[f64] = &[
-0.000146674699277760365803642,
0.162666552112280519955647e-4,
0.269116248509165239294897e-5,
0.979584479468091935086972e-7,
0.101994647625723465722285e-8,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [24, 38].
const ERF_IMPL_KD: &[f64] = &[
1.0,
0.165907812944847226546036,
0.0103361716191505884359634,
0.000286593026373868366935721,
0.298401570840900340874568e-5,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [38, 60].
const ERF_IMPL_LN: &[f64] = &[
-0.583905797629771786720406e-4,
0.412510325105496173512992e-5,
0.431790922420250949096906e-6,
0.993365155590013193345569e-8,
0.653480510020104699270084e-10,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [38, 60].
const ERF_IMPL_LD: &[f64] = &[
1.0,
0.105077086072039915406159,
0.00414278428675475620830226,
0.726338754644523769144108e-4,
0.477818471047398785369849e-6,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [60, 85].
const ERF_IMPL_MN: &[f64] = &[
-0.196457797609229579459841e-4,
0.157243887666800692441195e-5,
0.543902511192700878690335e-7,
0.317472492369117710852685e-9,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [60, 85].
const ERF_IMPL_MD: &[f64] = &[
1.0,
0.052803989240957632204885,
0.000926876069151753290378112,
0.541011723226630257077328e-5,
0.535093845803642394908747e-15,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [85, 110].
const ERF_IMPL_NN: &[f64] = &[
-0.789224703978722689089794e-5,
0.622088451660986955124162e-6,
0.145728445676882396797184e-7,
0.603715505542715364529243e-10,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [85, 110].
const ERF_IMPL_ND: &[f64] = &[
1.0,
0.0375328846356293715248719,
0.000467919535974625308126054,
0.193847039275845656900547e-5,
];
// **********************************************************
// ********** Coefficients for erf_inv_impl polynomial ******
// **********************************************************
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0, 0.5].
const ERF_INV_IMPL_AN: &[f64] = &[
-0.000508781949658280665617,
-0.00836874819741736770379,
0.0334806625409744615033,
-0.0126926147662974029034,
-0.0365637971411762664006,
0.0219878681111168899165,
0.00822687874676915743155,
-0.00538772965071242932965,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0, 0.5].
const ERF_INV_IMPL_AD: &[f64] = &[
1.0,
-0.970005043303290640362,
-1.56574558234175846809,
1.56221558398423026363,
0.662328840472002992063,
-0.71228902341542847553,
-0.0527396382340099713954,
0.0795283687341571680018,
-0.00233393759374190016776,
0.000886216390456424707504,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.5, 0.75].
const ERF_INV_IMPL_BN: &[f64] = &[
-0.202433508355938759655,
0.105264680699391713268,
8.37050328343119927838,
17.6447298408374015486,
-18.8510648058714251895,
-44.6382324441786960818,
17.445385985570866523,
21.1294655448340526258,
-3.67192254707729348546,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.5, 0.75].
const ERF_INV_IMPL_BD: &[f64] = &[
1.0,
6.24264124854247537712,
3.9713437953343869095,
-28.6608180499800029974,
-20.1432634680485188801,
48.5609213108739935468,
10.8268667355460159008,
-22.6436933413139721736,
1.72114765761200282724,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x less than 3.
const ERF_INV_IMPL_CN: &[f64] = &[
-0.131102781679951906451,
-0.163794047193317060787,
0.117030156341995252019,
0.387079738972604337464,
0.337785538912035898924,
0.142869534408157156766,
0.0290157910005329060432,
0.00214558995388805277169,
-0.679465575181126350155e-6,
0.285225331782217055858e-7,
-0.681149956853776992068e-9,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x less than 3.
const ERF_INV_IMPL_CD: &[f64] = &[
1.0,
3.46625407242567245975,
5.38168345707006855425,
4.77846592945843778382,
2.59301921623620271374,
0.848854343457902036425,
0.152264338295331783612,
0.01105924229346489121,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 3 and 6.
const ERF_INV_IMPL_DN: &[f64] = &[
-0.0350353787183177984712,
-0.00222426529213447927281,
0.0185573306514231072324,
0.00950804701325919603619,
0.00187123492819559223345,
0.000157544617424960554631,
0.460469890584317994083e-5,
-0.230404776911882601748e-9,
0.266339227425782031962e-11,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 3 and 6.
const ERF_INV_IMPL_DD: &[f64] = &[
1.0,
1.3653349817554063097,
0.762059164553623404043,
0.220091105764131249824,
0.0341589143670947727934,
0.00263861676657015992959,
0.764675292302794483503e-4,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 6 and 18.
const ERF_INV_IMPL_EN: &[f64] = &[
-0.0167431005076633737133,
-0.00112951438745580278863,
0.00105628862152492910091,
0.000209386317487588078668,
0.149624783758342370182e-4,
0.449696789927706453732e-6,
0.462596163522878599135e-8,
-0.281128735628831791805e-13,
0.99055709973310326855e-16,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 6 and 18.
const ERF_INV_IMPL_ED: &[f64] = &[
1.0,
0.591429344886417493481,
0.138151865749083321638,
0.0160746087093676504695,
0.000964011807005165528527,
0.275335474764726041141e-4,
0.282243172016108031869e-6,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 18 and 44.
const ERF_INV_IMPL_FN: &[f64] = &[
-0.0024978212791898131227,
-0.779190719229053954292e-5,
0.254723037413027451751e-4,
0.162397777342510920873e-5,
0.396341011304801168516e-7,
0.411632831190944208473e-9,
0.145596286718675035587e-11,
-0.116765012397184275695e-17,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 18 and 44.
const ERF_INV_IMPL_FD: &[f64] = &[
1.0,
0.207123112214422517181,
0.0169410838120975906478,
0.000690538265622684595676,
0.145007359818232637924e-4,
0.144437756628144157666e-6,
0.509761276599778486139e-9,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x greater than 44.
const ERF_INV_IMPL_GN: &[f64] = &[
-0.000539042911019078575891,
-0.28398759004727721098e-6,
0.899465114892291446442e-6,
0.229345859265920864296e-7,
0.225561444863500149219e-9,
0.947846627503022684216e-12,
0.135880130108924861008e-14,
-0.348890393399948882918e-21,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x greater than 44.
const ERF_INV_IMPL_GD: &[f64] = &[
1.0,
0.0845746234001899436914,
0.00282092984726264681981,
0.468292921940894236786e-4,
0.399968812193862100054e-6,
0.161809290887904476097e-8,
0.231558608310259605225e-11,
];
/// `erf_impl` computes the error function at `z`.
/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`
fn erf_impl(z: f64, inv: bool) -> f64 {
if z < 0.0 {
if !inv {
return -erf_impl(-z, false);
}
if z < -0.5 {
return 2.0 - erf_impl(-z, true);
}
return 1.0 + erf_impl(-z, false);
}
let result = if z < 0.5 {
if z < 1e-10 {
z * 1.125 + z * 0.003379167095512573896158903121545171688
} else {
z * 1.125
+ z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
}
} else if z < 110.0 {
let (r, b) = if z < 0.75 {
(
evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
/ evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
0.3440242112,
)
} else if z < 1.25 {
(
evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
/ evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
0.419990927,
)
} else if z < 2.25 {
(
evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
/ evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
0.4898625016,
)
} else if z < 3.5 {
(
evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
/ evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
0.5317370892,
)
} else if z < 5.25 {
(
evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
/ evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
0.5489973426,
)
} else if z < 8.0 {
(
evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
/ evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
0.5571740866,
)
} else if z < 11.5 {
(
evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
/ evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
0.5609807968,
)
} else if z < 17.0 {
(
evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
/ evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
0.5626493692,
)
} else if z < 24.0 {
(
evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
/ evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
0.5634598136,
)
} else if z < 38.0 {
(
evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
/ evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
0.5638477802,
)
} else if z < 60.0 {
(
evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
/ evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
0.5640528202,
)
} else if z < 85.0 {
(
evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
/ evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
0.5641309023,
)
} else {
(
evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
/ evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
0.5641584396,
)
};
let g = (-z * z).exp() / z;
g * b + g * r
} else {
0.0
};
if inv && z >= 0.5 {
result
} else if z >= 0.5 || inv {
1.0 - result
} else {
result
}
}
// `erf_inv_impl` computes the inverse error function where
// `p`,`q`, and `s` are the first, second, and third intermediate
// parameters respectively
fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
let result = if p <= 0.5 {
let y = 0.0891314744949340820313;
let g = p * (p + 10.0);
let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
g * y + g * r
} else if q >= 0.25 {
let y = 2.249481201171875;
let g = (-2.0 * q.ln()).sqrt();
let xs = q - 0.25;
let r =
evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
g / (y + r)
} else {
let x = (-q.ln()).sqrt();
if x < 3.0 {
let y = 0.807220458984375;
let xs = x - 1.125;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_CD);
y * x + r * x
} else if x < 6.0 {
let y = 0.93995571136474609375;
let xs = x - 3.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_DD);
y * x + r * x
} else if x < 18.0 {
let y = 0.98362827301025390625;
let xs = x - 6.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_ED);
y * x + r * x
} else if x < 44.0 {
let y = 0.99714565277099609375;
let xs = x - 18.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_FD);
y * x + r * x
} else {
let y = 0.99941349029541015625;
let xs = x - 44.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_GD);
y * x + r * x
}
};
s * result
}

View File

@ -0,0 +1,191 @@
pub trait VecOps: num_traits::NumAssign + Copy {
fn min(self, rhs: Self) -> Self;
fn max(self, rhs: Self) -> Self;
/// Dot-product of two vectors.
///
/// # Safety
///
/// The length of `lhs` and `rhs` have to be at least `len`. `res` has to point to a valid
/// element.
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
*res = Self::zero();
for i in 0..len {
*res += *lhs.add(i) * *rhs.add(i)
}
}
/// Sum of all elements in a vector.
///
/// # Safety
///
/// The length of `xs` must be at least `len`. `res` has to point to a valid
/// element.
#[inline(always)]
unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
*res = Self::zero();
for i in 0..len {
*res += *xs.add(i)
}
}
/// Maximum element in a non-empty vector.
///
/// # Safety
///
/// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
/// element.
#[inline(always)]
unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
*res = (*res).max(*xs.add(i))
}
}
/// Minimum element in a non-empty vector.
///
/// # Safety
///
/// The length of `xs` must be at least `len` and positive. `res` has to point to a valid
/// element.
#[inline(always)]
unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
*res = (*res).min(*xs.add(i))
}
}
}
impl VecOps for f32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
super::vec_dot_f32(lhs, rhs, res, len)
}
#[inline(always)]
unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
super::vec_sum(xs, res, len)
}
}
impl VecOps for half::f16 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
let mut res_f32 = 0f32;
super::vec_dot_f16(lhs, rhs, &mut res_f32, len);
*res = half::f16::from_f32(res_f32);
}
}
impl VecOps for f64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
}
impl VecOps for half::bf16 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Self::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
Self::max(self, other)
}
}
impl VecOps for u8 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for u32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}
#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
if n_threads == 1 {
func(0)
} else {
rayon::scope(|s| {
for thread_idx in 0..n_threads {
let func = &func;
s.spawn(move |_| func(thread_idx));
}
})
}
}
#[inline(always)]
pub fn par_range(lo: usize, up: usize, n_threads: usize, func: impl Fn(usize) + Send + Sync) {
if n_threads == 1 {
for i in lo..up {
func(i)
}
} else {
rayon::scope(|s| {
for thread_idx in 0..n_threads {
let func = &func;
s.spawn(move |_| {
for i in (thread_idx..up).step_by(n_threads) {
func(i)
}
});
}
})
}
}

184
candle-core/src/cpu/mod.rs Normal file
View File

@ -0,0 +1,184 @@
//! Traits and methods for CPU-backed Tensors
pub mod erf;
pub mod kernels;
#[allow(unused)]
trait Cpu<const ARR: usize> {
type Unit;
type Array;
const STEP: usize;
const EPR: usize;
fn n() -> usize;
unsafe fn zero() -> Self::Unit;
unsafe fn zero_array() -> Self::Array;
unsafe fn load(mem_addr: *const f32) -> Self::Unit;
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
unsafe fn from_f32(v: f32) -> Self::Unit;
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
}
#[allow(unused)]
trait CpuF16<const ARR: usize> {
type Unit;
type Array;
const STEP: usize;
const EPR: usize;
fn n() -> usize;
unsafe fn zero() -> Self::Unit;
unsafe fn zero_array() -> Self::Array;
unsafe fn load(mem_addr: *const f16) -> Self::Unit;
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
unsafe fn from_f32(v: f32) -> Self::Unit;
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit);
}
use half::f16;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub mod avx;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub use avx::{CurrentCpu, CurrentCpuF16};
#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
pub mod simd128;
#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
pub use simd128::CurrentCpu;
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(target_feature = "neon")]
pub use neon::CurrentCpu;
#[cfg(any(
target_feature = "neon",
target_feature = "avx",
target_feature = "simd128"
))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
let np = k & !(CurrentCpu::STEP - 1);
let mut sum = CurrentCpu::zero_array();
let mut ax = CurrentCpu::zero_array();
let mut ay = CurrentCpu::zero_array();
for i in (0..np).step_by(CurrentCpu::STEP) {
for j in 0..CurrentCpu::n() {
ax[j] = CurrentCpu::load(a_row.add(i + j * CurrentCpu::EPR));
ay[j] = CurrentCpu::load(b_row.add(i + j * CurrentCpu::EPR));
sum[j] = CurrentCpu::vec_fma(sum[j], ax[j], ay[j]);
}
}
CurrentCpu::vec_reduce(sum, c);
// leftovers
for i in np..k {
*c += *a_row.add(i) * (*b_row.add(i));
}
}
#[cfg(not(any(
target_feature = "neon",
target_feature = "avx",
target_feature = "simd128"
)))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f32(a_row: *const f32, b_row: *const f32, c: *mut f32, k: usize) {
// leftovers
for i in 0..k {
*c += *a_row.add(i) * (*b_row.add(i));
}
}
#[cfg(any(
target_feature = "neon",
target_feature = "avx",
target_feature = "simd128"
))]
#[inline(always)]
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
let np = k & !(CurrentCpu::STEP - 1);
let mut sum = CurrentCpu::zero_array();
let mut x = CurrentCpu::zero_array();
for i in (0..np).step_by(CurrentCpu::STEP) {
for j in 0..CurrentCpu::n() {
x[j] = CurrentCpu::load(row.add(i + j * CurrentCpu::EPR));
sum[j] = CurrentCpu::vec_add(sum[j], x[j]);
}
}
CurrentCpu::vec_reduce(sum, b);
// leftovers
for i in np..k {
*b += *row.add(i)
}
}
#[cfg(not(any(
target_feature = "neon",
target_feature = "avx",
target_feature = "simd128"
)))]
#[inline(always)]
pub(crate) unsafe fn vec_sum(row: *const f32, b: *mut f32, k: usize) {
*b = 0f32;
for i in 0..k {
*b += *row.add(i)
}
}
#[cfg(target_feature = "avx")]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
let mut sumf = 0.0f32;
let np = k & !(CurrentCpuF16::STEP - 1);
let mut sum = CurrentCpuF16::zero_array();
let mut ax = CurrentCpuF16::zero_array();
let mut ay = CurrentCpuF16::zero_array();
for i in (0..np).step_by(CurrentCpuF16::STEP) {
for j in 0..CurrentCpuF16::n() {
ax[j] = CurrentCpuF16::load(a_row.add(i + j * CurrentCpuF16::EPR));
ay[j] = CurrentCpuF16::load(b_row.add(i + j * CurrentCpuF16::EPR));
sum[j] = CurrentCpuF16::vec_fma(sum[j], ax[j], ay[j]);
}
}
CurrentCpuF16::vec_reduce(sum, &mut sumf);
// leftovers
for i in np..k {
sumf += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
}
*c = sumf;
}
#[cfg(not(target_feature = "avx"))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
// leftovers
let mut sum = 0.0;
for i in 0..k {
sum += (*a_row.add(i)).to_f32() * (*b_row.add(i)).to_f32();
}
*c = sum;
}

View File

@ -0,0 +1,74 @@
use super::Cpu;
#[cfg(target_arch = "arm")]
use core::arch::arm::*;
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
pub struct CurrentCpu {}
const STEP: usize = 16;
const EPR: usize = 4;
const ARR: usize = STEP / EPR;
impl CurrentCpu {
#[cfg(target_arch = "aarch64")]
unsafe fn reduce_one(x: float32x4_t) -> f32 {
vaddvq_f32(x)
}
#[cfg(target_arch = "arm")]
unsafe fn reduce_one(x: float32x4_t) -> f32 {
vgetq_lane_f32(x, 0) + vgetq_lane_f32(x, 1) + vgetq_lane_f32(x, 2) + vgetq_lane_f32(x, 3)
}
}
impl Cpu<ARR> for CurrentCpu {
type Unit = float32x4_t;
type Array = [float32x4_t; ARR];
const STEP: usize = STEP;
const EPR: usize = EPR;
fn n() -> usize {
ARR
}
unsafe fn zero() -> Self::Unit {
vdupq_n_f32(0.0)
}
unsafe fn from_f32(x: f32) -> Self::Unit {
vdupq_n_f32(x)
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}
unsafe fn load(mem_addr: *const f32) -> Self::Unit {
vld1q_f32(mem_addr)
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
vaddq_f32(a, b)
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
vfmaq_f32(a, b, c)
}
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
vst1q_f32(mem_addr, a);
}
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
for i in 0..ARR / 2 {
x[2 * i] = vaddq_f32(x[2 * i], x[2 * i + 1]);
}
for i in 0..ARR / 4 {
x[4 * i] = vaddq_f32(x[4 * i], x[4 * i + 2]);
}
*y = Self::reduce_one(x[0]);
}
}

View File

@ -0,0 +1,64 @@
use super::Cpu;
use core::arch::wasm32::*;
pub struct CurrentCpu {}
const STEP: usize = 16;
const EPR: usize = 4;
const ARR: usize = STEP / EPR;
impl Cpu<ARR> for CurrentCpu {
type Unit = v128;
type Array = [v128; ARR];
const STEP: usize = STEP;
const EPR: usize = EPR;
fn n() -> usize {
ARR
}
unsafe fn zero() -> Self::Unit {
f32x4_splat(0.0)
}
unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}
unsafe fn from_f32(v: f32) -> Self::Unit {
f32x4_splat(v)
}
unsafe fn load(mem_addr: *const f32) -> Self::Unit {
v128_load(mem_addr as *mut v128)
}
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
f32x4_add(a, b)
}
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
f32x4_add(f32x4_mul(b, c), a)
}
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit) {
v128_store(mem_addr as *mut v128, a);
}
unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
for i in 0..ARR / 2 {
x[2 * i] = f32x4_add(x[2 * i], x[2 * i + 1]);
}
for i in 0..ARR / 4 {
x[4 * i] = f32x4_add(x[4 * i], x[4 * i + 2]);
}
for i in 0..ARR / 8 {
x[8 * i] = f32x4_add(x[8 * i], x[8 * i + 4]);
}
*y = f32x4_extract_lane::<0>(x[0])
+ f32x4_extract_lane::<1>(x[0])
+ f32x4_extract_lane::<2>(x[0])
+ f32x4_extract_lane::<3>(x[0]);
}
}

View File

@ -0,0 +1,360 @@
/// Helper functions to write CPU kernels.
use crate::backend::BackendStorage;
use crate::{Error, Layout, Result, WithDType};
type C = super::CpuStorage;
pub trait Map1 {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
match vs {
C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
}
}
}
pub trait Map1Any {
fn f<T: WithDType, W: Fn(Vec<T>) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result<C>;
fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
match vs {
C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
}
}
}
pub trait Map2 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt()),
}
}
}
pub trait Map2U8 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt()),
}
}
}
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
rhs: &[T],
mut f: F,
) -> Vec<U> {
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
.iter()
.zip(rhs[o_r1..o_r2].iter())
.map(|(&l, &r)| f(l, r))
.collect(),
(Some((o_l1, o_l2)), None) => {
// TODO: Maybe we want to avoid going through the layout twice.
match rhs_l.offsets_b() {
Some(ob) => {
let mut i_in_block = 0;
let mut i_right_broadcast = 0;
lhs[o_l1..o_l2]
.iter()
.map(|&l| {
let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
i_right_broadcast += 1;
if i_right_broadcast >= ob.right_broadcast {
i_in_block += 1;
i_right_broadcast = 0;
}
if i_in_block >= ob.len {
i_in_block = 0
}
f(l, *r)
})
.collect()
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
(None, Some((o_r1, o_r2))) => {
// TODO: Maybe we want to avoid going through the layout twice.
match lhs_l.offsets_b() {
Some(ob) => {
let mut i_in_block = 0;
let mut i_right_broadcast = 0;
rhs[o_r1..o_r2]
.iter()
.map(|&r| {
let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
i_right_broadcast += 1;
if i_right_broadcast >= ob.right_broadcast {
i_in_block += 1;
i_right_broadcast = 0;
}
if i_in_block >= ob.len {
i_in_block = 0
}
f(*l, r)
})
.collect()
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
_ => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
// Similar to binary_map but with vectorized variants.
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
rhs: &[T],
mut f: F,
mut f_vec: FV,
) -> Vec<T> {
let el_count = lhs_l.shape().elem_count();
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
};
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
(Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
Some(ob) if ob.right_broadcast == 1 => {
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
};
let mut dst_i = 0;
for src_i in (o_l1..o_l2).step_by(ob.len) {
f_vec(
&lhs[src_i..src_i + ob.len],
rhs,
&mut ys_to_set[dst_i..dst_i + ob.len],
);
dst_i += ob.len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
Some(ob) => {
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys = lhs[o_l1..o_l2].to_vec();
for idx_l in 0..ob.left_broadcast {
let start = idx_l * ob.len * ob.right_broadcast;
for (i, &r) in rhs.iter().enumerate() {
let start = start + i * ob.right_broadcast;
for v in ys[start..start + ob.right_broadcast].iter_mut() {
*v = f(*v, r)
}
}
}
ys
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
},
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
Some(ob) if ob.right_broadcast == 1 => {
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
};
let mut dst_i = 0;
for src_i in (o_r1..o_r2).step_by(ob.len) {
f_vec(
lhs,
&rhs[src_i..src_i + ob.len],
&mut ys_to_set[dst_i..dst_i + ob.len],
);
dst_i += ob.len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
Some(ob) => {
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys = rhs[o_r1..o_r2].to_vec();
for idx_l in 0..ob.left_broadcast {
let start = idx_l * ob.len * ob.right_broadcast;
for (i, &l) in lhs.iter().enumerate() {
let start = start + i * ob.right_broadcast;
for v in ys[start..start + ob.right_broadcast].iter_mut() {
*v = f(l, *v)
}
}
}
ys
}
None => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
},
_ => lhs_l
.strided_index()
.zip(rhs_l.strided_index())
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect(),
}
}
pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
vs: &[T],
layout: &Layout,
mut f: F,
) -> Vec<U> {
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
[start_offset..start_offset + len]
.iter()
.map(|&v| f(v))
.collect(),
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
let mut result = Vec::with_capacity(layout.shape().elem_count());
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
} else {
for index in block_start_index {
for offset in 0..block_len {
let v = unsafe { vs.get_unchecked(index + offset) };
result.push(f(*v))
}
}
}
result
}
}
}
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
vs: &[T],
layout: &Layout,
mut f: F,
mut f_vec: FV,
) -> Vec<U> {
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => {
let mut ys: Vec<U> = Vec::with_capacity(len);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
};
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(len) };
ys
}
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
let el_count = layout.shape().elem_count();
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
let mut result = Vec::with_capacity(el_count);
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
result
} else {
let mut ys: Vec<U> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
};
let mut dst_index = 0;
for src_index in block_start_index {
let vs = &vs[src_index..src_index + block_len];
let ys = &mut ys_to_set[dst_index..dst_index + block_len];
f_vec(vs, ys);
dst_index += block_len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,225 @@
use crate::WithDType;
use cudarc;
use cudarc::cudnn::safe::{ConvForward, Cudnn};
use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits};
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::Arc;
// The cudnn handles are stored per thread here rather than on the CudaDevice as they are neither
// send nor sync.
thread_local! {
static CUDNN: RefCell<HashMap<crate::cuda_backend::DeviceId, Arc<Cudnn>>> = HashMap::new().into();
}
impl From<cudarc::cudnn::CudnnError> for crate::Error {
fn from(err: cudarc::cudnn::CudnnError) -> Self {
crate::Error::wrap(err)
}
}
impl From<cudarc::driver::DriverError> for crate::Error {
fn from(err: cudarc::driver::DriverError) -> Self {
crate::Error::wrap(err)
}
}
pub(crate) fn launch_conv2d<
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
Y: cudarc::cudnn::CudnnDataType,
>(
src: &CudaView<T>,
src_l: &crate::Layout,
filter: &CudaView<T>,
dst: &mut CudaSlice<T>,
params: &crate::conv::ParamsConv2D,
dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> {
use crate::conv::CudnnFwdAlgo as CandleAlgo;
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
let device_id = dev.id();
let cudnn = CUDNN.with(|cudnn| {
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
return Ok(cudnn.clone());
}
let c = Cudnn::new(dev.cuda_stream());
if let Ok(c) = &c {
cudnn.borrow_mut().insert(device_id, c.clone());
}
c
})?;
let conv = cudnn.create_conv2d::<Y>(
/* pad */ [params.padding as i32, params.padding as i32],
/* stride */ [params.stride as i32, params.stride as i32],
/* dilation */ [params.dilation as i32, params.dilation as i32],
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
)?;
let x_shape = [
params.b_size as i32,
params.c_in as i32,
params.i_h as i32,
params.i_w as i32,
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
cudnn.create_4d_tensor::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
x_shape,
)?
} else {
let s = src_l.stride();
cudnn.create_4d_tensor_ex::<T>(
x_shape,
[s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32],
)?
};
let w = cudnn.create_4d_filter::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[
params.c_out as i32,
params.c_in as i32,
params.k_h as i32,
params.k_w as i32,
],
)?;
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
let y = cudnn.create_4d_tensor::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, h_out, w_out],
)?;
let conv2d = ConvForward {
conv: &conv,
x: &x,
w: &w,
y: &y,
};
let alg = match params.cudnn_fwd_algo {
None => conv2d.pick_algorithm()?,
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
Some(CandleAlgo::ImplicitPrecompGemm) => {
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
}
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
};
let workspace_size = conv2d.get_workspace_size(alg)?;
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
unsafe {
conv2d.launch::<CudaSlice<u8>, _, _, _>(
alg,
Some(&mut workspace),
(T::one(), T::zero()),
src,
filter,
dst,
)?;
}
Ok(())
}
pub(crate) fn launch_conv1d<
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
Y: cudarc::cudnn::CudnnDataType,
>(
src: &CudaView<T>,
src_l: &crate::Layout,
filter: &CudaView<T>,
dst: &mut CudaSlice<T>,
params: &crate::conv::ParamsConv1D,
dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> {
use crate::conv::CudnnFwdAlgo as CandleAlgo;
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
let device_id = dev.id();
let cudnn = CUDNN.with(|cudnn| {
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
return Ok(cudnn.clone());
}
let c = Cudnn::new(dev.cuda_stream());
if let Ok(c) = &c {
cudnn.borrow_mut().insert(device_id, c.clone());
}
c
})?;
let conv = cudnn.create_conv2d::<Y>(
/* pad */ [params.padding as i32, 0],
/* stride */ [params.stride as i32, 1],
/* dilation */ [params.dilation as i32, 1],
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
)?;
// https://docs.nvidia.com/deeplearning/cudnn/backend/latest/api/cudnn-ops-library.html#cudnnsettensornddescriptor
// > Tensors are restricted to having at least 4 dimensions, and at most CUDNN_DIM_MAX
// > dimensions (defined in cudnn.h). When working with lower dimensional data, it is
// > recommended that the user create a 4D tensor, and set the size along unused dimensions
// > to 1.
let x_shape = [
params.b_size as i32,
params.c_in as i32,
params.l_in as i32,
1,
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
cudnn.create_4d_tensor::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
x_shape,
)?
} else {
let s = src_l.stride();
cudnn.create_4d_tensor_ex::<T>(x_shape, [s[0] as i32, s[1] as i32, s[2] as i32, 1i32])?
};
let w = cudnn.create_4d_filter::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[
params.c_out as i32,
params.c_in as i32,
params.k_size as i32,
1,
],
)?;
let l_out = params.l_out() as i32;
let y = cudnn.create_4d_tensor::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, l_out, 1],
)?;
let conv1d = ConvForward {
conv: &conv,
x: &x,
w: &w,
y: &y,
};
let alg = match params.cudnn_fwd_algo {
None => conv1d.pick_algorithm()?,
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
Some(CandleAlgo::ImplicitPrecompGemm) => {
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
}
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
};
let workspace_size = conv1d.get_workspace_size(alg)?;
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
unsafe {
conv1d.launch::<CudaSlice<u8>, _, _, _>(
alg,
Some(&mut workspace),
(T::one(), T::zero()),
src,
filter,
dst,
)?;
}
Ok(())
}

View File

@ -0,0 +1,646 @@
use crate::backend::BackendDevice;
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
use half::{bf16, f16};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);
impl DeviceId {
fn new() -> Self {
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
use std::sync::atomic;
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
}
}
struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}
pub struct ModuleStore {
mdls: [Option<Arc<cudarc::driver::CudaModule>>; kernels::ALL_IDS.len()],
}
#[derive(Clone)]
pub struct CudaDevice {
id: DeviceId,
context: Arc<cudarc::driver::CudaContext>,
modules: Arc<std::sync::RwLock<ModuleStore>>,
custom_modules: Arc<std::sync::RwLock<HashMap<String, Arc<cudarc::driver::CudaModule>>>>,
stream: Arc<cudarc::driver::CudaStream>,
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
curand: Arc<Mutex<CudaRng>>,
}
impl std::fmt::Debug for CudaDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "CudaDevice({:?})", self.id)
}
}
impl CudaDevice {
#[allow(clippy::missing_safety_doc)]
pub unsafe fn alloc<T: cudarc::driver::DeviceRepr>(
&self,
len: usize,
) -> Result<cudarc::driver::CudaSlice<T>> {
self.stream.alloc::<T>(len).w()
}
pub fn alloc_zeros<T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits>(
&self,
len: usize,
) -> Result<cudarc::driver::CudaSlice<T>> {
self.stream.alloc_zeros::<T>(len).w()
}
pub fn memcpy_htod<
T: cudarc::driver::DeviceRepr,
Src: cudarc::driver::HostSlice<T> + ?Sized,
Dst: cudarc::driver::DevicePtrMut<T>,
>(
&self,
src: &Src,
dst: &mut Dst,
) -> Result<()> {
self.stream.memcpy_htod(src, dst).w()
}
pub fn memcpy_dtov<T: cudarc::driver::DeviceRepr, Src: cudarc::driver::DevicePtr<T>>(
&self,
src: &Src,
) -> Result<Vec<T>> {
self.stream.memcpy_dtov(src).w()
}
pub fn memcpy_dtod<
T,
Src: cudarc::driver::DevicePtr<T>,
Dst: cudarc::driver::DevicePtrMut<T>,
>(
&self,
src: &Src,
dst: &mut Dst,
) -> Result<()> {
self.stream.memcpy_dtod(src, dst).w()
}
pub fn memcpy_stod<
T: cudarc::driver::DeviceRepr,
Src: cudarc::driver::HostSlice<T> + ?Sized,
>(
&self,
src: &Src,
) -> Result<cudarc::driver::CudaSlice<T>> {
self.stream.memcpy_stod(src).w()
}
}
pub struct CudaFunc {
func: CudaFunction,
stream: Arc<cudarc::driver::CudaStream>,
}
impl std::ops::Deref for CudaFunc {
type Target = CudaFunction;
fn deref(&self) -> &Self::Target {
&self.func
}
}
impl CudaFunc {
pub fn into_cuda_function(self) -> CudaFunction {
self.func
}
}
#[macro_export]
macro_rules! builder_arg {
($b:ident, $($arg:expr),*) => {
$(
let __arg = $arg;
$b.arg(&__arg);
)*
};
}
impl CudaFunc {
pub fn builder(&self) -> cudarc::driver::LaunchArgs<'_> {
self.stream.launch_builder(&self.func)
}
}
impl CudaDevice {
pub fn cuda_stream(&self) -> Arc<cudarc::driver::CudaStream> {
self.stream.clone()
}
#[cfg(not(target_arch = "wasm32"))]
pub fn compile(
&self,
func_name: &'static str,
kernel: ug::lang::ssa::Kernel,
) -> Result<CudaFunc> {
let mut buf = vec![];
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
let cuda_code = String::from_utf8(buf)?;
let opts = cudarc::nvrtc::CompileOptions {
use_fast_math: Some(true),
..Default::default()
};
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
let module = self.context.load_module(ptx).w()?;
let func = module.load_function(func_name).w()?;
Ok(CudaFunc {
func,
stream: self.stream.clone(),
})
}
pub fn id(&self) -> DeviceId {
self.id
}
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let slice = match dtype {
DType::U8 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count)? };
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u8;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count)? };
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u32;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count)? };
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as i64;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count)? };
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = bf16::from_f64(v);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count)? };
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = f16::from_f64(v);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count)? };
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as f32;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }?;
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
pub fn get_or_load_custom_func(
&self,
fn_name: &str,
module_name: &str,
ptx: &str,
) -> Result<CudaFunc> {
let ms = self.custom_modules.read().unwrap();
if let Some(mdl) = ms.get(module_name).as_ref() {
let func = mdl.load_function(fn_name).w()?;
return Ok(CudaFunc {
func,
stream: self.stream.clone(),
});
}
drop(ms);
let mut ms = self.custom_modules.write().unwrap();
let cuda_module = self.context.load_module(ptx.into()).w()?;
ms.insert(module_name.to_string(), cuda_module.clone());
let func = cuda_module.load_function(fn_name).w()?;
Ok(CudaFunc {
func,
stream: self.stream.clone(),
})
}
pub fn get_or_load_func(&self, fn_name: &str, mdl: &kernels::Module) -> Result<CudaFunc> {
let ms = self.modules.read().unwrap();
if let Some(mdl) = ms.mdls[mdl.index()].as_ref() {
let func = mdl.load_function(fn_name).w()?;
return Ok(CudaFunc {
func,
stream: self.stream.clone(),
});
}
drop(ms);
let mut ms = self.modules.write().unwrap();
let cuda_module = self.context.load_module(mdl.ptx().into()).w()?;
ms.mdls[mdl.index()] = Some(cuda_module.clone());
let func = cuda_module.load_function(fn_name).w()?;
Ok(CudaFunc {
func,
stream: self.stream.clone(),
})
}
}
impl CudaDevice {
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
let stream = context.new_stream().w()?;
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
let module_store = ModuleStore {
mdls: [const { None }; kernels::ALL_IDS.len()],
};
Ok(Self {
id: DeviceId::new(),
context,
stream,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
modules: Arc::new(std::sync::RwLock::new(module_store)),
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
})
}
}
impl BackendDevice for CudaDevice {
type Storage = CudaStorage;
fn new(ordinal: usize) -> Result<Self> {
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
let stream = context.default_stream();
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
let module_store = ModuleStore {
mdls: [const { None }; kernels::ALL_IDS.len()],
};
Ok(Self {
id: DeviceId::new(),
context,
stream,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
modules: Arc::new(std::sync::RwLock::new(module_store)),
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
})
}
fn set_seed(&self, seed: u64) -> Result<()> {
// We do not call set_seed but instead create a new curand object. This ensures that the
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
curand.0 = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?;
Ok(())
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
gpu_id: self.context.ordinal(),
}
}
fn same_device(&self, rhs: &Self) -> bool {
self.id == rhs.id
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc_zeros::<u8>(elem_count)?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc_zeros::<u32>(elem_count)?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc_zeros::<i64>(elem_count)?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc_zeros::<bf16>(elem_count)?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc_zeros::<f16>(elem_count)?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc_zeros::<f32>(elem_count)?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc_zeros::<f64>(elem_count)?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
})
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count)? };
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count)? };
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F64(data)
}
};
let slice = if lo == 0. && up == 1.0 {
slice
} else {
use super::utils::Map1;
let layout = Layout::contiguous(shape);
super::Affine(up - lo, lo).map(&slice, self, &layout)?
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
// curand can only generate an odd number of values.
// https://github.com/huggingface/candle/issues/734
let elem_count_round = if elem_count % 2 == 1 {
elem_count + 1
} else {
elem_count
};
let slice = match dtype {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
})
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count_round)? };
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count_round)? };
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
self.const_impl(1., shape, dtype)
}
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc::<u8>(elem_count)?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc::<u32>(elem_count)?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc::<i64>(elem_count)?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc::<bf16>(elem_count)?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc::<f16>(elem_count)?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc::<f32>(elem_count)?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc::<f64>(elem_count)?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
let slice = match T::cpu_storage_ref(s) {
CpuStorageRef::U8(storage) => {
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::U8(data)
}
CpuStorageRef::U32(storage) => {
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::U32(data)
}
CpuStorageRef::I64(storage) => {
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::I64(data)
}
CpuStorageRef::BF16(storage) => {
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::BF16(data)
}
CpuStorageRef::F16(storage) => {
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F16(data)
}
CpuStorageRef::F32(storage) => {
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F32(data)
}
CpuStorageRef::F64(storage) => {
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn synchronize(&self) -> Result<()> {
self.stream.synchronize().map_err(crate::Error::wrap)?;
Ok(())
}
}

View File

@ -0,0 +1,62 @@
use crate::{DType, Layout};
/// cudarc related errors
#[derive(thiserror::Error, Debug)]
pub enum CudaError {
#[error(transparent)]
Cuda(#[from] cudarc::driver::DriverError),
#[error(transparent)]
Compiler(#[from] cudarc::nvrtc::CompileError),
#[error(transparent)]
Cublas(#[from] cudarc::cublas::result::CublasError),
#[error(transparent)]
Curand(#[from] cudarc::curand::result::CurandError),
#[error("missing kernel '{module_name}'")]
MissingKernel { module_name: String },
#[error("unsupported dtype {dtype:?} for {op}")]
UnsupportedDtype { dtype: DType, op: &'static str },
#[error("internal error '{0}'")]
InternalError(&'static str),
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
MatMulNonContiguous {
lhs_stride: Layout,
rhs_stride: Layout,
mnk: (usize, usize, usize),
},
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
UnexpectedDType {
msg: &'static str,
expected: DType,
got: DType,
},
#[error("{cuda} when loading {module_name}")]
Load {
cuda: cudarc::driver::DriverError,
module_name: String,
},
}
impl From<CudaError> for crate::Error {
fn from(val: CudaError) -> Self {
crate::Error::Cuda(Box::new(val)).bt()
}
}
pub trait WrapErr<O> {
fn w(self) -> std::result::Result<O, crate::Error>;
}
impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
fn w(self) -> std::result::Result<O, crate::Error> {
self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,172 @@
/// Helper functions to plug cuda kernels in candle.
use crate::{Layout, Result, Shape, WithDType};
pub use cudarc;
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
use super::{CudaDevice, CudaError, WrapErr};
pub type S = super::CudaStorageSlice;
pub trait Map1 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>>;
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U8(s) => S::U8(self.f(s, d, l)?),
S::U32(s) => S::U32(self.f(s, d, l)?),
S::I64(s) => S::I64(self.f(s, d, l)?),
S::BF16(s) => S::BF16(self.f(s, d, l)?),
S::F16(s) => S::F16(self.f(s, d, l)?),
S::F32(s) => S::F32(self.f(s, d, l)?),
S::F64(s) => S::F64(self.f(s, d, l)?),
};
Ok(out)
}
}
pub trait Map2 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>>;
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
(S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
(S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
};
Ok(out)
}
}
pub trait Map3 {
#[allow(clippy::too_many_arguments)]
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
src3: &CudaSlice<T>,
layout3: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>>;
#[allow(clippy::too_many_arguments)]
fn map(
&self,
s1: &S,
l1: &Layout,
s2: &S,
l2: &Layout,
s3: &S,
l3: &Layout,
d: &CudaDevice,
) -> Result<S> {
let out = match (s1, s2, s3) {
(S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?),
_ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?,
};
Ok(out)
}
}
pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
) -> Result<()>;
fn map(
&self,
dst: &mut S,
dst_s: &Shape,
src: &S,
src_l: &Layout,
d: &CudaDevice,
) -> Result<()> {
match (dst, src) {
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
}
}
}
pub trait Map1Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
wrap: W,
) -> Result<S>;
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
S::U8(s) => self.f(s, d, l, S::U8)?,
S::U32(s) => self.f(s, d, l, S::U32)?,
S::I64(s) => self.f(s, d, l, S::I64)?,
S::BF16(s) => self.f(s, d, l, S::BF16)?,
S::F16(s) => self.f(s, d, l, S::F16)?,
S::F32(s) => self.f(s, d, l, S::F32)?,
S::F64(s) => self.f(s, d, l, S::F64)?,
};
Ok(out)
}
}
pub trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
dev: &CudaDevice,
) -> Result<S>;
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
_ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
};
Ok(out)
}
}

View File

@ -0,0 +1,490 @@
use crate::op::{BackpropOp, Op};
use crate::tensor::from_storage;
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use std::sync::Arc;
/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_storage: &MetalStorage,
_layout: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp2 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp3 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_arg3: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
impl Tensor {
/// Applies a unary custom op without backward support
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a binary custom op without backward support
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
let (storage, shape) =
self.storage()
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a ternary custom op without backward support
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c,
)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a unary custom op.
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
let (storage, shape) = self
.storage()
.apply_op1(self.layout(), c.as_ref().as_ref())?;
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
self.apply_op1_arc(Arc::new(Box::new(c)))
}
/// Applies a binary custom op.
pub fn apply_op2_arc(
&self,
rhs: &Self,
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op2(
self.layout(),
&rhs.storage(),
rhs.layout(),
c.as_ref().as_ref(),
)?;
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
self.apply_op2_arc(r, Arc::new(Box::new(c)))
}
/// Applies a ternary custom op.
pub fn apply_op3_arc(
&self,
t2: &Self,
t3: &Self,
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c.as_ref().as_ref(),
)?;
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
Op::CustomOp3(t1, t2, t3, c.clone())
});
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
&self,
t2: &Self,
t3: &Self,
c: C,
) -> Result<Self> {
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
}
// In place ops.
/// Unary ops that can be defined in user-land.
/// These ops work in place and as such back-prop is unsupported.
pub trait InplaceOp1 {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
}
pub trait InplaceOp2 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout)
-> Result<()>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &mut MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<()> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
}
pub trait InplaceOp3 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &mut CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<()>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &mut CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<()> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &mut MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<()> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
}
impl Tensor {
/// Applies a unary custom op in place.
pub fn inplace_op1<C: InplaceOp1>(&self, c: &C) -> Result<()> {
self.storage_mut().inplace_op1(self.layout(), c)
}
/// Applies a unary custom op in place (for the first tensor).
pub fn inplace_op2<C: InplaceOp2>(&self, rhs: &Self, c: &C) -> Result<()> {
self.storage_mut()
.inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c)
}
/// Applies a ternary custom op in place (for the first tensor).
pub fn inplace_op3<C: InplaceOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> {
self.storage_mut().inplace_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c,
)
}
}
pub struct UgIOp1 {
name: &'static str,
#[cfg(feature = "cuda")]
func: cudarc::driver::CudaFunction,
#[cfg(feature = "metal")]
func: metal::ComputePipelineState,
}
impl UgIOp1 {
#[allow(unused)]
#[cfg(not(target_arch = "wasm32"))]
pub fn new(
name: &'static str,
kernel: ug::lang::ssa::Kernel,
device: &crate::Device,
) -> Result<Self> {
#[cfg(feature = "cuda")]
{
let device = device.as_cuda_device()?;
let func = device.compile(name, kernel)?;
Ok(Self {
name,
func: func.into_cuda_function(),
})
}
#[cfg(feature = "metal")]
{
let device = device.as_metal_device()?;
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
}
#[cfg(not(any(feature = "cuda", feature = "metal")))]
{
Ok(Self { name })
}
}
}
impl InplaceOp1 for UgIOp1 {
fn name(&self) -> &'static str {
self.name
}
fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
crate::bail!("ug ops are only supported on metal/cuda at the moment")
}
#[cfg(feature = "metal")]
fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
use crate::backend::BackendStorage;
use candle_metal_kernels::utils::EncoderProvider;
let elem_count = layout.shape().elem_count();
if sto.dtype() != crate::DType::F32 {
// TODO: support more dtypes.
crate::bail!("input is not a f32 tensor")
}
let device = sto.device();
println!("here");
let command_buffer = device.command_buffer()?;
let command_buffer = &command_buffer;
let encoder = command_buffer.encoder();
let encoder = encoder.as_ref();
encoder.set_compute_pipeline_state(&self.func);
let (g, b) = if elem_count % 32 == 0 {
(elem_count / 32, 32)
} else {
(elem_count, 1)
};
let grid_dims = metal::MTLSize {
width: g as u64,
height: 1,
depth: 1,
};
let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1);
candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize));
encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write);
encoder.dispatch_threads(grid_dims, group_dims);
Ok(())
}
#[cfg(feature = "cuda")]
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
use crate::cuda_backend::WrapErr;
use cudarc::driver::PushKernelArg;
let elem_count = layout.shape().elem_count();
let stream = sto.device.cuda_stream();
// TODO: support more dtypes.
let sto = sto.as_cuda_slice::<f32>()?;
let sto = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => sto.slice(o1..o2),
};
let (g, b) = if elem_count % 32 == 0 {
(elem_count / 32, 32)
} else {
(elem_count, 1)
};
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (g as u32, 1, 1),
block_dim: (b as u32, 1, 1),
shared_mem_bytes: 0,
};
let mut builder = stream.launch_builder(&self.func);
builder.arg(&sto);
unsafe { builder.launch(cfg) }.w()?;
Ok(())
}
}

View File

@ -8,15 +8,17 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation {
Cpu,
Cuda { gpu_id: usize },
Metal { gpu_id: usize },
}
/// Cpu, Cuda, or Metal
#[derive(Debug, Clone)]
pub enum Device {
Cpu,
Cuda(crate::CudaDevice),
Metal(crate::MetalDevice),
}
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
pub trait NdArray {
fn shape(&self) -> Result<Shape>;
@ -81,15 +83,91 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
}
}
impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
for &[[[[S; N4]; N3]; N2]; N1]
{
fn shape(&self) -> Result<Shape> {
Ok(Shape::from((N1, N2, N3, N4)))
}
fn to_cpu_storage(&self) -> CpuStorage {
let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
for i1 in 0..N1 {
for i2 in 0..N2 {
for i3 in 0..N3 {
vec.extend(self[i1][i2][i3])
}
}
}
S::to_cpu_storage_owned(vec)
}
}
impl<S: NdArray> NdArray for Vec<S> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let shape0 = self[0].shape()?;
let n = self.len();
for v in self.iter() {
let shape = v.shape()?;
if shape != shape0 {
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
}
}
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
}
fn to_cpu_storage(&self) -> CpuStorage {
// This allocates intermediary memory and shouldn't be necessary.
let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
CpuStorage::concat(storages.as_slice()).unwrap()
}
}
impl Device {
pub fn new_cuda(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
match self {
Self::Cuda(d) => Ok(d),
Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
}
}
pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
match self {
Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
Self::Cpu => crate::bail!("expected a metal device, got cpu"),
Self::Metal(d) => Ok(d),
}
}
pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
}
pub fn new_metal(ordinal: usize) -> Result<Self> {
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
}
pub fn set_seed(&self, seed: u64) -> Result<()> {
match self {
Self::Cpu => CpuDevice.set_seed(seed),
Self::Cuda(c) => c.set_seed(seed),
Self::Metal(m) => m.set_seed(seed),
}
}
pub fn same_device(&self, rhs: &Self) -> bool {
match (self, rhs) {
(Self::Cpu, Self::Cpu) => true,
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
(Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
_ => false,
}
}
@ -98,20 +176,35 @@ impl Device {
match self {
Self::Cpu => DeviceLocation::Cpu,
Self::Cuda(device) => device.location(),
Device::Metal(device) => device.location(),
}
}
pub fn is_cpu(&self) -> bool {
match self {
Self::Cpu => true,
Self::Cuda(_) => false,
}
matches!(self, Self::Cpu)
}
pub fn is_cuda(&self) -> bool {
matches!(self, Self::Cuda(_))
}
pub fn is_metal(&self) -> bool {
matches!(self, Self::Metal(_))
}
pub fn supports_bf16(&self) -> bool {
match self {
Self::Cuda(_) | Self::Metal(_) => true,
Self::Cpu => false,
Self::Cuda(_) => true,
}
}
/// Return `BF16` for devices that support it, otherwise default to `F32`.
pub fn bf16_default_to_f32(&self) -> DType {
if self.supports_bf16() {
DType::BF16
} else {
DType::F32
}
}
@ -136,8 +229,18 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
}
Device::Metal(device) => {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
Ok(Storage::Metal(storage))
}
}
}
@ -164,8 +267,18 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_normal(shape, DType::F32, mean, std)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
Device::Metal(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
Ok(Storage::Metal(storage))
}
}
}
@ -189,6 +302,10 @@ impl Device {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
@ -202,6 +319,41 @@ impl Device {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.alloc_uninit(shape, dtype)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.alloc_uninit(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.alloc_uninit(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
Device::Cuda(device) => {
let storage = device.storage_from_slice(data)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.storage_from_slice(data)?;
Ok(Storage::Metal(storage))
}
}
}
@ -210,9 +362,14 @@ impl Device {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
Device::Cuda(device) => {
let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage(&storage)?;
let storage = device.storage_from_cpu_storage_owned(storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage_owned(storage)?;
Ok(Storage::Metal(storage))
}
}
}
@ -221,9 +378,22 @@ impl Device {
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
Device::Cuda(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage(&storage)?;
let storage = device.storage_from_cpu_storage_owned(storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage_owned(storage)?;
Ok(Storage::Metal(storage))
}
}
}
pub fn synchronize(&self) -> Result<()> {
match self {
Self::Cpu => Ok(()),
Self::Cuda(d) => d.synchronize(),
Self::Metal(d) => d.synchronize(),
}
}
}

View File

@ -1,6 +1,7 @@
/// Pretty printing of tensors
/// This implementation should be in line with the PyTorch version.
/// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py
//! Pretty printing of tensors
//!
//! This implementation should be in line with the [PyTorch version](https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py).
//!
use crate::{DType, Result, Tensor, WithDType};
use half::{bf16, f16};
@ -9,11 +10,17 @@ impl Tensor {
&self,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
let prefix = match self.device() {
crate::Device::Cpu => "Cpu",
crate::Device::Cuda(_) => "Cuda",
let device_str = match self.device().location() {
crate::DeviceLocation::Cpu => "".to_owned(),
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
};
write!(f, "{prefix}Tensor[")?;
write!(f, "Tensor[")?;
match self.dims() {
[] => {
if let Ok(v) = self.to_scalar::<T>() {
@ -40,7 +47,7 @@ impl Tensor {
}
}
}
write!(f, "; {}]", self.dtype().as_str())
write!(f, "; {}{}]", self.dtype().as_str(), device_str)
}
}
@ -49,6 +56,7 @@ impl std::fmt::Debug for Tensor {
match self.dtype() {
DType::U8 => self.fmt_dt::<u8>(f),
DType::U32 => self.fmt_dt::<u32>(f),
DType::I64 => self.fmt_dt::<i64>(f),
DType::BF16 => self.fmt_dt::<bf16>(f),
DType::F16 => self.fmt_dt::<f16>(f),
DType::F32 => self.fmt_dt::<f32>(f),
@ -58,12 +66,13 @@ impl std::fmt::Debug for Tensor {
}
/// Options for Tensor pretty printing
#[derive(Debug, Clone)]
pub struct PrinterOptions {
precision: usize,
threshold: usize,
edge_items: usize,
line_width: usize,
sci_mode: Option<bool>,
pub precision: usize,
pub threshold: usize,
pub edge_items: usize,
pub line_width: usize,
pub sci_mode: Option<bool>,
}
static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
@ -82,6 +91,10 @@ impl PrinterOptions {
}
}
pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
&PRINT_OPTS
}
pub fn set_print_options(options: PrinterOptions) {
*PRINT_OPTS.lock().unwrap() = options
}
@ -110,6 +123,26 @@ pub fn set_print_options_full() {
}
}
pub fn set_line_width(line_width: usize) {
PRINT_OPTS.lock().unwrap().line_width = line_width
}
pub fn set_precision(precision: usize) {
PRINT_OPTS.lock().unwrap().precision = precision
}
pub fn set_edge_items(edge_items: usize) {
PRINT_OPTS.lock().unwrap().edge_items = edge_items
}
pub fn set_threshold(threshold: usize) {
PRINT_OPTS.lock().unwrap().threshold = threshold
}
pub fn set_sci_mode(sci_mode: Option<bool>) {
PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
}
struct FmtSize {
current_size: usize,
}
@ -431,6 +464,12 @@ impl std::fmt::Display for Tensor {
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::I64 => {
let tf: IntFormatter<i64> = IntFormatter::new();
let max_w = tf.max_width(&to_display);
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
DType::BF16 => {
if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
let max_w = tf.max_width(&to_display);
@ -460,6 +499,23 @@ impl std::fmt::Display for Tensor {
}
}
};
write!(f, "Tensor[{:?}, {}]", self.dims(), self.dtype().as_str())
let device_str = match self.device().location() {
crate::DeviceLocation::Cpu => "".to_owned(),
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal { gpu_id } => {
format!(", metal:{}", gpu_id)
}
};
write!(
f,
"Tensor[{:?}, {}{}]",
self.dims(),
self.dtype().as_str(),
device_str
)
}
}

View File

@ -1,18 +1,37 @@
//! Types for elements that can be stored and manipulated using tensors.
#![allow(clippy::redundant_closure_call)]
use crate::backend::BackendStorage;
use crate::{CpuStorage, Error, Result};
use crate::{CpuStorage, CpuStorageRef, Error, Result};
/// The different types of elements allowed in tensors.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
// Unsigned 8 bits integer.
U8,
// Unsigned 32 bits integer.
U32,
// Signed 64 bits integer.
I64,
// Brain floating-point using half precision (16 bits).
BF16,
// Floating-point using half precision (16 bits).
F16,
// Floating-point using single precision (32 bits).
F32,
// Floating-point using double precision (64 bits).
F64,
}
#[derive(Debug, PartialEq, Eq)]
pub struct DTypeParseError;
pub struct DTypeParseError(String);
impl std::fmt::Display for DTypeParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "cannot parse '{}' as a dtype", self.0)
}
}
impl std::error::Error for DTypeParseError {}
impl std::str::FromStr for DType {
type Err = DTypeParseError;
@ -20,20 +39,23 @@ impl std::str::FromStr for DType {
match s {
"u8" => Ok(Self::U8),
"u32" => Ok(Self::U32),
"i64" => Ok(Self::I64),
"bf16" => Ok(Self::BF16),
"f16" => Ok(Self::F16),
"f32" => Ok(Self::F32),
"f64" => Ok(Self::F64),
_ => Err(DTypeParseError),
_ => Err(DTypeParseError(s.to_string())),
}
}
}
impl DType {
/// String representation for dtypes.
pub fn as_str(&self) -> &'static str {
match self {
Self::U8 => "u8",
Self::U32 => "u32",
Self::I64 => "i64",
Self::BF16 => "bf16",
Self::F16 => "f16",
Self::F32 => "f32",
@ -41,23 +63,51 @@ impl DType {
}
}
/// The size used by each element in bytes, i.e. 1 for `U8`, 4 for `F32`.
pub fn size_in_bytes(&self) -> usize {
match self {
Self::U8 => 4,
Self::U8 => 1,
Self::U32 => 4,
Self::I64 => 8,
Self::BF16 => 2,
Self::F16 => 2,
Self::F32 => 4,
Self::F64 => 8,
}
}
pub fn is_int(&self) -> bool {
match self {
Self::U8 | Self::U32 | Self::I64 => true,
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
}
}
pub fn is_float(&self) -> bool {
match self {
Self::U8 | Self::U32 | Self::I64 => false,
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
}
}
}
pub trait WithDType: Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + 'static {
pub trait WithDType:
Sized
+ Copy
+ num_traits::NumAssign
+ std::cmp::PartialOrd
+ std::fmt::Display
+ 'static
+ Send
+ Sync
+ std::any::Any
+ crate::cpu::kernels::VecOps
{
const DTYPE: DType;
fn from_f64(v: f64) -> Self;
fn to_f64(self) -> f64;
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
@ -81,6 +131,10 @@ macro_rules! with_dtype {
$to_f64(self)
}
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
CpuStorageRef::$dtype(data)
}
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
CpuStorage::$dtype(data)
}
@ -115,6 +169,7 @@ use half::{bf16, f16};
with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64);
with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64);
with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
@ -125,6 +180,15 @@ pub trait IntDType: WithDType {
fn as_usize(&self) -> usize;
}
impl IntDType for i64 {
fn is_true(&self) -> bool {
*self != 0
}
fn as_usize(&self) -> usize {
*self as usize
}
}
impl IntDType for u32 {
fn is_true(&self) -> bool {
*self != 0

View File

@ -1,3 +1,5 @@
//! Implementation of the Cuda backend when Cuda support has not been compiled in.
//!
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
@ -14,6 +16,12 @@ macro_rules! fail {
};
}
impl CudaDevice {
pub fn new_with_stream(_: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
}
impl crate::backend::BackendStorage for CudaStorage {
type Device = CudaDevice;
@ -37,6 +45,10 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -75,6 +87,36 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv_transpose1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv2d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -119,6 +161,35 @@ impl crate::backend::BackendStorage for CudaStorage {
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn copy2d(
&self,
_: &mut Self,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
}
impl crate::backend::BackendDevice for CudaDevice {
@ -127,6 +198,10 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport)
}
fn set_seed(&self, _: u64) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn location(&self) -> crate::DeviceLocation {
fail!()
}
@ -143,10 +218,22 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport)
}
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -154,4 +241,38 @@ impl crate::backend::BackendDevice for CudaDevice {
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
fn synchronize(&self) -> Result<()> {
Ok(())
}
}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with f16 GEMMs.
pub fn gemm_reduced_precision_f16() -> bool {
true
}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with f16 GEMMs.
pub fn set_gemm_reduced_precision_f16(_: bool) {}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with bf16 GEMMs.
pub fn gemm_reduced_precision_bf16() -> bool {
true
}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with bf16 GEMMs.
pub fn set_gemm_reduced_precision_bf16(_: bool) {}
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// allowed with f32 GEMMs.
pub fn gemm_reduced_precision_f32() -> bool {
true
}
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// allowed with f32 GEMMs.
pub fn set_gemm_reduced_precision_f32(_b: bool) {}

View File

@ -0,0 +1,252 @@
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
#[derive(Debug, Clone)]
pub struct MetalDevice;
#[derive(Debug)]
pub struct MetalStorage;
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
#[error("{0}")]
Message(String),
}
impl From<String> for MetalError {
fn from(e: String) -> Self {
MetalError::Message(e)
}
}
macro_rules! fail {
() => {
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
};
}
impl crate::backend::BackendStorage for MetalStorage {
type Device = MetalDevice;
fn try_clone(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn dtype(&self) -> DType {
fail!()
}
fn device(&self) -> &Self::Device {
fail!()
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv_transpose1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv2d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn matmul(
&self,
_: &Self,
_: (usize, usize, usize, usize),
_: &Layout,
_: &Layout,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn copy2d(
&self,
_: &mut Self,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
}
impl crate::backend::BackendDevice for MetalDevice {
type Storage = MetalStorage;
fn new(_: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn set_seed(&self, _: u64) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn location(&self) -> crate::DeviceLocation {
fail!()
}
fn same_device(&self, _: &Self) -> bool {
fail!()
}
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn synchronize(&self) -> Result<()> {
Ok(())
}
}

View File

@ -1,4 +1,5 @@
use crate::{DType, DeviceLocation, Layout, Shape};
//! Candle-specific Error and Result
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
@ -8,8 +9,14 @@ pub struct MatMulUnexpectedStriding {
pub msg: &'static str,
}
impl std::fmt::Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self}")
}
}
/// Main library error type.
#[derive(thiserror::Error, Debug)]
#[derive(thiserror::Error)]
pub enum Error {
// === DType Errors ===
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
@ -30,7 +37,7 @@ pub enum Error {
UnsupportedDTypeForOp(DType, &'static str),
// === Dimension Index Errors ===
#[error("{op}: dimension index {dim} out of range for {shape:?}")]
#[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
DimOutOfRange {
shape: Shape,
dim: i32,
@ -142,6 +149,9 @@ pub enum Error {
#[error("{op} expects at least one tensor")]
OpRequiresAtLeastOneTensor { op: &'static str },
#[error("{op} expects at least two tensors")]
OpRequiresAtLeastTwoTensors { op: &'static str },
#[error("backward is not supported for {op}")]
BackwardNotSupported { op: &'static str },
@ -149,6 +159,9 @@ pub enum Error {
#[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport,
#[error("the candle crate has not been built with metal support")]
NotCompiledWithMetalSupport,
#[error("cannot find tensor {path}")]
CannotFindTensor { path: String },
@ -156,6 +169,13 @@ pub enum Error {
#[error(transparent)]
Cuda(Box<dyn std::error::Error + Send + Sync>),
#[error("Metal error {0}")]
Metal(#[from] MetalError),
#[cfg(not(target_arch = "wasm32"))]
#[error(transparent)]
Ug(#[from] ug::Error),
#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),
@ -170,6 +190,10 @@ pub enum Error {
#[error(transparent)]
ParseInt(#[from] std::num::ParseIntError),
/// Utf8 parse error.
#[error(transparent)]
FromUtf8(#[from] std::string::FromUtf8Error),
/// I/O error.
#[error(transparent)]
Io(#[from] std::io::Error),
@ -182,8 +206,21 @@ pub enum Error {
UnsupportedSafeTensorDtype(safetensors::Dtype),
/// Arbitrary errors wrapping.
#[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>),
#[error("{0}")]
Wrapped(Box<dyn std::fmt::Display + Send + Sync>),
#[error("{context}\n{inner}")]
Context {
inner: Box<Self>,
context: Box<dyn std::fmt::Display + Send + Sync>,
},
/// Adding path information to an error.
#[error("path: {path:?} {inner}")]
WithPath {
inner: Box<Self>,
path: std::path::PathBuf,
},
#[error("{inner}\n{backtrace}")]
WithBacktrace {
@ -194,13 +231,24 @@ pub enum Error {
/// User generated error message, typically created via `bail!`.
#[error("{0}")]
Msg(String),
#[error("unwrap none")]
UnwrapNone,
}
pub type Result<T> = std::result::Result<T, Error>;
impl Error {
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
Self::Wrapped(Box::new(err))
pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self {
Self::Wrapped(Box::new(err)).bt()
}
pub fn msg(err: impl std::fmt::Display) -> Self {
Self::Msg(err.to_string()).bt()
}
pub fn debug(err: impl std::fmt::Debug) -> Self {
Self::Msg(format!("{err:?}")).bt()
}
pub fn bt(self) -> Self {
@ -214,6 +262,20 @@ impl Error {
},
}
}
pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
Self::WithPath {
inner: Box::new(self),
path: p.as_ref().to_path_buf(),
}
}
pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self {
Self::Context {
inner: Box::new(self),
context: Box::new(c),
}
}
}
#[macro_export]
@ -236,3 +298,41 @@ pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
(_, Err(e)) => Err(e),
}
}
// Taken from anyhow.
pub trait Context<T> {
/// Wrap the error value with additional context.
fn context<C>(self, context: C) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static;
/// Wrap the error value with additional context that is evaluated lazily
/// only once an error does occur.
fn with_context<C, F>(self, f: F) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
F: FnOnce() -> C;
}
impl<T> Context<T> for Option<T> {
fn context<C>(self, context: C) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
{
match self {
Some(v) => Ok(v),
None => Err(Error::UnwrapNone.context(context).bt()),
}
}
fn with_context<C, F>(self, f: F) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
F: FnOnce() -> C,
{
match self {
Some(v) => Ok(v),
None => Err(Error::UnwrapNone.context(f()).bt()),
}
}
}

View File

@ -1,582 +0,0 @@
//! Support for the GGML file format.
use crate::{DType, Device, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use half::f16;
// Default to QK_K 256 rather than 64.
pub const QK_K: usize = 256;
pub const K_SCALE_SIZE: usize = 12;
pub const QK4_0: usize = 32;
pub const QK4_1: usize = 32;
pub const QK5_0: usize = 32;
pub const QK5_1: usize = 32;
pub const QK8_0: usize = 32;
pub const QK8_1: usize = 32;
#[repr(C)]
struct BlockQ4_0 {
d: f16,
qs: [u8; QK4_0 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);
#[repr(C)]
struct BlockQ4_1 {
d: f16,
m: f16,
qs: [u8; QK4_1 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);
#[repr(C)]
struct BlockQ5_0 {
d: f16,
qh: [u8; 4],
qs: [u8; QK5_0 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);
#[repr(C)]
struct BlockQ5_1 {
d: f16,
m: f16,
qh: [u8; 4],
qs: [u8; QK5_1 / 2],
}
const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
#[repr(C)]
struct BlockQ8_0 {
d: f16,
qs: [u8; QK8_0],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
#[repr(C)]
struct BlockQ8_1 {
d: f16,
s: f16,
qs: [u8; QK8_1],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
#[repr(C)]
struct BlockQ2K {
scales: [u8; QK_K / 16],
qs: [u8; QK_K / 4],
d: f16,
dmin: f16,
}
const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());
#[repr(C)]
struct BlockQ3K {
hmask: [u8; QK_K / 8],
qs: [u8; QK_K / 4],
scales: [u8; 12],
d: f16,
}
const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>());
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
#[repr(C)]
struct BlockQ4K {
d: f16,
dmin: f16,
scales: [u8; K_SCALE_SIZE],
qs: [u8; QK_K / 2],
}
const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());
#[repr(C)]
struct BlockQ5K {
d: f16,
dmin: f16,
scales: [u8; K_SCALE_SIZE],
qh: [u8; QK_K / 8],
qs: [u8; QK_K / 2],
}
const _: () =
assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());
#[repr(C)]
struct BlockQ6K {
ql: [u8; QK_K / 2],
qh: [u8; QK_K / 4],
scales: [i8; QK_K / 16],
d: f16,
}
const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354
fn dequantize_row_q2k(xs: &[BlockQ2K], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK_K != 0 {
crate::bail!("dequantize_row_q2k: {k} is not divisible by {QK_K}")
}
let mut ys_index = 0;
for x in xs {
let d = x.d.to_f32();
let min = x.dmin.to_f32();
let q = &x.qs;
let mut is = 0;
for n in (0..QK_K).step_by(128) {
// Step by 32 over q.
let q = &q[n / 4..];
let mut shift = 0;
for _j in 0..4 {
let sc = x.scales[is];
is += 1;
let dl = d * (sc & 0xF) as f32;
let ml = min * (sc >> 4) as f32;
for q in &q[..16] {
let y = dl * ((q >> shift) & 3) as i8 as f32 - ml;
ys[ys_index] = y;
ys_index += 1;
}
let sc = x.scales[is];
is += 1;
let dl = d * (sc & 0xF) as f32;
let ml = min * (sc >> 4) as f32;
for q in &q[16..32] {
let y = dl * ((q >> shift) & 3) as i8 as f32 - ml;
ys[ys_index] = y;
ys_index += 1;
}
shift += 2;
}
}
}
Ok(())
}
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
if j < 4 {
let d = q[j] & 63;
let m = q[j + 4] & 63;
(d, m)
} else {
let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
(d, m)
}
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735
fn dequantize_row_q4k(xs: &[BlockQ4K], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK_K != 0 {
crate::bail!("dequantize_row_q4k: {k} is not divisible by {QK_K}")
}
let mut ys_index = 0;
for x in xs.iter() {
let d = x.d.to_f32();
let min = x.dmin.to_f32();
let q = &x.qs;
let mut is = 0;
for j in (0..QK_K).step_by(64) {
let q = &q[j / 2..j / 2 + 32];
let (sc, m) = get_scale_min_k4(is, &x.scales);
let d1 = d * sc as f32;
let m1 = min * m as f32;
let (sc, m) = get_scale_min_k4(is + 1, &x.scales);
let d2 = d * sc as f32;
let m2 = min * m as f32;
for q in q {
let y = d1 * (q & 0xF) as f32 - m1;
ys[ys_index] = y;
ys_index += 1;
}
for q in q {
let y = d2 * (q >> 4) as f32 - m2;
ys[ys_index] = y;
ys_index += 1;
}
is += 2;
}
}
Ok(())
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
fn dequantize_row_q3k(_xs: &[BlockQ3K], _ys: &mut [f32]) -> Result<()> {
todo!()
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928
fn dequantize_row_q5k(xs: &[BlockQ5K], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK_K != 0 {
crate::bail!("dequantize_row_q5k: {k} is not divisible by {QK_K}")
}
let mut ys_index = 0;
for x in xs.iter() {
let d = x.d.to_f32();
let min = x.dmin.to_f32();
let ql = &x.qs;
let qh = &x.qh;
let mut is = 0;
let mut u1 = 1;
let mut u2 = 2;
for j in (0..QK_K).step_by(64) {
let ql = &ql[j / 2..j / 2 + 32];
let (sc, m) = get_scale_min_k4(is, &x.scales);
let d1 = d * sc as f32;
let m1 = min * m as f32;
let (sc, m) = get_scale_min_k4(is + 1, &x.scales);
let d2 = d * sc as f32;
let m2 = min * m as f32;
for (ql, qh) in ql.iter().zip(qh) {
let to_add = if qh & u1 != 0 { 16 } else { 1 };
let y = d1 * ((ql & 0xF) + to_add) as f32 - m1;
ys[ys_index] = y;
ys_index += 1;
}
for (ql, qh) in ql.iter().zip(qh) {
let to_add = if qh & u2 != 0 { 16 } else { 1 };
let y = d2 * ((ql >> 4) + to_add) as f32 - m2;
ys[ys_index] = y;
ys_index += 1;
}
is += 2;
u1 <<= 2;
u2 <<= 2;
}
}
Ok(())
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067
fn dequantize_row_q6k(xs: &[BlockQ6K], ys: &mut [f32]) -> Result<()> {
let k = ys.len();
if k % QK_K != 0 {
crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}")
}
for x in xs.iter() {
let d = x.d.to_f32();
let ql = &x.ql;
let qh = &x.qh;
let sc = &x.scales;
for n in (0..QK_K).step_by(128) {
let idx = n / 128;
let ys = &mut ys[n..];
let sc = &sc[8 * idx..];
let ql = &ql[64 * idx..];
let qh = &qh[32 * idx..];
for l in 0..32 {
let is = l / 16;
let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32;
let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32;
let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32;
let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32;
ys[l] = d * sc[is] as f32 * q1 as f32;
ys[l + 32] = d * sc[is + 2] as f32 * q2 as f32;
ys[l + 64] = d * sc[is + 4] as f32 * q3 as f32;
ys[l + 96] = d * sc[is + 6] as f32 * q4 as f32;
}
}
}
Ok(())
}
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Magic {
Ggjt,
Ggla,
Ggmf,
Ggml,
Ggsn,
}
impl TryFrom<u32> for Magic {
type Error = crate::Error;
fn try_from(value: u32) -> Result<Self> {
let magic = match value {
0x67676a74 => Self::Ggjt,
0x67676c61 => Self::Ggla,
0x67676d66 => Self::Ggmf,
0x67676d6c => Self::Ggml,
0x6767736e => Self::Ggsn,
_ => crate::bail!("unknown magic {value:08x}"),
};
Ok(magic)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionedMagic {
GgmlUnversioned,
GgmfV1,
GgjtV1,
GgjtV2,
GgjtV3,
}
impl VersionedMagic {
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
let magic = reader.read_u32::<LittleEndian>()?;
let magic = Magic::try_from(magic)?;
if magic == Magic::Ggml {
return Ok(Self::GgmlUnversioned);
}
let version = reader.read_u32::<LittleEndian>()?;
let versioned_magic = match (magic, version) {
(Magic::Ggmf, 1) => Self::GgmfV1,
(Magic::Ggjt, 1) => Self::GgjtV1,
(Magic::Ggjt, 2) => Self::GgjtV2,
(Magic::Ggjt, 3) => Self::GgjtV3,
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
};
Ok(versioned_magic)
}
fn align32(&self) -> bool {
match self {
Self::GgmlUnversioned | Self::GgmfV1 => false,
Self::GgjtV1 | Self::GgjtV2 | Self::GgjtV3 => true,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HParams {
pub n_vocab: u32,
pub n_embd: u32,
pub n_mult: u32,
pub n_head: u32,
pub n_layer: u32,
pub n_rot: u32,
pub ftype: u32,
}
impl HParams {
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
let n_vocab = reader.read_u32::<LittleEndian>()?;
let n_embd = reader.read_u32::<LittleEndian>()?;
let n_mult = reader.read_u32::<LittleEndian>()?;
let n_head = reader.read_u32::<LittleEndian>()?;
let n_layer = reader.read_u32::<LittleEndian>()?;
let n_rot = reader.read_u32::<LittleEndian>()?;
let ftype = reader.read_u32::<LittleEndian>()?;
Ok(Self {
n_vocab,
n_embd,
n_mult,
n_head,
n_layer,
n_rot,
ftype,
})
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct Vocab {
pub token_score_pairs: Vec<(Vec<u8>, f32)>,
}
impl Vocab {
fn read<R: std::io::Read>(reader: &mut R, n_vocab: usize) -> Result<Self> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L556
let mut token_score_pairs = Vec::with_capacity(n_vocab);
for _index in 0..n_vocab {
let len = reader.read_u32::<LittleEndian>()? as usize;
let mut word = vec![0u8; len];
reader.read_exact(&mut word)?;
let score = reader.read_f32::<LittleEndian>()?;
token_score_pairs.push((word, score))
}
Ok(Self { token_score_pairs })
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GgmlDType {
F32,
F16,
Q4_0,
Q4_1,
Q5_0,
Q5_1,
Q8_0,
Q8_1,
Q2K,
Q3K,
Q4K,
Q5K,
Q6K,
}
impl GgmlDType {
fn from_u32(u: u32) -> Result<Self> {
let dtype = match u {
0 => Self::F32,
1 => Self::F16,
2 => Self::Q4_0,
3 => Self::Q4_1,
6 => Self::Q5_0,
7 => Self::Q5_1,
8 => Self::Q8_0,
9 => Self::Q8_1,
10 => Self::Q2K,
11 => Self::Q3K,
12 => Self::Q4K,
13 => Self::Q5K,
14 => Self::Q6K,
_ => crate::bail!("unknown dtype for tensor {u}"),
};
Ok(dtype)
}
fn type_size(&self) -> usize {
match self {
Self::F32 => 4,
Self::F16 => 2,
Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
Self::Q2K => std::mem::size_of::<BlockQ2K>(),
Self::Q3K => std::mem::size_of::<BlockQ3K>(),
Self::Q4K => std::mem::size_of::<BlockQ4K>(),
Self::Q5K => std::mem::size_of::<BlockQ5K>(),
Self::Q6K => std::mem::size_of::<BlockQ6K>(),
}
}
fn blck_size(&self) -> usize {
match self {
Self::F32 => 1,
Self::F16 => 1,
Self::Q4_0 => QK4_0,
Self::Q4_1 => QK4_1,
Self::Q5_0 => QK5_0,
Self::Q5_1 => QK5_1,
Self::Q8_0 => QK8_0,
Self::Q8_1 => QK8_1,
Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K => QK_K,
}
}
}
#[derive(Debug)]
pub struct Content {
pub magic: VersionedMagic,
pub hparams: HParams,
pub vocab: Vocab,
pub tensors: Vec<(String, Tensor)>,
}
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
reader: &mut R,
magic: VersionedMagic,
device: &Device,
) -> Result<(String, Tensor)> {
let n_dims = reader.read_u32::<LittleEndian>()?;
let name_len = reader.read_u32::<LittleEndian>()?;
let dtype = reader.read_u32::<LittleEndian>()?;
let dtype = GgmlDType::from_u32(dtype)?;
let mut dims = vec![0u32; n_dims as usize];
reader.read_u32_into::<LittleEndian>(&mut dims)?;
let mut name = vec![0u8; name_len as usize];
reader.read_exact(&mut name)?;
let name = String::from_utf8_lossy(&name).into_owned();
if magic.align32() {
let pos = reader.stream_position()?;
reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?;
}
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
let tensor_elems = dims.iter().product::<usize>();
let size_in_bytes = tensor_elems * dtype.type_size() / dtype.blck_size();
println!("{name} {dtype:?} {dims:?}");
// TODO: Mmap version to avoid copying the data around?
let mut raw_data = vec![0u8; size_in_bytes];
reader.read_exact(&mut raw_data)?;
let tensor = match dtype {
GgmlDType::F32 => Tensor::from_raw_buffer(&raw_data, DType::F32, &dims, device)?,
GgmlDType::F16 => Tensor::from_raw_buffer(&raw_data, DType::F16, &dims, device)?,
GgmlDType::Q2K => {
let mut f32_data = vec![0f32; tensor_elems];
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ2K>();
let raw_data =
unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ2K, n_blocks) };
dequantize_row_q2k(raw_data, &mut f32_data)?;
// Maybe we should use bf16 instead?
Tensor::from_vec(f32_data, dims, device)?
}
GgmlDType::Q3K => {
let mut f32_data = vec![0f32; tensor_elems];
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ3K>();
let raw_data =
unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ3K, n_blocks) };
dequantize_row_q3k(raw_data, &mut f32_data)?;
Tensor::from_vec(f32_data, dims, device)?
}
GgmlDType::Q4K => {
let mut f32_data = vec![0f32; tensor_elems];
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ4K>();
let raw_data =
unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ4K, n_blocks) };
dequantize_row_q4k(raw_data, &mut f32_data)?;
Tensor::from_vec(f32_data, dims, device)?
}
GgmlDType::Q5K => {
let mut f32_data = vec![0f32; tensor_elems];
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ5K>();
let raw_data =
unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ5K, n_blocks) };
dequantize_row_q5k(raw_data, &mut f32_data)?;
Tensor::from_vec(f32_data, dims, device)?
}
GgmlDType::Q6K => {
let mut f32_data = vec![0f32; tensor_elems];
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<BlockQ6K>();
let raw_data =
unsafe { std::slice::from_raw_parts(raw_data_ptr as *const BlockQ6K, n_blocks) };
dequantize_row_q6k(raw_data, &mut f32_data)?;
Tensor::from_vec(f32_data, dims, device)?
}
_ => crate::bail!("quantized type {dtype:?} used in {name} is not supported yet"),
};
Ok((name, tensor))
}
impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(
reader: &mut R,
device: &Device,
) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?;
let magic = VersionedMagic::read(reader)?;
let hparams = HParams::read(reader)?;
let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
let mut tensors = vec![];
while reader.stream_position()? != last_position {
let (name, tensor) = read_one_tensor(reader, magic, device)?;
tensors.push((name, tensor))
}
Ok(Self {
magic,
hparams,
vocab,
tensors,
})
}
}

View File

@ -46,19 +46,31 @@ impl Tensor {
current_dim += 1;
out
}
TensorIndexer::IndexSelect(indexes) => {
if indexes.rank() != 1 {
crate::bail!("multi-dimensional tensor indexing is not supported")
}
let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
current_dim += 1;
out
}
TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
};
}
Ok(x)
}
}
#[derive(Debug, Clone)]
#[derive(Debug)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
/// This selects the elemnts for which an index has some specific value.
/// This selects the elements for which an index has some specific value.
Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>),
/// Indexing via a 1d tensor
IndexSelect(Tensor),
Err(Error),
}
impl From<usize> for TensorIndexer {
@ -67,36 +79,55 @@ impl From<usize> for TensorIndexer {
}
}
macro_rules! impl_from_range {
($range_type:ty) => {
impl From<$range_type> for TensorIndexer {
fn from(range: $range_type) -> Self {
use std::ops::Bound::*;
let start = match range.start_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
let end = match range.end_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
TensorIndexer::Narrow(start, end)
}
impl From<&[u32]> for TensorIndexer {
fn from(index: &[u32]) -> Self {
match Tensor::new(index, &crate::Device::Cpu) {
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
Err(e) => TensorIndexer::Err(e),
}
};
}
}
impl_from_range!(Range<usize>);
impl_from_range!(RangeFrom<usize>);
impl_from_range!(RangeFull);
impl_from_range!(RangeInclusive<usize>);
impl_from_range!(RangeTo<usize>);
impl_from_range!(RangeToInclusive<usize>);
impl From<Vec<u32>> for TensorIndexer {
fn from(index: Vec<u32>) -> Self {
let len = index.len();
match Tensor::from_vec(index, len, &crate::Device::Cpu) {
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
Err(e) => TensorIndexer::Err(e),
}
}
}
impl From<&Tensor> for TensorIndexer {
fn from(tensor: &Tensor) -> Self {
TensorIndexer::IndexSelect(tensor.clone())
}
}
trait RB: RangeBounds<usize> {}
impl RB for Range<usize> {}
impl RB for RangeFrom<usize> {}
impl RB for RangeFull {}
impl RB for RangeInclusive<usize> {}
impl RB for RangeTo<usize> {}
impl RB for RangeToInclusive<usize> {}
impl<T: RB> From<T> for TensorIndexer {
fn from(range: T) -> Self {
use std::ops::Bound::*;
let start = match range.start_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
let end = match range.end_bound() {
Included(idx) => Included(*idx),
Excluded(idx) => Excluded(*idx),
Unbounded => Unbounded,
};
TensorIndexer::Narrow(start, end)
}
}
/// Trait used to implement multiple signatures for ease of use of the slicing
/// of a tensor
@ -110,28 +141,117 @@ impl<T> IndexOp<T> for Tensor
where
T: Into<TensorIndexer>,
{
///```rust
/// use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::new(&[
/// [0., 1.],
/// [2., 3.],
/// [4., 5.]
/// ], &Device::Cpu)?;
///
/// let b = a.i(0)?;
/// assert_eq!(b.shape().dims(), &[2]);
/// assert_eq!(b.to_vec1::<f64>()?, &[0., 1.]);
///
/// let c = a.i(..2)?;
/// assert_eq!(c.shape().dims(), &[2, 2]);
/// assert_eq!(c.to_vec2::<f64>()?, &[
/// [0., 1.],
/// [2., 3.]
/// ]);
///
/// let d = a.i(1..)?;
/// assert_eq!(d.shape().dims(), &[2, 2]);
/// assert_eq!(d.to_vec2::<f64>()?, &[
/// [2., 3.],
/// [4., 5.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
fn i(&self, index: T) -> Result<Tensor, Error> {
self.index(&[index.into()])
}
}
impl<A> IndexOp<(A,)> for Tensor
where
A: Into<TensorIndexer>,
{
///```rust
/// use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::new(&[
/// [0f32, 1.],
/// [2. , 3.],
/// [4. , 5.]
/// ], &Device::Cpu)?;
///
/// let b = a.i((0,))?;
/// assert_eq!(b.shape().dims(), &[2]);
/// assert_eq!(b.to_vec1::<f32>()?, &[0., 1.]);
///
/// let c = a.i((..2,))?;
/// assert_eq!(c.shape().dims(), &[2, 2]);
/// assert_eq!(c.to_vec2::<f32>()?, &[
/// [0., 1.],
/// [2., 3.]
/// ]);
///
/// let d = a.i((1..,))?;
/// assert_eq!(d.shape().dims(), &[2, 2]);
/// assert_eq!(d.to_vec2::<f32>()?, &[
/// [2., 3.],
/// [4., 5.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
fn i(&self, (a,): (A,)) -> Result<Tensor, Error> {
self.index(&[a.into()])
}
}
#[allow(non_snake_case)]
impl<A, B> IndexOp<(A, B)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
{
///```rust
/// use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.], [6., 7., 8.]], &Device::Cpu)?;
///
/// let b = a.i((1, 0))?;
/// assert_eq!(b.to_vec0::<f32>()?, 3.);
///
/// let c = a.i((..2, 1))?;
/// assert_eq!(c.shape().dims(), &[2]);
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
///
/// let d = a.i((2.., ..))?;
/// assert_eq!(c.shape().dims(), &[2]);
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
/// # Ok::<(), candle_core::Error>(())
/// ```
fn i(&self, (a, b): (A, B)) -> Result<Tensor, Error> {
self.index(&[a.into(), b.into()])
}
}
macro_rules! index_op_tuple {
($($t:ident),+) => {
($doc:tt, $($t:ident),+) => {
#[allow(non_snake_case)]
impl<$($t),*> IndexOp<($($t,)*)> for Tensor
where
$($t: Into<TensorIndexer>,)*
{
#[doc=$doc]
fn i(&self, ($($t,)*): ($($t,)*)) -> Result<Tensor, Error> {
self.index(&[$($t.into(),)*])
}
}
};
}
index_op_tuple!(A);
index_op_tuple!(A, B);
index_op_tuple!(A, B, C);
index_op_tuple!(A, B, C, D);
index_op_tuple!(A, B, C, D, E);
index_op_tuple!(A, B, C, D, E, F);
index_op_tuple!(A, B, C, D, E, F, G);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F, G);

View File

@ -1,3 +1,4 @@
//! Tensor Layouts including contiguous or sparse strides
use crate::{Error, Result, Shape};
#[derive(Debug, PartialEq, Eq, Clone)]
@ -9,6 +10,14 @@ pub struct Layout {
}
impl Layout {
pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
Self {
shape,
stride,
start_offset,
}
}
pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
let shape = shape.into();
let stride = shape.stride_contiguous();
@ -27,6 +36,12 @@ impl Layout {
self.shape.dims()
}
/// The dimension size for a specified dimension index.
pub fn dim<D: crate::shape::Dim>(&self, dim: D) -> Result<usize> {
let dim = dim.to_index(&self.shape, "dim")?;
Ok(self.dims()[dim])
}
pub fn shape(&self) -> &Shape {
&self.shape
}
@ -62,7 +77,7 @@ impl Layout {
self.shape.is_fortran_contiguous(&self.stride)
}
pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
let dims = self.shape().dims();
if dim >= dims.len() {
Err(Error::DimOutOfRange {
@ -91,7 +106,7 @@ impl Layout {
})
}
pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
let rank = self.shape.rank();
if rank <= dim1 || rank <= dim2 {
Err(Error::UnexpectedNumberOfDims {
@ -112,6 +127,31 @@ impl Layout {
})
}
pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
let is_permutation =
idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
if !is_permutation {
crate::bail!(
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
self.dims(),
idxs
)
}
let stride = self.stride();
let dims = self.shape().dims();
let mut perm_stride = stride.to_vec();
let mut perm_dims = dims.to_vec();
for (i, &idx) in idxs.iter().enumerate() {
perm_stride[i] = stride[idx];
perm_dims[i] = dims[idx];
}
Ok(Self {
shape: Shape::from(perm_dims),
stride: perm_stride,
start_offset: self.start_offset,
})
}
pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
let shape = shape.into();
if shape.rank() < self.shape().rank() {

View File

@ -7,14 +7,14 @@
//!
//! let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
//! let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
//!
//! let c = a.matmul(&b)?;
//!
//! # Ok(())}
//! ```
//!
//! ## Features
//!
//! - Simple syntax (looks and like PyTorch)
//! - Simple syntax (looks and feels like PyTorch)
//! - CPU and Cuda backends (and M1 support)
//! - Enable serverless (CPU) small and fast deployments
//! - Model training
@ -32,52 +32,142 @@
//! Python can really add overhead in more complex workflows and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-future/) is a notorious source of headaches.
//!
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
//!
//! ## Other Crates
//!
//! Candle consists of a number of crates. This crate holds core the common data structures but you may wish
//! to look at the docs for the other crates which can be found here:
//!
//! - [candle-core](https://docs.rs/candle-core/). Core Datastructures and DataTypes.
//! - [candle-nn](https://docs.rs/candle-nn/). Building blocks for Neural Nets.
//! - [candle-datasets](https://docs.rs/candle-datasets/). Rust access to commonly used Datasets like MNIST.
//! - [candle-examples](https://docs.rs/candle-examples/). Examples of Candle in Use.
//! - [candle-onnx](https://docs.rs/candle-onnx/). Loading and using ONNX models.
//! - [candle-pyo3](https://docs.rs/candle-pyo3/). Access to Candle from Python.
//! - [candle-transformers](https://docs.rs/candle-transformers/). Candle implemntation of many published transformer models.
//!
#[cfg(feature = "accelerate")]
mod accelerate;
pub mod backend;
pub mod backprop;
mod conv;
pub mod conv;
mod convert;
pub mod cpu;
pub mod cpu_backend;
#[cfg(feature = "cuda")]
pub mod cuda_backend;
mod custom_op;
mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
pub mod dummy_cuda_backend;
mod dummy_metal_backend;
pub mod error;
pub mod ggml;
mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "mkl")]
mod mkl;
pub mod npy;
mod op;
pub mod op;
pub mod pickle;
pub mod quantized;
pub mod safetensors;
pub mod scalar;
pub mod shape;
mod sort;
mod storage;
pub mod streaming;
mod strided_index;
mod tensor;
mod tensor_cat;
pub mod test_utils;
pub mod utils;
mod variable;
pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation};
pub use dtype::{DType, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;
#[cfg(feature = "cudnn")]
pub use cuda_backend::cudnn;
pub use cpu_backend::{CpuStorage, CpuStorageRef};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use error::{Context, Error, Result};
pub use indexer::{IndexOp, TensorIndexer};
pub use layout::Layout;
pub use op::{CustomOp1, CustomOp2, CustomOp3};
pub use shape::{Shape, D};
pub use storage::Storage;
pub use streaming::{StreamTensor, StreamingBinOp, StreamingModule};
pub use strided_index::{StridedBlocks, StridedIndex};
pub use tensor::{Tensor, TensorId};
pub use variable::Var;
#[cfg(feature = "cuda")]
pub use cuda_backend::{CudaDevice, CudaStorage};
pub use cuda_backend as cuda;
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
pub use dummy_cuda_backend as cuda;
pub use cuda::{CudaDevice, CudaStorage};
#[cfg(feature = "metal")]
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
#[cfg(not(feature = "metal"))]
pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
pub trait ToUsize2 {
fn to_usize2(self) -> (usize, usize);
}
impl ToUsize2 for usize {
fn to_usize2(self) -> (usize, usize) {
(self, self)
}
}
impl ToUsize2 for (usize, usize) {
fn to_usize2(self) -> (usize, usize) {
self
}
}
/// Defining a module with forward method using a single argument.
pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self(xs)
}
}
impl<M: Module> Module for Option<&M> {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
None => Ok(xs.clone()),
Some(m) => m.forward(xs),
}
}
}
/// A single forward method using a single single tensor argument and a flag to
/// separate the training and evaluation behaviors.
pub trait ModuleT {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}
impl<M: Module> ModuleT for M {
fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
self.forward(xs)
}
}

View File

@ -0,0 +1,340 @@
use crate::{DType, Result};
use candle_metal_kernels::Kernels;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock};
use super::MetalError;
/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);
impl DeviceId {
pub(crate) fn new() -> Self {
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
use std::sync::atomic;
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
}
}
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
pub(crate) struct Commands {
/// Single command queue for the entire device.
command_queue: CommandQueue,
/// One command buffer at a time.
/// The scheduler works by allowing multiple
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
/// to start to work).
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
/// for their START time, but there's no guarantee that command buffer1 will finish before
/// command buffer2 starts (or there are metal bugs there)
command_buffer: CommandBuffer,
/// Keeps track of the current amount of compute command encoders on the current
/// command buffer
/// Arc, RwLock because of the interior mutability.
command_buffer_index: usize,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
compute_per_buffer: usize,
}
impl Commands {
pub(crate) fn new(command_queue: CommandQueue) -> Result<Self> {
let command_buffer = command_queue.new_command_buffer().to_owned();
command_buffer.enqueue();
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?,
_ => 50,
};
Ok(Self {
command_queue,
command_buffer,
command_buffer_index: 0,
compute_per_buffer,
})
}
pub fn command_buffer(&mut self) -> Result<(bool, CommandBuffer)> {
let mut command_buffer = self.command_buffer.to_owned();
let mut flushed = false;
if self.command_buffer_index > self.compute_per_buffer {
self.command_buffer.commit();
command_buffer = self.command_queue.new_command_buffer().to_owned();
self.command_buffer = command_buffer.clone();
self.command_buffer_index = 0;
flushed = true;
}
self.command_buffer_index += 1;
Ok((flushed, command_buffer))
}
pub fn wait_until_completed(&mut self) -> Result<()> {
match self.command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => {
panic!("Already committed");
}
_ => {}
}
self.command_buffer.commit();
self.command_buffer.wait_until_completed();
self.command_buffer = self.command_queue.new_command_buffer().to_owned();
Ok(())
}
}
#[derive(Clone)]
pub struct MetalDevice {
/// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
/// the device itself.
pub(crate) id: DeviceId,
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
pub(crate) device: metal::Device,
pub(crate) commands: Arc<RwLock<Commands>>,
/// Simple allocator struct.
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
/// (could be linked to FFI communication overhead).
///
/// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the
/// graph calculation, and only we the allocator kept a reference to it, therefore it's free
/// to be reused. However, in order for this to work, we need to guarantee the order of
/// operation, so that this buffer is not being used by another kernel at the same time.
/// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
///
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
/// (strong_count = 1).
pub(crate) buffers: Arc<RwLock<BufferMap>>,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`]
pub(crate) kernels: Arc<Kernels>,
/// Seed for random number generation.
pub(crate) seed: Arc<Mutex<Buffer>>,
}
impl std::fmt::Debug for MetalDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MetalDevice({:?})", self.id)
}
}
impl std::ops::Deref for MetalDevice {
type Target = metal::DeviceRef;
fn deref(&self) -> &Self::Target {
&self.device
}
}
impl MetalDevice {
#[cfg(not(target_arch = "wasm32"))]
pub fn compile(
&self,
func_name: &'static str,
kernel: ug::lang::ssa::Kernel,
) -> Result<metal::ComputePipelineState> {
let mut buf = vec![];
ug_metal::code_gen::gen(&mut buf, func_name, &kernel)?;
let metal_code = String::from_utf8(buf)?;
let lib = self
.device
.new_library_with_source(&metal_code, &metal::CompileOptions::new())
.map_err(MetalError::from)?;
let func = lib
.get_function(func_name, None)
.map_err(MetalError::from)?;
let pl = self
.device
.new_compute_pipeline_state_with_function(&func)
.map_err(MetalError::from)?;
Ok(pl)
}
pub fn id(&self) -> DeviceId {
self.id
}
pub fn metal_device(&self) -> &metal::Device {
&self.device
}
fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(*s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(())
}
pub fn command_buffer(&self) -> Result<CommandBuffer> {
let mut commands = self.commands.write().map_err(MetalError::from)?;
let (flushed, command_buffer) = commands.command_buffer()?;
if flushed {
self.drop_unused_buffers()?
}
Ok(command_buffer)
}
pub fn wait_until_completed(&self) -> Result<()> {
let mut commands = self.commands.write().map_err(MetalError::from)?;
commands.wait_until_completed()
}
pub fn kernels(&self) -> &Kernels {
&self.kernels
}
pub fn device(&self) -> &metal::Device {
&self.device
}
/// Creates a new buffer (not necessarily zeroed).
/// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
/// This means the buffer data cannot be read on the CPU directly.
///
/// [`name`] is only used to keep track of the resource origin in case of bugs
pub fn new_buffer(
&self,
element_count: usize,
dtype: DType,
name: &str,
) -> Result<Arc<Buffer>> {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
}
/// Creates a new buffer (not necessarily zeroed).
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
/// This means the buffer can be read on the CPU but will require manual
/// synchronization when the CPU memory is modified
/// Used as a bridge to gather data back from the GPU
pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
}
/// Creates a new buffer from data.
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
///
/// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
/// allocates the buffer and copies over the existing data before returning the MTLBuffer.
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
let size = core::mem::size_of_val(data) as NSUInteger;
let new_buffer = self.device.new_buffer_with_data(
data.as_ptr().cast(),
size,
MTLResourceOptions::StorageModeManaged,
);
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
let subbuffers = buffers
.entry((size, MTLResourceOptions::StorageModeManaged))
.or_insert(vec![]);
let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());
Ok(new_buffer)
}
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
let buffer = self.allocate_buffer(
size_in_bytes as NSUInteger,
MTLResourceOptions::StorageModePrivate,
"allocate_zeros",
)?;
let command_buffer = self.command_buffer()?;
command_buffer.set_label("zeros");
let blit = command_buffer.new_blit_command_encoder();
blit.fill_buffer(
&buffer,
metal::NSRange {
location: 0,
length: buffer.length(),
},
0,
);
blit.end_encoding();
Ok(buffer)
}
/// The critical allocator algorithm
fn allocate_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
_name: &str,
) -> Result<Arc<Buffer>> {
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
if let Some(b) = find_available_buffer(size, option, &buffers) {
// Cloning also ensures we increment the strong count
return Ok(b.clone());
}
let size = buf_size(size);
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
let new_buffer = self.device.new_buffer(size as NSUInteger, option);
let new_buffer = Arc::new(new_buffer);
subbuffers.push(new_buffer.clone());
Ok(new_buffer)
}
/// Create a metal GPU capture trace on [`path`].
pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let capture = metal::CaptureManager::shared();
let descriptor = metal::CaptureDescriptor::new();
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
descriptor.set_capture_device(self);
// The [set_output_url] call requires an absolute path so we convert it if needed.
if path.as_ref().is_absolute() {
descriptor.set_output_url(path);
} else {
let path = std::env::current_dir()?.join(path);
descriptor.set_output_url(path);
}
capture
.start_capture(&descriptor)
.map_err(MetalError::from)?;
Ok(())
}
}
fn buf_size(size: NSUInteger) -> NSUInteger {
size.saturating_sub(1).next_power_of_two() as NSUInteger
}
fn find_available_buffer(
size: NSUInteger,
option: MTLResourceOptions,
buffers: &BufferMap,
) -> Option<Arc<Buffer>> {
let mut best_buffer: Option<&Arc<Buffer>> = None;
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
for sub in subbuffers {
if Arc::strong_count(sub) == 1 {
best_buffer = Some(sub);
best_buffer_size = *buffer_size;
}
}
}
}
best_buffer.cloned()
}

File diff suppressed because it is too large Load Diff

View File

@ -25,6 +25,10 @@ mod ffi {
pub fn vdMul(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsDiv(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdDiv(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsFmax(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdFmax(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsFmin(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdFmin(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn sgemm_(
transa: *const c_char,
@ -297,7 +301,7 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
}
#[inline]
fn vs_tanh(a: &[f32], y: &mut [f32]) {
pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
@ -307,7 +311,7 @@ fn vs_tanh(a: &[f32], y: &mut [f32]) {
}
#[inline]
fn vd_tanh(a: &[f64], y: &mut [f64]) {
pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
@ -329,6 +333,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_exp_inplace(y: &mut [f32]) {
unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_exp_inplace(y: &mut [f64]) {
unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@ -351,6 +365,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
}
}
#[inline]
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vs_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
#[inline]
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vd_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
macro_rules! binary_op {
($fn_name:ident, $ty:ty, $mkl_name:ident) => {
#[inline]
@ -376,3 +412,7 @@ binary_op!(vs_mul, f32, vsMul);
binary_op!(vd_mul, f64, vdMul);
binary_op!(vs_div, f32, vsDiv);
binary_op!(vd_div, f64, vdDiv);
binary_op!(vs_max, f32, vsFmax);
binary_op!(vd_max, f64, vdFmax);
binary_op!(vs_min, f32, vsFmin);
binary_op!(vd_min, f64, vdFmin);

View File

@ -26,7 +26,7 @@
//! values = np.loadz("test.npz")
//! ```
use crate::{DType, Device, Error, Result, Shape, Tensor};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use byteorder::{LittleEndian, ReadBytesExt};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::collections::HashMap;
use std::fs::File;
@ -85,6 +85,7 @@ impl Header {
DType::F16 => "f2",
DType::F32 => "f4",
DType::F64 => "f8",
DType::I64 => "i8",
DType::U32 => "u4",
DType::U8 => "u1",
};
@ -160,7 +161,7 @@ impl Header {
"f" | "f4" => DType::F32,
"d" | "f8" => DType::F64,
// "i" | "i4" => DType::S32,
// "q" | "i8" => DType::S64,
"q" | "i8" => DType::I64,
// "h" | "i2" => DType::S16,
// "b" | "i1" => DType::S8,
"B" | "u1" => DType::U8,
@ -196,7 +197,11 @@ impl Header {
impl Tensor {
// TODO: Add the possibility to read directly to a device?
fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
pub(crate) fn from_reader<R: std::io::Read>(
shape: Shape,
dtype: DType,
reader: &mut R,
) -> Result<Self> {
let elem_count = shape.elem_count();
match dtype {
DType::BF16 => {
@ -229,6 +234,11 @@ impl Tensor {
reader.read_u32_into::<LittleEndian>(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::I64 => {
let mut data_t = vec![0i64; elem_count];
reader.read_i64_into::<LittleEndian>(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
}
}
@ -240,8 +250,6 @@ impl Tensor {
if header.fortran_order {
return Err(Error::Npy("fortran order not supported".to_string()));
}
let mut data: Vec<u8> = vec![];
reader.read_to_end(&mut data)?;
Self::from_reader(header.shape(), header.descr, &mut reader)
}
@ -307,42 +315,7 @@ impl Tensor {
header.push('\n');
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
f.write_all(header.as_bytes())?;
let elem_count = self.elem_count();
match self.dtype() {
DType::BF16 => {
let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F16 => {
let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F32 => {
// TODO: Avoid using a buffer when data is already on the CPU.
for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
f.write_f32::<LittleEndian>(v)?
}
}
DType::F64 => {
for v in self.reshape(elem_count)?.to_vec1::<f64>()? {
f.write_f64::<LittleEndian>(v)?
}
}
DType::U32 => {
for v in self.reshape(elem_count)?.to_vec1::<u32>()? {
f.write_u32::<LittleEndian>(v)?
}
}
DType::U8 => {
let data = self.reshape(elem_count)?.to_vec1::<u8>()?;
f.write_all(&data)?;
}
}
Ok(())
self.write_bytes(f)
}
/// Writes a multi-dimensional array in the npy format.
@ -357,7 +330,7 @@ impl Tensor {
path: P,
) -> Result<()> {
let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
let options =
let options: zip::write::FileOptions<()> =
zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
for (name, tensor) in ts.iter() {
@ -373,7 +346,7 @@ pub struct NpzTensors {
index_per_name: HashMap<String, usize>,
path: std::path::PathBuf,
// We do not store a zip reader as it needs mutable access to extract data. Instead we
// re-create a zip reader each time.
// re-create a zip reader for each tensor.
}
impl NpzTensors {
@ -396,6 +369,25 @@ impl NpzTensors {
})
}
pub fn names(&self) -> Vec<&String> {
self.index_per_name.keys().collect()
}
/// This only returns the shape and dtype for a named tensor. Compared to `get`, this avoids
/// reading the whole tensor data.
pub fn get_shape_and_dtype(&self, name: &str) -> Result<(Shape, DType)> {
let index = match self.index_per_name.get(name) {
None => crate::bail!("cannot find tensor {name}"),
Some(index) => *index,
};
let zip_reader = BufReader::new(File::open(&self.path)?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_index(index)?;
let header = read_header(&mut reader)?;
let header = Header::parse(&header)?;
Ok((header.shape(), header.descr))
}
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
let index = match self.index_per_name.get(name) {
None => return Ok(None),

View File

@ -1,4 +1,7 @@
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
//! Tensor Opertion Enums and Traits
//!
#![allow(clippy::redundant_closure_call)]
use crate::Tensor;
use half::{bf16, f16};
use num_traits::float::Float;
@ -40,6 +43,8 @@ pub enum BinaryOp {
Mul,
Sub,
Div,
Maximum,
Minimum,
}
// Unary ops with no argument
@ -51,10 +56,19 @@ pub enum UnaryOp {
Cos,
Abs,
Neg,
Recip,
Sqr,
Sqrt,
Gelu,
GeluErf,
Erf,
Relu,
Silu,
Tanh,
Floor,
Ceil,
Round,
Sign,
}
#[derive(Clone)]
@ -77,6 +91,58 @@ pub enum Op {
kernel: Tensor,
padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)]
ConvTranspose1D {
arg: Tensor,
kernel: Tensor,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)]
Conv2D {
arg: Tensor,
kernel: Tensor,
padding: usize,
stride: usize,
dilation: usize,
},
#[allow(dead_code)]
ConvTranspose2D {
arg: Tensor,
kernel: Tensor,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
},
AvgPool2D {
arg: Tensor,
kernel_size: (usize, usize),
stride: (usize, usize),
},
MaxPool2D {
arg: Tensor,
kernel_size: (usize, usize),
stride: (usize, usize),
},
UpsampleNearest1D {
arg: Tensor,
target_size: usize,
},
UpsampleNearest2D {
arg: Tensor,
target_h: usize,
target_w: usize,
},
Cat(Vec<Tensor>, usize),
@ -91,119 +157,28 @@ pub enum Op {
Copy(Tensor),
Broadcast(Tensor),
Narrow(Tensor, usize, usize, usize),
SliceScatter0(Tensor, Tensor, usize),
Reshape(Tensor),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
Permute(Tensor, Vec<usize>),
Elu(Tensor, f64),
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>),
CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>),
}
/// Unary ops that can be defined in user-land.
pub trait CustomOp1: Send + Sync {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp2: Send + Sync {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp3: Send + Sync {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_arg3: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
Powf(Tensor, f64),
CustomOp1(
Tensor,
std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
),
CustomOp2(
Tensor,
Tensor,
std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>,
),
CustomOp3(
Tensor,
Tensor,
Tensor,
std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>,
),
}
pub trait UnaryOpT {
@ -216,6 +191,7 @@ pub trait UnaryOpT {
fn f64(v1: f64) -> f64;
fn u8(v1: u8) -> u8;
fn u32(v1: u32) -> u32;
fn i64(v1: i64) -> i64;
// There is no very good way to represent optional function in traits so we go for an explicit
// boolean flag to mark the function as existing.
@ -239,6 +215,7 @@ pub trait BinaryOpT {
fn f64(v1: f64, v2: f64) -> f64;
fn u8(v1: u8, v2: u8) -> u8;
fn u32(v1: u32, v2: u32) -> u32;
fn i64(v1: i64, v2: i64) -> i64;
const BF16_VEC: bool = false;
fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
@ -252,22 +229,35 @@ pub trait BinaryOpT {
fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
const U32_VEC: bool = false;
fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
const I64_VEC: bool = false;
fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
}
pub(crate) struct Add;
pub(crate) struct Div;
pub(crate) struct Mul;
pub(crate) struct Sub;
pub(crate) struct Maximum;
pub(crate) struct Minimum;
pub(crate) struct Exp;
pub(crate) struct Log;
pub(crate) struct Sin;
pub(crate) struct Cos;
pub(crate) struct Abs;
pub(crate) struct Neg;
pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Silu;
pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;
pub(crate) struct Sign;
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@ -299,6 +289,10 @@ macro_rules! bin_op {
fn u32(v1: u32, v2: u32) -> u32 {
$e(v1, v2)
}
#[inline(always)]
fn i64(v1: i64, v2: i64) -> i64 {
$e(v1, v2)
}
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
@ -314,6 +308,21 @@ macro_rules! bin_op {
fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
crate::mkl::$f64_vec(xs1, xs2, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
crate::accelerate::$f32_vec(xs1, xs2, ys)
}
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
crate::accelerate::$f64_vec(xs1, xs2, ys)
}
}
};
}
@ -322,7 +331,22 @@ bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
bin_op!(
Minimum,
"minimum",
|v1, v2| if v1 > v2 { v2 } else { v1 },
vs_min,
vd_min
);
bin_op!(
Maximum,
"maximum",
|v1, v2| if v1 < v2 { v2 } else { v1 },
vs_max,
vd_max
);
#[allow(clippy::redundant_closure_call)]
macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => {
impl UnaryOpT for $op {
@ -353,6 +377,10 @@ macro_rules! unary_op {
fn u32(_: u32) -> u32 {
todo!("no unary function for u32")
}
#[inline(always)]
fn i64(_: i64) -> i64 {
todo!("no unary function for i64")
}
}
};
@ -385,6 +413,10 @@ macro_rules! unary_op {
fn u32(_: u32) -> u32 {
todo!("no unary function for u32")
}
#[inline(always)]
fn i64(_: i64) -> i64 {
todo!("no unary function for i64")
}
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
@ -400,6 +432,21 @@ macro_rules! unary_op {
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::$f64_vec(xs, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::accelerate::$f32_vec(xs, ys)
}
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::accelerate::$f64_vec(xs, ys)
}
}
};
}
@ -408,12 +455,21 @@ unary_op!(Exp, "exp", v, v.exp(), vs_exp, vd_exp);
unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
/// `gelu` operation
// Hardcode the value for sqrt(2/pi)
// https://github.com/huggingface/candle/issues/1982
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
#[allow(clippy::excessive_precision)]
const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
/// Tanh based approximation of the `gelu` operation
/// GeluErf is the more precise one.
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
impl UnaryOpT for Gelu {
const NAME: &'static str = "gelu";
@ -424,7 +480,7 @@ impl UnaryOpT for Gelu {
* v
* (bf16::ONE
+ bf16::tanh(
(bf16::from_f32_const(2.0) / bf16::PI).sqrt()
bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
* v
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
))
@ -435,22 +491,18 @@ impl UnaryOpT for Gelu {
* v
* (f16::ONE
+ f16::tanh(
(f16::from_f32_const(2.0) / f16::PI).sqrt()
f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
* v
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
0.5 * v
* (1.0
+ f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
}
#[inline(always)]
fn f64(v: f64) -> f64 {
0.5 * v
* (1.0
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
}
#[inline(always)]
fn u8(_: u8) -> u8 {
@ -460,6 +512,10 @@ impl UnaryOpT for Gelu {
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
const KERNEL: &'static str = "ugelu";
#[cfg(feature = "mkl")]
@ -479,6 +535,301 @@ impl UnaryOpT for Gelu {
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_gelu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::accelerate::vs_gelu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::accelerate::vd_gelu(xs, ys)
}
}
/// `erf` operation
/// <https://en.wikipedia.org/wiki/Error_function>
impl UnaryOpT for Erf {
const NAME: &'static str = "erf";
const KERNEL: &'static str = "uerf";
const V: Self = Erf;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
Self::f64(v as f64) as f32
}
#[inline(always)]
fn f64(v: f64) -> f64 {
crate::cpu::erf::erf(v)
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
}
/// Silu operation
impl UnaryOpT for Silu {
const NAME: &'static str = "silu";
const V: Self = Silu;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v / (bf16::ONE + (-v).exp())
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v / (f16::ONE + (-v).exp())
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v / (1.0 + (-v).exp())
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v / (1.0 + (-v).exp())
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
const KERNEL: &'static str = "usilu";
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::mkl::vs_silu(xs, ys)
}
#[cfg(feature = "mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_silu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::accelerate::vs_silu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::accelerate::vd_silu(xs, ys)
}
}
impl UnaryOpT for Abs {
const NAME: &'static str = "abs";
const KERNEL: &'static str = "uabs";
const V: Self = Abs;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.abs()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.abs()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.abs()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.abs()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v.abs()
}
}
impl UnaryOpT for Ceil {
const NAME: &'static str = "ceil";
const KERNEL: &'static str = "uceil";
const V: Self = Ceil;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.ceil()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.ceil()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.ceil()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.ceil()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
impl UnaryOpT for Floor {
const NAME: &'static str = "floor";
const KERNEL: &'static str = "ufloor";
const V: Self = Floor;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.floor()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.floor()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.floor()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.floor()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
impl UnaryOpT for Round {
const NAME: &'static str = "round";
const KERNEL: &'static str = "uround";
const V: Self = Round;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.round()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.round()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.round()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.round()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
impl UnaryOpT for GeluErf {
const NAME: &'static str = "gelu_erf";
const KERNEL: &'static str = "ugelu_erf";
const V: Self = GeluErf;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
Self::f64(v as f64) as f32
}
#[inline(always)]
fn f64(v: f64) -> f64 {
(crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
}
impl UnaryOpT for Relu {
@ -509,6 +860,10 @@ impl UnaryOpT for Relu {
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
@ -562,6 +917,10 @@ impl BackpropOp {
};
Self(op)
}
pub(crate) fn is_none(&self) -> bool {
self.0.is_none()
}
}
impl std::ops::Deref for BackpropOp {
@ -570,3 +929,37 @@ impl std::ops::Deref for BackpropOp {
&self.0
}
}
impl UnaryOpT for Sign {
const NAME: &'static str = "sign";
const KERNEL: &'static str = "usign";
const V: Self = Sign;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
}
#[inline(always)]
fn f32(v: f32) -> f32 {
f32::from(v > 0.) - f32::from(v < 0.)
}
#[inline(always)]
fn f64(v: f64) -> f64 {
f64::from(v > 0.) - f64::from(v < 0.)
}
#[inline(always)]
fn u8(v: u8) -> u8 {
u8::min(1, v)
}
#[inline(always)]
fn u32(v: u32) -> u32 {
u32::min(1, v)
}
#[inline(always)]
fn i64(v: i64) -> i64 {
(v > 0) as i64 - (v < 0) as i64
}
}

841
candle-core/src/pickle.rs Normal file
View File

@ -0,0 +1,841 @@
//! Just enough pickle support to be able to read PyTorch checkpoints.
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
// composable/tensor agnostic at some point.
use crate::{Context, DType, Error as E, Layout, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
use std::io::BufRead;
const VERBOSE: bool = false;
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/
#[repr(u8)]
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum OpCode {
// https://github.com/python/cpython/blob/ed25f097160b5cbb0c9a1f9a746d2f1bbc96515a/Lib/pickletools.py#L2123
Proto = 0x80,
Global = b'c',
BinPut = b'q',
LongBinPut = b'r',
EmptyTuple = b')',
Reduce = b'R',
Mark = b'(',
BinUnicode = b'X',
BinInt = b'J',
Tuple = b't',
BinPersId = b'Q',
BinInt1 = b'K',
BinInt2 = b'M',
Tuple1 = 0x85,
Tuple2 = 0x86,
Tuple3 = 0x87,
NewTrue = 0x88,
NewFalse = 0x89,
None = b'N',
BinGet = b'h',
LongBinGet = b'j',
SetItem = b's',
SetItems = b'u',
EmptyDict = b'}',
Dict = b'd',
Build = b'b',
Stop = b'.',
NewObj = 0x81,
EmptyList = b']',
BinFloat = b'G',
Append = b'a',
Appends = b'e',
Long1 = 0x8a,
}
// Avoid using FromPrimitive so as not to drag another dependency.
impl TryFrom<u8> for OpCode {
type Error = u8;
fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
match value {
0x80 => Ok(Self::Proto),
b'c' => Ok(Self::Global),
b'q' => Ok(Self::BinPut),
b'r' => Ok(Self::LongBinPut),
b')' => Ok(Self::EmptyTuple),
b'R' => Ok(Self::Reduce),
b'(' => Ok(Self::Mark),
b'X' => Ok(Self::BinUnicode),
b'J' => Ok(Self::BinInt),
b't' => Ok(Self::Tuple),
b'Q' => Ok(Self::BinPersId),
b'K' => Ok(Self::BinInt1),
b'M' => Ok(Self::BinInt2),
b'N' => Ok(Self::None),
0x85 => Ok(Self::Tuple1),
0x86 => Ok(Self::Tuple2),
0x87 => Ok(Self::Tuple3),
0x88 => Ok(Self::NewTrue),
0x89 => Ok(Self::NewFalse),
b'h' => Ok(Self::BinGet),
b'j' => Ok(Self::LongBinGet),
b's' => Ok(Self::SetItem),
b'u' => Ok(Self::SetItems),
b'}' => Ok(Self::EmptyDict),
b'd' => Ok(Self::EmptyDict),
b'b' => Ok(Self::Build),
b'.' => Ok(Self::Stop),
0x81 => Ok(Self::NewObj),
b']' => Ok(Self::EmptyList),
b'G' => Ok(Self::BinFloat),
b'a' => Ok(Self::Append),
b'e' => Ok(Self::Appends),
0x8a => Ok(Self::Long1),
value => Err(value),
}
}
}
fn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
let mut data: Vec<u8> = Vec::with_capacity(32);
r.read_until(b'\n', &mut data)?;
data.pop();
if data.last() == Some(&b'\r') {
data.pop();
}
Ok(data)
}
#[derive(Debug, Clone, PartialEq)]
pub enum Object {
Class {
module_name: String,
class_name: String,
},
Int(i32),
Long(i64),
Float(f64),
Unicode(String),
Bool(bool),
None,
Tuple(Vec<Object>),
List(Vec<Object>),
Mark,
Dict(Vec<(Object, Object)>),
Reduce {
callable: Box<Object>,
args: Box<Object>,
},
Build {
callable: Box<Object>,
args: Box<Object>,
},
PersistentLoad(Box<Object>),
}
type OResult<T> = std::result::Result<T, Object>;
impl Object {
pub fn unicode(self) -> OResult<String> {
match self {
Self::Unicode(t) => Ok(t),
_ => Err(self),
}
}
pub fn reduce(self) -> OResult<(Self, Self)> {
match self {
Self::Reduce { callable, args } => Ok((*callable, *args)),
_ => Err(self),
}
}
pub fn none(self) -> OResult<()> {
match self {
Self::None => Ok(()),
_ => Err(self),
}
}
pub fn persistent_load(self) -> OResult<Self> {
match self {
Self::PersistentLoad(t) => Ok(*t),
_ => Err(self),
}
}
pub fn bool(self) -> OResult<bool> {
match self {
Self::Bool(t) => Ok(t),
_ => Err(self),
}
}
pub fn int(self) -> OResult<i32> {
match self {
Self::Int(t) => Ok(t),
_ => Err(self),
}
}
pub fn int_or_long(self) -> OResult<i64> {
match self {
Self::Int(t) => Ok(t as i64),
Self::Long(t) => Ok(t),
_ => Err(self),
}
}
pub fn tuple(self) -> OResult<Vec<Self>> {
match self {
Self::Tuple(t) => Ok(t),
_ => Err(self),
}
}
pub fn dict(self) -> OResult<Vec<(Self, Self)>> {
match self {
Self::Dict(t) => Ok(t),
_ => Err(self),
}
}
pub fn class(self) -> OResult<(String, String)> {
match self {
Self::Class {
module_name,
class_name,
} => Ok((module_name, class_name)),
_ => Err(self),
}
}
pub fn into_tensor_info(
self,
name: Self,
dir_name: &std::path::Path,
) -> Result<Option<TensorInfo>> {
let name = match name.unicode() {
Ok(name) => name,
Err(_) => return Ok(None),
};
let (callable, args) = match self.reduce() {
Ok(callable_args) => callable_args,
_ => return Ok(None),
};
let (callable, args) = match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
let mut args = args.tuple()?;
let callable = args.remove(0);
let args = args.remove(1);
(callable, args)
}
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
let mut args = args.tuple()?;
args.remove(0).reduce()?
}
_ => (callable, args),
};
match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
_ => return Ok(None),
};
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
Ok(Some(TensorInfo {
name,
dtype,
layout,
path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
storage_size,
}))
}
}
impl TryFrom<Object> for String {
type Error = Object;
fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
match value {
Object::Unicode(s) => Ok(s),
other => Err(other),
}
}
}
impl TryFrom<Object> for usize {
type Error = Object;
fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
match value {
Object::Int(s) if s >= 0 => Ok(s as usize),
other => Err(other),
}
}
}
impl<T: TryFrom<Object, Error = Object>> TryFrom<Object> for Vec<T> {
type Error = Object;
fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
match value {
Object::Tuple(values) => {
// This does not return the appropriate value in the error case but instead return
// the object related to the first error.
values
.into_iter()
.map(|v| T::try_from(v))
.collect::<std::result::Result<Vec<T>, Self::Error>>()
}
other => Err(other),
}
}
}
#[derive(Debug)]
pub struct Stack {
stack: Vec<Object>,
memo: HashMap<u32, Object>,
}
impl Stack {
pub fn empty() -> Self {
Self {
stack: Vec::with_capacity(512),
memo: HashMap::new(),
}
}
pub fn stack(&self) -> &[Object] {
self.stack.as_slice()
}
pub fn read_loop<R: BufRead>(&mut self, r: &mut R) -> Result<()> {
loop {
if self.read(r)? {
break;
}
}
Ok(())
}
pub fn finalize(mut self) -> Result<Object> {
self.pop()
}
fn push(&mut self, obj: Object) {
self.stack.push(obj)
}
fn pop(&mut self) -> Result<Object> {
match self.stack.pop() {
None => crate::bail!("unexpected empty stack"),
Some(obj) => Ok(obj),
}
}
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD
fn build(&mut self) -> Result<()> {
let args = self.pop()?;
let obj = self.pop()?;
let obj = match (obj, args) {
(Object::Dict(mut obj), Object::Dict(mut args)) => {
obj.append(&mut args);
Object::Dict(obj)
}
(obj, args) => Object::Build {
callable: Box::new(obj),
args: Box::new(args),
},
};
self.push(obj);
Ok(())
}
fn reduce(&mut self) -> Result<()> {
let args = self.pop()?;
let callable = self.pop()?;
#[allow(clippy::single_match)]
let reduced = match &callable {
Object::Class {
module_name,
class_name,
} => {
if module_name == "collections"
&& (class_name == "OrderedDict" || class_name == "defaultdict")
{
// TODO: have a separate ordered dict and a separate default dict.
Some(Object::Dict(vec![]))
} else {
None
}
}
_ => None,
};
let reduced = reduced.unwrap_or_else(|| Object::Reduce {
callable: Box::new(callable),
args: Box::new(args),
});
self.push(reduced);
Ok(())
}
fn last(&mut self) -> Result<&mut Object> {
match self.stack.last_mut() {
None => crate::bail!("unexpected empty stack"),
Some(obj) => Ok(obj),
}
}
fn memo_get(&self, id: u32) -> Result<Object> {
match self.memo.get(&id) {
None => crate::bail!("missing object in memo {id}"),
Some(obj) => {
// Maybe we should use refcounting rather than doing potential large clones here.
Ok(obj.clone())
}
}
}
fn memo_put(&mut self, id: u32) -> Result<()> {
let obj = self.last()?.clone();
self.memo.insert(id, obj);
Ok(())
}
fn persistent_load(&self, id: Object) -> Result<Object> {
Ok(Object::PersistentLoad(Box::new(id)))
}
fn new_obj(&self, class: Object, args: Object) -> Result<Object> {
Ok(Object::Reduce {
callable: Box::new(class),
args: Box::new(args),
})
}
fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
let mut mark_idx = None;
for (idx, obj) in self.stack.iter().enumerate().rev() {
if obj == &Object::Mark {
mark_idx = Some(idx);
break;
}
}
match mark_idx {
Some(mark_idx) => {
let objs = self.stack.split_off(mark_idx + 1);
self.stack.pop();
Ok(objs)
}
None => {
crate::bail!("marker object not found")
}
}
}
pub fn read<R: BufRead>(&mut self, r: &mut R) -> Result<bool> {
let op_code = match OpCode::try_from(r.read_u8()?) {
Ok(op_code) => op_code,
Err(op_code) => {
crate::bail!("unknown op-code {op_code}")
}
};
// println!("op: {op_code:?}");
// println!("{:?}", self.stack);
match op_code {
OpCode::Proto => {
let version = r.read_u8()?;
if VERBOSE {
println!("proto {version}");
}
}
OpCode::Global => {
let module_name = read_to_newline(r)?;
let class_name = read_to_newline(r)?;
let module_name = String::from_utf8_lossy(&module_name).to_string();
let class_name = String::from_utf8_lossy(&class_name).to_string();
self.push(Object::Class {
module_name,
class_name,
})
}
OpCode::BinInt1 => {
let arg = r.read_u8()?;
self.push(Object::Int(arg as i32))
}
OpCode::BinInt2 => {
let arg = r.read_u16::<LittleEndian>()?;
self.push(Object::Int(arg as i32))
}
OpCode::BinInt => {
let arg = r.read_i32::<LittleEndian>()?;
self.push(Object::Int(arg))
}
OpCode::BinFloat => {
// Somehow floats are encoded using BigEndian whereas int types use LittleEndian.
// https://github.com/python/cpython/blob/0c80da4c14d904a367968955544dd6ae58c8101c/Lib/pickletools.py#L855
// https://github.com/pytorch/pytorch/blob/372d078f361e726bb4ac0884ac334b04c58179ef/torch/_weights_only_unpickler.py#L243
let arg = r.read_f64::<byteorder::BigEndian>()?;
self.push(Object::Float(arg))
}
OpCode::BinUnicode => {
let len = r.read_u32::<LittleEndian>()?;
let mut data = vec![0u8; len as usize];
r.read_exact(&mut data)?;
let data = String::from_utf8(data).map_err(E::wrap)?;
self.push(Object::Unicode(data))
}
OpCode::BinPersId => {
let id = self.pop()?;
let obj = self.persistent_load(id)?;
self.push(obj)
}
OpCode::Tuple => {
let objs = self.pop_to_marker()?;
self.push(Object::Tuple(objs))
}
OpCode::Tuple1 => {
let obj = self.pop()?;
self.push(Object::Tuple(vec![obj]))
}
OpCode::Tuple2 => {
let obj2 = self.pop()?;
let obj1 = self.pop()?;
self.push(Object::Tuple(vec![obj1, obj2]))
}
OpCode::Tuple3 => {
let obj3 = self.pop()?;
let obj2 = self.pop()?;
let obj1 = self.pop()?;
self.push(Object::Tuple(vec![obj1, obj2, obj3]))
}
OpCode::NewTrue => self.push(Object::Bool(true)),
OpCode::NewFalse => self.push(Object::Bool(false)),
OpCode::Append => {
let value = self.pop()?;
let pylist = self.last()?;
if let Object::List(d) = pylist {
d.push(value)
} else {
crate::bail!("expected a list, got {pylist:?}")
}
}
OpCode::Appends => {
let objs = self.pop_to_marker()?;
let pylist = self.last()?;
if let Object::List(d) = pylist {
d.extend(objs)
} else {
crate::bail!("expected a list, got {pylist:?}")
}
}
OpCode::SetItem => {
let value = self.pop()?;
let key = self.pop()?;
let pydict = self.last()?;
if let Object::Dict(d) = pydict {
d.push((key, value))
} else {
crate::bail!("expected a dict, got {pydict:?}")
}
}
OpCode::SetItems => {
let mut objs = self.pop_to_marker()?;
let pydict = self.last()?;
if let Object::Dict(d) = pydict {
if objs.len() % 2 != 0 {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().context("empty objs")?;
d.push((key, value))
}
} else {
crate::bail!("expected a dict, got {pydict:?}")
}
}
OpCode::None => self.push(Object::None),
OpCode::Stop => {
return Ok(true);
}
OpCode::Build => self.build()?,
OpCode::EmptyDict => self.push(Object::Dict(vec![])),
OpCode::Dict => {
let mut objs = self.pop_to_marker()?;
let mut pydict = vec![];
if objs.len() % 2 != 0 {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().context("empty objs")?;
pydict.push((key, value))
}
self.push(Object::Dict(pydict))
}
OpCode::Mark => self.push(Object::Mark),
OpCode::Reduce => self.reduce()?,
OpCode::EmptyTuple => self.push(Object::Tuple(vec![])),
OpCode::EmptyList => self.push(Object::List(vec![])),
OpCode::BinGet => {
let arg = r.read_u8()?;
let obj = self.memo_get(arg as u32)?;
self.push(obj)
}
OpCode::LongBinGet => {
let arg = r.read_u32::<LittleEndian>()?;
let obj = self.memo_get(arg)?;
self.push(obj)
}
OpCode::BinPut => {
let arg = r.read_u8()?;
self.memo_put(arg as u32)?
}
OpCode::LongBinPut => {
let arg = r.read_u32::<LittleEndian>()?;
self.memo_put(arg)?
}
OpCode::NewObj => {
let args = self.pop()?;
let class = self.pop()?;
let obj = self.new_obj(class, args)?;
self.push(obj)
}
OpCode::Long1 => {
let n_bytes = r.read_u8()?;
let mut v = 0;
// Decode the next n bytes in little endian
for i in 0..n_bytes {
v |= (r.read_u8()? as i64) << (i * 8);
}
self.push(Object::Long(v))
}
}
Ok(false)
}
}
impl From<Object> for E {
fn from(value: Object) -> Self {
E::Msg(format!("conversion error on {value:?}"))
}
}
// https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
// Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
let mut args = args.tuple()?;
let stride = Vec::<usize>::try_from(args.remove(3))?;
let size = Vec::<usize>::try_from(args.remove(2))?;
let offset = args.remove(1).int_or_long()? as usize;
let storage = args.remove(0).persistent_load()?;
let mut storage = storage.tuple()?;
let storage_size = storage.remove(4).int_or_long()? as usize;
let path = storage.remove(2).unicode()?;
let (_module_name, class_name) = storage.remove(1).class()?;
let dtype = match class_name.as_str() {
"FloatStorage" => DType::F32,
"DoubleStorage" => DType::F64,
"HalfStorage" => DType::F16,
"BFloat16Storage" => DType::BF16,
"ByteStorage" => DType::U8,
"LongStorage" => DType::I64,
other => {
crate::bail!("unsupported storage type {other}")
}
};
let layout = Layout::new(
crate::Shape::from(size),
stride,
offset * dtype.size_in_bytes(),
);
Ok((layout, dtype, path, storage_size))
}
#[derive(Debug, Clone)]
pub struct TensorInfo {
pub name: String,
pub dtype: DType,
pub layout: Layout,
pub path: String,
pub storage_size: usize,
}
/// Read the tensor info from a .pth file.
///
/// # Arguments
/// * `file` - The path to the .pth file.
/// * `verbose` - Whether to print debug information.
/// * `key` - Optional key to retrieve `state_dict` from the pth file.
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
file: P,
verbose: bool,
key: Option<&str>,
) -> Result<Vec<TensorInfo>> {
let file = std::fs::File::open(file)?;
let zip_reader = std::io::BufReader::new(file);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let zip_file_names = zip
.file_names()
.map(|f| f.to_string())
.collect::<Vec<String>>();
let mut tensor_infos = vec![];
for file_name in zip_file_names.iter() {
if !file_name.ends_with("data.pkl") {
continue;
}
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").context("no .pkl")?);
let reader = zip.by_name(file_name)?;
let mut reader = std::io::BufReader::new(reader);
let mut stack = Stack::empty();
stack.read_loop(&mut reader)?;
let obj = stack.finalize()?;
if VERBOSE || verbose {
println!("{obj:#?}");
}
let obj = match obj {
Object::Build { callable, args } => match *callable {
Object::Reduce { callable, args: _ } => match *callable {
Object::Class {
module_name,
class_name,
} if module_name == "__torch__" && class_name == "Module" => *args,
_ => continue,
},
_ => continue,
},
obj => obj,
};
// If key is provided, then we need to extract the state_dict from the object.
let obj = if let Some(key) = key {
if let Object::Dict(key_values) = obj {
key_values
.into_iter()
.find(|(k, _)| *k == Object::Unicode(key.to_owned()))
.map(|(_, v)| v)
.ok_or_else(|| E::Msg(format!("key {key} not found")))?
} else {
obj
}
} else {
obj
};
// If the object is a dict, then we can extract the tensor info from it.
// NOTE: We are assuming that the `obj` is state_dict by this stage.
if let Object::Dict(key_values) = obj {
for (name, value) in key_values.into_iter() {
match value.into_tensor_info(name, &dir_name) {
Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
Ok(None) => {}
Err(err) => eprintln!("skipping: {err:?}"),
}
}
}
}
Ok(tensor_infos)
}
/// Lazy tensor loader.
pub struct PthTensors {
tensor_infos: HashMap<String, TensorInfo>,
path: std::path::PathBuf,
// We do not store a zip reader as it needs mutable access to extract data. Instead we
// re-create a zip reader for each tensor.
}
impl PthTensors {
pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
let tensor_infos = tensor_infos
.into_iter()
.map(|ti| (ti.name.to_string(), ti))
.collect();
let path = path.as_ref().to_owned();
Ok(Self { tensor_infos, path })
}
pub fn tensor_infos(&self) -> &HashMap<String, TensorInfo> {
&self.tensor_infos
}
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
use std::io::Read;
let tensor_info = match self.tensor_infos.get(name) {
None => return Ok(None),
Some(tensor_info) => tensor_info,
};
// We hope that the file has not changed since first reading it.
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_name(&tensor_info.path)?;
let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
let rank = tensor_info.layout.shape().rank();
// Reading the data is a bit tricky as it can be strided, for now only support the basic
// case and when the tensor is fortran contiguous.
if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
crate::bail!(
"cannot retrieve non-contiguous tensors {:?}",
tensor_info.layout
)
}
let start_offset = tensor_info.layout.start_offset();
if start_offset > 0 {
std::io::copy(
&mut reader.by_ref().take(start_offset as u64),
&mut std::io::sink(),
)?;
}
let tensor = Tensor::from_reader(
tensor_info.layout.shape().clone(),
tensor_info.dtype,
&mut reader,
)?;
if rank > 1 && is_fortran_contiguous {
// Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
let tensor = tensor.reshape(shape_reversed)?;
// Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
let tensor = tensor.permute(dim_indeces_reversed)?;
Ok(Some(tensor))
} else {
Ok(Some(tensor))
}
}
}
/// Read all the tensors from a PyTorch pth file with a given key.
///
/// # Arguments
/// * `path` - Path to the pth file.
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
/// contains multiple objects and the state_dict is the one we are interested in.
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
path: P,
key: Option<&str>,
) -> Result<Vec<(String, Tensor)>> {
let pth = PthTensors::new(path, key)?;
let tensor_names = pth.tensor_infos.keys();
let mut tensors = Vec::with_capacity(tensor_names.len());
for name in tensor_names {
if let Some(tensor) = pth.get(name)? {
tensors.push((name.to_string(), tensor))
}
}
Ok(tensors)
}
/// Read all the tensors from a PyTorch pth file.
///
/// # Arguments
/// * `path` - Path to the pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
read_all_with_key(path, None)
}

View File

@ -0,0 +1,667 @@
use super::k_quants::{
BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
#[inline(always)]
pub(crate) unsafe fn sum_i16_pairs_float(x: __m256i) -> __m256 {
let ones = _mm256_set1_epi16(1);
let summed_pairs = _mm256_madd_epi16(ones, x);
_mm256_cvtepi32_ps(summed_pairs)
}
#[inline(always)]
pub(crate) unsafe fn mul_sum_us8_pairs_float(ax: __m256i, sy: __m256i) -> __m256 {
let dot = _mm256_maddubs_epi16(ax, sy);
sum_i16_pairs_float(dot)
}
#[inline(always)]
pub(crate) unsafe fn hsum_float_8(x: __m256) -> f32 {
let res = _mm256_extractf128_ps(x, 1);
let res = _mm_add_ps(res, _mm256_castps256_ps128(x));
let res = _mm_add_ps(res, _mm_movehl_ps(res, res));
let res = _mm_add_ss(res, _mm_movehdup_ps(res));
_mm_cvtss_f32(res)
}
#[inline(always)]
pub(crate) unsafe fn bytes_from_nibbles_32(rsi: *const u8) -> __m256i {
let tmp = _mm_loadu_si128(rsi as *const __m128i);
let bytes = _mm256_insertf128_si256::<1>(_mm256_castsi128_si256(tmp), _mm_srli_epi16(tmp, 4));
let low_mask = _mm256_set1_epi8(0xF);
_mm256_and_si256(low_mask, bytes)
}
#[inline(always)]
pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
let ax = _mm256_sign_epi8(x, x);
let sy = _mm256_sign_epi8(y, x);
mul_sum_us8_pairs_float(ax, sy)
}
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
let bx = bytes_from_nibbles_32(x.qs.as_ptr());
let off = _mm256_set1_epi8(8);
let bx = _mm256_sub_epi8(bx, off);
let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
let q = mul_sum_i8_pairs_float(bx, by);
acc = _mm256_fmadd_ps(d, q, acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
let bx = _mm256_loadu_si256(x.qs.as_ptr() as *const __m256i);
let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
let q = mul_sum_i8_pairs_float(bx, by);
acc = _mm256_fmadd_ps(d, q, acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
unsafe fn get_scale_shuffle(i: usize) -> __m128i {
const K_SHUFFLE: [u8; 128] = [
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3,
3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7,
7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10,
11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13,
13, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15,
];
_mm_loadu_si128((K_SHUFFLE.as_ptr() as *const __m128i).add(i))
}
#[inline(always)]
unsafe fn get_scale_shuffle_k4(i: usize) -> __m256i {
const K_SHUFFLE: [u8; 256] = [
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
2, 3, 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
6, 7, 6, 7, 6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10,
11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13,
12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
13, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
];
_mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
}
#[inline(always)]
unsafe fn get_scale_shuffle_q3k(i: usize) -> __m256i {
const K_SHUFFLE: [u8; 128] = [
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11,
10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
];
_mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
}
#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % qk != 0 {
crate::bail!("vec_dot_q6k_8k: {n} is not divisible by {qk}")
}
unsafe {
let m4 = _mm256_set1_epi8(0xF);
let m2 = _mm256_set1_epi8(3);
let m32s = _mm256_set1_epi8(32);
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let mut q4 = x.ql.as_ptr();
let mut qh = x.qh.as_ptr();
let mut q8 = y.qs.as_ptr();
let scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i);
let mut sumi = _mm256_setzero_si256();
for j in 0..QK_K / 128 {
let is = j * 4;
let scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is));
let scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
let scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
let scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
let q4bits1 = _mm256_loadu_si256(q4 as *const __m256i);
q4 = q4.add(32);
let q4bits2 = _mm256_loadu_si256(q4 as *const __m256i);
q4 = q4.add(32);
let q4bits_h = _mm256_loadu_si256(qh as *const __m256i);
qh = qh.add(32);
let q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bits_h, m2), 4);
let q4h_1 =
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 2), m2), 4);
let q4h_2 =
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 4), m2), 4);
let q4h_3 =
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bits_h, 6), m2), 4);
let q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
let q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
let q4_2 =
_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
let q4_3 =
_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
let q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
let q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
let q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
let p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
let p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
let p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
let p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
let p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
let p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
let p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
let p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
let p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
let p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
let p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
let p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
}
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
unsafe fn mm256_set_m128i(a: __m128i, b: __m128i) -> __m256i {
_mm256_insertf128_si256(_mm256_castsi128_si256(b), a, 1)
}
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
}
unsafe {
let m3 = _mm256_set1_epi8(3);
let m4 = _mm_set1_epi8(0xF);
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = -y.d * x.dmin.to_f32();
let mut q2 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mins_and_scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i);
let scales8 = _mm_and_si128(mins_and_scales, m4);
let mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
let mins = _mm256_cvtepi8_epi16(mins8);
let prod =
_mm256_madd_epi16(mins, _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i));
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
let all_scales = _mm256_cvtepi8_epi16(scales8);
let l_scales = _mm256_extracti128_si256(all_scales, 0);
let h_scales = _mm256_extracti128_si256(all_scales, 1);
let scales = [
mm256_set_m128i(l_scales, l_scales),
mm256_set_m128i(h_scales, h_scales),
];
let mut sumi = _mm256_setzero_si256();
for scale in scales {
let q2bits = _mm256_loadu_si256(q2 as *const __m256i);
q2 = q2.add(32);
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q2_0 = _mm256_and_si256(q2bits, m3);
let q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
let q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
let q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
let p0 = _mm256_maddubs_epi16(q2_0, q8_0);
let p1 = _mm256_maddubs_epi16(q2_1, q8_1);
let p2 = _mm256_maddubs_epi16(q2_2, q8_2);
let p3 = _mm256_maddubs_epi16(q2_3, q8_3);
let p0 =
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(0)), p0);
let p1 =
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(1)), p1);
let p2 =
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(2)), p2);
let p3 =
_mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(3)), p3);
let p0 = _mm256_add_epi32(p0, p1);
let p2 = _mm256_add_epi32(p2, p3);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
}
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
}
const KMASK1: u32 = 0x03030303;
const KMASK2: u32 = 0x0f0f0f0f;
let mut aux = [0u32; 3];
unsafe {
let m3 = _mm256_set1_epi8(3);
let mone = _mm256_set1_epi8(1);
let m32 = _mm_set1_epi8(32);
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let mut q3 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
LittleEndian::read_u32_into(&x.scales, &mut aux);
let scales128 = _mm_set_epi32(
(((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4)) as i32,
(((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4)) as i32,
((aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4)) as i32,
((aux[0] & KMASK2) | (((aux[2]) & KMASK1) << 4)) as i32,
);
let scales128 = _mm_sub_epi8(scales128, m32);
let all_scales = _mm256_cvtepi8_epi16(scales128);
let l_scales = _mm256_extracti128_si256(all_scales, 0);
let h_scales = _mm256_extracti128_si256(all_scales, 1);
let scales = [
mm256_set_m128i(l_scales, l_scales),
mm256_set_m128i(h_scales, h_scales),
];
// high bit
let hbits = _mm256_loadu_si256(x.hmask.as_ptr() as *const __m256i);
let mut sumi = _mm256_setzero_si256();
for (j, scale) in scales.iter().enumerate() {
// load low 2 bits
let q3bits = _mm256_loadu_si256(q3 as *const __m256i);
q3 = q3.add(32);
// Prepare low and high bits
// We hardcode the shifts here to avoid loading them into a separate register
let q3l_0 = _mm256_and_si256(q3bits, m3);
let q3h_0 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
} else {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 4)), 4)
};
let q3h_0 = _mm256_slli_epi16(q3h_0, 2);
let q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
let q3h_1 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 1)), 1)
} else {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 5)), 5)
};
let q3h_1 = _mm256_slli_epi16(q3h_1, 2);
let q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
let q3h_2 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 2)), 2)
} else {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 6)), 6)
};
let q3h_2 = _mm256_slli_epi16(q3h_2, 2);
let q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
let q3h_3 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 3)), 3)
} else {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 7)), 7)
};
let q3h_3 = _mm256_slli_epi16(q3h_3, 2);
// load Q8 quants
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
// Dot product: we multiply the 2 low bits and 1 high bit part separately, so we
// can use _mm256_maddubs_epi16, and then subtract. The high bit part has the 2
// already subtracted (and so, it is zero if the high bit was not set, and 2 if the
// high bit was set)
let q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
let q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
let q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
let q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
let p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
let p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
let p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
let p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
let p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
let p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
let p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
let p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
// multiply with scales
let p16_0 =
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(0)), p16_0);
let p16_1 =
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(1)), p16_1);
let p16_2 =
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(2)), p16_2);
let p16_3 =
_mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(3)), p16_3);
// accumulate
let p16_0 = _mm256_add_epi32(p16_0, p16_1);
let p16_2 = _mm256_add_epi32(p16_2, p16_3);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
}
// multiply with block scale and accumulate
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
let mut utmp = [0u32; 4];
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
unsafe {
let m4 = _mm256_set1_epi8(0xF);
let mut acc = _mm256_setzero_ps();
let mut acc_m = _mm_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = -y.d * x.dmin.to_f32();
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
let mut q4 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
utmp[3] as i32,
utmp[2] as i32,
utmp[1] as i32,
utmp[0] as i32,
));
let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
let q8s = _mm_hadd_epi16(
_mm256_extracti128_si256(q8sums, 0),
_mm256_extracti128_si256(q8sums, 1),
);
let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
let scales = mm256_set_m128i(sc128, sc128);
let mut sumi = _mm256_setzero_si256();
for j in 0..QK_K / 64 {
let scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
let scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));
let q4bits = _mm256_loadu_si256(q4 as *const __m256i);
q4 = q4.add(32);
let q4l = _mm256_and_si256(q4bits, m4);
let q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
let q8l = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let p16l = _mm256_maddubs_epi16(q4l, q8l);
let p16l = _mm256_madd_epi16(scale_l, p16l);
sumi = _mm256_add_epi32(sumi, p16l);
let q8h = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let p16h = _mm256_maddubs_epi16(q4h, q8h);
let p16h = _mm256_madd_epi16(scale_h, p16h);
sumi = _mm256_add_epi32(sumi, p16h);
}
let vd = _mm256_set1_ps(d);
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
}
let acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
let acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
Ok(hsum_float_8(acc) + _mm_cvtss_f32(acc_m))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
}
let mut utmp = [0u32; 4];
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
unsafe {
let m4 = _mm256_set1_epi8(0xF);
let mzero = _mm_setzero_si128();
let mone = _mm256_set1_epi8(1);
let mut acc = _mm256_setzero_ps();
let mut summs = 0.0;
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = -y.d * x.dmin.to_f32();
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
let mut q5 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
utmp[3] as i32,
utmp[2] as i32,
utmp[1] as i32,
utmp[0] as i32,
));
let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
let q8s = _mm_hadd_epi16(
_mm256_extracti128_si256(q8sums, 0),
_mm256_extracti128_si256(q8sums, 1),
);
let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
let hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
summs += dmin * _mm_extract_epi32(hsum, 0) as f32;
let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
let scales = mm256_set_m128i(sc128, sc128);
let hbits = _mm256_loadu_si256(x.qh.as_ptr() as *const __m256i);
let mut hmask = mone;
let mut sumi = _mm256_setzero_si256();
for j in 0..QK_K / 64 {
let scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
let scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));
let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
q5 = q5.add(32);
//Similar to q3k we hardcode the shifts here to avoid loading them into a separate register
let q5l_0 = _mm256_and_si256(q5bits, m4);
let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
let q5l_0_right_shift = match j {
0 => _mm256_srli_epi16(q5l_0_shift_input, 0),
1 => _mm256_srli_epi16(q5l_0_shift_input, 2),
2 => _mm256_srli_epi16(q5l_0_shift_input, 4),
3 => _mm256_srli_epi16(q5l_0_shift_input, 6),
_ => unreachable!(),
};
let q5h_0 = _mm256_slli_epi16(q5l_0_right_shift, 4);
let q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
hmask = _mm256_slli_epi16(hmask, 1);
let q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
let q5l_1_shift_input = _mm256_and_si256(hbits, hmask);
let q5l_1_right_shift = match j {
0 => _mm256_srli_epi16(q5l_1_shift_input, 1),
1 => _mm256_srli_epi16(q5l_1_shift_input, 3),
2 => _mm256_srli_epi16(q5l_1_shift_input, 5),
3 => _mm256_srli_epi16(q5l_1_shift_input, 7),
_ => unreachable!(),
};
let q5h_1 = _mm256_slli_epi16(q5l_1_right_shift, 4);
let q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
hmask = _mm256_slli_epi16(hmask, 1);
let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
q8 = q8.add(32);
let p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
let p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
let p16_0 = _mm256_madd_epi16(scale_0, p16_0);
let p16_1 = _mm256_madd_epi16(scale_1, p16_1);
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
}
let vd = _mm256_set1_ps(d);
acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc) + summs)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % qk != 0 {
crate::bail!("vec_dot_q8k_8k: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = _mm256_setzero_ps();
for (xs, ys) in xs.iter().zip(ys.iter()) {
let mut sumi = _mm256_setzero_si256();
let x_qs = xs.qs.as_ptr();
let y_qs = ys.qs.as_ptr();
for j in (0..QK_K).step_by(32) {
let xs = _mm256_loadu_si256(x_qs.add(j) as *const __m256i);
let ys = _mm256_loadu_si256(y_qs.add(j) as *const __m256i);
let xs0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 0));
let ys0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 0));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs0, ys0));
let xs1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 1));
let ys1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 1));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs1, ys1));
}
let d = _mm256_set1_ps(xs.d * ys.d);
acc = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}

View File

@ -0,0 +1,737 @@
use super::{GgmlDType, QStorage};
use crate::quantized::k_quants::GgmlType;
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{builder_arg as barg, CudaDevice, CudaStorage, Result};
use half::f16;
use cudarc::driver::{CudaSlice, CudaView, PushKernelArg};
#[derive(Clone, Debug)]
struct PaddedCudaSlice {
inner: CudaSlice<u8>,
len: usize,
}
#[derive(Clone, Debug)]
pub struct QCudaStorage {
data: PaddedCudaSlice,
dtype: GgmlDType,
device: CudaDevice,
}
static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
pub fn set_force_dmmv(f: bool) {
FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
}
pub const WARP_SIZE: usize = 32;
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
pub const NWARPS_Q4_0_AMPERE: usize = 4;
pub const GGML_CUDA_MMV_X: usize = 32;
pub const GGML_CUDA_MMV_Y: usize = 1;
pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256;
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
pub const MATRIX_ROW_PADDING: usize = 512;
fn ceil_div(p: usize, q: usize) -> usize {
p.div_ceil(q)
}
fn pad(p: usize, q: usize) -> usize {
ceil_div(p, q) * q
}
fn quantize_q8_1(
src: &CudaView<f32>,
dst: &mut CudaSlice<u8>,
elem_count: usize,
ky: usize,
dev: &CudaDevice,
) -> Result<()> {
let kx = elem_count;
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
let func = dev.get_or_load_func("quantize_q8_1", &candle_kernels::QUANTIZED)?;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, ky as u32, 1),
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
shared_mem_bytes: 0,
};
let mut builder = func.builder();
builder.arg(src);
builder.arg(dst);
barg!(builder, kx as i32, kx_padded as i32);
unsafe { builder.launch(cfg) }.w()?;
Ok(())
}
fn dequantize_f32(
data: &PaddedCudaSlice,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
let nb = elem_count.div_ceil(256);
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
GgmlDType::Q5_0 => (
"dequantize_block_q5_0_f32",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q5_1 => (
"dequantize_block_q5_1_f32",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(elem_count)? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, 1, 1),
block_dim: (block_dim as u32, 1, 1),
shared_mem_bytes: 0,
};
if is_k {
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&dst);
unsafe { builder.launch(cfg) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&dst);
barg!(builder, nb32 as i32);
unsafe { builder.launch(cfg) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_f16(
data: &PaddedCudaSlice,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
let nb = elem_count.div_ceil(256);
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
GgmlDType::Q5_0 => (
"dequantize_block_q5_0_f16",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q5_1 => (
"dequantize_block_q5_1_f16",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f16>(elem_count)? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, 1, 1),
block_dim: (block_dim as u32, 1, 1),
shared_mem_bytes: 0,
};
if is_k {
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&dst);
unsafe { builder.launch(cfg) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&dst);
barg!(builder, nb32 as i32);
unsafe { builder.launch(cfg) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_mul_mat_vec(
data: &PaddedCudaSlice,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
if y.len() != ncols {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
let kernel_name = match dtype {
GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
GgmlDType::Q5_0 => "dequantize_mul_mat_vec_q5_0_cuda",
GgmlDType::Q5_1 => "dequantize_mul_mat_vec_q5_1_cuda",
GgmlDType::Q8_0 => "dequantize_mul_mat_vec_q8_0_cuda",
GgmlDType::Q2K => "dequantize_mul_mat_vec_q2_k",
GgmlDType::Q3K => "dequantize_mul_mat_vec_q3_k",
GgmlDType::Q4K => "dequantize_mul_mat_vec_q4_k",
GgmlDType::Q5K => "dequantize_mul_mat_vec_q5_k",
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows)? };
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (block_num_y as u32, 1, 1),
block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
shared_mem_bytes: 0,
};
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(y);
builder.arg(&dst);
barg!(builder, ncols as i32, nrows as i32);
unsafe { builder.launch(cfg) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn mul_mat_vec_via_q8_1(
data: &PaddedCudaSlice,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
b_size: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
if y.len() != ncols * b_size {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
if b_size == 0 || b_size > 8 {
crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
}
// Start by quantizing y
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
let y_size_in_bytes =
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
let kernel_name = match dtype {
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda",
GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda",
GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda",
GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda",
GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda",
GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda",
GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda",
GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda",
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let kernel_name = format!("{kernel_name}{b_size}");
let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size)? };
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
let (nblocks, nwarps) = match b_size {
1 => (nrows as u32, 4),
2..=4 => ((nrows as u32).div_ceil(2), 4),
5..=8 => ((nrows as u32).div_ceil(2), 2),
_ => crate::bail!("unexpected bsize {b_size}"),
};
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (nblocks, 1, 1),
block_dim: (WARP_SIZE as u32, nwarps, 1),
shared_mem_bytes: 0,
};
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&y_q8_1);
builder.arg(&dst);
barg!(
builder,
/* ncols_x */ ncols as i32,
/* nrows_x */ nrows as i32,
/* nrows_y */ ncols_padded as i32,
/* nrows_dst */ nrows as i32
);
unsafe { builder.launch(cfg) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
#[allow(clippy::too_many_arguments)]
fn mul_mat_via_q8_1(
data: &PaddedCudaSlice,
y: &CudaView<f32>,
dtype: GgmlDType,
x_rows: usize,
x_cols: usize,
y_rows: usize,
y_cols: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < x_rows * x_cols {
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
}
if y.len() != y_rows * y_cols {
crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len())
}
if x_cols != y_rows {
crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}")
}
let k = x_cols;
// Start by quantizing y
let k_padded = pad(k, MATRIX_ROW_PADDING);
let y_size_in_bytes =
k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
let (kernel_name, mmq_x, mmq_y) = match dtype {
GgmlDType::Q4_0 => ("mul_mat_q4_0", 64, 128),
GgmlDType::Q4_1 => ("mul_mat_q4_1", 64, 128),
GgmlDType::Q5_0 => ("mul_mat_q5_0", 128, 64),
GgmlDType::Q5_1 => ("mul_mat_q5_1", 128, 64),
GgmlDType::Q8_0 => ("mul_mat_q8_0", 128, 64),
GgmlDType::Q2K => ("mul_mat_q2_K", 64, 128),
GgmlDType::Q3K => ("mul_mat_q3_K", 128, 128),
GgmlDType::Q4K => ("mul_mat_q4_K", 64, 128),
GgmlDType::Q5K => ("mul_mat_q5_K", 64, 128),
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols)? };
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (
ceil_div(x_rows, mmq_y) as u32,
ceil_div(y_cols, mmq_x) as u32,
1,
),
block_dim: (WARP_SIZE as u32, 4, 1),
shared_mem_bytes: 0,
};
let mut builder = func.builder();
builder.arg(/* vx */ &data.inner);
builder.arg(/* vy */ &y_q8_1);
builder.arg(/* dst */ &dst);
barg!(
builder,
/* ncols_x */ x_cols as i32,
/* nrows_x */ x_rows as i32,
/* ncols_y */ y_cols as i32,
/* nrows_y */ k_padded as i32,
/* nrows_dst */ x_rows as i32
);
unsafe { builder.launch(cfg) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
let padded_size_in_bytes =
ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size();
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes)?;
Ok(QCudaStorage {
data: PaddedCudaSlice {
inner,
len: size_in_bytes,
},
device: device.clone(),
dtype,
})
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &CudaDevice {
&self.device
}
pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
fn deq<T: GgmlType>(buffer: &[u8], n: usize, dst: &mut [f32]) -> Result<()> {
let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
let vec = slice.to_vec();
T::to_float(&vec, dst)
}
let fast_kernel = matches!(
self.dtype,
GgmlDType::Q4_0
| GgmlDType::Q4_1
| GgmlDType::Q5_0
| GgmlDType::Q5_1
| GgmlDType::Q8_0
| GgmlDType::Q2K
| GgmlDType::Q3K
| GgmlDType::Q4K
| GgmlDType::Q5K
| GgmlDType::Q6K
| GgmlDType::Q8K
);
if fast_kernel {
return dequantize_f32(&self.data, self.dtype, elem_count, self.device());
}
// Run the dequantization on cpu.
let buffer = self
.device
.memcpy_dtov(&self.data.inner.slice(..self.data.len))?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q5_1 => deq::<crate::quantized::BlockQ5_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q8_0 => deq::<crate::quantized::BlockQ8_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q8_1 => deq::<crate::quantized::BlockQ8_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q2K => deq::<crate::quantized::BlockQ2K>(&buffer, block_len, &mut out)?,
GgmlDType::Q3K => deq::<crate::quantized::BlockQ3K>(&buffer, block_len, &mut out)?,
GgmlDType::Q4K => deq::<crate::quantized::BlockQ4K>(&buffer, block_len, &mut out)?,
GgmlDType::Q5K => deq::<crate::quantized::BlockQ5K>(&buffer, block_len, &mut out)?,
GgmlDType::Q6K => deq::<crate::quantized::BlockQ6K>(&buffer, block_len, &mut out)?,
GgmlDType::Q8K => deq::<crate::quantized::BlockQ8K>(&buffer, block_len, &mut out)?,
}
self.device
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
}
pub fn dequantize_f16(&self, elem_count: usize) -> Result<CudaStorage> {
dequantize_f16(&self.data, self.dtype, elem_count, self.device())
}
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
// Run the quantization on cpu.
let src = match &src.slice {
crate::cuda_backend::CudaStorageSlice::F32(data) => self.device.memcpy_dtov(data)?,
_ => crate::bail!("only f32 can be quantized"),
};
let src_len = src.len();
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
qcpu_storage.quantize(&src)?;
let data = qcpu_storage.data()?;
let padded_len =
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
let mut inner = unsafe { self.device.alloc::<u8>(padded_len)? };
self.device
.memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))?;
self.data = PaddedCudaSlice {
inner,
len: data.len(),
};
Ok(())
}
pub fn storage_size_in_bytes(&self) -> usize {
self.data.len
}
pub fn fwd(
&self,
self_shape: &crate::Shape,
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
1
} else {
8
};
let use_vec_kernel = match layout.shape().dims() {
[b, m, _k] => b * m <= max_bm,
[b, _k] => *b <= max_bm,
_ => false,
};
if use_vec_kernel {
self.dequantize_matmul_vec(self_shape, storage, layout)
} else {
self.dequantize_matmul(self_shape, storage, layout)
}
}
}
impl QCudaStorage {
fn dequantize_matmul_vec(
&self,
self_shape: &crate::Shape,
rhs: &CudaStorage,
rhs_l: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
let (nrows, ncols) = self_shape.dims2()?;
let rhs = rhs.as_cuda_slice::<f32>()?;
let rhs = match rhs_l.contiguous_offsets() {
Some((o1, o2)) => rhs.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
};
let (b_size, k) = match rhs_l.shape().dims() {
[b, m, k] => (b * m, *k),
[b, k] => (*b, *k),
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
};
if ncols != k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
}
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
} else {
mul_mat_vec_via_q8_1(
&self.data,
&rhs,
self.dtype,
ncols,
nrows,
b_size,
self.device(),
)?
};
let mut out_shape = rhs_l.shape().dims().to_vec();
out_shape.pop();
out_shape.push(nrows);
Ok((out, out_shape.into()))
}
fn dequantize_matmul(
&self,
self_shape: &crate::Shape,
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
use crate::backend::BackendStorage;
let (n, k) = self_shape.dims2()?;
let (b, m, k2) = match layout.shape().dims() {
&[b, m, k2] => (b, m, k2),
&[m, k2] => (1, m, k2),
s => crate::bail!("unexpected shape for input {s:?}"),
};
if k2 != k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
}
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
let data_f32 = self.dequantize(n * k)?;
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?
} else {
let storage = storage.as_cuda_slice::<f32>()?;
let storage = match layout.contiguous_offsets() {
Some((o1, o2)) => storage.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous {
op: "quantized-matmul",
}
.bt())?,
};
mul_mat_via_q8_1(
&self.data,
&storage,
self.dtype,
/* x_rows */ n,
/* x_cols */ k,
/* y_rows */ k,
/* y_cols */ b * m,
self.device(),
)?
};
let mut out_shape = layout.shape().dims().to_vec();
out_shape.pop();
out_shape.push(n);
Ok((out, out_shape.into()))
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
device: &CudaDevice,
data: &[T],
) -> Result<super::QStorage> {
let data = unsafe {
std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
};
let dtype = T::DTYPE;
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
let mut inner = unsafe { device.alloc::<u8>(padded_len)? };
device.memcpy_htod(data, &mut inner.slice_mut(..data.len()))?;
Ok(QStorage::Cuda(QCudaStorage {
data: PaddedCudaSlice {
inner,
len: data.len(),
},
device: device.clone(),
dtype,
}))
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn cuda_quantize_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let el = 256;
let el_padded = pad(el, MATRIX_ROW_PADDING);
let y_size_in_bytes =
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.memcpy_stod(&vs)?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
Ok(())
}
#[test]
fn cuda_mmv_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
let y = dev.memcpy_stod(&vs)?;
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_vec_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
/* b_size */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.memcpy_dtov(&vs.slice(..))?;
assert_eq!(vs.len(), 1);
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
// Q8 means 1/256 precision.
assert_eq!(vs[0], 5561664.5);
let cuda_storage = dequantize_mul_mat_vec(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.memcpy_dtov(&vs.slice(..))?;
assert_eq!(vs.len(), 1);
assert_eq!(vs[0], 5561851.0);
Ok(())
}
#[test]
fn cuda_mm_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
let y = dev.memcpy_stod(&vs)?;
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* x_rows */ 4,
/* x_cols */ ncols,
/* y_rows */ ncols,
/* y_cols */ 4,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.memcpy_dtov(&vs.slice(..))?;
/*
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
x @ x.t() / 16
tensor([[ 347480.0000, 869720.0000, 1391960.0000, 1914200.0000],
[ 869720.0000, 2440536.0000, 4011352.0000, 5582166.5000],
[ 1391960.0000, 4011352.0000, 6630742.0000, 9250132.0000],
[ 1914200.0000, 5582166.5000, 9250132.0000, 12918099.0000]])
*/
assert_eq!(vs.len(), 16);
assert_eq!(vs[0], 347604.0);
assert_eq!(vs[1], 888153.06);
assert_eq!(vs[4], 869780.7);
assert_eq!(vs[5], 2483145.0);
assert_eq!(vs[11], 9407368.0);
assert_eq!(vs[14], 9470856.0);
assert_eq!(vs[15], 13138824.0);
Ok(())
}
// The following test used to fail under compute-sanitizer until #2526.
#[test]
fn cuda_mm_q8_1_pad() -> Result<()> {
let dev = CudaDevice::new(0)?;
let (x_rows, ncols, y_cols) = (4, 16, 2048);
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
let y = dev.memcpy_stod(&vs)?;
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* x_rows */ x_rows,
/* x_cols */ ncols,
/* y_rows */ ncols,
/* y_cols */ y_cols,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let _vs = dev.memcpy_dtov(&vs.slice(..))?;
Ok(())
}
}

View File

@ -0,0 +1,54 @@
#![allow(unused)]
use super::GgmlDType;
use crate::{CudaDevice, CudaStorage, Error, Result};
pub struct QCudaStorage {
dtype: GgmlDType,
device: CudaDevice,
}
impl QCudaStorage {
pub fn zeros(_: &CudaDevice, _: usize, _: GgmlDType) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &CudaDevice {
&self.device
}
pub fn dequantize(&self, _elem_count: usize) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn dequantize_f16(&self, _elem_count: usize) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn storage_size_in_bytes(&self) -> usize {
0
}
pub fn fwd(
&self,
_self_shape: &crate::Shape,
_storage: &CudaStorage,
_layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
Err(Error::NotCompiledWithCudaSupport)
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
_device: &CudaDevice,
_data: &[T],
) -> Result<super::QStorage> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -0,0 +1,50 @@
#![allow(unused)]
use super::GgmlDType;
use crate::{Error, MetalDevice, MetalStorage, Result};
pub struct QMetalStorage {
dtype: GgmlDType,
device: MetalDevice,
}
impl QMetalStorage {
pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &MetalDevice {
&self.device
}
pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn storage_size_in_bytes(&self) -> usize {
0
}
pub fn fwd(
&self,
_self_shape: &crate::Shape,
_storage: &MetalStorage,
_layout: &crate::Layout,
) -> Result<(MetalStorage, crate::Shape)> {
Err(Error::NotCompiledWithMetalSupport)
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
_device: &MetalDevice,
_data: &[T],
) -> Result<super::QStorage> {
Err(Error::NotCompiledWithMetalSupport)
}

View File

@ -0,0 +1,265 @@
//! Support for the GGML file format.
use super::{k_quants, GgmlDType, QStorage};
use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Magic {
Ggjt,
Ggla,
Ggmf,
Ggml,
Ggsn,
}
impl TryFrom<u32> for Magic {
type Error = crate::Error;
fn try_from(value: u32) -> Result<Self> {
let magic = match value {
0x67676a74 => Self::Ggjt,
0x67676c61 => Self::Ggla,
0x67676d66 => Self::Ggmf,
0x67676d6c => Self::Ggml,
0x6767736e => Self::Ggsn,
_ => crate::bail!("unknown magic {value:08x}"),
};
Ok(magic)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionedMagic {
GgmlUnversioned,
GgmfV1,
GgjtV1,
GgjtV2,
GgjtV3,
}
impl VersionedMagic {
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
let magic = reader.read_u32::<LittleEndian>()?;
let magic = Magic::try_from(magic)?;
if magic == Magic::Ggml {
return Ok(Self::GgmlUnversioned);
}
let version = reader.read_u32::<LittleEndian>()?;
let versioned_magic = match (magic, version) {
(Magic::Ggmf, 1) => Self::GgmfV1,
(Magic::Ggjt, 1) => Self::GgjtV1,
(Magic::Ggjt, 2) => Self::GgjtV2,
(Magic::Ggjt, 3) => Self::GgjtV3,
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
};
Ok(versioned_magic)
}
fn align32(&self) -> bool {
match self {
Self::GgmlUnversioned | Self::GgmfV1 => false,
Self::GgjtV1 | Self::GgjtV2 | Self::GgjtV3 => true,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HParams {
pub n_vocab: u32,
pub n_embd: u32,
pub n_mult: u32,
pub n_head: u32,
pub n_layer: u32,
pub n_rot: u32,
pub ftype: u32,
}
impl HParams {
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
let n_vocab = reader.read_u32::<LittleEndian>()?;
let n_embd = reader.read_u32::<LittleEndian>()?;
let n_mult = reader.read_u32::<LittleEndian>()?;
let n_head = reader.read_u32::<LittleEndian>()?;
let n_layer = reader.read_u32::<LittleEndian>()?;
let n_rot = reader.read_u32::<LittleEndian>()?;
let ftype = reader.read_u32::<LittleEndian>()?;
Ok(Self {
n_vocab,
n_embd,
n_mult,
n_head,
n_layer,
n_rot,
ftype,
})
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct Vocab {
pub token_score_pairs: Vec<(Vec<u8>, f32)>,
}
impl Vocab {
fn read<R: std::io::Read>(reader: &mut R, n_vocab: usize) -> Result<Self> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L556
let mut token_score_pairs = Vec::with_capacity(n_vocab);
for _index in 0..n_vocab {
let len = reader.read_u32::<LittleEndian>()? as usize;
let mut word = vec![0u8; len];
reader.read_exact(&mut word)?;
let score = reader.read_f32::<LittleEndian>()?;
token_score_pairs.push((word, score))
}
Ok(Self { token_score_pairs })
}
}
fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
raw_data: &[u8],
size_in_bytes: usize,
dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> {
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
let data: QStorage = match device {
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
Device::Metal(metal) => super::metal::load_quantized(metal, data)?,
Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?,
};
super::QTensor::new(data, dims)
}
/// Creates a Tensor from a raw GGML tensor.
pub fn qtensor_from_ggml(
ggml_dtype: GgmlDType,
raw_data: &[u8],
dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> {
let tensor_elems = dims.iter().product::<usize>();
let block_size = ggml_dtype.block_size();
if tensor_elems % block_size != 0 {
crate::bail!(
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
)
}
let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();
match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4_0 => {
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4_1 => {
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_0 => {
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_1 => {
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q8_0 => {
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q2K => {
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q3K => {
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4K => {
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5K => {
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q6K => {
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
}
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
}
}
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
reader: &mut R,
magic: VersionedMagic,
device: &Device,
) -> Result<(String, super::QTensor)> {
let n_dims = reader.read_u32::<LittleEndian>()?;
let name_len = reader.read_u32::<LittleEndian>()?;
let ggml_dtype = reader.read_u32::<LittleEndian>()?;
let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
let mut dims = vec![0u32; n_dims as usize];
reader.read_u32_into::<LittleEndian>(&mut dims)?;
// The dimensions are stored in reverse order, see for example:
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/convert.py#L969
dims.reverse();
let mut name = vec![0u8; name_len as usize];
reader.read_exact(&mut name)?;
let name = String::from_utf8_lossy(&name).into_owned();
if magic.align32() {
let pos = reader.stream_position()?;
reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?;
}
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
let tensor_elems = dims.iter().product::<usize>();
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
// TODO: Mmap version to avoid copying the data around?
let mut raw_data = vec![0u8; size_in_bytes];
reader.read_exact(&mut raw_data)?;
match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
Ok(tensor) => Ok((name, tensor)),
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
}
}
pub struct Content {
pub magic: VersionedMagic,
pub hparams: HParams,
pub vocab: Vocab,
pub tensors: HashMap<String, super::QTensor>,
pub device: Device,
}
impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(
reader: &mut R,
device: &Device,
) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?;
let magic = VersionedMagic::read(reader)?;
let hparams = HParams::read(reader)?;
let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
let mut tensors = HashMap::new();
while reader.stream_position()? != last_position {
let (name, tensor) = read_one_tensor(reader, magic, device)?;
tensors.insert(name, tensor);
}
let device = device.clone();
Ok(Self {
magic,
hparams,
vocab,
tensors,
device,
})
}
pub fn remove(&mut self, name: &str) -> Result<super::QTensor> {
match self.tensors.remove(name) {
None => crate::bail!("cannot find tensor with name '{name}'"),
Some(tensor) => Ok(tensor),
}
}
}

View File

@ -0,0 +1,538 @@
//! Support for the [GGUF file format](https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md).
//!
use super::{GgmlDType, QTensor};
use crate::{Context, Device, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;
pub const DEFAULT_ALIGNMENT: u64 = 32;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Magic {
Gguf,
}
impl TryFrom<u32> for Magic {
type Error = crate::Error;
fn try_from(value: u32) -> Result<Self> {
let magic = match value {
0x46554747 | 0x47475546 => Self::Gguf,
_ => crate::bail!("unknown magic 0x{value:08x}"),
};
Ok(magic)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionedMagic {
GgufV1,
GgufV2,
GgufV3,
}
impl VersionedMagic {
fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
let magic = reader.read_u32::<LittleEndian>()?;
let magic = Magic::try_from(magic)?;
let version = reader.read_u32::<LittleEndian>()?;
let versioned_magic = match (magic, version) {
(Magic::Gguf, 1) => Self::GgufV1,
(Magic::Gguf, 2) => Self::GgufV2,
(Magic::Gguf, 3) => Self::GgufV3,
_ => crate::bail!("gguf: unsupported magic/version {magic:?}/{version}"),
};
Ok(versioned_magic)
}
}
#[derive(Debug)]
pub struct TensorInfo {
pub ggml_dtype: GgmlDType,
pub shape: crate::Shape,
pub offset: u64,
}
impl TensorInfo {
pub fn read<R: std::io::Seek + std::io::Read>(
&self,
reader: &mut R,
tensor_data_offset: u64,
device: &Device,
) -> Result<QTensor> {
let tensor_elems = self.shape.elem_count();
let block_size = self.ggml_dtype.block_size();
if tensor_elems % block_size != 0 {
crate::bail!(
"the number of elements {tensor_elems} is not divisible by the block size {block_size}"
)
}
let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size();
let mut raw_data = vec![0u8; size_in_bytes];
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
reader.read_exact(&mut raw_data)?;
super::ggml_file::qtensor_from_ggml(
self.ggml_dtype,
&raw_data,
self.shape.dims().to_vec(),
device,
)
}
}
#[derive(Debug)]
pub struct Content {
pub magic: VersionedMagic,
pub metadata: HashMap<String, Value>,
pub tensor_infos: HashMap<String, TensorInfo>,
pub tensor_data_offset: u64,
}
fn read_string<R: std::io::Read>(reader: &mut R, magic: &VersionedMagic) -> Result<String> {
let len = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
};
let mut v = vec![0u8; len];
reader.read_exact(&mut v)?;
// GGUF strings are supposed to be non-null terminated but in practice this happens.
while let Some(0) = v.last() {
v.pop();
}
// GGUF strings are utf8 encoded but there are cases that don't seem to be valid.
Ok(String::from_utf8_lossy(&v).into_owned())
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValueType {
// The value is a 8-bit unsigned integer.
U8,
// The value is a 8-bit signed integer.
I8,
// The value is a 16-bit unsigned little-endian integer.
U16,
// The value is a 16-bit signed little-endian integer.
I16,
// The value is a 32-bit unsigned little-endian integer.
U32,
// The value is a 32-bit signed little-endian integer.
I32,
// The value is a 64-bit unsigned little-endian integer.
U64,
// The value is a 64-bit signed little-endian integer.
I64,
// The value is a 32-bit IEEE754 floating point number.
F32,
// The value is a 64-bit IEEE754 floating point number.
F64,
// The value is a boolean.
// 1-byte value where 0 is false and 1 is true.
// Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy.
Bool,
// The value is a UTF-8 non-null-terminated string, with length prepended.
String,
// The value is an array of other values, with the length and type prepended.
// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
Array,
}
#[derive(Debug, Clone)]
pub enum Value {
U8(u8),
I8(i8),
U16(u16),
I16(i16),
U32(u32),
I32(i32),
U64(u64),
I64(i64),
F32(f32),
F64(f64),
Bool(bool),
String(String),
Array(Vec<Value>),
}
impl Value {
pub fn value_type(&self) -> ValueType {
match self {
Self::U8(_) => ValueType::U8,
Self::I8(_) => ValueType::I8,
Self::U16(_) => ValueType::U16,
Self::I16(_) => ValueType::I16,
Self::U32(_) => ValueType::U32,
Self::I32(_) => ValueType::I32,
Self::U64(_) => ValueType::U64,
Self::I64(_) => ValueType::I64,
Self::F32(_) => ValueType::F32,
Self::F64(_) => ValueType::F64,
Self::Bool(_) => ValueType::Bool,
Self::String(_) => ValueType::String,
Self::Array(_) => ValueType::Array,
}
}
pub fn to_u8(&self) -> Result<u8> {
match self {
Self::U8(v) => Ok(*v),
v => crate::bail!("not a u8 {v:?}"),
}
}
pub fn to_i8(&self) -> Result<i8> {
match self {
Self::I8(v) => Ok(*v),
v => crate::bail!("not a i8 {v:?}"),
}
}
pub fn to_u16(&self) -> Result<u16> {
match self {
Self::U16(v) => Ok(*v),
v => crate::bail!("not a u16 {v:?}"),
}
}
pub fn to_i16(&self) -> Result<i16> {
match self {
Self::I16(v) => Ok(*v),
v => crate::bail!("not a i16 {v:?}"),
}
}
pub fn to_u32(&self) -> Result<u32> {
match self {
Self::U32(v) => Ok(*v),
v => crate::bail!("not a u32 {v:?}"),
}
}
pub fn to_i32(&self) -> Result<i32> {
match self {
Self::I32(v) => Ok(*v),
v => crate::bail!("not a i32 {v:?}"),
}
}
/// This will also automatically upcast any integral types which will not truncate.
pub fn to_u64(&self) -> Result<u64> {
match self {
Self::U64(v) => Ok(*v),
// Autoupcast cases here
Self::U8(v) => Ok(*v as u64),
Self::U16(v) => Ok(*v as u64),
Self::U32(v) => Ok(*v as u64),
Self::Bool(v) => Ok(*v as u64),
v => crate::bail!("not a u64 or upcastable to u64 {v:?}"),
}
}
pub fn to_i64(&self) -> Result<i64> {
match self {
Self::I64(v) => Ok(*v),
v => crate::bail!("not a i64 {v:?}"),
}
}
pub fn to_f32(&self) -> Result<f32> {
match self {
Self::F32(v) => Ok(*v),
v => crate::bail!("not a f32 {v:?}"),
}
}
pub fn to_f64(&self) -> Result<f64> {
match self {
Self::F64(v) => Ok(*v),
v => crate::bail!("not a f64 {v:?}"),
}
}
pub fn to_bool(&self) -> Result<bool> {
match self {
Self::Bool(v) => Ok(*v),
v => crate::bail!("not a bool {v:?}"),
}
}
pub fn to_vec(&self) -> Result<&Vec<Value>> {
match self {
Self::Array(v) => Ok(v),
v => crate::bail!("not a vec {v:?}"),
}
}
pub fn to_string(&self) -> Result<&String> {
match self {
Self::String(v) => Ok(v),
v => crate::bail!("not a string {v:?}"),
}
}
fn read<R: std::io::Read>(
reader: &mut R,
value_type: ValueType,
magic: &VersionedMagic,
) -> Result<Self> {
let v = match value_type {
ValueType::U8 => Self::U8(reader.read_u8()?),
ValueType::I8 => Self::I8(reader.read_i8()?),
ValueType::U16 => Self::U16(reader.read_u16::<LittleEndian>()?),
ValueType::I16 => Self::I16(reader.read_i16::<LittleEndian>()?),
ValueType::U32 => Self::U32(reader.read_u32::<LittleEndian>()?),
ValueType::I32 => Self::I32(reader.read_i32::<LittleEndian>()?),
ValueType::U64 => Self::U64(reader.read_u64::<LittleEndian>()?),
ValueType::I64 => Self::I64(reader.read_i64::<LittleEndian>()?),
ValueType::F32 => Self::F32(reader.read_f32::<LittleEndian>()?),
ValueType::F64 => Self::F64(reader.read_f64::<LittleEndian>()?),
ValueType::Bool => match reader.read_u8()? {
0 => Self::Bool(false),
1 => Self::Bool(true),
b => crate::bail!("unexpected bool value {b}"),
},
ValueType::String => Self::String(read_string(reader, magic)?),
ValueType::Array => {
let value_type = reader.read_u32::<LittleEndian>()?;
let value_type = ValueType::from_u32(value_type)?;
let len = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
};
let mut vs = Vec::with_capacity(len);
for _ in 0..len {
vs.push(Value::read(reader, value_type, magic)?)
}
Self::Array(vs)
}
};
Ok(v)
}
fn write<W: std::io::Write>(&self, w: &mut W) -> Result<()> {
match self {
&Self::U8(v) => w.write_u8(v)?,
&Self::I8(v) => w.write_i8(v)?,
&Self::U16(v) => w.write_u16::<LittleEndian>(v)?,
&Self::I16(v) => w.write_i16::<LittleEndian>(v)?,
&Self::U32(v) => w.write_u32::<LittleEndian>(v)?,
&Self::I32(v) => w.write_i32::<LittleEndian>(v)?,
&Self::U64(v) => w.write_u64::<LittleEndian>(v)?,
&Self::I64(v) => w.write_i64::<LittleEndian>(v)?,
&Self::F32(v) => w.write_f32::<LittleEndian>(v)?,
&Self::F64(v) => w.write_f64::<LittleEndian>(v)?,
&Self::Bool(v) => w.write_u8(u8::from(v))?,
Self::String(v) => write_string(w, v.as_str())?,
Self::Array(v) => {
// The `Value` type does not enforce that all the values in an Array have the same
// type.
let value_type = if v.is_empty() {
// Doesn't matter, the array is empty.
ValueType::U32
} else {
let value_type: std::collections::HashSet<_> =
v.iter().map(|elem| elem.value_type()).collect();
if value_type.len() != 1 {
crate::bail!("multiple value-types in the same array {value_type:?}")
}
value_type.into_iter().next().context("empty value_type")?
};
w.write_u32::<LittleEndian>(value_type.to_u32())?;
w.write_u64::<LittleEndian>(v.len() as u64)?;
for elem in v.iter() {
elem.write(w)?
}
}
}
Ok(())
}
}
impl ValueType {
fn from_u32(v: u32) -> Result<Self> {
let v = match v {
0 => Self::U8,
1 => Self::I8,
2 => Self::U16,
3 => Self::I16,
4 => Self::U32,
5 => Self::I32,
6 => Self::F32,
7 => Self::Bool,
8 => Self::String,
9 => Self::Array,
10 => Self::U64,
11 => Self::I64,
12 => Self::F64,
v => crate::bail!("unrecognized value-type {v:#08x}"),
};
Ok(v)
}
fn to_u32(self) -> u32 {
match self {
Self::U8 => 0,
Self::I8 => 1,
Self::U16 => 2,
Self::I16 => 3,
Self::U32 => 4,
Self::I32 => 5,
Self::F32 => 6,
Self::Bool => 7,
Self::String => 8,
Self::Array => 9,
Self::U64 => 10,
Self::I64 => 11,
Self::F64 => 12,
}
}
}
impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Self> {
let magic = VersionedMagic::read(reader)?;
let tensor_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
};
let metadata_kv_count = match magic {
VersionedMagic::GgufV1 => reader.read_u32::<LittleEndian>()? as usize,
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
reader.read_u64::<LittleEndian>()? as usize
}
};
let mut metadata = HashMap::new();
for _idx in 0..metadata_kv_count {
let key = read_string(reader, &magic)?;
let value_type = reader.read_u32::<LittleEndian>()?;
let value_type = ValueType::from_u32(value_type)?;
let value = Value::read(reader, value_type, &magic)?;
metadata.insert(key, value);
}
let mut tensor_infos = HashMap::new();
for _idx in 0..tensor_count {
let tensor_name = read_string(reader, &magic)?;
let n_dimensions = reader.read_u32::<LittleEndian>()?;
let mut dimensions: Vec<usize> = match magic {
VersionedMagic::GgufV1 => {
let mut dimensions = vec![0; n_dimensions as usize];
reader.read_u32_into::<LittleEndian>(&mut dimensions)?;
dimensions.into_iter().map(|c| c as usize).collect()
}
VersionedMagic::GgufV2 | VersionedMagic::GgufV3 => {
let mut dimensions = vec![0; n_dimensions as usize];
reader.read_u64_into::<LittleEndian>(&mut dimensions)?;
dimensions.into_iter().map(|c| c as usize).collect()
}
};
dimensions.reverse();
let ggml_dtype = reader.read_u32::<LittleEndian>()?;
let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
let offset = reader.read_u64::<LittleEndian>()?;
tensor_infos.insert(
tensor_name,
TensorInfo {
shape: crate::Shape::from(dimensions),
offset,
ggml_dtype,
},
);
}
let position = reader.stream_position()?;
let alignment = match metadata.get("general.alignment") {
Some(Value::U8(v)) => *v as u64,
Some(Value::U16(v)) => *v as u64,
Some(Value::U32(v)) => *v as u64,
Some(Value::I8(v)) if *v >= 0 => *v as u64,
Some(Value::I16(v)) if *v >= 0 => *v as u64,
Some(Value::I32(v)) if *v >= 0 => *v as u64,
_ => DEFAULT_ALIGNMENT,
};
let tensor_data_offset = position.div_ceil(alignment) * alignment;
Ok(Self {
magic,
metadata,
tensor_infos,
tensor_data_offset,
})
}
pub fn tensor<R: std::io::Seek + std::io::Read>(
&self,
reader: &mut R,
name: &str,
device: &Device,
) -> Result<QTensor> {
let tensor_info = match self.tensor_infos.get(name) {
Some(tensor_info) => tensor_info,
None => crate::bail!("cannot find tensor info for {name}"),
};
tensor_info.read(reader, self.tensor_data_offset, device)
}
}
fn write_string<W: std::io::Write>(w: &mut W, str: &str) -> Result<()> {
let bytes = str.as_bytes();
w.write_u64::<LittleEndian>(bytes.len() as u64)?;
w.write_all(bytes)?;
Ok(())
}
pub fn write<W: std::io::Seek + std::io::Write>(
w: &mut W,
metadata: &[(&str, &Value)],
tensors: &[(&str, &QTensor)],
) -> Result<()> {
w.write_u32::<LittleEndian>(0x46554747)?;
w.write_u32::<LittleEndian>(2)?; // version 2.
w.write_u64::<LittleEndian>(tensors.len() as u64)?;
w.write_u64::<LittleEndian>(metadata.len() as u64)?;
for (name, value) in metadata.iter() {
write_string(w, name)?;
w.write_u32::<LittleEndian>(value.value_type().to_u32())?;
value.write(w)?;
}
let mut offset = 0usize;
let mut offsets = Vec::with_capacity(tensors.len());
for (name, tensor) in tensors.iter() {
write_string(w, name)?;
let dims = tensor.shape().dims();
w.write_u32::<LittleEndian>(dims.len() as u32)?;
for &dim in dims.iter().rev() {
w.write_u64::<LittleEndian>(dim as u64)?;
}
w.write_u32::<LittleEndian>(tensor.dtype().to_u32())?;
w.write_u64::<LittleEndian>(offset as u64)?;
offsets.push(offset);
let size_in_bytes = tensor.storage_size_in_bytes();
let padding = 31 - (31 + size_in_bytes) % 32;
offset += size_in_bytes + padding;
}
let pos = w.stream_position()? as usize;
let padding = 31 - (31 + pos) % 32;
w.write_all(&vec![0u8; padding])?;
let tensor_start_pos = w.stream_position()? as usize;
for (offset, (_name, tensor)) in offsets.iter().zip(tensors.iter()) {
let pos = w.stream_position()? as usize;
if tensor_start_pos + offset != pos {
crate::bail!(
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
)
}
let data = tensor.data()?;
let size_in_bytes = data.len();
w.write_all(&data)?;
let padding = 31 - (31 + size_in_bytes) % 32;
w.write_all(&vec![0u8; padding])?;
}
Ok(())
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,230 @@
use super::{GgmlDType, QStorage};
use crate::backend::BackendStorage;
use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
use metal::Buffer;
use std::sync::Arc;
pub struct QMetalStorage {
dtype: GgmlDType,
device: MetalDevice,
buffer: Arc<Buffer>,
}
impl QMetalStorage {
pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
let size = elem_count * dtype.type_size() / dtype.block_size();
let buffer = device.allocate_zeros(size)?;
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &MetalDevice {
&self.device
}
pub fn buffer(&self) -> &Buffer {
&self.buffer
}
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
use crate::quantized::k_quants::GgmlType;
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.end_encoding();
self.device.wait_until_completed()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
GgmlDType::F32 => {
let vec: Vec<f32> = read_to_vec(&buffer, block_len);
f32::to_float(&vec, &mut out)?;
}
GgmlDType::F16 => {
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
half::f16::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_0 => {
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_1 => {
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_0 => {
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_1 => {
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_0 => {
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_1 => {
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q2K => {
let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
}
GgmlDType::Q3K => {
let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
}
GgmlDType::Q4K => {
let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
}
GgmlDType::Q5K => {
let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
}
GgmlDType::Q6K => {
let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
}
GgmlDType::Q8K => {
let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
}
}
let buffer = self.device.new_buffer_with_data(&out)?;
Ok(MetalStorage::new(
buffer,
self.device.clone(),
elem_count,
DType::F32,
))
}
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
// Quantization only happens on CPU for now.
let src = src.to_cpu::<f32>()?;
let elem_count = src.len();
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
qcpu_storage.quantize(&src)?;
let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
self.buffer = buffer;
Ok(())
}
pub fn storage_size_in_bytes(&self) -> usize {
self.buffer.length() as usize
}
pub fn fwd(
&self,
self_shape: &Shape,
storage: &MetalStorage,
layout: &crate::Layout,
) -> Result<(MetalStorage, Shape)> {
use crate::MetalError;
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
// self is transposed so n is first then k.
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let (n, k) = self_shape.dims2()?;
let mut dst_shape = src_shape.dims().to_vec();
// We always use a single batch dimension and stack all the tensors in the batch on the
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
// properly.
let m = match dst_shape.len() {
3 => dst_shape[0] * dst_shape[1],
2 => dst_shape[0],
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};
let last_k = dst_shape.pop().unwrap();
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
let device = storage.device().clone();
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
let command_buffer = device.command_buffer()?;
// In some cases it would be better to use the mm variant, though it has its drawbacks
// around memory alignemnt.
for batch_id in 0..m {
candle_metal_kernels::call_quantized_matmul_mv_t(
device.device(),
&command_buffer,
device.kernels(),
self.dtype.into(),
(1, 1, n, k),
storage.buffer(),
(layout.start_offset() + batch_id * k) * storage.dtype().size_in_bytes(),
&self.buffer,
batch_id * n * DType::F32.size_in_bytes(),
&dst,
)
.map_err(MetalError::from)?;
}
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
Ok((dst_storage, dst_shape))
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
device: &MetalDevice,
data: &[T],
) -> Result<QStorage> {
let buffer = device.new_buffer_with_data(data)?;
let device = device.clone();
Ok(QStorage::Metal(QMetalStorage {
dtype: T::DTYPE,
device,
buffer,
}))
}
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
assert!(!ptr.is_null());
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
slice.to_vec()
}
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
fn from(value: GgmlDType) -> Self {
match value {
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
}
}
}

View File

@ -0,0 +1,550 @@
//! Code for GGML and GGUF files
use crate::{Context, CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;
#[cfg(target_feature = "avx")]
pub mod avx;
mod dummy_cuda;
mod dummy_metal;
pub mod ggml_file;
pub mod gguf_file;
pub mod k_quants;
#[cfg(feature = "metal")]
pub mod metal;
#[cfg(not(feature = "metal"))]
mod metal {
pub use super::dummy_metal::*;
}
#[cfg(feature = "cuda")]
pub mod cuda;
#[cfg(not(feature = "cuda"))]
mod cuda {
pub use super::dummy_cuda::*;
}
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(target_feature = "simd128")]
pub mod simd128;
pub mod utils;
use half::f16;
pub use k_quants::GgmlType;
pub struct QTensor {
storage: QStorage,
shape: Shape,
}
impl Device {
fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
match self {
Device::Cpu => {
let storage = dtype.cpu_zeros(elem_count);
Ok(QStorage::Cpu(storage))
}
Device::Metal(metal) => {
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
Ok(QStorage::Metal(storage))
}
Device::Cuda(cuda) => {
let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
Ok(QStorage::Cuda(storage))
}
}
}
}
pub enum QStorage {
Cpu(Box<dyn QuantizedType>),
Metal(metal::QMetalStorage),
Cuda(cuda::QCudaStorage),
}
impl QStorage {
fn block_size(&self) -> usize {
match self {
QStorage::Cpu(storage) => storage.block_size(),
QStorage::Metal(storage) => storage.dtype().block_size(),
QStorage::Cuda(storage) => storage.dtype().block_size(),
}
}
fn dtype(&self) -> GgmlDType {
match self {
QStorage::Cpu(storage) => storage.dtype(),
QStorage::Metal(storage) => storage.dtype(),
QStorage::Cuda(storage) => storage.dtype(),
}
}
fn device(&self) -> Device {
match self {
QStorage::Cpu(_storage) => Device::Cpu,
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
}
}
fn size_in_bytes(&self) -> usize {
match self {
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
QStorage::Metal(storage) => storage.storage_size_in_bytes(),
QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
}
}
fn quantize(&mut self, src: &Storage) -> Result<()> {
match (self, src) {
(QStorage::Cpu(storage), Storage::Cpu(src)) => {
storage.from_float(src.as_slice::<f32>()?)?;
}
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
(QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
_ => crate::bail!("Invalid dequantize storage locations do not match"),
}
Ok(())
}
fn dequantize(&self, elem_count: usize) -> Result<Storage> {
match self {
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
}
}
fn data(&self) -> Result<Cow<[u8]>> {
match self {
QStorage::Cpu(storage) => {
let data_ptr = storage.as_ptr();
let size_in_bytes = storage.storage_size_in_bytes();
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
Ok(Cow::from(data))
}
QStorage::Metal(_) | QStorage::Cuda(_) => {
crate::bail!("not implemented");
}
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
F32,
F16,
Q4_0,
Q4_1,
Q5_0,
Q5_1,
Q8_0,
Q8_1,
Q2K,
Q3K,
Q4K,
Q5K,
Q6K,
Q8K,
}
impl GgmlDType {
pub(crate) fn from_u32(u: u32) -> Result<Self> {
let dtype = match u {
0 => Self::F32,
1 => Self::F16,
2 => Self::Q4_0,
3 => Self::Q4_1,
6 => Self::Q5_0,
7 => Self::Q5_1,
8 => Self::Q8_0,
9 => Self::Q8_1,
10 => Self::Q2K,
11 => Self::Q3K,
12 => Self::Q4K,
13 => Self::Q5K,
14 => Self::Q6K,
15 => Self::Q8K,
_ => crate::bail!("unknown dtype for tensor {u}"),
};
Ok(dtype)
}
pub(crate) fn to_u32(self) -> u32 {
match self {
Self::F32 => 0,
Self::F16 => 1,
Self::Q4_0 => 2,
Self::Q4_1 => 3,
Self::Q5_0 => 6,
Self::Q5_1 => 7,
Self::Q8_0 => 8,
Self::Q8_1 => 9,
Self::Q2K => 10,
Self::Q3K => 11,
Self::Q4K => 12,
Self::Q5K => 13,
Self::Q6K => 14,
Self::Q8K => 15,
}
}
/// The block dtype
pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
match self {
Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
}
}
/// The type size for blocks in bytes.
pub fn type_size(&self) -> usize {
use k_quants::*;
match self {
Self::F32 => 4,
Self::F16 => 2,
Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
Self::Q2K => std::mem::size_of::<BlockQ2K>(),
Self::Q3K => std::mem::size_of::<BlockQ3K>(),
Self::Q4K => std::mem::size_of::<BlockQ4K>(),
Self::Q5K => std::mem::size_of::<BlockQ5K>(),
Self::Q6K => std::mem::size_of::<BlockQ6K>(),
Self::Q8K => std::mem::size_of::<BlockQ8K>(),
}
}
/// The block size, i.e. the number of elements stored in each block.
pub fn block_size(&self) -> usize {
match self {
Self::F32 => 1,
Self::F16 => 1,
Self::Q4_0 => k_quants::QK4_0,
Self::Q4_1 => k_quants::QK4_1,
Self::Q5_0 => k_quants::QK5_0,
Self::Q5_1 => k_quants::QK5_1,
Self::Q8_0 => k_quants::QK8_0,
Self::Q8_1 => k_quants::QK8_1,
Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
}
}
}
// A version of GgmlType without `vec_dot` so that it can be dyn boxed.
pub trait QuantizedType: Send + Sync {
fn dtype(&self) -> GgmlDType;
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
fn storage_size_in_bytes(&self) -> usize;
fn as_ptr(&self) -> *const u8;
fn block_size(&self) -> usize;
#[allow(clippy::wrong_self_convention)]
fn from_float(&mut self, xs: &[f32]) -> Result<()>;
fn size(&self) -> usize;
}
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
k_quants::matmul(mkn, lhs, self.as_slice(), dst)
}
fn size(&self) -> usize {
self.len() * core::mem::size_of::<T>()
}
fn from_float(&mut self, xs: &[f32]) -> Result<()> {
T::from_float(xs, self)
}
fn dtype(&self) -> GgmlDType {
T::DTYPE
}
fn block_size(&self) -> usize {
T::BLCK_SIZE
}
fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
let mut ys = vec![0.0f32; elem_count];
T::to_float(self.as_slice(), &mut ys)?;
Ok(CpuStorage::F32(ys))
}
fn storage_size_in_bytes(&self) -> usize {
self.len() * std::mem::size_of::<T>()
}
fn as_ptr(&self) -> *const u8 {
self.as_ptr() as *const u8
}
}
impl std::fmt::Debug for QTensor {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
}
}
fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
let dims = shape.dims();
if dims.is_empty() {
crate::bail!("scalar tensor cannot be quantized {shape:?}")
}
if dims[dims.len() - 1] % block_size != 0 {
crate::bail!(
"quantized tensor must have their last dim divisible by block size {shape:?} {}",
block_size
)
}
Ok(())
}
impl QTensor {
pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
let shape = shape.into();
check_shape(&shape, storage.block_size())?;
Ok(Self { storage, shape })
}
pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
let shape = src.shape();
let block_size = dtype.block_size();
check_shape(shape, block_size)?;
let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
let elem_count = shape.elem_count();
if elem_count % block_size != 0 {
crate::bail!(
"tensor size ({shape:?}) is not divisible by block size {}",
block_size
)
}
let mut storage = src.device().qzeros(elem_count, dtype)?;
storage.quantize(&src.storage())?;
Ok(Self {
storage,
shape: shape.clone(),
})
}
pub fn dtype(&self) -> GgmlDType {
self.storage.dtype()
}
pub fn device(&self) -> Device {
self.storage.device()
}
pub fn rank(&self) -> usize {
self.shape.rank()
}
pub fn shape(&self) -> &Shape {
&self.shape
}
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
let storage = self.storage.dequantize(self.shape.elem_count())?;
let none = crate::op::BackpropOp::none();
crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
}
pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
// In the CUDA case, we have a specialized kernel as this can be useful for volta
// architectures. https://github.com/huggingface/candle/issues/2136
match &self.storage {
QStorage::Cuda(s) => {
let s = s.dequantize_f16(self.shape.elem_count())?;
let none = crate::op::BackpropOp::none();
crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
.to_device(device)
}
_ => {
let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
Ok(s)
}
}
}
pub fn storage_size_in_bytes(&self) -> usize {
self.storage.size_in_bytes()
}
pub fn data(&self) -> Result<Cow<'_, [u8]>> {
self.storage.data()
}
}
#[derive(Clone, Debug)]
pub enum QMatMul {
QTensor(std::sync::Arc<QTensor>),
Tensor(Tensor),
TensorF16(Tensor),
}
thread_local! {
static DEQUANTIZE_ALL: bool = {
match std::env::var("CANDLE_DEQUANTIZE_ALL") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
thread_local! {
static DEQUANTIZE_ALL_F16: bool = {
match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
impl QMatMul {
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
let dequantize = match qtensor.dtype() {
GgmlDType::F32 | GgmlDType::F16 => true,
_ => DEQUANTIZE_ALL.with(|b| *b),
};
let t = if dequantize {
let tensor = qtensor.dequantize(&qtensor.device())?;
Self::Tensor(tensor)
} else if DEQUANTIZE_ALL_F16.with(|b| *b) {
let tensor = qtensor.dequantize_f16(&qtensor.device())?;
Self::TensorF16(tensor)
} else {
Self::QTensor(qtensor)
};
Ok(t)
}
pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
Self::from_arc(std::sync::Arc::new(qtensor))
}
pub fn dequantize_f16(&self) -> Result<Tensor> {
match self {
Self::QTensor(t) => t.dequantize_f16(&t.device()),
Self::Tensor(t) => t.to_dtype(DType::F16),
Self::TensorF16(t) => Ok(t.clone()),
}
}
pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
let w = self.dequantize_f16()?;
let in_dtype = xs.dtype();
let w = match *xs.dims() {
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
_ => w.t()?,
};
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
}
}
impl crate::CustomOp1 for QTensor {
fn name(&self) -> &'static str {
"qmatmul"
}
fn cpu_fwd(
&self,
storage: &crate::CpuStorage,
layout: &crate::Layout,
) -> Result<(crate::CpuStorage, Shape)> {
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
// self is transposed so n is first then k.
let (n, k) = self.shape.dims2()?;
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let mut dst_shape = src_shape.dims().to_vec();
let last_k = dst_shape.pop().context("empty dst_shape")?;
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
#[allow(clippy::infallible_destructuring_match)]
let self_storage = match &self.storage {
QStorage::Cpu(storage) => storage,
QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
};
let slice = storage.as_slice::<f32>()?;
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
let mut dst_storage = vec![0f32; dst_shape.elem_count()];
self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
}
fn metal_fwd(
&self,
storage: &crate::MetalStorage,
layout: &crate::Layout,
) -> Result<(crate::MetalStorage, Shape)> {
let self_storage = match &self.storage {
QStorage::Metal(metal) => metal,
_ => unreachable!("Cannot call metal matmul on non metal QTensor"),
};
self_storage.fwd(&self.shape, storage, layout)
}
fn cuda_fwd(
&self,
storage: &crate::CudaStorage,
layout: &crate::Layout,
) -> Result<(crate::CudaStorage, Shape)> {
let self_storage = match &self.storage {
QStorage::Cuda(cuda) => cuda,
_ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
};
self_storage.fwd(&self.shape, storage, layout)
}
}
impl crate::Module for QMatMul {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
Self::Tensor(w) => {
let w = match *xs.dims() {
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
_ => w.t()?,
};
xs.matmul(&w)
}
Self::TensorF16(w) => {
let in_dtype = xs.dtype();
let w = match *xs.dims() {
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
_ => w.t()?,
};
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
}
}
}
}

View File

@ -0,0 +1,613 @@
use super::k_quants::{
BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
#[allow(unused_imports)]
#[cfg(target_arch = "arm")]
use core::arch::arm::*;
#[allow(unused_imports)]
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
#[inline(always)]
unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
// TODO: dotprod
let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
}
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32);
for i in 0..nb {
let x0 = &xs[i];
let y0 = &ys[i];
let m4b = vdupq_n_u8(0x0F);
let s8b = vdupq_n_s8(0x8);
let v0_0 = vld1q_u8(x0.qs.as_ptr());
// 4-bit -> 8-bit
let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
// sub 8
let v0_0ls = vsubq_s8(v0_0l, s8b);
let v0_0hs = vsubq_s8(v0_0h, s8b);
// load y
let v1_0l = vld1q_s8(y0.qs.as_ptr());
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
let pl0 = vdotq_s32(v0_0ls, v1_0l);
let ph0 = vdotq_s32(v0_0hs, v1_0h);
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
x0.d.to_f32() * y0.d.to_f32(),
);
}
Ok(vaddvq_f32(sumv0))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32);
for i in 0..nb {
let x0 = &xs[i];
let y0 = &ys[i];
let x0_0 = vld1q_s8(x0.qs.as_ptr());
let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
// load y
let y0_0 = vld1q_s8(y0.qs.as_ptr());
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
let p0 = vdotq_s32(x0_0, y0_0);
let p1 = vdotq_s32(x0_1, y0_1);
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(p0, p1)),
x0.d.to_f32() * y0.d.to_f32(),
);
}
Ok(vaddvq_f32(sumv0))
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % QK_K != 0 {
crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
}
let mut sumf = 0f32;
for (xs, ys) in xs.iter().zip(ys.iter()) {
unsafe {
let mut sum_i = vdupq_n_s32(0);
let scale = xs.d * ys.d;
let xs = xs.qs.as_ptr();
let ys = ys.qs.as_ptr();
for i in (0..QK_K).step_by(16) {
let xs = vld1q_s8(xs.add(i));
let ys = vld1q_s8(ys.add(i));
let xy = vdotq_s32(xs, ys);
sum_i = vaddq_s32(sum_i, xy)
}
sumf += vaddvq_s32(sum_i) as f32 * scale
}
}
Ok(sumf)
}
#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
}
let mut sum = 0f32;
unsafe {
let m4b = vdupq_n_u8(0xF);
let mone = vdupq_n_u8(3);
for (x, y) in xs.iter().zip(ys.iter()) {
let d_all = x.d.to_f32();
let mut q6 = x.ql.as_ptr();
let mut qh = x.qh.as_ptr();
let mut q8 = y.qs.as_ptr();
let mut scale = x.scales.as_ptr();
let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
let scales = vld1q_s8(scale);
let q6scales = int16x8x2_t(
vmovl_s8(vget_low_s8(scales)),
vmovl_s8(vget_high_s8(scales)),
);
let prod = vaddq_s32(
vaddq_s32(
vmull_s16(vget_low_s16(q8sums.0), vget_low_s16(q6scales.0)),
vmull_s16(vget_high_s16(q8sums.0), vget_high_s16(q6scales.0)),
),
vaddq_s32(
vmull_s16(vget_low_s16(q8sums.1), vget_low_s16(q6scales.1)),
vmull_s16(vget_high_s16(q8sums.1), vget_high_s16(q6scales.1)),
),
);
let isum_mins = vaddvq_s32(prod);
let mut isum = 0i32;
for _j in 0..QK_K / 128 {
let qhbits = vld1q_u8_x2(qh);
qh = qh.add(32);
let q6bits = vld1q_u8_x4(q6);
q6 = q6.add(64);
let q8bytes = vld1q_s8_x4(q8);
q8 = q8.add(64);
let q6h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
let q6h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
let shifted = vshrq_n_u8(qhbits.0, 2);
let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let shifted = vshrq_n_u8(qhbits.1, 2);
let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.0, m4b), q6h_0));
let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.1, m4b), q6h_1));
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
scale = scale.add(2);
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
scale = scale.add(2);
let q8bytes = vld1q_s8_x4(q8);
q8 = q8.add(64);
let shifted = vshrq_n_u8(qhbits.0, 4);
let q6h_0 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let shifted = vshrq_n_u8(qhbits.1, 4);
let q6h_1 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let shifted = vshrq_n_u8(qhbits.0, 6);
let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let shifted = vshrq_n_u8(qhbits.1, 6);
let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4);
let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.0, 4), q6h_0));
let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.1, 4), q6h_1));
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
scale = scale.add(2);
let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
scale = scale.add(2);
}
sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
}
}
Ok(sum)
}
#[inline(always)]
pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
}
let mut sumf = 0f32;
let mut utmp = [0u32; 4];
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
unsafe {
let m4b = vdupq_n_u8(0xF);
let mone = vdupq_n_u8(1);
let mtwo = vdupq_n_u8(2);
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = y.d * x.dmin.to_f32();
let q8sums = vpaddq_s16(
vld1q_s16(y.bsums.as_ptr()),
vld1q_s16(y.bsums.as_ptr().add(8)),
);
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
let mins8 = vld1_u8((utmp.as_ptr() as *const u8).add(8));
let mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
let prod = vaddq_s32(
vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
);
let sumi_mins = vaddvq_s32(prod);
let mut scales = utmp.as_ptr() as *const u8;
let mut q5 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mut qhbits = vld1q_u8_x2(x.qh.as_ptr());
let mut sumi = 0i32;
for _j in 0..QK_K / 64 {
let q5bits = vld1q_u8_x2(q5);
q5 = q5.add(32);
let q8bytes = vld1q_s8_x4(q8);
q8 = q8.add(64);
let q5h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
let q5h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
let q5h_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits.0), 3);
let q5h_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits.1), 3);
qhbits.0 = vshrq_n_u8(qhbits.0, 2);
qhbits.1 = vshrq_n_u8(qhbits.1, 2);
let q5bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.0, m4b), q5h_0));
let q5bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.1, m4b), q5h_1));
let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
scales = scales.add(1);
let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
scales = scales.add(1);
}
sumf += d * sumi as f32 - dmin * sumi_mins as f32;
}
}
Ok(sumf)
}
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
let mut sumf = 0f32;
let mut utmp = [0u32; 4];
let mut scales = [0u8; 16];
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
unsafe {
let m4b = vdupq_n_u8(0xF);
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = y.d * x.dmin.to_f32();
let q8sums = vpaddq_s16(
vld1q_s16(y.bsums.as_ptr()),
vld1q_s16(y.bsums.as_ptr().add(8)),
);
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
let mins8 = vld1_u32(
[
utmp[1] & KMASK1,
((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4),
]
.as_ptr(),
);
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[0] &= KMASK1;
let mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
let prod = vaddq_s32(
vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
);
sumf -= dmin * vaddvq_s32(prod) as f32;
LittleEndian::write_u32_into(&utmp, &mut scales);
let mut q4 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let mut sumi1 = 0i32;
let mut sumi2 = 0i32;
for j in 0..QK_K / 64 {
let q4bits = vld1q_u8_x2(q4);
q4 = q4.add(32);
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
let q4bytes = int8x16x2_t(
vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
);
let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
let q4bytes = int8x16x2_t(
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
);
let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
}
sumf += d * (sumi1 + sumi2) as f32;
}
}
Ok(sumf)
}
#[inline(always)]
pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
}
let mut sumf = 0f32;
let mut utmp = [0u32; 4];
let mut aux = [0u32; 3];
const KMASK1: u32 = 0x03030303;
const KMASK2: u32 = 0x0f0f0f0f;
unsafe {
let m3b = vdupq_n_u8(0x3);
let m0 = vdupq_n_u8(1);
let m1 = vshlq_n_u8(m0, 1);
let m2 = vshlq_n_u8(m0, 2);
let m3 = vshlq_n_u8(m0, 3);
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let mut q3 = x.qs.as_ptr();
let qh = x.hmask.as_ptr();
let mut q8 = y.qs.as_ptr();
let mut qhbits = vld1q_u8_x2(qh);
let mut isum = 0i32;
// Set up scales
LittleEndian::read_u32_into(&x.scales, &mut aux);
utmp[3] = ((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4);
utmp[2] = ((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4);
utmp[1] = (aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4);
utmp[0] = (aux[0] & KMASK2) | ((aux[2] & KMASK1) << 4);
let mut scale = utmp.as_mut_ptr() as *mut i8;
for j in 0..16 {
*scale.add(j) -= 32i8
}
for j in 0..QK_K / 128 {
let q3bits = vld1q_u8_x2(q3);
q3 = q3.add(32);
let q8bytes_1 = vld1q_s8_x4(q8);
q8 = q8.add(64);
let q8bytes_2 = vld1q_s8_x4(q8);
q8 = q8.add(64);
let q3h_0 = vshlq_n_u8(vbicq_u8(m0, qhbits.0), 2);
let q3h_1 = vshlq_n_u8(vbicq_u8(m0, qhbits.1), 2);
let q3h_2 = vshlq_n_u8(vbicq_u8(m1, qhbits.0), 1);
let q3h_3 = vshlq_n_u8(vbicq_u8(m1, qhbits.1), 1);
let q3bytes_0 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(q3bits.0, m3b)),
vreinterpretq_s8_u8(q3h_0),
);
let q3bytes_1 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(q3bits.1, m3b)),
vreinterpretq_s8_u8(q3h_1),
);
let q3bytes_2 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 2), m3b)),
vreinterpretq_s8_u8(q3h_2),
);
let q3bytes_3 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 2), m3b)),
vreinterpretq_s8_u8(q3h_3),
);
let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
isum += vaddvq_s32(p0) * *scale as i32
+ vaddvq_s32(p1) * *scale.add(1) as i32
+ vaddvq_s32(p2) * *scale.add(2) as i32
+ vaddvq_s32(p3) * *scale.add(3) as i32;
scale = scale.add(4);
let q3h_0 = vbicq_u8(m2, qhbits.0);
let q3h_1 = vbicq_u8(m2, qhbits.1);
let q3h_2 = vshrq_n_u8(vbicq_u8(m3, qhbits.0), 1);
let q3h_3 = vshrq_n_u8(vbicq_u8(m3, qhbits.1), 1);
let q3bytes_0 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 4), m3b)),
vreinterpretq_s8_u8(q3h_0),
);
let q3bytes_1 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 4), m3b)),
vreinterpretq_s8_u8(q3h_1),
);
let q3bytes_2 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 6), m3b)),
vreinterpretq_s8_u8(q3h_2),
);
let q3bytes_3 = vsubq_s8(
vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 6), m3b)),
vreinterpretq_s8_u8(q3h_3),
);
let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
isum += vaddvq_s32(p0) * *scale as i32
+ vaddvq_s32(p1) * *scale.add(1) as i32
+ vaddvq_s32(p2) * *scale.add(2) as i32
+ vaddvq_s32(p3) * *scale.add(3) as i32;
scale = scale.add(4);
if j == 0 {
qhbits.0 = vshrq_n_u8(qhbits.0, 4);
qhbits.1 = vshrq_n_u8(qhbits.1, 4);
}
}
sumf += d * isum as f32;
}
}
Ok(sumf)
}
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
}
let mut sumf = 0f32;
let mut aux = [0u8; 16];
unsafe {
let m3 = vdupq_n_u8(0x3);
let m4 = vdupq_n_u8(0xF);
for (x, y) in xs.iter().zip(ys.iter()) {
let d = y.d * x.d.to_f32();
let dmin = -y.d * x.dmin.to_f32();
let mut q2 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
let sc = x.scales.as_ptr();
let mins_and_scales = vld1q_u8(sc);
let scales = vandq_u8(mins_and_scales, m4);
vst1q_u8(aux.as_mut_ptr(), scales);
let mins = vshrq_n_u8(mins_and_scales, 4);
let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
let mins16 = int16x8x2_t(
vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))),
vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins))),
);
let s0 = vaddq_s32(
vmull_s16(vget_low_s16(mins16.0), vget_low_s16(q8sums.0)),
vmull_s16(vget_high_s16(mins16.0), vget_high_s16(q8sums.0)),
);
let s1 = vaddq_s32(
vmull_s16(vget_low_s16(mins16.1), vget_low_s16(q8sums.1)),
vmull_s16(vget_high_s16(mins16.1), vget_high_s16(q8sums.1)),
);
sumf += dmin * vaddvq_s32(vaddq_s32(s0, s1)) as f32;
let mut isum = 0i32;
let mut is = 0usize;
// TODO: dotprod
for _j in 0..QK_K / 128 {
let q2bits = vld1q_u8_x2(q2);
q2 = q2.add(32);
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
let mut q2bytes = int8x16x2_t(
vreinterpretq_s8_u8(vandq_u8(q2bits.0, m3)),
vreinterpretq_s8_u8(vandq_u8(q2bits.1, m3)),
);
isum += multiply_accum_with_scale(&aux, is, 0, q2bytes, q8bytes);
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 2), m3));
q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 2), m3));
isum += multiply_accum_with_scale(&aux, is, 2, q2bytes, q8bytes);
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 4), m3));
q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 4), m3));
isum += multiply_accum_with_scale(&aux, is, 4, q2bytes, q8bytes);
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 6), m3));
q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 6), m3));
isum += multiply_accum_with_scale(&aux, is, 6, q2bytes, q8bytes);
is += 8;
}
sumf += d * isum as f32;
}
}
Ok(sumf)
}
#[inline(always)]
unsafe fn multiply_accum_with_scale(
aux: &[u8; 16],
is: usize,
index: usize,
q2bytes: int8x16x2_t,
q8bytes: int8x16x2_t,
) -> i32 {
let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
}

View File

@ -0,0 +1,419 @@
use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;
use core::arch::wasm32::*;
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let x1234 = v128_load(x.qs.as_ptr() as *const v128);
let x12 = v128_and(x1234, u8x16_splat(0x0F));
let x12 = i8x16_sub(x12, i8x16_splat(8));
let x34 = u8x16_shr(x1234, 4);
let x34 = i8x16_sub(x34, i8x16_splat(8));
let x1 = i16x8_extend_low_i8x16(x12);
let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
let sum_xy = i32x4_dot_i16x8(x1, y1);
let x2 = i16x8_extend_high_i8x16(x12);
let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
let x3 = i16x8_extend_low_i8x16(x34);
let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
let x4 = i16x8_extend_high_i8x16(x34);
let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
let sum_xy = f32x4_convert_i32x4(sum_xy);
// f32x4_relaxed_madd is nightly only.
let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
let scaled = f32x4_mul(sum_xy, d);
acc = f32x4_add(acc, scaled)
}
let res = f32x4_extract_lane::<0>(acc)
+ f32x4_extract_lane::<1>(acc)
+ f32x4_extract_lane::<2>(acc)
+ f32x4_extract_lane::<3>(acc);
Ok(res)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let x1 = i16x8_load_extend_i8x8(x.qs.as_ptr());
let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
let sum_xy = i32x4_dot_i16x8(x1, y1);
let x2 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(8));
let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
let x3 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(16));
let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
let x4 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(24));
let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
let sum_xy = f32x4_convert_i32x4(sum_xy);
// f32x4_relaxed_madd is nightly only.
let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
let scaled = f32x4_mul(sum_xy, d);
acc = f32x4_add(acc, scaled)
}
let res = f32x4_extract_lane::<0>(acc)
+ f32x4_extract_lane::<1>(acc)
+ f32x4_extract_lane::<2>(acc)
+ f32x4_extract_lane::<3>(acc);
Ok(res)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
}
unsafe {
let mut sumf = f32x4_splat(0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let mut q2: &[_] = &x.qs;
let mut q8: &[_] = &y.qs;
let sc = &x.scales;
let mut summs = i32x4_splat(0);
for i in (0..(QK_K / 16)).step_by(4) {
let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(i));
let scales = i32x4_shr(
i32x4(
sc[i] as i32,
sc[i + 1] as i32,
sc[i + 2] as i32,
sc[i + 3] as i32,
),
4,
);
summs = i32x4_add(summs, i32x4_mul(bsums, scales))
}
let summs = f32x4_convert_i32x4(summs);
let dall = y.d * x.d.to_f32();
let dmin = y.d * x.dmin.to_f32();
let mut isum = i32x4_splat(0);
let mut is = 0;
for _ in 0..(QK_K / 128) {
let mut shift = 0;
for _ in 0..4 {
let d = (sc[is] & 0xF) as i32;
is += 1;
let mut isuml = i16x8_splat(0);
for l in (0..16).step_by(8) {
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
}
let dd = i32x4_splat(d);
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
let d = (sc[is] & 0xF) as i32;
is += 1;
let mut isuml = i16x8_splat(0);
for l in (16..32).step_by(8) {
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
}
let dd = i32x4_splat(d);
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
shift += 2;
// adjust the indexing
q8 = &q8[32..];
}
// adjust the indexing
q2 = &q2[32..];
}
let isum = f32x4_convert_i32x4(isum);
sumf = f32x4_add(
sumf,
f32x4_sub(
f32x4_mul(isum, f32x4_splat(dall)),
f32x4_mul(summs, f32x4_splat(dmin)),
),
);
}
let sumf = f32x4_extract_lane::<0>(sumf)
+ f32x4_extract_lane::<1>(sumf)
+ f32x4_extract_lane::<2>(sumf)
+ f32x4_extract_lane::<3>(sumf);
Ok(sumf)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
let mut utmp: [u32; 4] = [0; 4];
let mut scales: [u8; 8] = [0; 8];
let mut mins: [u8; 8] = [0; 8];
let mut aux8: [u8; QK_K] = [0; QK_K];
let mut sums = f32x4_splat(0f32);
unsafe {
for (y, x) in ys.iter().zip(xs.iter()) {
let q4 = &x.qs;
let q8 = &y.qs;
for j in 0..QK_K / 64 {
let q4_1 = v128_load(q4.as_ptr().add(32 * j) as *const v128);
let q4_2 = v128_load(q4.as_ptr().add(32 * j + 16) as *const v128);
v128_store(
aux8.as_mut_ptr().add(64 * j) as *mut v128,
v128_and(q4_1, u8x16_splat(0x0F)),
);
v128_store(
aux8.as_mut_ptr().add(64 * j + 16) as *mut v128,
v128_and(q4_2, u8x16_splat(0x0F)),
);
v128_store(
aux8.as_mut_ptr().add(64 * j + 32) as *mut v128,
u8x16_shr(q4_1, 4),
);
v128_store(
aux8.as_mut_ptr().add(64 * j + 48) as *mut v128,
u8x16_shr(q4_2, 4),
);
}
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
//extract scales and mins
LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
LittleEndian::write_u32_into(&utmp[2..4], &mut mins);
let mut sumi = i32x4_splat(0);
for j in (0..QK_K / 16).step_by(4) {
let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(j));
let (m1, m2) = (mins[j / 2] as i32, mins[j / 2 + 1] as i32);
let mins = i32x4(m1, m1, m2, m2);
sumi = i32x4_add(sumi, i32x4_mul(bsums, mins));
}
let mut aux32 = i32x4_splat(0i32);
for (scale_i, scale) in scales.iter().enumerate() {
let scale = i32x4_splat(*scale as i32);
for j in 0..4 {
let i = 32 * scale_i + 8 * j;
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(i));
let aux8 = i16x8_load_extend_u8x8(aux8.as_ptr().add(i));
let aux16 = i16x8_mul(q8, aux8);
aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_low_i16x8(aux16)));
aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_high_i16x8(aux16)));
}
}
let aux32 = f32x4_convert_i32x4(aux32);
let d = f32x4_splat(x.d.to_f32() * y.d);
sums = f32x4_add(sums, f32x4_mul(aux32, d));
let dmin = x.dmin.to_f32() * y.d;
let dmin = f32x4_splat(dmin);
let sumi = f32x4_convert_i32x4(sumi);
sums = f32x4_sub(sums, f32x4_mul(sumi, dmin));
}
let sums = f32x4_extract_lane::<0>(sums)
+ f32x4_extract_lane::<1>(sums)
+ f32x4_extract_lane::<2>(sums)
+ f32x4_extract_lane::<3>(sums);
Ok(sums)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
}
let mut aux8 = [0i8; QK_K];
unsafe {
let mut sums = f32x4_splat(0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let q4 = &x.ql;
let qh = &x.qh;
let q8 = &y.qs;
let mut aux32 = f32x4_splat(0f32);
for j in (0..QK_K).step_by(128) {
let aux8 = aux8.as_mut_ptr().add(j);
let q4 = &q4.as_ptr().add(j / 2);
let qh = &qh.as_ptr().add(j / 4);
for l in (0..32).step_by(16) {
// aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
let a8 = v128_or(
v128_and(v128_load(q4.add(l) as *const v128), u8x16_splat(0xF)),
u8x16_shl(
v128_and(v128_load(qh.add(l) as *const v128), u8x16_splat(3)),
4,
),
);
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
v128_store(
aux8.add(l) as *mut v128,
i8x16_narrow_i16x8(a8_low, a8_high),
);
// aux8[l + 32] =
// (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
let a8 = v128_or(
v128_and(v128_load(q4.add(l + 32) as *const v128), u8x16_splat(0xF)),
u8x16_shl(
v128_and(
u8x16_shr(v128_load(qh.add(l) as *const v128), 2),
u8x16_splat(3),
),
4,
),
);
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
v128_store(
aux8.add(l + 32) as *mut v128,
i8x16_narrow_i16x8(a8_low, a8_high),
);
// aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
let a8 = v128_or(
u8x16_shr(v128_load(q4.add(l) as *const v128), 4),
u8x16_shl(
v128_and(
u8x16_shr(v128_load(qh.add(l) as *const v128), 4),
u8x16_splat(3),
),
4,
),
);
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
v128_store(
aux8.add(l + 64) as *mut v128,
i8x16_narrow_i16x8(a8_low, a8_high),
);
// aux8[l + 96] =
// (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
let a8 = v128_or(
u8x16_shr(v128_load(q4.add(l + 32) as *const v128), 4),
u8x16_shl(
v128_and(
u8x16_shr(v128_load(qh.add(l) as *const v128), 6),
u8x16_splat(3),
),
4,
),
);
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
v128_store(
aux8.add(l + 96) as *mut v128,
i8x16_narrow_i16x8(a8_low, a8_high),
);
}
}
for (j, &scale) in x.scales.iter().enumerate() {
let scale = f32x4_splat(scale as f32);
for offset in [0, 8] {
let aux16 = i16x8_mul(
i16x8_load_extend_i8x8(q8.as_ptr().add(16 * j + offset)),
i16x8_load_extend_i8x8(aux8.as_ptr().add(16 * j + offset)),
);
aux32 = f32x4_add(
aux32,
f32x4_mul(f32x4_convert_i32x4(i32x4_extend_low_i16x8(aux16)), scale),
);
aux32 = f32x4_add(
aux32,
f32x4_mul(f32x4_convert_i32x4(i32x4_extend_high_i16x8(aux16)), scale),
);
}
}
let d = f32x4_splat(x.d.to_f32() * y.d);
sums = f32x4_add(sums, f32x4_mul(aux32, d));
}
let sums = f32x4_extract_lane::<0>(sums)
+ f32x4_extract_lane::<1>(sums)
+ f32x4_extract_lane::<2>(sums)
+ f32x4_extract_lane::<3>(sums);
Ok(sums)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % QK_K != 0 {
crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (xs, ys) in xs.iter().zip(ys.iter()) {
let x_qs = xs.qs.as_ptr();
let y_qs = ys.qs.as_ptr();
let mut sumi = i32x4_splat(0);
for j in (0..QK_K).step_by(8) {
let xs = i16x8_load_extend_i8x8(x_qs.add(j));
let ys = i16x8_load_extend_i8x8(y_qs.add(j));
let sum_xy = i32x4_dot_i16x8(xs, ys);
sumi = i32x4_add(sumi, sum_xy)
}
let d = f32x4_splat(xs.d * ys.d);
acc = f32x4_add(acc, f32x4_mul(f32x4_convert_i32x4(sumi), d))
}
let res = f32x4_extract_lane::<0>(acc)
+ f32x4_extract_lane::<1>(acc)
+ f32x4_extract_lane::<2>(acc)
+ f32x4_extract_lane::<3>(acc);
Ok(res)
}
}

View File

@ -0,0 +1,326 @@
use crate::Result;
pub(super) fn nearest_int(v: f32) -> i32 {
v.round() as i32
}
/// Validates that the input and output are the right size and returns an iterator which maps each
/// input region `xs` to its corresponding output block in `ys`. Each output region is guaranteed
/// to be `T::BLCK_SIZE` long.
pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
xs: &'b [f32],
ys: &'a mut [T],
) -> Result<Vec<(&'a mut T, &'b [f32])>> {
let block_size = T::BLCK_SIZE;
let dtype = T::DTYPE;
let expected_blocks = xs.len() / block_size;
let actual_blocks = ys.len();
// Validate that the input is the right size
if expected_blocks != actual_blocks {
crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
}
Ok(ys.iter_mut().zip(xs.chunks_exact(block_size)).collect())
}
/// Validates that the input and output are the right size and returns an iterator which maps each
/// input block `xs` to its corresponding output region in `ys`. Each output region is guaranteed
/// to be `T::BLCK_SIZE` long.
pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
xs: &'a [T],
ys: &'b mut [f32],
) -> Result<Vec<(&'a T, &'b mut [f32])>> {
let block_size = T::BLCK_SIZE;
let dtype = T::DTYPE;
let actual_output_len = ys.len();
let expected_output_len = xs.len() * block_size;
// Validate that the output is the right size
if expected_output_len != actual_output_len {
crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!")
}
// Zip the blocks and outputs together
Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())
}
pub(super) fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
if j < 4 {
let d = q[j] & 63;
let m = q[j + 4] & 63;
(d, m)
} else {
let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
(d, m)
}
}
pub(super) unsafe fn make_qx_quants(
n: usize,
nmax: i32,
x: *const f32,
ls: *mut i8,
rmse_type: i32,
) -> f32 {
let mut max = 0f32;
let mut amax = 0f32;
for i in 0..n {
let x = *x.add(i);
let ax = x.abs();
if ax > amax {
amax = ax;
max = x;
}
}
if amax == 0. {
// all zero
for i in 0..n {
*ls.add(i) = 0;
}
return 0.;
}
let mut iscale = -(nmax as f32) / max;
if rmse_type == 0 {
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
}
return 1.0 / iscale;
}
let weight_type = rmse_type % 2;
let mut sumlx = 0f32;
let mut suml2 = 0f32;
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
let l = l.clamp(-nmax, nmax - 1);
*ls.add(i) = (l + nmax) as i8;
let w = if weight_type == 1 { x * x } else { 1.0 };
let l = l as f32;
sumlx += w * x * l;
suml2 += w * l * l;
}
let mut scale = sumlx / suml2;
let mut best = scale * sumlx;
for _itry in 0..3 {
let iscale = 1.0 / scale;
let mut slx = 0f32;
let mut sl2 = 0f32;
let mut changed = false;
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
let l = l.clamp(-nmax, nmax - 1);
if l + nmax != *ls.add(i) as i32 {
changed = true;
}
let w = if weight_type == 1 { x * x } else { 1f32 };
let l = l as f32;
slx += w * x * l;
sl2 += w * l * l;
}
if !changed || sl2 == 0.0 || slx * slx <= best * sl2 {
break;
}
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
}
sumlx = slx;
suml2 = sl2;
scale = sumlx / suml2;
best = scale * sumlx;
}
for _itry in 0..5 {
let mut n_changed = 0;
for i in 0..n {
let x = *x.add(i);
let w = if weight_type == 1 { x * x } else { 1. };
let l = *ls.add(i) as i32 - nmax;
let mut slx = sumlx - w * x * l as f32;
if slx > 0. {
let mut sl2 = suml2 - w * l as f32 * l as f32;
let new_l = nearest_int(x * sl2 / slx);
let new_l = new_l.clamp(-nmax, nmax - 1);
if new_l != l {
slx += w * x * new_l as f32;
sl2 += w * new_l as f32 * new_l as f32;
if sl2 > 0. && slx * slx * suml2 > sumlx * sumlx * sl2 {
*ls.add(i) = (nmax + new_l) as i8;
sumlx = slx;
suml2 = sl2;
scale = sumlx / suml2;
best = scale * sumlx;
n_changed += 1;
}
}
}
}
if n_changed == 0 {
break;
}
}
if rmse_type < 3 {
return scale;
}
for is in -4..4 {
if is == 0 {
continue;
}
iscale = -(nmax as f32 + 0.1f32 * is as f32) / max;
let mut sumlx = 0.;
let mut suml2 = 0.;
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
let l = l.clamp(-nmax, nmax - 1);
let w = if weight_type == 1 { x * x } else { 1. };
let l = l as f32;
sumlx += w * x * l;
suml2 += w * l * l;
}
if suml2 > 0. && sumlx * sumlx > best * suml2 {
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
*ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
}
scale = sumlx / suml2;
best = scale * sumlx;
}
}
scale
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L224
pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32) {
let n = x.len();
let mut l = vec![0; n];
// Get min/max
let min = *x
.iter()
.take(n)
.min_by(|a, b| a.total_cmp(b))
.unwrap_or(&x[0]);
let max = *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap_or(&x[0]);
// If min == max, all values are the same => nothing to do here
if max == min {
return (0.0, 0.0);
}
// Ensure min <= 0.0
let mut min = min.min(0.);
// Compute scale and inverse scale
let mut iscale = nmax as f32 / (max - min);
let mut scale = 1.0 / iscale;
for _ in 0..ntry {
let mut sumlx = 0.0;
let mut suml2 = 0;
let mut did_change = false;
for (i, value) in x.iter().enumerate().take(n) {
let li = nearest_int(iscale * (value - min)).clamp(0, nmax);
let clamped_li = li as u8;
if clamped_li != l[i] {
l[i] = clamped_li;
did_change = true;
}
sumlx += (value - min) * li as f32;
suml2 += li * li;
}
scale = sumlx / suml2 as f32;
let sum: f32 = x
.iter()
.take(n)
.zip(l.iter().take(n))
.map(|(xi, &li)| xi - scale * li as f32)
.sum();
min = sum / n as f32;
if min > 0.0 {
min = 0.0;
}
iscale = 1.0 / scale;
if !did_change {
break;
}
}
(scale, -min)
}
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L165
pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 {
let n = x.len();
let mut l = vec![0i8; n];
let mut max = 0.0;
let mut amax = 0.0;
for &xi in x.iter().take(n) {
let ax = xi.abs();
if ax > amax {
amax = ax;
max = xi;
}
}
if amax == 0.0 {
return 0.0;
}
let iscale = -(nmax as f32) / max;
if do_rmse {
let mut sumlx = 0.0;
let mut suml2 = 0.0;
for i in 0..n {
let li = (iscale * x[i]).round() as i32;
let li = li.clamp(-nmax, nmax - 1);
l[i] = li as i8;
let w = x[i] * x[i];
sumlx += w * x[i] * li as f32;
suml2 += w * (li * li) as f32;
}
for _ in 0..5 {
let mut n_changed = 0;
for i in 0..n {
let w = x[i] * x[i];
let mut slx = sumlx - w * x[i] * l[i] as f32;
if slx > 0.0 {
let mut sl2 = suml2 - w * (l[i] as i32 * l[i] as i32) as f32;
let mut new_l = (x[i] * sl2 / slx).round() as i32;
new_l = new_l.clamp(-nmax, nmax - 1);
if new_l != l[i] as i32 {
slx += w * x[i] * new_l as f32;
sl2 += w * (new_l * new_l) as f32;
if sl2 > 0.0 && slx * slx * suml2 > sumlx * sumlx * sl2 {
l[i] = new_l as i8;
sumlx = slx;
suml2 = sl2;
n_changed += 1;
}
}
}
}
if n_changed == 0 {
break;
}
}
for li in l.iter_mut() {
*li += nmax as i8;
}
return sumlx / suml2;
}
for i in 0..n {
let li = (iscale * x[i]).round() as i32;
l[i] = (li.clamp(-nmax, nmax - 1) + nmax) as i8;
}
1.0 / iscale
}

View File

@ -1,3 +1,14 @@
//! Module to load `safetensor` files into CPU/GPU memory.
//!
//! There are multiple ways to load tensors from safetensor files:
//! - `load` function for loading directly into memory and returning a HashMap of tensors
//! - `MmapedSafetensors` for memory mapping files and avoiding full allocation
//! - `SliceSafetensors` for working with in-memory buffers
//! - `BufferedSafetensors` for owning a buffer of data
//!
//! Tensors can also be serialized to safetensor format using the `save` function or
//! `Tensor::save_safetensors` method.
//!
use crate::{DType, Device, Error, Result, Tensor, WithDType};
use safetensors::tensor as st;
use safetensors::tensor::SafeTensors;
@ -10,6 +21,7 @@ impl From<DType> for st::Dtype {
match value {
DType::U8 => st::Dtype::U8,
DType::U32 => st::Dtype::U32,
DType::I64 => st::Dtype::I64,
DType::BF16 => st::Dtype::BF16,
DType::F16 => st::Dtype::F16,
DType::F32 => st::Dtype::F32,
@ -24,6 +36,7 @@ impl TryFrom<st::Dtype> for DType {
match value {
st::Dtype::U8 => Ok(DType::U8),
st::Dtype::U32 => Ok(DType::U32),
st::Dtype::I64 => Ok(DType::I64),
st::Dtype::BF16 => Ok(DType::BF16),
st::Dtype::F16 => Ok(DType::F16),
st::Dtype::F32 => Ok(DType::F32),
@ -76,11 +89,7 @@ impl st::View for &Tensor {
}
impl Tensor {
pub fn save_safetensors<P: AsRef<std::path::Path>>(
&self,
name: &str,
filename: P,
) -> Result<()> {
pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
let data = [(name, self.clone())];
Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
}
@ -173,7 +182,7 @@ pub trait Load {
fn load(&self, device: &Device) -> Result<Tensor>;
}
impl<'a> Load for st::TensorView<'a> {
impl Load for st::TensorView<'_> {
fn load(&self, device: &Device) -> Result<Tensor> {
convert(self, device)
}
@ -189,6 +198,7 @@ impl Tensor {
match dtype {
DType::U8 => convert_slice::<u8>(data, shape, device),
DType::U32 => convert_slice::<u32>(data, shape, device),
DType::I64 => convert_slice::<i64>(data, shape, device),
DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
DType::F16 => convert_slice::<half::f16>(data, shape, device),
DType::F32 => convert_slice::<f32>(data, shape, device),
@ -205,24 +215,15 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
convert_with_cast_::<u16, u32, _>(view, device, conv)
}
st::Dtype::U32 => convert_::<u32>(view, device),
st::Dtype::I32 => {
let conv = |x| Ok(i64::from(x));
convert_with_cast_::<i32, i64, _>(view, device, conv)
}
st::Dtype::I64 => convert_::<i64>(view, device),
st::Dtype::BF16 => convert_::<half::bf16>(view, device),
st::Dtype::F16 => convert_::<half::f16>(view, device),
st::Dtype::F32 => convert_::<f32>(view, device),
st::Dtype::F64 => convert_::<f64>(view, device),
st::Dtype::I32 => {
let conv = |x| {
u32::try_from(x)
.map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
};
convert_with_cast_::<i32, u32, _>(view, device, conv)
}
st::Dtype::I64 => {
let conv = |x| {
u32::try_from(x)
.map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
};
convert_with_cast_::<i64, u32, _>(view, device, conv)
}
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
}
}
@ -233,6 +234,7 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
match tensor.dtype() {
DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),
@ -242,18 +244,180 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
let data = std::fs::read(filename.as_ref())?;
let st = safetensors::SafeTensors::deserialize(&data)?;
load_buffer(&data[..], device)
}
pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> {
let st = safetensors::SafeTensors::deserialize(data)?;
st.tensors()
.into_iter()
.map(|(name, view)| Ok((name, view.load(device)?)))
.collect()
}
pub fn save<P: AsRef<Path>>(tensors: &HashMap<&str, Tensor>, filename: P) -> Result<()> {
pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
tensors: &HashMap<K, Tensor>,
filename: P,
) -> Result<()> {
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}
pub struct MmapedFile(memmap2::Mmap);
#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);
pub struct MmapedSafetensors {
safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
routing: Option<HashMap<String, usize>>,
}
impl MmapedSafetensors {
/// Creates a wrapper around a memory mapped file and deserialize the safetensors header.
///
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let file = memmap2::MmapOptions::new()
.map(&file)
.map_err(|e| Error::from(e).with_path(p))?;
let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
file,
|data: &[u8]| {
let st = safetensors::SafeTensors::deserialize(data)
.map_err(|e| Error::from(e).with_path(p))?;
Ok::<_, Error>(SafeTensors_(st))
},
)?;
Ok(Self {
safetensors: vec![safetensors],
routing: None,
})
}
/// Creates a wrapper around multiple memory mapped file and deserialize the safetensors headers.
///
/// If a tensor name appears in multiple files, the last entry is returned.
///
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
let mut routing = HashMap::new();
let mut safetensors = vec![];
for (index, p) in paths.iter().enumerate() {
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let file = memmap2::MmapOptions::new()
.map(&file)
.map_err(|e| Error::from(e).with_path(p))?;
let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
file,
|data: &[u8]| {
let st = safetensors::SafeTensors::deserialize(data)
.map_err(|e| Error::from(e).with_path(p))?;
Ok::<_, Error>(SafeTensors_(st))
},
)?;
for k in data.get().0.names() {
routing.insert(k.to_string(), index);
}
safetensors.push(data)
}
Ok(Self {
safetensors,
routing: Some(routing),
})
}
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
self.get(name)?.load(dev)
}
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
let mut tensors = vec![];
for safetensors in self.safetensors.iter() {
tensors.push(safetensors.get().0.tensors())
}
tensors.into_iter().flatten().collect()
}
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
let index = match &self.routing {
None => 0,
Some(routing) => {
let index = routing.get(name).ok_or_else(|| {
Error::CannotFindTensor {
path: name.to_string(),
}
.bt()
})?;
*index
}
};
Ok(self.safetensors[index].get().0.tensor(name)?)
}
}
pub struct SliceSafetensors<'a> {
safetensors: SafeTensors<'a>,
}
impl<'a> SliceSafetensors<'a> {
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
pub fn new(buffer: &'a [u8]) -> Result<Self> {
let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
Ok(Self { safetensors })
}
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
self.safetensors.tensor(name)?.load(dev)
}
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
self.safetensors.tensors()
}
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
Ok(self.safetensors.tensor(name)?)
}
}
pub struct BufferedSafetensors {
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
}
impl BufferedSafetensors {
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
pub fn new(buffer: Vec<u8>) -> Result<Self> {
let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
buffer,
|data: &[u8]| {
let st = safetensors::SafeTensors::deserialize(data)?;
Ok::<_, Error>(SafeTensors_(st))
},
)?;
Ok(Self { safetensors })
}
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
self.get(name)?.load(dev)
}
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
self.safetensors.get().0.tensors()
}
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
Ok(self.safetensors.get().0.tensor(name)?)
}
}
pub struct MmapedFile {
path: std::path::PathBuf,
inner: memmap2::Mmap,
}
impl MmapedFile {
/// Creates a wrapper around a memory mapped file from which you can retrieve
@ -262,14 +426,21 @@ impl MmapedFile {
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
let file = std::fs::File::open(p)?;
let mmap = memmap2::MmapOptions::new().map(&file)?;
Ok(Self(mmap))
pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let inner = memmap2::MmapOptions::new()
.map(&file)
.map_err(|e| Error::from(e).with_path(p))?;
Ok(Self {
inner,
path: p.to_path_buf(),
})
}
pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
let st = safetensors::SafeTensors::deserialize(&self.0)?;
let st = safetensors::SafeTensors::deserialize(&self.inner)
.map_err(|e| Error::from(e).with_path(&self.path))?;
Ok(st)
}
}

Some files were not shown because too many files have changed in this diff Show More