Compare commits

...

75 Commits

Author SHA1 Message Date
ed353eb76d revert some changes 2025-05-17 03:46:18 +00:00
ffb8d63324 Use HF Papers 2025-05-17 03:41:24 +00:00
92106c8762 Fixes for clippy 1.87. (#2956) 2025-05-15 21:50:27 +02:00
9ce4fe6194 Fix docs quantized qwen3 (#2955)
* fixed docs quantized-qwen3 README

* fixed docs quantized-qwen2-instruct README
2025-05-15 07:58:03 +02:00
450a49ed1a Olmo 2 model (#2954)
* OLMo 2 model

* Update olmo-2 to example

* Clippy fix.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2025-05-14 19:18:02 +02:00
6bd61727bc Make tensor contiguous before the repeat_kv calls to avoid strided copies (#2953) 2025-05-14 10:47:28 +02:00
485ddf2996 Fixed Quantized Qwen3 Model (#2951)
* optimize KV cache to reduce GPU memory usage

* revert to using candle_nn::kv_cache::KvCache with initial capacity of 512
2025-05-13 05:53:42 +02:00
36508a2c93 Add Resize to onnx ops (#2946)
* added resize to candle-onnx, not currently working

* changed unreachable to bail, and bailed when both scales and sizes are set

* cleanup and added other unused options for this op

* cleanup

* fixed image loading to make output work

* cleanup and removed unused variables

* removed path creation code, and changed unwrap to ?
2025-05-10 07:05:03 +02:00
3d05f5cf3d Qwen3 quantized implementation (#2939)
* fixed quantized_phi3 implementation

* quantized_qwen3 implementation

* Update quantized_phi3.rs

* Update quantized_phi3.rs

* add quantized_qwen3 example

* Clippy fixes.

* Cleanup.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-05-08 15:06:10 +02:00
637473cb5e Bump cudarc to 0.16.3. (#2942) 2025-05-04 09:14:28 +02:00
e27b4700ad Indexing with max-value results in zero/no-op. (#2940)
* Indexing with max-value results in zero/no-op.

* Add some testing.

* Also adapt the metal kernels.

* Another test.

* Fix.
2025-05-03 11:36:31 +02:00
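This change treats an index equal to the index dtype's maximum value as a "skip" marker: gather writes a zero for it and scatter/index-add leave the destination untouched (see the CPU kernels further down in this diff). A minimal sketch of the gather behaviour on CPU, for illustration only:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let src = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?;
    // u32::MAX is the sentinel: instead of an InvalidIndex error, the
    // corresponding output element is simply set to zero.
    let ids = Tensor::new(&[[0u32, u32::MAX, 2], [u32::MAX, 1, 0]], &dev)?;
    let out = src.gather(&ids, 1)?; // [[1., 0., 3.], [0., 5., 4.]]
    println!("{out}");
    Ok(())
}
```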
1fdfb58de5 Updating Add qwen3 (PR 2903) to use HF weights (#2930)
* add Qwen3.rs

* fixed compile error

* attempting to get PR 2903 working with qwen weights

* different qwen variants working

* added moe model

* clippy

* added additional eos token

* translated Korean comments to English as well as I can

* removed specialized Qwen3RmsNorm and replaced with generic Candle RmsNorm

* replaced custom repeat_kv implementation with candle's repeat_kv implementation

* replace linear with linear_b in attention initialization

* replaced custom kv_cache implementation with candle kv_cache

* style

* replaced explicit broadcast add with normal add in decoder layer

* removed keeping the Rotary embedding layer in the model struct

* used tie_word_embeddings bool from config instead of relying on existence of weights for lm head in CausalLM

* removed duplicate code from qwen3_moe

* removed sliding window from qwen3 attention

* removed MoE code

* removed unused option

* Fixed Typo

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>

* fixed tie word embeddings to use the correct embedding weights instead of the opposite

---------

Co-authored-by: Max <naturale@hufs.ac.kr>
Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2025-05-02 06:05:53 +02:00
cd96fa80da Add a scattered kv cache. (#2936)
* Add a scattered kv cache.

* Update some comments.
2025-05-01 10:20:48 +02:00
8a19bb7df2 Bump the candle version to 0.9.1. (#2935) 2025-05-01 10:08:16 +02:00
38fc86621c Add support for Helium-v1. (#2932) 2025-04-30 19:38:44 +02:00
5029ac52bb Added tracing page to the candle book. (#2922)
* tracing page

* warned about asynchronous execution

* cleanup

* added Nsight Systems recommendation
2025-04-29 21:35:36 +02:00
de23d34a28 Switch Tensor::full to return a contiguous tensor. (#2929) 2025-04-28 21:36:39 +02:00
d4bac37a61 Fix the gumbel softmax by casting to f32. (#2928) 2025-04-28 19:48:51 +02:00
e98754fc5a Optimize Tensor::new when called on nested Vec<..>. (#2927)
* Optimize Tensor::new when called on nested Vec<..>.

* Improve performance.

* Similar flattening for the 4d case.

* More tweaks.

* Add some dummy test.
2025-04-28 09:19:45 +02:00
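For context, this is the construction path being optimized: passing nested `Vec`s to `Tensor::new`, which flattens them into a single contiguous buffer. A small usage sketch (sizes chosen arbitrarily):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // A Vec<Vec<f32>> becomes a rank-2 tensor; #2927 speeds up this flattening.
    let rows: Vec<Vec<f32>> = (0..512).map(|i| vec![i as f32; 1024]).collect();
    let t = Tensor::new(rows, &Device::Cpu)?;
    assert_eq!(t.dims(), &[512, 1024]);
    Ok(())
}
```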
e3db30021f Support for "unbatched" rope. (#2926)
* Support for (un)-batched rope.

* Use 3d rope in the rope/ropei/rope_thd functions.

* Get the CPU versions to work.

* Fix the cuda version.

* Adapt the metal side.

* Fix the metal tests.
2025-04-27 15:12:02 +02:00
6e0646c208 Remove redundant mlx gemm dtype check (#2925) 2025-04-27 06:14:57 +02:00
fbaf0b0e32 Bump the crate version to 0.9.0. (#2924) 2025-04-26 11:01:21 +02:00
a2e925462c Add the scatter in place ops. (#2923)
* Add the scatter_set op.

* Metal op.

* Cuda version.

* Merge the checks.

* Add the actual ops.
2025-04-26 07:36:49 +02:00
3827685524 Add the scatter op. (#2921)
* Add the scatter op.

* Backprop support.

* Cuda support.
2025-04-25 21:46:58 +02:00
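The backprop changes later in this diff use the new op as `t.scatter(indexes, src, dim)`, mirroring the existing `scatter_add`. A small sketch of the two side by side (illustrative, assuming contiguous CPU tensors):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let init = Tensor::zeros((2, 4), DType::F32, &dev)?;
    let ids = Tensor::new(&[[0u32, 1, 2], [3, 2, 1]], &dev)?;
    let src = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?;
    // scatter overwrites the indexed positions with values from src,
    // while scatter_add accumulates into them.
    let set = init.scatter(&ids, &src, 1)?;
    let add = init.scatter_add(&ids, &src, 1)?;
    println!("{set}\n{add}");
    Ok(())
}
```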
3aeb9575c7 Fixed Quantized Gemma3 Model and example (#2918)
* removed scale factor from computation and made quantized gemma3 work similarly to non-quantized gemma3

* created default consts, replaced is_sliding with Option holding a window_size
2025-04-25 05:47:48 +02:00
6ff0a6999c Fixed Gemma3 model and example (#2917)
* gemma3: changed RotaryEmbedding base freq based on layer and sliding window

* Changed attention mask per layer, either normal or sliding

* made attention mask creation slightly more efficient by only creating them once per model iteration

* changed is_sliding to an Option

* clippy

* changed to stop on both <eos> and <end_of_turn> instead of either or
2025-04-25 05:35:08 +02:00
82def7ae38 Cudarc update. (#2915) 2025-04-23 07:03:26 +02:00
99bd69f383 fixed quantized-gemma example (#2914)
* fixed quantized-gemma example

* lint
2025-04-23 05:39:03 +02:00
a4c56a958e Add the const-set op. (#2910)
* Add the const-set op.

* Cuda implementation.

* Bugfix.

* Metal cleanup.

* Add the metal kernels.

* Add some testing.

* Finish the metal implementation.

* Bump the version.
2025-04-19 10:07:02 +02:00
b2904a830b implemented quantized-gemma3 (#2902)
* implemented quantized-gemma, inference not working

* Fixed a few modeling bugs: outputting the correct tokens for a few iterations then garbage

* lint

* clippy

* quantized-gemma3 example working

* added readme

* clippy
2025-04-19 07:46:41 +02:00
21055b5697 Add PRelu operation (#2904)
* Add PRelu operation

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-04-19 07:24:10 +02:00
9dbaf958dc Add an enum for scalar values. (#2909)
* Add a scalar enum type.

* Add a bit more to the scalar type.

* Small tweak.

* More scalar usage.
2025-04-18 22:13:38 +02:00
ce5f8dd129 Check the bounds in the cuda indexing kernels. (#2908)
* Check the bounds in the cuda indexing kernels.

* Another check.
2025-04-18 20:08:17 +02:00
9954981327 Allow from_vec/from_slice to use a ShapeWithOneHole as shape. (#2905) 2025-04-17 08:59:18 +02:00
7f0f83a7c1 Rotating kv cache positions (#2901)
* Retrieve the current positions for rotating KV caches.

* Add the function to the kv cache too.

* More testing.
2025-04-15 23:09:26 +02:00
76e565c4ab Updated candle-book: Introduction, Installation, MNIST guide, and added CONTRIBUTING.md (#2897)
* added CONTRIBUTING.md to candle-book

* added description to candle-book introduction

* Updated formatting and added different features to candle-book installation

* mnist guide first draft candle-book

* updated mnist guide syntax and grammar for candle-book

* changed HelloWorld - Mnist to Tutorial - Mnist in SUMMARY.md

* updated intro to mnist guide in candle-book
2025-04-15 21:41:10 +02:00
e4e7b0b2da Use cudarc 0.16. (#2900)
* Use cudarc 0.16.

* Allow for disabling event tracking.

* Tweaks.

* Bump the ug version.

* And bump the candle version too.
2025-04-15 21:40:18 +02:00
b01ebbad8a Use cudarc 0.15.2. (#2896) 2025-04-14 20:47:52 +02:00
1d1d6d4fe6 Bump the crate version. (#2895) 2025-04-14 15:52:11 +02:00
2653002f29 Gumbel-Softmax sampling. (#2894)
* Gumbel-Softmax sampling.

* Add a sampling test.

* Share the gumbel-softmax bits.
2025-04-14 15:42:42 +02:00
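The shared helper added by this PR lives in candle; the underlying Gumbel-max trick itself is easy to sketch with plain tensor ops. The function below is an illustrative sketch of the technique, not the API introduced here, and assumes 1-dimensional floating-point logits:

```rust
use candle_core::{Result, Tensor, D};

/// Gumbel-max sampling sketch: argmax(logits / temperature - ln(-ln(u)))
/// with u drawn uniformly from (0, 1).
fn gumbel_sample(logits: &Tensor, temperature: f64) -> Result<u32> {
    let uniform = logits.rand_like(1e-9, 1.0)?;
    let gumbel = uniform.log()?.neg()?.log()?.neg()?; // -ln(-ln(u))
    let scores = ((logits / temperature)? + gumbel)?;
    scores.argmax(D::Minus1)?.to_scalar::<u32>()
}
```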
a52b76ae82 Expose the cudnn algo in the conv ops. (#2892)
* Set the algo.

* Expose the cudnn preferred algo for conv ops.
2025-04-14 08:25:32 +02:00
fb660b8d43 Add a cudnn feature to candle-nn/candle-transformers. (#2890) 2025-04-13 17:43:41 +02:00
2f9606b187 Exclude candle-book to avoid some CI failures. (#2889)
* Exclude candle-book to avoid some CI failures.

* Remove the book CIs.
2025-04-13 17:11:41 +02:00
f3a73f80d1 Support for cudnn conv1d. (#2888)
* Support for cudnn conv1d.

* More conv1d work.

* Get the conv1d to work with cudnn.

* Cleanup.
2025-04-13 16:47:37 +02:00
b44d38de0e Add the Orpheus TTS. (#2886)
* Add the Orpheus TTS.

* Add a small readme.

* Token fix.

* Support more voices.

* Clippy fixes.
2025-04-13 12:02:17 +02:00
d9198deb37 Im2col cuda optimization. (#2885) 2025-04-13 10:07:53 +02:00
15ed0b11ce Optimize the batched matmul for the cpu backend. (#2884) 2025-04-12 21:40:40 +02:00
34505fdf3a Avoid using batched-matmul in nn::Linear. (#2883)
* Avoid using batched-matmul in nn::Linear.

* Also avoid batched matmul in conv1d.

* Also tweak the conv2d.

* Batched tests.

* Also cover conv2d.
2025-04-12 19:53:58 +02:00
d7b7ce16e4 Upgrade ug. (#2882) 2025-04-12 13:19:32 +02:00
19fb6dac1f Bump the crate version. (#2881) 2025-04-11 22:28:21 +02:00
acc5bd335f Cuda cleanup. (#2880)
* Cuda cleanup.

* More fixes.
2025-04-11 21:43:35 +02:00
eb478ece92 Implementing DistilBertForMaskedLM. (#2866)
* Initial commit: model weights working, prediction incorrect

* moved distilbertformaskedlm into distilbert modeling file

* made maskedLM like bert example, still incorrect predictions

* finally not getting NaNs, fixed attention mask

* getting correct output sentences

* get top k predictions

* fixed output formatting slightly

* added default arg for model_id

* lint

* moved masked token example code from distilbertformaskedlm example to distilbert example

* lint

* removed distilbertformaskedlm example

* cleanup

* clippy

* removed embedding normalization from example

* made output and model dependent on args instead of prompt

* lint

* replaced or_ok anyhow error with anyhow context

* changed error message for mask token not found
2025-04-11 13:25:39 +02:00
d339b01726 Fix hardcoded f32 dtype for attention_mask. Use the model dtype for compatibility. (#2872) 2025-04-08 06:12:14 +02:00
2f3bf42bcb Support more snac variants. (#2871) 2025-04-07 08:23:47 +02:00
e3370c6316 Add the SNAC audio tokenizer. (#2869)
* Add the SNAC audio tokenizer.

* More snac.

* Again more snac.

* Add some example code for snac.

* Get the weights to load.

* Add to the snac model.

* Fixes.

* Get round-tripping to work.

* Save/load code files.

* Clippy fix.

* Fmt fix.
2025-04-06 22:15:36 +02:00
338f6a102e Clippy 1.86 fixes for cuda. (#2868) 2025-04-05 15:45:35 +02:00
bc33df77e1 Add the missing voices for CSM. (#2867) 2025-04-05 06:52:36 +02:00
cf9d7bf24c Add the CSM model. (#2862)
* Add the CSM model.

* Add some code to load the model.

* Load the text tokenizer.

* Add frame generation.

* Get the sampling to work.

* Rope fix.

* Autoregressive generation.

* Generate some audio file.

* Use the actual prompt.

* Support multiple turns.

* Add a very barebone readme.

* Move some of the shared bits to the model.
2025-04-04 06:48:03 +02:00
9d31361c4f Fix for clippy 1.86. (#2864)
* Fix for clippy 1.86.

* More clippy fixes.

* More fixes.
2025-04-03 19:38:27 +02:00
648596c073 Added readmes to examples (#2835)
* added chatGLM readme

* changed wording in readme

* added readme for chinese-clip

* added readme for convmixer

* added readme for custom ops

* added readme for efficientnet

* added readme for llama

* added readme to mnist-training

* added readme to musicgen

* added readme to quantized-phi

* added readme to starcoder2

* added readme to whisper-microphone

* added readme to yi

* added readme to yolo-v3

* added readme to whisper-microphone

* added space to example in glm4 readme

* fixed mamba example readme to run mamba instead of mamba-minimal

* removed slash escape character

* changed moondream image to yolo-v8 example image

* added procedure for making the reinforcement-learning example work with a virtual environment on my machine

* added simple one line summaries to the example readmes without

* changed non-existent image to yolo example's bike.jpg

* added backslash to sam command

* removed trailing - from siglip

* added SoX to silero-vad example readme

* replaced procedure for uv on mac with warning that uv isn't currently compatible with pyo3

* added example to falcon readme

* added --which arg to stella-en-v5 readme

* fixed image path in vgg readme

* fixed the image path in the vit readme

* Update README.md

* Update README.md

* Update README.md

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2025-04-03 09:18:29 +02:00
d9904a3baf Update to cudarc 0.14 (breaking change). (#2858)
* Start updating to cudarc 0.14.

* Adapt a couple more things.

* And a couple more fixes.

* More tweaks.

* And a couple more fixes.

* Bump the major version number.

* Proper module system for the cuda kernels.

* Proper ptx loading.

* Launch the sort kernel.

* Custom op.

* Start using the builder pattern.

* More builder.

* More builder.

* Get candle-core to compile.

* Get the tests to pass.

* Get candle-nn to work too.

* Support for custom cuda functions.

* cudnn fixes.

* Get flash attn to run.

* Switch the crate versions to be alpha.

* Bump the ug dependency.
2025-04-03 09:12:19 +02:00
d6db305829 Added new language pairs to marian-mt example. (#2860)
* added new language pairs to marian-mt

* lint

* separated python code for converting tokenizers into its own file and added a requirements.txt for dependencies, updated instructions in readme and included python version

* Cleanup.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-04-02 23:50:14 +02:00
b4daa03e59 add as_cuda_slice_mut to CudaStorage and CudaDType (#2859) 2025-04-01 19:34:52 +02:00
9541467d6b Add flip to tensor (#2855)
* Add `flip` to `tensor`

* Move the tests to the proper places.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2025-04-01 09:07:16 +02:00
6429609090 Added Deepseekr1 Llama8b variant to quantized example (#2842)
* added deepseekr1 llama8b variant to quantized example

* lint
2025-03-30 10:55:21 +02:00
ba473290da Added DeepseekR1 Qwen7B variant to quantized-qwen2-instruct example (#2843)
* quantized deepseek qwen generating tokens

* removed is_deepseek from Args and replaced prompt if statement with pattern matching
2025-03-30 10:54:22 +02:00
59c26195db Fix CIFAR10 dataset types and dimension ordering (#2845) 2025-03-30 10:53:25 +02:00
cb02b389d5 Fix reinforcement learning example (#2837) 2025-03-26 16:27:45 +01:00
0d4097031c fixed rand import for mnist-training (#2833) 2025-03-26 08:10:03 +01:00
10853b803c fixed rand imports for whisper-microphone example (#2834) 2025-03-26 08:09:27 +01:00
f3d472952f fix: candle-flash-attn linux and msvc build (#2829)
* fix: candle-flash-attn linux and msvc build

* Missing newline at eof.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2025-03-25 08:45:12 +01:00
67b85f79f1 Pickle decoder fix and Long1 opcode addition. (#2824)
* Pickle decoder changes: added Long1 opcode, fixed tensor offset calculation

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-03-23 08:10:08 +01:00
0b24f7f0a4 Fix for whisper example. rand::distribution is now rand::distr (#2811) 2025-03-16 19:14:55 +01:00
3afb04925a Allow for growing the default KV cache when needed. (#2810) 2025-03-16 17:30:25 +01:00
cbf5fc80c2 Add Gemma 3 1b IT to Gemma examples (#2809)
- Updates the Gemma example to include Gemma 3 1b instruction tuned.
2025-03-16 17:00:48 +01:00
257 changed files with 10759 additions and 3072 deletions

View File

@ -1,40 +0,0 @@
name: Deploy Rust book
on:
push:
branches:
- main
jobs:
deploy:
runs-on: ubuntu-latest
permissions:
contents: write # To push a branch
pull-requests: write # To create a PR from that branch
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Install latest mdbook
run: |
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
mkdir mdbook
curl -sSL $url | tar -xz --directory=./mdbook
echo `pwd`/mdbook >> $GITHUB_PATH
- name: Deploy GitHub Pages
run: |
# This assumes your book is in the root of your repository.
# Just add a `cd` here if you need to change to another directory.
cd candle-book
mdbook build
git worktree add gh-pages
git config user.name "Deploy from CI"
git config user.email ""
cd gh-pages
# Delete the ref to avoid keeping history.
git update-ref -d refs/heads/gh-pages
rm -rf *
mv ../book/* .
git add .
git commit -m "Deploy $GITHUB_SHA to gh-pages"
git push --force --set-upstream origin gh-pages

View File

@ -1,29 +0,0 @@
name: CI
on:
pull_request:
jobs:
test:
name: Test candle-book
runs-on: ubuntu-latest
permissions:
contents: write # To push a branch
pull-requests: write # To create a PR from that branch
steps:
- uses: actions/checkout@master
- name: Install Rust
run: |
rustup set profile minimal
rustup toolchain install stable
rustup default stable
- name: Install latest mdbook
run: |
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
mkdir bin
curl -sSL $url | tar -xz --directory=bin
echo "$(pwd)/bin" >> $GITHUB_PATH
- name: Run tests
run: cd candle-book && cargo build && mdbook test -L ../target/debug/deps/

View File

@ -3,7 +3,6 @@ members = [
"candle-core",
"candle-datasets",
"candle-examples",
"candle-book",
"candle-nn",
"candle-pyo3",
"candle-transformers",
@ -12,6 +11,7 @@ members = [
"tensor-tools",
]
exclude = [
"candle-book",
"candle-flash-attn",
"candle-kernels",
"candle-metal-kernels",
@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.8.4"
version = "0.9.1"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.8.4" }
candle-datasets = { path = "./candle-datasets", version = "0.8.4" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.4" }
candle-kernels = { path = "./candle-kernels", version = "0.8.4" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.4" }
candle-nn = { path = "./candle-nn", version = "0.8.4" }
candle-onnx = { path = "./candle-onnx", version = "0.8.4" }
candle-transformers = { path = "./candle-transformers", version = "0.8.4" }
candle = { path = "./candle-core", package = "candle-core", version = "0.9.1" }
candle-datasets = { path = "./candle-datasets", version = "0.9.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.1" }
candle-kernels = { path = "./candle-kernels", version = "0.9.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.1" }
candle-nn = { path = "./candle-nn", version = "0.9.1" }
candle-onnx = { path = "./candle-onnx", version = "0.9.1" }
candle-transformers = { path = "./candle-transformers", version = "0.9.1" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.13.5", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
cudarc = { version = "0.16.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.4.1"
@ -70,9 +70,9 @@ tokenizers = { version = "0.21.0", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
ug = "0.1.0"
ug-cuda = "0.1.0"
ug-metal = "0.1.0"
ug = "0.4.0"
ug-cuda = "0.4.0"
ug-metal = "0.4.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}

View File

@ -290,6 +290,8 @@ Cheatsheet:
### Why should I use Candle?
<!--- ANCHOR: goals --->
Candle's core goal is to *make serverless inference possible*. Full machine learning frameworks like PyTorch
are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight
binaries.
@ -299,6 +301,7 @@ and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-fut
Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers).
<!--- ANCHOR_END: goals --->
### Other ML frameworks

View File

@ -0,0 +1,13 @@
# Candle Book
The book uses [mdBook](https://github.com/rust-lang/mdBook) for building.
## Installation
To install mdBook, run `cargo install mdbook`. More instructions can be found [here](https://rust-lang.github.io/mdBook/guide/installation.html).
## Viewing the book
To view the book, run `mdbook serve --open candle-book`. More instructions can be found [here](https://rust-lang.github.io/mdBook/guide/creating.html).
The book is built automatically in GitHub CI.

View File

@ -1,6 +1,7 @@
# Introduction
{{#include ../../README.md:goals}}
{{#include ../../README.md:features}}
This book will introduce step by step how to use `candle`.
This book will introduce step by step how to use `candle`.

View File

@ -5,7 +5,10 @@
# User Guide
- [Installation](guide/installation.md)
- [Hello World - MNIST](guide/hello_world.md)
- [Tutorial - MNIST](guide/mnist/intro.md)
- [Modeling](guide/mnist/modeling.md)
- [Training](guide/mnist/training.md)
- [Saving And Loading](guide/mnist/saving_loading.md)
- [PyTorch cheatsheet](guide/cheatsheet.md)
# Reference Guide
@ -13,6 +16,7 @@
- [Running a model](inference/inference.md)
- [Using the hub](inference/hub.md)
- [Error management](error_manage.md)
- [Tracing](tracing.md)
- [Training](training/training.md)
- [Simplified](training/simplified.md)
- [MNIST](training/mnist.md)

View File

@ -1,8 +1,23 @@
# Installation
**With Cuda support**:
## 1. Create a new rust app or library
1. First, make sure that Cuda is correctly installed.
```bash
cargo new myapp
cd myapp
```
## 2. Add the correct candle version
### Standard
```bash
cargo add --git https://github.com/huggingface/candle.git candle-core
```
### CUDA
First, make sure that Cuda is correctly installed.
- `nvcc --version` should print information about your Cuda compiler driver.
- `nvidia-smi --query-gpu=compute_cap --format=csv` should print your GPU's compute capability, e.g. something
like:
@ -17,43 +32,36 @@ You can also compile the Cuda kernels for a specific compute cap using the
If any of the above commands errors out, please make sure to update your Cuda version.
2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.
Start by creating a new cargo:
```bash
cargo new myapp
cd myapp
```
Make sure to add the `candle-core` crate with the cuda feature:
Add the `candle-core` crate with the cuda feature:
```bash
cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"
```
### MKL
You can also use the `mkl` feature, which can speed up inference on CPU.
Add the `candle-core` crate with the mkl feature:
```bash
cargo add --git https://github.com/huggingface/candle.git candle-core --features "mkl"
```
### Metal
Metal is exclusive to macOS.
Add the `candle-core` crate with the metal feature:
```bash
cargo add --git https://github.com/huggingface/candle.git candle-core --features "metal"
```
## 3. Building
Run `cargo build` to make sure everything can be correctly built.
```bash
cargo build
```
**Without Cuda support**:
Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) as follows:
```bash
cargo new myapp
cd myapp
cargo add --git https://github.com/huggingface/candle.git candle-core
```
Finally, run `cargo build` to make sure everything can be correctly built.
```bash
cargo build
```
**With mkl support**
You can also see the `mkl` feature which could be interesting to get faster inference on CPU. [Using mkl](./advanced/mkl.md)

View File

@ -0,0 +1,17 @@
# Candle MNIST Tutorial
## Introduction
This tutorial provides an introduction to Candle by implementing and training a neural network for MNIST digit classification from scratch.
Throughout this tutorial, you will learn the basics of:
- Tensor operations and model construction
- Creating and implementing neural network layers
- Parameter initialization
- Training loop implementation
- Saving and loading trained models
## Getting Started
Before proceeding, please ensure that you have properly installed Candle by following the instructions in the [Installation](../installation.md) guide.

View File

@ -0,0 +1,172 @@
# Candle MNIST Tutorial
## Modeling
Open `src/main.rs` in your project folder and insert the following code:
```rust
use candle_core::{Device, Result, Tensor};
struct Model {
first: Tensor,
second: Tensor,
}
impl Model {
fn forward(&self, image: &Tensor) -> Result<Tensor> {
let x = image.matmul(&self.first)?;
let x = x.relu()?;
x.matmul(&self.second)
}
}
fn main() -> Result<()> {
// Use Device::new_cuda(0)?; to utilize GPU acceleration.
let device = Device::Cpu;
let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
let model = Model { first, second };
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
Ok(())
}
```
Execute the program with:
```bash
$ cargo run --release
> Digit Tensor[dims 1, 10; f32] digit
```
Since random inputs are provided, expect an incoherent output.
## Implementing a `Linear` Layer
To create a more sophisticated layer type, add a `bias` to the weight to construct the standard `Linear` layer.
Replace the entire content of `src/main.rs` with:
```rust
use candle_core::{Device, Result, Tensor};
struct Linear {
weight: Tensor,
bias: Tensor,
}
impl Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.matmul(&self.weight)?;
x.broadcast_add(&self.bias)
}
}
struct Model {
first: Linear,
second: Linear,
}
impl Model {
fn forward(&self, image: &Tensor) -> Result<Tensor> {
let x = self.first.forward(image)?;
let x = x.relu()?;
self.second.forward(&x)
}
}
fn main() -> Result<()> {
// Use Device::new_cuda(0)?; for GPU acceleration.
// Use Device::Cpu; for CPU computation.
let device = Device::cuda_if_available(0)?;
// Initialize model parameters
let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
let first = Linear { weight, bias };
let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
let second = Linear { weight, bias };
let model = Model { first, second };
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
// Perform inference
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
Ok(())
}
```
Execute again with:
```bash
$ cargo run --release
> Digit Tensor[dims 1, 10; f32] digit
```
## Utilizing `candle_nn`
Many classical layers (such as [Linear](https://github.com/huggingface/candle/blob/main/candle-nn/src/linear.rs)) are already implemented in [candle-nn](https://github.com/huggingface/candle/tree/main/candle-nn).
This `Linear` implementation follows PyTorch conventions for better compatibility with existing models: it stores the transposed weight matrix rather than the weight itself.
Let's simplify our implementation. First, add `candle-nn` as a dependency:
```bash
$ cargo add --git https://github.com/huggingface/candle.git candle-nn
```
Now, replace the entire content of `src/main.rs` with:
```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};
struct Model {
first: Linear,
second: Linear,
}
impl Model {
fn forward(&self, image: &Tensor) -> Result<Tensor> {
let x = self.first.forward(image)?;
let x = x.relu()?;
self.second.forward(&x)
}
}
fn main() -> Result<()> {
// Use Device::new_cuda(0)?; for GPU acceleration.
let device = Device::Cpu;
// Note the dimension change: (784, 100) -> (100, 784)
let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
let first = Linear::new(weight, Some(bias));
let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
let second = Linear::new(weight, Some(bias));
let model = Model { first, second };
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
Ok(())
}
```
Execute the final version:
```bash
$ cargo run --release
> Digit Tensor[dims 1, 10; f32] digit
```

View File

@ -0,0 +1,158 @@
# Candle MNIST Tutorial
## Saving and Loading Models
After training a model, it is useful to save and subsequently load the model parameters. In Candle, this functionality is managed through the `VarMap` data structure, with parameters stored on disk using the [safetensors](https://huggingface.co/docs/safetensors/index) format.
### Saving Model Parameters
Let's modify our `training_loop` function to include functionality for saving weights:
```rust
fn training_loop(
m: candle_datasets::vision::Dataset,
) -> anyhow::Result<()> {
let dev = Device::cuda_if_available(0)?;
let train_labels = m.train_labels;
let train_images = m.train_images.to_device(&dev)?;
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
// Initialize a VarMap for trainable parameters
let varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
let model = Model::new(vs.clone())?;
let learning_rate = 0.05;
let epochs = 10;
// Initialize stochastic gradient descent optimizer
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
let test_images = m.test_images.to_device(&dev)?;
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
for epoch in 1..epochs {
// Standard MNIST forward pass
let logits = model.forward(&train_images)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
// Compute Negative Log Likelihood loss
let loss = loss::nll(&log_sm, &train_labels)?;
// Perform backward pass and update weights
sgd.backward_step(&loss)?;
// Evaluate model on test set
let test_logits = model.forward(&test_images)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&test_labels)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::<f32>()?;
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
println!(
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
loss.to_scalar::<f32>()?,
test_accuracy
);
}
// Save model weights to disk
varmap.save("model_weights.safetensors")?;
Ok(())
}
```
```bash
$ cargo run --release
> 1 train loss: 2.40485 test acc: 0.11%
> 2 train loss: 2.34161 test acc: 0.14%
> 3 train loss: 2.28841 test acc: 0.17%
> 4 train loss: 2.24158 test acc: 0.19%
> 5 train loss: 2.19898 test acc: 0.23%
> 6 train loss: 2.15927 test acc: 0.26%
> 7 train loss: 2.12161 test acc: 0.29%
> 8 train loss: 2.08549 test acc: 0.32%
> 9 train loss: 2.05053 test acc: 0.35%
```
### Loading Model Parameters
Now that we have saved our model parameters, we can modify the code to load them. The primary changes are to make the `varmap` variable mutable and to load the saved weights into it:
```rust
fn training_loop(
m: candle_datasets::vision::Dataset,
) -> anyhow::Result<()> {
let dev = Device::cuda_if_available(0)?;
let train_labels = m.train_labels;
let train_images = m.train_images.to_device(&dev)?;
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
// Create a mutable VarMap for trainable parameters
let mut varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
let model = Model::new(vs.clone())?;
// Load pre-trained weights from file
varmap.load("model_weights.safetensors")?;
let learning_rate = 0.05;
let epochs = 10;
// Initialize stochastic gradient descent optimizer
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
let test_images = m.test_images.to_device(&dev)?;
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
for epoch in 1..epochs {
// Standard MNIST forward pass
let logits = model.forward(&train_images)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
// Compute Negative Log Likelihood loss
let loss = loss::nll(&log_sm, &train_labels)?;
// Perform backward pass and update weights
sgd.backward_step(&loss)?;
// Evaluate model on test set
let test_logits = model.forward(&test_images)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&test_labels)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::<f32>()?;
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
println!(
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
loss.to_scalar::<f32>()?,
test_accuracy
);
}
// Save updated weights back to disk
varmap.save("model_weights.safetensors")?;
Ok(())
}
```
```bash
$ cargo run --release
> 1 train loss: 2.01645 test acc: 0.38%
> 2 train loss: 1.98300 test acc: 0.41%
> 3 train loss: 1.95008 test acc: 0.44%
> 4 train loss: 1.91754 test acc: 0.47%
> 5 train loss: 1.88534 test acc: 0.50%
> 6 train loss: 1.85349 test acc: 0.53%
> 7 train loss: 1.82198 test acc: 0.56%
> 8 train loss: 1.79077 test acc: 0.59%
> 9 train loss: 1.75989 test acc: 0.61%
```
Note that loading the weights will fail if the specified file does not exist or is incompatible with the current model architecture. Implementing file existence checks and appropriate error handling is left to the user.
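A minimal sketch of such a guard, reusing the `model_weights.safetensors` path from above:

```rust
// Only load previously saved parameters when the file actually exists;
// otherwise keep the freshly initialized values and train from scratch.
let weights_path = std::path::Path::new("model_weights.safetensors");
if weights_path.exists() {
    varmap.load(weights_path)?;
} else {
    println!("no saved weights found, training from scratch");
}
```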

View File

@ -0,0 +1,134 @@
# Candle MNIST Tutorial
## Training Implementation
First, let's create a utility function `make_linear` that accepts a `VarBuilder` and returns an initialized linear layer. The `VarBuilder` is backed by a `VarMap`, the data structure that stores our trainable parameters.
```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder, VarMap};
fn make_linear(vs: VarBuilder, in_dim: usize, out_dim: usize) -> Result<Linear> {
let ws = vs.get_with_hints(
(out_dim, in_dim),
"weight",
candle_nn::init::DEFAULT_KAIMING_NORMAL,
)?;
let bound = 1. / (in_dim as f64).sqrt();
let bs = vs.get_with_hints(
out_dim,
"bias",
candle_nn::Init::Uniform {
lo: -bound,
up: bound,
},
)?;
Ok(Linear::new(ws, Some(bs)))
}
```
Next, let's implement a `new` method for our model class to accept a `VarBuilder` and initialize the model. We use `VarBuilder::pp` to "push prefix" so that the parameter names are organized hierarchically: the first layer's parameters as `first.weight` and `first.bias`, and the second layer's as `second.weight` and `second.bias`.
```rust
impl Model {
fn new(vs: VarBuilder) -> Result<Self> {
const IMAGE_DIM: usize = 784;
const HIDDEN_DIM: usize = 100;
const LABELS: usize = 10;
let first = make_linear(vs.pp("first"), IMAGE_DIM, HIDDEN_DIM)?;
let second = make_linear(vs.pp("second"), HIDDEN_DIM, LABELS)?;
Ok(Self { first, second })
}
fn forward(&self, image: &Tensor) -> Result<Tensor> {
let x = self.first.forward(image)?;
let x = x.relu()?;
self.second.forward(&x)
}
}
```
Now, let's add the `candle-datasets` package to our project to access the MNIST dataset:
```bash
$ cargo add --git https://github.com/huggingface/candle.git candle-datasets
```
With the dataset available, we can implement our training loop:
```rust
use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{loss, ops, Linear, Module, Optimizer, VarBuilder, VarMap};
fn training_loop(
m: candle_datasets::vision::Dataset,
) -> anyhow::Result<()> {
let dev = Device::cuda_if_available(0)?;
let train_labels = m.train_labels;
let train_images = m.train_images.to_device(&dev)?;
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
// Initialize a VarMap to store trainable parameters
let varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
let model = Model::new(vs.clone())?;
let learning_rate = 0.05;
let epochs = 10;
// Initialize a stochastic gradient descent optimizer to update parameters
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), learning_rate)?;
let test_images = m.test_images.to_device(&dev)?;
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
for epoch in 1..epochs {
// Perform forward pass on MNIST data
let logits = model.forward(&train_images)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
// Compute Negative Log Likelihood loss
let loss = loss::nll(&log_sm, &train_labels)?;
// Perform backward pass and update weights
sgd.backward_step(&loss)?;
// Evaluate model on test set
let test_logits = model.forward(&test_images)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&test_labels)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::<f32>()?;
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
println!(
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
loss.to_scalar::<f32>()?,
test_accuracy
);
}
Ok(())
}
```
Finally, let's implement our main function:
```rust
pub fn main() -> anyhow::Result<()> {
let m = candle_datasets::vision::mnist::load()?;
return training_loop(m);
}
```
Let's execute the training process:
```bash
$ cargo run --release
> 1 train loss: 2.35449 test acc: 0.12%
> 2 train loss: 2.30760 test acc: 0.15%
> ...
```

View File

@ -0,0 +1,68 @@
# Tracing
Tracing is a powerful tool for identifying performance issues and bottlenecks in code.
> Profiling on GPUs is trickier due to asynchronous execution, see the [GPU section](#gpu).
## Overview
Candle uses the [tracing](https://docs.rs/tracing/latest/tracing/) crate for instrumentation.
To try it out, run an example in `candle-examples` with the `--tracing` flag.
This generates a trace file, typically named `trace-<timestamp>.json`.
You can view the trace in Chrome by navigating to `chrome://tracing/`, clicking **Load**, and selecting the generated trace file.
## Adding Tracing
Candle includes built-in tracing for many internal operations, using [spans](https://docs.rs/tracing/latest/tracing/struct.Span.html) to mark key points of execution.
To add custom tracing in your code, you can define a span like this:
```rust
let span = tracing::span!(tracing::Level::TRACE, name);
```
Then, to record the span during execution, create a guard:
```rust
let _enter = span.enter();
```
This guard will record the span's duration, from when it is created to when it is dropped, into a global data structure managed by the tracing crate.
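Putting the two snippets together, here is a small sketch of instrumenting a single function (the span name `"my_op"` is just an example):

```rust
use tracing::{span, Level};

fn my_op(x: f64) -> f64 {
    // The guard keeps the span open for the whole body, so each call shows up
    // as one slice in the generated trace.
    let span = span!(Level::TRACE, "my_op");
    let _enter = span.enter();
    x.sqrt().sin()
}
```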
## Recording and Saving a Trace
To capture and save trace data, you need to configure the tracing system with an output format. Candle uses the [tracing_subscriber](https://docs.rs/tracing-subscriber/latest/tracing_subscriber/) and [tracing_chrome](https://docs.rs/tracing-chrome/latest/tracing_chrome/) crates.
The snippet below sets up a Chrome-compatible recorder that logs all tracing activity between the creation and the drop of the guard:
```rust
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let _guard = {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
guard
};
```
## GPU
When using CUDA, Metal, or other asynchronous GPU backends, tracing may produce misleading timing data because operations are queued rather than executed immediately.
### CUDA
For CUDA-specific profiling, you have two options:
1. Set the environment variable `CUDA_LAUNCH_BLOCKING=1` which forces synchronous execution. This makes trace timings more accurate, at the cost of reduced performance.
2. Use [NVIDIA's Nsight Systems](https://developer.nvidia.com/nsight-systems) (`nsys profile` and `nsys-ui`) which are designed specifically for profiling asynchronous CUDA executions.
We recommend using NVIDIA's Nsight Systems when possible, as it offers accurate performance data without altering typical execution patterns. In contrast, setting the `CUDA_LAUNCH_BLOCKING` environment variable forces synchronous execution, which can significantly alter execution behavior.
#### Performance Profiling with NVIDIA Nsight Systems
1. Generate an `.nsys-rep` file containing performance data ([docs](https://docs.nvidia.com/nsight-systems/UserGuide/index.html#example-single-command-lines))
- Run `nsys profile --trace cuda,nvtx,osrt --gpu-metrics-device=all --output profile_run ./target/debug/... --prompt "whatever "`
2. Open the generated `.nsys-rep` report file in Nsight Systems GUI
- File > Open

View File

@ -56,3 +56,7 @@ harness = false
[[example]]
name = "metal_basics"
required-features = ["metal"]
[[example]]
name = "cuda_basics"
required-features = ["cuda"]

View File

@ -4,11 +4,12 @@ use criterion::criterion_main;
criterion_main!(
benchmarks::affine::benches,
benchmarks::copy::benches,
benchmarks::conv_transpose2d::benches,
benchmarks::matmul::benches,
benchmarks::qmatmul::benches,
benchmarks::random::benches,
benchmarks::reduce::benches,
benchmarks::unary::benches,
benchmarks::where_cond::benches,
benchmarks::conv_transpose2d::benches,
benchmarks::qmatmul::benches,
benchmarks::unary::benches
);

View File

@ -0,0 +1,38 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{Device, Tensor, WithDType};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn run_copy_mask_benchmark<D: WithDType>(c: &mut Criterion, device: &Device, name: &str) {
let batch_size = 128;
let in_seq_len = 1;
let kv_seq_len = 1024;
let attn_mask = vec![vec![vec![D::zero(); kv_seq_len]; in_seq_len]; batch_size];
let size_in_bytes = batch_size * in_seq_len * kv_seq_len * D::DTYPE.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name(name));
group.throughput(Throughput::Bytes(size_in_bytes as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let attn_masks = vec![attn_mask.clone(); iters as usize];
let start = Instant::now();
for attn_mask in attn_masks.into_iter() {
let tensor = Tensor::new(black_box(attn_mask), device).unwrap();
black_box(tensor);
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_copy_mask_benchmark::<f32>(c, &device, "copy_mask");
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -1,5 +1,6 @@
pub(crate) mod affine;
pub(crate) mod conv_transpose2d;
pub(crate) mod copy;
pub(crate) mod matmul;
pub(crate) mod qmatmul;
pub(crate) mod random;
@ -21,7 +22,9 @@ impl BenchDevice for Device {
Device::Cpu => Ok(()),
Device::Cuda(device) => {
#[cfg(feature = "cuda")]
return Ok(device.synchronize()?);
return Ok(device
.synchronize()
.map_err(|e| candle_core::Error::Cuda(Box::new(e)))?);
#[cfg(not(feature = "cuda"))]
panic!("Cuda device without cuda feature enabled: {:?}", device)
}

View File

@ -6,28 +6,18 @@ extern crate intel_mkl_src;
use anyhow::Result;
use candle_core::{Device, Tensor};
// xs: [1024, 64, 1924], c Tensor[dims 128, 64, 8; f32, cuda:0] Conv1dConfig { padding: 0, stride: 4, dilation: 1, groups: 1 }
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
.to_dtype(candle_core::DType::BF16)?;
candle_core::cuda::set_gemm_reduced_precision_f32(false);
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("fp32: {:?}", start_time.elapsed());
drop(_x1);
candle_core::cuda::set_gemm_reduced_precision_f32(true);
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("tf32: {:?}", start_time.elapsed());
let x = Tensor::randn(0f32, 1.0, (1024, 64, 1924), &device)?;
let c = Tensor::randn(0f32, 1.0, (128, 64, 8), &device)?;
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
drop(_x1);
for _ in 0..20 {
let start_time = std::time::Instant::now();
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
device.synchronize()?;
println!("conv1d: {:?}", start_time.elapsed());
}
Ok(())
}

View File

@ -71,15 +71,27 @@ pub trait BackendStorage: Sized {
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
fn scatter_add(
&self,
fn scatter_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self>;
) -> Result<()>;
fn scatter_add_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<()>;
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
fn index_add(
&self,
@ -113,6 +125,8 @@ pub trait BackendStorage: Sized {
_src_offset: usize,
_dst_offset: usize,
) -> Result<()>;
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()>;
}
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
@ -127,8 +141,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
/// # Safety
/// This function is unsafe as it doesn't initialize the underlying data store.
/// The caller should ensure that the data is properly initialized as early as possible

View File

@ -53,6 +53,7 @@ impl Tensor {
} else if let Some(op) = node.op() {
match op {
Op::IndexAdd(t1, t2, t3, _)
| Op::Scatter(t1, t2, t3, _)
| Op::ScatterAdd(t1, t2, t3, _)
| Op::CustomOp3(t1, t2, t3, _)
| Op::WhereCond(t1, t2, t3) => {
@ -419,7 +420,7 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
}
Op::ScatterAdd(init, indexes, src, dim) => {
Op::Scatter(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
*init_sum_grad = init_sum_grad.add(&grad)?;
@ -427,6 +428,16 @@ impl Tensor {
let src_sum_grad = grads.or_insert(src)?;
*src_sum_grad = src_sum_grad.add(&src_grad)?;
}
Op::ScatterAdd(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
let mask = init.ones_like()?;
let mask = mask.scatter(indexes, &mask.zeros_like()?, *dim)?;
*init_sum_grad = init_sum_grad.add(&grad.mul(&mask)?)?;
let src_grad = grad.gather(indexes, *dim)?;
let src_sum_grad = grads.or_insert(src)?;
*src_sum_grad = src_sum_grad.add(&src_grad)?;
}
Op::IndexAdd(init, indexes, src, dim) => {
let init_sum_grad = grads.or_insert(init)?;
*init_sum_grad = init_sum_grad.add(&grad)?;

View File

@ -14,6 +14,7 @@ pub struct ParamsConv1D {
pub(crate) padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
impl ParamsConv1D {
@ -54,7 +55,7 @@ impl ParamsConvTranspose1D {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
ImplicitGemm,
ImplicitPrecompGemm,
@ -151,6 +152,19 @@ impl Tensor {
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
self.conv1d_with_algo(kernel, padding, stride, dilation, groups, None)
}
/// Applies a 1D convolution over the input tensor.
pub fn conv1d_with_algo(
&self,
kernel: &Self,
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
let (b_size, c_in, l_in) = self.dims3()?;
@ -174,6 +188,7 @@ impl Tensor {
padding,
stride,
dilation,
cudnn_fwd_algo,
};
if groups == 1 {
self.conv1d_single_group(kernel, &params)
@ -278,6 +293,18 @@ impl Tensor {
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
self.conv2d_with_algo(kernel, padding, stride, dilation, groups, None)
}
pub fn conv2d_with_algo(
&self,
kernel: &Self,
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
@ -297,7 +324,7 @@ impl Tensor {
padding,
stride,
dilation,
cudnn_fwd_algo: None,
cudnn_fwd_algo,
};
if groups == 1 {
self.conv2d_single_group(kernel, &params)

View File

@ -7,7 +7,7 @@ use rayon::prelude::*;
mod utils;
pub use utils::{
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
};
const USE_IM2COL_CONV1D: bool = true;
@ -483,17 +483,22 @@ impl<I: IntDType> Map1 for Gather<'_, I> {
let start_dst_idx = start_dst_idx + i * dst_right_len;
for right_i in 0..dst_right_len {
let dst_idx = start_dst_idx + right_i;
let index = ids[dst_idx].as_usize();
if index >= src_dim_len {
Err(Error::InvalidIndex {
index,
size: src_dim_len,
op: "gather",
let index = ids[dst_idx];
if index == I::max_value() {
dst[dst_idx] = T::zero();
} else {
let index = index.as_usize();
if index >= src_dim_len {
Err(Error::InvalidIndex {
index,
size: src_dim_len,
op: "gather",
}
.bt())?
}
.bt())?
let src_idx = start_src_idx + index * src_right_len + right_i;
dst[dst_idx] = src[src_idx]
}
let src_idx = start_src_idx + index * src_right_len + right_i;
dst[dst_idx] = src[src_idx]
}
}
}
@ -535,45 +540,89 @@ impl<I: IntDType> Map1 for IndexSelect<'_, I> {
let start_src_idx = left_i * right_len * src_dim;
let start_dst_idx = left_i * right_len * n_ids;
for i in 0..n_ids {
let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
if index >= src_dim {
Err(Error::InvalidIndex {
index,
size: src_dim,
op: "index-select",
}
.bt())?
}
let start_src_idx = start_src_idx + index * right_len;
let start_dst_idx = start_dst_idx + i * right_len;
dst[start_dst_idx..start_dst_idx + right_len]
.copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
let index = self.ids[self.ids_l.start_offset() + stride_ids * i];
if index == I::max_value() {
dst[start_dst_idx..start_dst_idx + right_len].fill(T::zero());
} else {
let index = index.as_usize();
if index >= src_dim {
Err(Error::InvalidIndex {
index,
size: src_dim,
op: "index-select",
}
.bt())?
}
let start_src_idx = start_src_idx + index * right_len;
dst[start_dst_idx..start_dst_idx + right_len]
.copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
}
}
}
Ok(dst)
}
}
struct ScatterAdd<'a, I: IntDType> {
trait ElemUpdate {
fn f<T: WithDType>(dst: &mut T, src: T);
}
struct Set;
struct Add;
impl ElemUpdate for Set {
fn f<T: WithDType>(dst: &mut T, src: T) {
*dst = src
}
}
impl ElemUpdate for Add {
fn f<T: WithDType>(dst: &mut T, src: T) {
*dst += src
}
}
struct Scatter<'a, I: IntDType, M: ElemUpdate> {
ids: &'a [I],
ids_l: &'a Layout,
dim: usize,
_phantom: std::marker::PhantomData<M>,
}
impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
const OP: &'static str = "scatter-add";
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let dst_len = l1.shape().elem_count();
let mut dst = vec![T::zero(); dst_len];
copy_strided_src_(v1, &mut dst, 0, l1);
impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
fn new(ids: &'a [I], ids_l: &'a Layout, dim: usize) -> Self {
Self {
ids,
ids_l,
dim,
_phantom: Default::default(),
}
}
}
impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
const OP: &'static str = "scatter";
fn f<T: WithDType>(
&self,
dst: &mut [T],
dst_l: &Layout,
src: &[T],
src_l: &Layout,
) -> Result<()> {
let dst = match dst_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
Some((o1, o2)) => &mut dst[o1..o2],
};
let src = match src_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
Some((o1, o2)) => &src[o1..o2],
};
let dim = self.dim;
let ids_dims = self.ids_l.dims();
let dst_dims = l1.dims();
let dst_dims = dst_l.dims();
let dst_dim_len = dst_dims[dim];
let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
@ -592,7 +641,11 @@ impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
let start_ids_idx = start_ids_idx + i * ids_right_len;
for right_i in 0..dst_right_len {
let ids_idx = start_ids_idx + right_i;
let index = ids[ids_idx].as_usize();
let index = ids[ids_idx];
if index == I::max_value() {
continue;
}
let index = index.as_usize();
if index >= dst_dim_len {
Err(Error::InvalidIndex {
index,
@ -602,12 +655,12 @@ impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
.bt())?
}
let dst_idx = start_dst_idx + index * dst_right_len + right_i;
dst[dst_idx] += src[ids_idx]
M::f(&mut dst[dst_idx], src[ids_idx])
}
}
}
Ok(dst)
Ok(())
}
}
@ -635,6 +688,9 @@ impl<I: IntDType> Map2 for IndexAdd<'_, I> {
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
if dim == 0 {
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
if *dst_idx == I::max_value() {
continue;
}
let dst_idx = dst_idx.as_usize();
if dst_idx >= max_idx {
Err(Error::InvalidIndex {
@ -653,6 +709,9 @@ impl<I: IntDType> Map2 for IndexAdd<'_, I> {
}
} else {
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
if *dst_idx == I::max_value() {
continue;
}
let dst_idx = dst_idx.as_usize();
if dst_idx >= max_idx {
Err(Error::InvalidIndex {
@ -1289,6 +1348,15 @@ impl Map2 for MatMul {
} else {
Parallelism::None
};
let (b, m, n, k) = if b_skip == 0 && a_skip == m * k {
// a_skip and c_skip should be updated but step is always 0 so
// it wouldn't matter.
(1, b * m, n, k)
} else if a_skip == 0 && b_skip == n * k {
(1, m, b * n, k)
} else {
(b, m, n, k)
};
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
@ -2372,19 +2440,36 @@ impl BackendStorage for CpuStorage {
}
}
fn scatter_add(
&self,
fn scatter_set(
&mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
) -> Result<()> {
match ids {
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
Self::I64(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter").bt()),
}
}
fn scatter_add_set(
&mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<()> {
match ids {
Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
}
}
@ -2445,6 +2530,48 @@ impl BackendStorage for CpuStorage {
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Ok(self.clone())
}
fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
use crate::scalar::Scalar;
fn set<T: crate::WithDType>(src: &mut [T], l: &Layout, s: T) {
match l.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => {
src[start_offset..start_offset + len].fill(s)
}
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len: 1,
} => {
for src_index in block_start_index {
src[src_index] = s
}
}
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
for src_index in block_start_index {
src[src_index..src_index + block_len].fill(s)
}
}
}
}
match (self, s) {
(Self::BF16(storage), Scalar::BF16(v)) => set(storage, l, v),
(Self::F16(storage), Scalar::F16(v)) => set(storage, l, v),
(Self::F32(storage), Scalar::F32(v)) => set(storage, l, v),
(Self::F64(storage), Scalar::F64(v)) => set(storage, l, v),
(Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
(Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
(Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
(st, s) => crate::bail!(
"const_set dtype mismatch, expected {:?} but got {:?}",
st.dtype(),
s
),
}
Ok(())
}
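const_set fills exactly the elements addressed by the layout, so it works on strided views as well as contiguous buffers: a single block is filled in one go, while multi-block layouts are walked block by block. Presumably this is also what lets the generic ones path become alloc_uninit followed by const_set(Scalar::one(dtype)), which is why the per-backend ones_impl below is dropped. A standalone sketch of the strided fill, not using candle's Layout type:

// Filling a strided view of a flat buffer, analogous to the MultipleBlocks case above.
fn fill_strided(buf: &mut [f32], start: usize, block_len: usize, stride: usize, n_blocks: usize, v: f32) {
    for i in 0..n_blocks {
        let o = start + i * stride;
        buf[o..o + block_len].fill(v);
    }
}

fn main() {
    let mut buf = vec![0f32; 8];
    // A view with stride 2 addresses every other element of the buffer.
    fill_strided(&mut buf, 0, 1, 2, 4, 1.0);
    assert_eq!(buf, vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0]);
}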
}
impl BackendDevice for CpuDevice {
@ -2619,20 +2746,6 @@ impl BackendDevice for CpuDevice {
Ok(storage)
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
let elem_count = shape.elem_count();
let storage = match dtype {
DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
};
Ok(storage)
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
let elem_count = shape.elem_count();
let storage = match dtype {

View File

@ -58,6 +58,30 @@ pub trait Map2 {
}
}
pub trait Map2InPlace {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &mut [T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<()>;
fn map(&self, v1: &mut C, l1: &Layout, v2: &C, l2: &Layout) -> Result<()> {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?,
(C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?,
(C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?,
(C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?,
(C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?,
(C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?,
(C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?,
(v1, v2) => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
op: Self::OP,
}
.bt())?,
};
Ok(())
}
}
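Map2InPlace is the in-place counterpart of Map2: map dispatches on the dtype of both storages, hands matching typed slices to f, and bails out with a DTypeMismatchBinaryOp error otherwise. A toy op with an f of roughly that shape (layouts omitted, purely illustrative):

// A minimal in-place binary op: add src into dst element-wise.
struct AddAssignOp;
impl AddAssignOp {
    fn f<T: Copy + std::ops::AddAssign>(&self, dst: &mut [T], src: &[T]) {
        for (d, &s) in dst.iter_mut().zip(src.iter()) {
            *d += s;
        }
    }
}

fn main() {
    let mut dst = vec![1u32, 2, 3];
    AddAssignOp.f(&mut dst, &[10, 20, 30]);
    assert_eq!(dst, vec![11, 22, 33]);
}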
pub trait Map2U8 {
const OP: &'static str;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;

View File

@ -43,7 +43,7 @@ pub(crate) fn launch_conv2d<
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
return Ok(cudnn.clone());
}
let c = Cudnn::new(dev.cuda_device());
let c = Cudnn::new(dev.cuda_stream());
if let Ok(c) = &c {
cudnn.borrow_mut().insert(device_id, c.clone());
}
@ -109,7 +109,7 @@ pub(crate) fn launch_conv2d<
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
};
let workspace_size = conv2d.get_workspace_size(alg)?;
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
unsafe {
conv2d.launch::<CudaSlice<u8>, _, _, _>(
alg,
@ -122,3 +122,104 @@ pub(crate) fn launch_conv2d<
}
Ok(())
}
pub(crate) fn launch_conv1d<
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
Y: cudarc::cudnn::CudnnDataType,
>(
src: &CudaView<T>,
src_l: &crate::Layout,
filter: &CudaView<T>,
dst: &mut CudaSlice<T>,
params: &crate::conv::ParamsConv1D,
dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> {
use crate::conv::CudnnFwdAlgo as CandleAlgo;
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
let device_id = dev.id();
let cudnn = CUDNN.with(|cudnn| {
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
return Ok(cudnn.clone());
}
let c = Cudnn::new(dev.cuda_stream());
if let Ok(c) = &c {
cudnn.borrow_mut().insert(device_id, c.clone());
}
c
})?;
let conv = cudnn.create_conv2d::<Y>(
/* pad */ [params.padding as i32, 0],
/* stride */ [params.stride as i32, 1],
/* dilation */ [params.dilation as i32, 1],
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
)?;
// https://docs.nvidia.com/deeplearning/cudnn/backend/latest/api/cudnn-ops-library.html#cudnnsettensornddescriptor
// > Tensors are restricted to having at least 4 dimensions, and at most CUDNN_DIM_MAX
// > dimensions (defined in cudnn.h). When working with lower dimensional data, it is
// > recommended that the user create a 4D tensor, and set the size along unused dimensions
// > to 1.
let x_shape = [
params.b_size as i32,
params.c_in as i32,
params.l_in as i32,
1,
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
cudnn.create_4d_tensor::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
x_shape,
)?
} else {
let s = src_l.stride();
cudnn.create_4d_tensor_ex::<T>(x_shape, [s[0] as i32, s[1] as i32, s[2] as i32, 1i32])?
};
let w = cudnn.create_4d_filter::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[
params.c_out as i32,
params.c_in as i32,
params.k_size as i32,
1,
],
)?;
let l_out = params.l_out() as i32;
let y = cudnn.create_4d_tensor::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, l_out, 1],
)?;
let conv1d = ConvForward {
conv: &conv,
x: &x,
w: &w,
y: &y,
};
let alg = match params.cudnn_fwd_algo {
None => conv1d.pick_algorithm()?,
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
Some(CandleAlgo::ImplicitPrecompGemm) => {
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
}
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
};
let workspace_size = conv1d.get_workspace_size(alg)?;
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
unsafe {
conv1d.launch::<CudaSlice<u8>, _, _, _>(
alg,
Some(&mut workspace),
(T::one(), T::zero()),
src,
filter,
dst,
)?;
}
Ok(())
}
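launch_conv1d reuses the cudnn 2D convolution path by appending a unit dimension: the input is described as an NCHW tensor of shape (b, c_in, l_in, 1), the filter as (c_out, c_in, k_size, 1), and padding/stride/dilation become (p, 0), (s, 1), (d, 1) so the extra axis is left untouched. A small standalone check of the resulting output length, using the standard convolution arithmetic that params.l_out() is assumed to implement:

// Output length of a 1d convolution (standard formula).
fn l_out(l_in: usize, k: usize, padding: usize, stride: usize, dilation: usize) -> usize {
    (l_in + 2 * padding - dilation * (k - 1) - 1) / stride + 1
}

fn main() {
    // A "same"-style conv: length 16, kernel 3, padding 1, stride 1, dilation 1.
    assert_eq!(l_out(16, 3, 1, 1, 1), 16);
    // Stride 2 roughly halves the length.
    assert_eq!(l_out(16, 3, 1, 2, 1), 8);
}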

View File

@ -2,8 +2,9 @@ use crate::backend::BackendDevice;
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
use cudarc::driver::CudaFunction;
use half::{bf16, f16};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
@ -24,10 +25,17 @@ impl DeviceId {
struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}
pub struct ModuleStore {
mdls: [Option<Arc<cudarc::driver::CudaModule>>; kernels::ALL_IDS.len()],
}
#[derive(Clone)]
pub struct CudaDevice {
id: DeviceId,
device: Arc<cudarc::driver::CudaDevice>,
context: Arc<cudarc::driver::CudaContext>,
modules: Arc<std::sync::RwLock<ModuleStore>>,
custom_modules: Arc<std::sync::RwLock<HashMap<String, Arc<cudarc::driver::CudaModule>>>>,
stream: Arc<cudarc::driver::CudaStream>,
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
curand: Arc<Mutex<CudaRng>>,
}
@ -38,17 +46,120 @@ impl std::fmt::Debug for CudaDevice {
}
}
impl std::ops::Deref for CudaDevice {
type Target = Arc<cudarc::driver::CudaDevice>;
impl CudaDevice {
#[allow(clippy::missing_safety_doc)]
pub unsafe fn alloc<T: cudarc::driver::DeviceRepr>(
&self,
len: usize,
) -> Result<cudarc::driver::CudaSlice<T>> {
self.stream.alloc::<T>(len).w()
}
pub fn alloc_zeros<T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits>(
&self,
len: usize,
) -> Result<cudarc::driver::CudaSlice<T>> {
self.stream.alloc_zeros::<T>(len).w()
}
pub fn memcpy_htod<
T: cudarc::driver::DeviceRepr,
Src: cudarc::driver::HostSlice<T> + ?Sized,
Dst: cudarc::driver::DevicePtrMut<T>,
>(
&self,
src: &Src,
dst: &mut Dst,
) -> Result<()> {
self.stream.memcpy_htod(src, dst).w()
}
pub fn memcpy_dtov<T: cudarc::driver::DeviceRepr, Src: cudarc::driver::DevicePtr<T>>(
&self,
src: &Src,
) -> Result<Vec<T>> {
self.stream.memcpy_dtov(src).w()
}
pub fn memcpy_dtod<
T,
Src: cudarc::driver::DevicePtr<T>,
Dst: cudarc::driver::DevicePtrMut<T>,
>(
&self,
src: &Src,
dst: &mut Dst,
) -> Result<()> {
self.stream.memcpy_dtod(src, dst).w()
}
pub fn memcpy_stod<
T: cudarc::driver::DeviceRepr,
Src: cudarc::driver::HostSlice<T> + ?Sized,
>(
&self,
src: &Src,
) -> Result<cudarc::driver::CudaSlice<T>> {
self.stream.memcpy_stod(src).w()
}
}
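These wrappers route all host/device copies through the device's own stream and convert cudarc errors into candle's Result. A hedged round-trip sketch, assuming this module's CudaDevice and crate::Result are in scope and a device has already been constructed:

// Host -> device -> host round trip through the new helpers (sketch).
fn roundtrip(dev: &CudaDevice) -> crate::Result<Vec<f32>> {
    let host = vec![1f32, 2.0, 3.0];
    let on_device = dev.memcpy_stod(&host)?; // allocate and copy the host data onto the device
    dev.memcpy_dtov(&on_device)              // copy it back into a host Vec
}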
pub struct CudaFunc {
func: CudaFunction,
stream: Arc<cudarc::driver::CudaStream>,
}
impl std::ops::Deref for CudaFunc {
type Target = CudaFunction;
fn deref(&self) -> &Self::Target {
&self.device
&self.func
}
}
impl CudaFunc {
pub fn into_cuda_function(self) -> CudaFunction {
self.func
}
}
#[macro_export]
macro_rules! builder_arg {
($b:ident, $($arg:expr),*) => {
$(
let __arg = $arg;
$b.arg(&__arg);
)*
};
}
impl CudaFunc {
pub fn builder(&self) -> cudarc::driver::LaunchArgs<'_> {
self.stream.launch_builder(&self.func)
}
}
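Kernel launches now go through cudarc's launch builder instead of tuple parameter packs: a CudaFunc keeps hold of its stream, builder() pairs the function with that stream, and builder_arg! binds scalar arguments to locals first because the builder only borrows its arguments. A hedged sketch of a launch under the new API, assuming func was obtained from get_or_load_func, data is a CudaSlice<f32>, el_count is its length, and cudarc::driver::PushKernelArg is in scope:

// Launching a kernel with the builder API (sketch; the argument list is illustrative).
let cfg = cudarc::driver::LaunchConfig::for_num_elems(el_count as u32);
let mut builder = func.builder();               // launch_builder on the function's own stream
builder.arg(&data);                             // device buffers are passed by reference
crate::builder_arg!(builder, el_count as i32);  // scalars are bound to locals, then borrowed
unsafe { builder.launch(cfg) }.w()?;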
impl CudaDevice {
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
self.device.clone()
pub fn cuda_stream(&self) -> Arc<cudarc::driver::CudaStream> {
self.stream.clone()
}
/// When turned on, all cuda tensors **created after calling this function** will
/// not track uses via cuda events.
///
/// # Safety
///
/// It is up to the user to ensure proper synchronization between multiple streams:
/// - Ensure that no tensor is freed before a use on another stream is finished.
/// - Ensure that a tensor is not used on another stream before allocation on the
/// allocating stream finishes.
/// - Ensure that a tensor is not written to concurrently by multiple streams.
pub unsafe fn disable_event_tracking(&self) {
self.context.disable_event_tracking()
}
pub fn is_event_tracking(&self) -> bool {
self.context.is_event_tracking()
}
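A short usage note on the event-tracking switch: it is an opt-out for callers that manage cross-stream synchronization themselves, and it only affects tensors created after the call. A hedged sketch, assuming dev is a CudaDevice:

// The caller takes over all cross-stream synchronization from here on.
unsafe { dev.disable_event_tracking() };
assert!(!dev.is_event_tracking());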
#[cfg(not(target_arch = "wasm32"))]
@ -56,7 +167,7 @@ impl CudaDevice {
&self,
func_name: &'static str,
kernel: ug::lang::ssa::Kernel,
) -> Result<CudaFunction> {
) -> Result<CudaFunc> {
let mut buf = vec![];
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
let cuda_code = String::from_utf8(buf)?;
@ -65,117 +176,81 @@ impl CudaDevice {
..Default::default()
};
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
let func = match self.device.get_func("ug", func_name) {
Some(func) => func,
None => crate::bail!("unknown function ug::{func_name}"),
};
Ok(func)
let module = self.context.load_module(ptx).w()?;
let func = module.load_function(func_name).w()?;
Ok(CudaFunc {
func,
stream: self.stream.clone(),
})
}
pub fn id(&self) -> DeviceId {
self.id
}
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let slice = match dtype {
DType::U8 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
let params = (&data, v as u8, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
let params = (&data, v as u32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
let params = (&data, v as i64, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
let params = (&data, bf16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
let params = (&data, f16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
let params = (&data, v as f32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
let params = (&data, v, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
pub fn get_or_load_custom_func(
&self,
fn_name: &str,
module_name: &str,
ptx: &str,
) -> Result<CudaFunc> {
let ms = self.custom_modules.read().unwrap();
if let Some(mdl) = ms.get(module_name).as_ref() {
let func = mdl.load_function(fn_name).w()?;
return Ok(CudaFunc {
func,
stream: self.stream.clone(),
});
}
drop(ms);
let mut ms = self.custom_modules.write().unwrap();
let cuda_module = self.context.load_module(ptx.into()).w()?;
ms.insert(module_name.to_string(), cuda_module.clone());
let func = cuda_module.load_function(fn_name).w()?;
Ok(CudaFunc {
func,
stream: self.stream.clone(),
})
}
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
if !self.has_func(module_name, module_name) {
// Leaking the string here is a bit sad but we need a &'static str and this is only
// done once per kernel name.
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
self.load_ptx(ptx.into(), module_name, &[static_module_name])
.map_err(|cuda| CudaError::Load {
cuda,
module_name: module_name.to_string(),
})
.w()?;
pub fn get_or_load_func(&self, fn_name: &str, mdl: &kernels::Module) -> Result<CudaFunc> {
let ms = self.modules.read().unwrap();
if let Some(mdl) = ms.mdls[mdl.index()].as_ref() {
let func = mdl.load_function(fn_name).w()?;
return Ok(CudaFunc {
func,
stream: self.stream.clone(),
});
}
self.get_func(module_name, module_name)
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
// able to only build the error value if needed.
.ok_or(CudaError::MissingKernel {
module_name: module_name.to_string(),
})
.w()
drop(ms);
let mut ms = self.modules.write().unwrap();
let cuda_module = self.context.load_module(mdl.ptx().into()).w()?;
ms.mdls[mdl.index()] = Some(cuda_module.clone());
let func = cuda_module.load_function(fn_name).w()?;
Ok(CudaFunc {
func,
stream: self.stream.clone(),
})
}
}
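Both lookup paths use a read-mostly cache: a read lock for the fast path, then a write lock to load the PTX once and store the module (built-in kernels get a fixed slot per kernels::ALL_IDS entry, user PTX goes through a name-keyed map). A standalone sketch of that locking pattern on a plain HashMap; entry() is used here so a racing initializer stays idempotent:

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// Read-mostly cache: cheap lookups under the read lock, one-time init under the write lock.
fn get_or_insert(cache: &RwLock<HashMap<String, Arc<String>>>, key: &str) -> Arc<String> {
    if let Some(v) = cache.read().unwrap().get(key) {
        return v.clone();
    }
    let mut w = cache.write().unwrap();
    w.entry(key.to_string())
        .or_insert_with(|| Arc::new(format!("loaded:{key}")))
        .clone()
}

fn main() {
    let cache = RwLock::new(HashMap::new());
    let a = get_or_insert(&cache, "fill_f32");
    let b = get_or_insert(&cache, "fill_f32");
    assert!(Arc::ptr_eq(&a, &b)); // second call hits the cached module
}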
impl CudaDevice {
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
let stream = context.new_stream().w()?;
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
let module_store = ModuleStore {
mdls: [const { None }; kernels::ALL_IDS.len()],
};
Ok(Self {
id: DeviceId::new(),
device,
context,
stream,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
modules: Arc::new(std::sync::RwLock::new(module_store)),
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
})
}
}
@ -184,14 +259,21 @@ impl BackendDevice for CudaDevice {
type Storage = CudaStorage;
fn new(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
let stream = context.default_stream();
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
let module_store = ModuleStore {
mdls: [const { None }; kernels::ALL_IDS.len()],
};
Ok(Self {
id: DeviceId::new(),
device,
context,
stream,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
modules: Arc::new(std::sync::RwLock::new(module_store)),
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
})
}
@ -199,13 +281,13 @@ impl BackendDevice for CudaDevice {
// We do not call set_seed but instead create a new curand object. This ensures that the
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
curand.0 = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?;
Ok(())
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
gpu_id: self.device.ordinal(),
gpu_id: self.context.ordinal(),
}
}
@ -217,31 +299,31 @@ impl BackendDevice for CudaDevice {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc_zeros::<u8>(elem_count).w()?;
let data = self.alloc_zeros::<u8>(elem_count)?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc_zeros::<u32>(elem_count).w()?;
let data = self.alloc_zeros::<u32>(elem_count)?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc_zeros::<i64>(elem_count).w()?;
let data = self.alloc_zeros::<i64>(elem_count)?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
let data = self.alloc_zeros::<bf16>(elem_count)?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc_zeros::<f16>(elem_count).w()?;
let data = self.alloc_zeros::<f16>(elem_count)?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc_zeros::<f32>(elem_count).w()?;
let data = self.alloc_zeros::<f32>(elem_count)?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc_zeros::<f64>(elem_count).w()?;
let data = self.alloc_zeros::<f64>(elem_count)?;
CudaStorageSlice::F64(data)
}
};
@ -265,12 +347,12 @@ impl BackendDevice for CudaDevice {
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
let mut data = unsafe { self.alloc::<f32>(elem_count)? };
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
let mut data = unsafe { self.alloc::<f64>(elem_count)? };
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F64(data)
}
@ -309,7 +391,7 @@ impl BackendDevice for CudaDevice {
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
let mut data = unsafe { self.alloc::<f32>(elem_count_round)? };
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
@ -317,7 +399,7 @@ impl BackendDevice for CudaDevice {
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
let mut data = unsafe { self.alloc::<f64>(elem_count_round)? };
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
@ -328,39 +410,35 @@ impl BackendDevice for CudaDevice {
})
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
self.const_impl(1., shape, dtype)
}
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc::<u8>(elem_count).w()?;
let data = self.alloc::<u8>(elem_count)?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc::<u32>(elem_count).w()?;
let data = self.alloc::<u32>(elem_count)?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc::<i64>(elem_count).w()?;
let data = self.alloc::<i64>(elem_count)?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc::<bf16>(elem_count).w()?;
let data = self.alloc::<bf16>(elem_count)?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc::<f16>(elem_count).w()?;
let data = self.alloc::<f16>(elem_count)?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc::<f32>(elem_count).w()?;
let data = self.alloc::<f32>(elem_count)?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc::<f64>(elem_count).w()?;
let data = self.alloc::<f64>(elem_count)?;
CudaStorageSlice::F64(data)
}
};
@ -373,31 +451,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
let slice = match T::cpu_storage_ref(s) {
CpuStorageRef::U8(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::U8(data)
}
CpuStorageRef::U32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::U32(data)
}
CpuStorageRef::I64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::I64(data)
}
CpuStorageRef::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::BF16(data)
}
CpuStorageRef::F16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F16(data)
}
CpuStorageRef::F32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F32(data)
}
CpuStorageRef::F64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F64(data)
}
};
@ -410,31 +488,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage)?;
CudaStorageSlice::F64(data)
}
};
@ -447,31 +525,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage)?;
CudaStorageSlice::F64(data)
}
};
@ -482,7 +560,7 @@ impl BackendDevice for CudaDevice {
}
fn synchronize(&self) -> Result<()> {
self.device.synchronize().map_err(crate::Error::wrap)?;
self.stream.synchronize().map_err(crate::Error::wrap)?;
Ok(())
}
}

File diff suppressed because it is too large

View File

@ -1,5 +1,5 @@
/// Helper functions to plug cuda kernels in candle.
use crate::{Layout, Result, Shape, WithDType};
use crate::{Layout, Result, WithDType};
pub use cudarc;
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
@ -96,7 +96,7 @@ pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
dst_l: &Layout,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
@ -105,19 +105,19 @@ pub trait Map2InPlace {
fn map(
&self,
dst: &mut S,
dst_s: &Shape,
dst_l: &Layout,
src: &S,
src_l: &Layout,
d: &CudaDevice,
) -> Result<()> {
match (dst, src) {
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U8(dst), S::U8(src)) => self.f(dst, dst_l, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_l, src, src_l, d),
(S::I64(dst), S::I64(src)) => self.f(dst, dst_l, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_l, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_l, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_l, src, src_l, d),
(S::F64(dst), S::F64(src)) => self.f(dst, dst_l, src, src_l, d),
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
}
}

View File

@ -396,7 +396,10 @@ impl UgIOp1 {
{
let device = device.as_cuda_device()?;
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
Ok(Self {
name,
func: func.into_cuda_function(),
})
}
#[cfg(feature = "metal")]
{
@ -459,16 +462,16 @@ impl InplaceOp1 for UgIOp1 {
#[cfg(feature = "cuda")]
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
use crate::cuda_backend::WrapErr;
use cudarc::driver::LaunchAsync;
use cudarc::driver::PushKernelArg;
let elem_count = layout.shape().elem_count();
let stream = sto.device.cuda_stream();
// TODO: support more dtypes.
let sto = sto.as_cuda_slice::<f32>()?;
let sto = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => sto.slice(o1..o2),
};
let params = (&sto,);
let (g, b) = if elem_count % 32 == 0 {
(elem_count / 32, 32)
} else {
@ -479,7 +482,9 @@ impl InplaceOp1 for UgIOp1 {
block_dim: (b as u32, 1, 1),
shared_mem_bytes: 0,
};
unsafe { self.func.clone().launch(cfg, params) }.w()?;
let mut builder = stream.launch_builder(&self.func);
builder.arg(&sto);
unsafe { builder.launch(cfg) }.w()?;
Ok(())
}
}

View File

@ -103,7 +103,63 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
}
}
impl<S: NdArray> NdArray for Vec<S> {
impl<S: WithDType> NdArray for Vec<S> {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(self.len()))
}
fn to_cpu_storage(&self) -> CpuStorage {
S::to_cpu_storage(self.as_slice())
}
}
impl<S: WithDType> NdArray for Vec<&[S]> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let n = self.len();
let m = self[0].len();
for v in self.iter() {
if v.len() != m {
crate::bail!("two elements have different len {m} {}", v.len())
}
}
Ok(Shape::from((n, m)))
}
fn to_cpu_storage(&self) -> CpuStorage {
let data = self.iter().copied().flatten().copied().collect::<Vec<_>>();
S::to_cpu_storage_owned(data)
}
}
impl<S: WithDType> NdArray for Vec<Vec<S>> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let n = self.len();
let m = self[0].len();
for v in self.iter() {
if v.len() != m {
crate::bail!("two elements have different len {m} {}", v.len())
}
}
Ok(Shape::from((n, m)))
}
fn to_cpu_storage(&self) -> CpuStorage {
let len: usize = self.iter().map(|v| v.len()).sum();
let mut dst = Vec::with_capacity(len);
for v in self.iter() {
dst.extend(v.iter().copied());
}
S::to_cpu_storage_owned(dst)
}
}
impl<S: WithDType> NdArray for Vec<Vec<Vec<S>>> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
@ -120,9 +176,57 @@ impl<S: NdArray> NdArray for Vec<S> {
}
fn to_cpu_storage(&self) -> CpuStorage {
// This allocates intermediary memory and shouldn't be necessary.
let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
CpuStorage::concat(storages.as_slice()).unwrap()
if self.is_empty() {
return S::to_cpu_storage_owned(vec![]);
}
let len: usize = self
.iter()
.map(|v| v.iter().map(|v| v.len()).sum::<usize>())
.sum();
let mut dst = Vec::with_capacity(len);
for v1 in self.iter() {
for v2 in v1.iter() {
dst.extend(v2.iter().copied());
}
}
S::to_cpu_storage_owned(dst)
}
}
impl<S: WithDType> NdArray for Vec<Vec<Vec<Vec<S>>>> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let shape0 = self[0].shape()?;
let n = self.len();
for v in self.iter() {
let shape = v.shape()?;
if shape != shape0 {
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
}
}
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
}
fn to_cpu_storage(&self) -> CpuStorage {
let len: usize = self
.iter()
.map(|v| {
v.iter()
.map(|v| v.iter().map(|v| v.len()).sum::<usize>())
.sum::<usize>()
})
.sum();
let mut dst = Vec::with_capacity(len);
for v1 in self.iter() {
for v2 in v1.iter() {
for v3 in v2.iter() {
dst.extend(v3.iter().copied());
}
}
}
S::to_cpu_storage_owned(dst)
}
}
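The blanket Vec<S: NdArray> impl is replaced by dedicated impls per nesting depth: each one checks that all rows share the same length (or shape) and copies directly into a single flat buffer, instead of building intermediate CpuStorages and concatenating them. A standalone sketch of the 2D case (shape check plus flatten), mirroring the Vec<Vec<S>> impl above:

// Shape check + flatten for a 2d array given as nested vectors.
fn shape_and_flatten<T: Copy>(v: &[Vec<T>]) -> Result<((usize, usize), Vec<T>), String> {
    if v.is_empty() {
        return Err("empty array".to_string());
    }
    let (n, m) = (v.len(), v[0].len());
    if let Some(row) = v.iter().find(|row| row.len() != m) {
        return Err(format!("two elements have different len {m} {}", row.len()));
    }
    let mut flat = Vec::with_capacity(n * m);
    for row in v {
        flat.extend_from_slice(row);
    }
    Ok(((n, m), flat))
}

fn main() {
    let ((n, m), flat) = shape_and_flatten(&[vec![1u32, 2], vec![3, 4]]).unwrap();
    assert_eq!((n, m), (2, 2));
    assert_eq!(flat, vec![1, 2, 3, 4]);
}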
@ -292,23 +396,6 @@ impl Device {
self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
}
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuDevice.ones_impl(shape, dtype)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {

View File

@ -107,6 +107,7 @@ pub trait WithDType:
fn from_f64(v: f64) -> Self;
fn to_f64(self) -> f64;
fn to_scalar(self) -> crate::scalar::Scalar;
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
@ -131,6 +132,10 @@ macro_rules! with_dtype {
$to_f64(self)
}
fn to_scalar(self) -> crate::scalar::Scalar {
crate::scalar::Scalar::$dtype(self)
}
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
CpuStorageRef::$dtype(data)
}
@ -175,7 +180,7 @@ with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
pub trait IntDType: WithDType {
pub trait IntDType: WithDType + num_traits::Bounded {
fn is_true(&self) -> bool;
fn as_usize(&self) -> usize;
}

View File

@ -37,6 +37,10 @@ impl crate::backend::BackendStorage for CudaStorage {
fail!()
}
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -124,15 +128,27 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn scatter_add(
&self,
fn scatter_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn scatter_add_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -214,10 +230,6 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport)
}
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -41,6 +41,10 @@ impl crate::backend::BackendStorage for MetalStorage {
fail!()
}
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
@ -128,15 +132,27 @@ impl crate::backend::BackendStorage for MetalStorage {
Err(Error::NotCompiledWithMetalSupport)
}
fn scatter_add(
&self,
fn scatter_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn scatter_add_set(
&mut self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
@ -218,10 +234,6 @@ impl crate::backend::BackendDevice for MetalDevice {
Err(Error::NotCompiledWithMetalSupport)
}
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}

View File

@ -413,6 +413,100 @@ impl BackendStorage for MetalStorage {
self.binary(name, rhs, lhs_l, rhs_l)
}
fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
use crate::scalar::Scalar;
fn set<S: crate::WithDType + candle_metal_kernels::utils::EncoderParam>(
self_: &mut MetalStorage,
s: S,
l: &Layout,
) -> Result<()> {
let device = self_.device();
let dtype = self_.dtype;
let shape = l.shape();
let el_count = shape.elem_count();
let command_buffer = device.command_buffer()?;
command_buffer.set_label("const-set");
let dst = buffer_o(&self_.buffer, l, self_.dtype);
match (el_count % 2, dtype, l.is_contiguous()) {
(0, DType::BF16 | DType::F16, true) => {
use candle_metal_kernels::unary::contiguous_tiled;
let kernel_name = match dtype {
DType::F16 => contiguous_tiled::const_set::HALF,
DType::BF16 => contiguous_tiled::const_set::BFLOAT,
_ => crate::bail!("internal bug in const_set"),
};
candle_metal_kernels::call_const_set_contiguous_tiled(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
s,
dst,
)
.map_err(MetalError::from)?;
}
(_, _, true) => {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match dtype {
DType::F16 => contiguous::const_set::HALF,
DType::BF16 => contiguous::const_set::BFLOAT,
DType::F32 => contiguous::const_set::FLOAT,
DType::I64 => contiguous::const_set::I64,
DType::U32 => contiguous::const_set::U32,
DType::U8 => contiguous::const_set::U8,
DType::F64 => crate::bail!("unsupported const-set f64"),
};
candle_metal_kernels::call_const_set_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
s,
dst,
)
.map_err(MetalError::from)?;
}
(_, _, false) => {
use candle_metal_kernels::unary::strided;
let kernel_name = match dtype {
DType::F16 => strided::const_set::HALF,
DType::BF16 => strided::const_set::BFLOAT,
DType::F32 => strided::const_set::FLOAT,
DType::I64 => strided::const_set::I64,
DType::U32 => strided::const_set::U32,
DType::U8 => strided::const_set::U8,
DType::F64 => crate::bail!("unsupported const-set f64"),
};
candle_metal_kernels::call_const_set_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
l.dims(),
s,
l.stride(),
dst,
)
.map_err(MetalError::from)?;
}
}
Ok(())
}
match (self.dtype, s) {
(DType::U8, Scalar::U8(s)) => set(self, s, l),
(DType::U32, Scalar::U32(s)) => set(self, s, l),
(DType::I64, Scalar::I64(s)) => set(self, s, l),
(DType::F16, Scalar::F16(s)) => set(self, s, l),
(DType::BF16, Scalar::BF16(s)) => set(self, s, l),
(DType::F32, Scalar::F32(s)) => set(self, s, l),
(DType::F64, Scalar::F64(s)) => set(self, s, l),
_ => crate::bail!("dtype mismatch, expected {:?}, got {:?}", self.dtype, s),
}
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
let device = self.device();
let shape = layout.shape();
@ -1332,18 +1426,65 @@ impl BackendStorage for MetalStorage {
Ok(Self::new(buffer, device.clone(), dst_el, dtype))
}
fn scatter_add(
&self,
fn scatter_set(
&mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
) -> Result<()> {
if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "scatter" }.bt());
};
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::F32) => "s_u8_f32",
(DType::U8, DType::F16) => "s_u8_f16",
(DType::U8, DType::BF16) => "s_u8_bf16",
(DType::U32, DType::U32) => "s_u32_u32",
(DType::U32, DType::F32) => "s_u32_f32",
(DType::U32, DType::F16) => "s_u32_f16",
(DType::U32, DType::BF16) => "s_u32_bf16",
(DType::I64, DType::F32) => "s_i64_f32",
(DType::I64, DType::F16) => "s_i64_f16",
(DType::I64, DType::BF16) => "s_i64_bf16",
_ => Err(MetalError::UnexpectedDType {
msg: "scatter ids should be u8/u32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
};
let command_buffer = self.device.command_buffer()?;
let dst = buffer_o(&self.buffer, l, self.dtype);
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_scatter(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
src_l.dims(),
l.dims(),
dim,
src,
ids,
dst,
)
.map_err(MetalError::from)?;
Ok(())
}
fn scatter_add_set(
&mut self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<()> {
if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
};
let name = match (ids.dtype, self.dtype) {
@ -1364,9 +1505,10 @@ impl BackendStorage for MetalStorage {
})?,
};
let command_buffer = self.device.command_buffer()?;
let dst = buffer_o(&self.buffer, l, self.dtype);
let src = buffer_o(&src.buffer, src_l, src.dtype);
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
candle_metal_kernels::call_scatter_add(
candle_metal_kernels::call_scatter(
&self.device.device,
&command_buffer,
&self.device.kernels,
@ -1376,10 +1518,10 @@ impl BackendStorage for MetalStorage {
dim,
src,
ids,
&acc.buffer,
dst,
)
.map_err(MetalError::from)?;
Ok(acc)
Ok(())
}
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
@ -1513,50 +1655,32 @@ impl BackendStorage for MetalStorage {
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("matmul");
if self.dtype == DType::BF16 {
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
candle_metal_kernels::GemmDType::BF16,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else {
let dtype = match self.dtype {
DType::F32 => candle_metal_kernels::GemmDType::F32,
DType::F16 => candle_metal_kernels::GemmDType::F16,
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
dtype => {
return Err(MetalError::Message(format!(
"mlx matmul doesn't support {dtype:?}"
))
.into())
}
};
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
dtype,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
}
let dtype = match self.dtype {
DType::F32 => candle_metal_kernels::GemmDType::F32,
DType::F16 => candle_metal_kernels::GemmDType::F16,
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
dtype => {
return Err(
MetalError::Message(format!("mlx matmul doesn't support {dtype:?}")).into(),
)
}
};
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
dtype,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(
buffer,
self.device.clone(),
@ -1965,40 +2089,6 @@ impl BackendDevice for MetalDevice {
))
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let name = match dtype {
DType::U8 => "fill_u8",
DType::U32 => "fill_u32",
DType::I64 => "fill_i64",
DType::F16 => "fill_f16",
DType::BF16 => "fill_bf16",
DType::F32 => "fill_f32",
DType::F64 => {
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
return self.storage_from_cpu_storage(&cpu_storage);
}
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_const_fill(
&self.device,
&command_buffer,
&self.kernels,
name,
shape.elem_count(),
&buffer,
1.,
)
.map_err(MetalError::from)?;
Ok(MetalStorage::new(
buffer,
self.clone(),
shape.elem_count(),
dtype,
))
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
let (count, buffer) = match T::cpu_storage_ref(s) {
CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),

View File

@ -80,6 +80,7 @@ pub enum Op {
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
Gather(Tensor, Tensor, usize),
Scatter(Tensor, Tensor, Tensor, usize),
ScatterAdd(Tensor, Tensor, Tensor, usize),
IndexSelect(Tensor, Tensor, usize),
IndexAdd(Tensor, Tensor, Tensor, usize),

View File

@ -45,6 +45,7 @@ pub enum OpCode {
BinFloat = b'G',
Append = b'a',
Appends = b'e',
Long1 = 0x8a,
}
// Avoid using FromPrimitive so as not to drag another dependency.
@ -84,6 +85,7 @@ impl TryFrom<u8> for OpCode {
b'G' => Ok(Self::BinFloat),
b'a' => Ok(Self::Append),
b'e' => Ok(Self::Appends),
0x8a => Ok(Self::Long1),
value => Err(value),
}
}
@ -106,6 +108,7 @@ pub enum Object {
class_name: String,
},
Int(i32),
Long(i64),
Float(f64),
Unicode(String),
Bool(bool),
@ -170,6 +173,14 @@ impl Object {
}
}
pub fn int_or_long(self) -> OResult<i64> {
match self {
Self::Int(t) => Ok(t as i64),
Self::Long(t) => Ok(t),
_ => Err(self),
}
}
pub fn tuple(self) -> OResult<Vec<Self>> {
match self {
Self::Tuple(t) => Ok(t),
@ -590,6 +601,15 @@ impl Stack {
let obj = self.new_obj(class, args)?;
self.push(obj)
}
OpCode::Long1 => {
let n_bytes = r.read_u8()?;
let mut v = 0;
// Decode the next n bytes in little endian
for i in 0..n_bytes {
v |= (r.read_u8()? as i64) << (i * 8);
}
self.push(Object::Long(v))
}
}
Ok(false)
}
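LONG1 (opcode 0x8a) stores a one-byte length followed by that many bytes of a little-endian integer, which the loop above reassembles by shifting each byte into place. A standalone worked example of the same decode (pickle's LONG1 is a signed two's-complement encoding; like the code above, this sketch only handles the common non-negative case):

// Little-endian reassembly matching the Long1 branch above.
fn decode_long1(bytes: &[u8]) -> i64 {
    let mut v = 0i64;
    for (i, b) in bytes.iter().enumerate() {
        v |= (*b as i64) << (i * 8);
    }
    v
}

fn main() {
    // 123456789 == 0x075B_CD15, stored least-significant byte first.
    assert_eq!(decode_long1(&[0x15, 0xCD, 0x5B, 0x07]), 123456789);
}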
@ -607,10 +627,10 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
let mut args = args.tuple()?;
let stride = Vec::<usize>::try_from(args.remove(3))?;
let size = Vec::<usize>::try_from(args.remove(2))?;
let offset = args.remove(1).int()? as usize;
let offset = args.remove(1).int_or_long()? as usize;
let storage = args.remove(0).persistent_load()?;
let mut storage = storage.tuple()?;
let storage_size = storage.remove(4).int()? as usize;
let storage_size = storage.remove(4).int_or_long()? as usize;
let path = storage.remove(2).unicode()?;
let (_module_name, class_name) = storage.remove(1).class()?;
let dtype = match class_name.as_str() {
@ -624,7 +644,11 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
crate::bail!("unsupported storage type {other}")
}
};
let layout = Layout::new(crate::Shape::from(size), stride, offset);
let layout = Layout::new(
crate::Shape::from(size),
stride,
offset * dtype.size_in_bytes(),
);
Ok((layout, dtype, path, storage_size))
}
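The rebuilt layout now multiplies the pickled storage offset by the dtype's element size before handing it to Layout::new, so the recorded start offset accounts for the element width of the underlying storage. A trivial arithmetic check of what the expression computes:

fn main() {
    let offset: usize = 16;             // offset as read from the pickle
    let f32_size_in_bytes: usize = 4;   // DType::F32 element size
    assert_eq!(offset * f32_size_in_bytes, 64); // start offset passed to Layout::new above
}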
@ -792,7 +816,7 @@ impl PthTensors {
/// # Arguments
/// * `path` - Path to the pth file.
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
/// contains multiple objects and the state_dict is the one we are interested in.
/// contains multiple objects and the state_dict is the one we are interested in.
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
path: P,
key: Option<&str>,

View File

@ -1,10 +1,10 @@
use super::{GgmlDType, QStorage};
use crate::quantized::k_quants::GgmlType;
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};
use crate::{builder_arg as barg, CudaDevice, CudaStorage, Result};
use half::f16;
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
use cudarc::driver::{CudaSlice, CudaView, PushKernelArg};
#[derive(Clone, Debug)]
struct PaddedCudaSlice {
@ -50,19 +50,20 @@ fn quantize_q8_1(
ky: usize,
dev: &CudaDevice,
) -> Result<()> {
use cudarc::driver::LaunchAsync;
let kx = elem_count;
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
let func = dev.get_or_load_func("quantize_q8_1", &candle_kernels::QUANTIZED)?;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, ky as u32, 1),
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
shared_mem_bytes: 0,
};
let params = (src, dst, kx as i32, kx_padded as i32);
unsafe { func.launch(cfg, params) }.w()?;
let mut builder = func.builder();
builder.arg(src);
builder.arg(dst);
barg!(builder, kx as i32, kx_padded as i32);
unsafe { builder.launch(cfg) }.w()?;
Ok(())
}
@ -72,9 +73,7 @@ fn dequantize_f32(
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let nb = (elem_count + 255) / 256;
let nb = elem_count.div_ceil(256);
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
@ -99,8 +98,8 @@ fn dequantize_f32(
GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(elem_count)? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
@ -110,15 +109,20 @@ fn dequantize_f32(
};
if is_k {
let params = (&data.inner, &dst);
unsafe { func.launch(cfg, params) }.w()?;
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&dst);
unsafe { builder.launch(cfg) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (&data.inner, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&dst);
barg!(builder, nb32 as i32);
unsafe { builder.launch(cfg) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
@ -129,9 +133,7 @@ fn dequantize_f16(
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let nb = (elem_count + 255) / 256;
let nb = elem_count.div_ceil(256);
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
@ -156,8 +158,8 @@ fn dequantize_f16(
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f16>(elem_count)? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
@ -167,15 +169,20 @@ fn dequantize_f16(
};
if is_k {
let params = (&data.inner, &dst);
unsafe { func.launch(cfg, params) }.w()?;
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&dst);
unsafe { builder.launch(cfg) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (&data.inner, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&dst);
barg!(builder, nb32 as i32);
unsafe { builder.launch(cfg) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
@ -188,8 +195,6 @@ fn dequantize_mul_mat_vec(
nrows: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
@ -210,8 +215,8 @@ fn dequantize_mul_mat_vec(
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows)? };
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (block_num_y as u32, 1, 1),
@ -219,8 +224,12 @@ fn dequantize_mul_mat_vec(
shared_mem_bytes: 0,
};
let params = (&data.inner, y, &dst, ncols as i32, nrows as i32);
unsafe { func.launch(cfg, params) }.w()?;
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(y);
builder.arg(&dst);
barg!(builder, ncols as i32, nrows as i32);
unsafe { builder.launch(cfg) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
@ -233,8 +242,6 @@ fn mul_mat_vec_via_q8_1(
b_size: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
@ -249,7 +256,7 @@ fn mul_mat_vec_via_q8_1(
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
let y_size_in_bytes =
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
let kernel_name = match dtype {
@ -266,13 +273,13 @@ fn mul_mat_vec_via_q8_1(
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let kernel_name = format!("{kernel_name}{b_size}");
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size)? };
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
let (nblocks, nwarps) = match b_size {
1 => (nrows as u32, 4),
2..=4 => ((nrows as u32 + 1) / 2, 4),
5..=8 => ((nrows as u32 + 1) / 2, 2),
2..=4 => ((nrows as u32).div_ceil(2), 4),
5..=8 => ((nrows as u32).div_ceil(2), 2),
_ => crate::bail!("unexpected bsize {b_size}"),
};
let cfg = cudarc::driver::LaunchConfig {
@ -281,16 +288,18 @@ fn mul_mat_vec_via_q8_1(
shared_mem_bytes: 0,
};
let params = (
&data.inner,
&y_q8_1,
&dst,
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&y_q8_1);
builder.arg(&dst);
barg!(
builder,
/* ncols_x */ ncols as i32,
/* nrows_x */ nrows as i32,
/* nrows_y */ ncols_padded as i32,
/* nrows_dst */ nrows as i32,
/* nrows_dst */ nrows as i32
);
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
@ -305,8 +314,6 @@ fn mul_mat_via_q8_1(
y_cols: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < x_rows * x_cols {
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
@ -322,7 +329,7 @@ fn mul_mat_via_q8_1(
let k_padded = pad(k, MATRIX_ROW_PADDING);
let y_size_in_bytes =
k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
let (kernel_name, mmq_x, mmq_y) = match dtype {
@ -338,8 +345,8 @@ fn mul_mat_via_q8_1(
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols)? };
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (
ceil_div(x_rows, mmq_y) as u32,
@ -350,17 +357,19 @@ fn mul_mat_via_q8_1(
shared_mem_bytes: 0,
};
let params = (
/* vx */ &data.inner,
/* vy */ &y_q8_1,
/* dst */ &dst,
let mut builder = func.builder();
builder.arg(/* vx */ &data.inner);
builder.arg(/* vy */ &y_q8_1);
builder.arg(/* dst */ &dst);
barg!(
builder,
/* ncols_x */ x_cols as i32,
/* nrows_x */ x_rows as i32,
/* ncols_y */ y_cols as i32,
/* nrows_y */ k_padded as i32,
/* nrows_dst */ x_rows as i32,
/* nrows_dst */ x_rows as i32
);
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
@ -369,7 +378,7 @@ impl QCudaStorage {
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
let padded_size_in_bytes =
ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size();
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes).w()?;
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes)?;
Ok(QCudaStorage {
data: PaddedCudaSlice {
inner,
@ -416,8 +425,7 @@ impl QCudaStorage {
let buffer = self
.device
.dtoh_sync_copy(&self.data.inner.slice(..self.data.len))
.w()?;
.memcpy_dtov(&self.data.inner.slice(..self.data.len))?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
@ -448,9 +456,7 @@ impl QCudaStorage {
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
// Run the quantization on cpu.
let src = match &src.slice {
crate::cuda_backend::CudaStorageSlice::F32(data) => {
self.device.dtoh_sync_copy(data).w()?
}
crate::cuda_backend::CudaStorageSlice::F32(data) => self.device.memcpy_dtov(data)?,
_ => crate::bail!("only f32 can be quantized"),
};
let src_len = src.len();
@ -460,10 +466,9 @@ impl QCudaStorage {
let data = qcpu_storage.data()?;
let padded_len =
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
let mut inner = unsafe { self.device.alloc::<u8>(padded_len)? };
self.device
.htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len()))
.w()?;
.memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))?;
self.data = PaddedCudaSlice {
inner,
len: data.len(),
@ -597,10 +602,8 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
};
let dtype = T::DTYPE;
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
device
.htod_sync_copy_into(data, &mut inner.slice_mut(..data.len()))
.w()?;
let mut inner = unsafe { device.alloc::<u8>(padded_len)? };
device.memcpy_htod(data, &mut inner.slice_mut(..data.len()))?;
Ok(QStorage::Cuda(QCudaStorage {
data: PaddedCudaSlice {
inner,
@ -622,9 +625,9 @@ mod test {
let el_padded = pad(el, MATRIX_ROW_PADDING);
let y_size_in_bytes =
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let y = dev.memcpy_stod(&vs)?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
Ok(())
}
@ -634,7 +637,7 @@ mod test {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let y = dev.memcpy_stod(&vs)?;
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_vec_via_q8_1(
@ -647,7 +650,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
let vs = dev.memcpy_dtov(&vs.slice(..))?;
assert_eq!(vs.len(), 1);
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
// Q8 means 1/256 precision.
@ -662,7 +665,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
let vs = dev.memcpy_dtov(&vs.slice(..))?;
assert_eq!(vs.len(), 1);
assert_eq!(vs[0], 5561851.0);
Ok(())
@ -673,7 +676,7 @@ mod test {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let y = dev.memcpy_stod(&vs)?;
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
@ -687,7 +690,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
let vs = dev.memcpy_dtov(&vs.slice(..))?;
/*
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
@ -714,7 +717,7 @@ mod test {
let dev = CudaDevice::new(0)?;
let (x_rows, ncols, y_cols) = (4, 16, 2048);
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let y = dev.memcpy_stod(&vs)?;
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
@ -728,7 +731,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let _vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
let _vs = dev.memcpy_dtov(&vs.slice(..))?;
Ok(())
}
}

View File

@ -1,6 +1,74 @@
//! TensorScalar Enum and Trait
//!
use crate::{Result, Tensor, WithDType};
use crate::{DType, Result, Tensor, WithDType};
use half::{bf16, f16};
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Scalar {
U8(u8),
U32(u32),
I64(i64),
BF16(bf16),
F16(f16),
F32(f32),
F64(f64),
}
impl<T: WithDType> From<T> for Scalar {
fn from(value: T) -> Self {
value.to_scalar()
}
}
impl Scalar {
pub fn zero(dtype: DType) -> Self {
match dtype {
DType::U8 => Scalar::U8(0),
DType::U32 => Scalar::U32(0),
DType::I64 => Scalar::I64(0),
DType::BF16 => Scalar::BF16(bf16::ZERO),
DType::F16 => Scalar::F16(f16::ZERO),
DType::F32 => Scalar::F32(0.0),
DType::F64 => Scalar::F64(0.0),
}
}
pub fn one(dtype: DType) -> Self {
match dtype {
DType::U8 => Scalar::U8(1),
DType::U32 => Scalar::U32(1),
DType::I64 => Scalar::I64(1),
DType::BF16 => Scalar::BF16(bf16::ONE),
DType::F16 => Scalar::F16(f16::ONE),
DType::F32 => Scalar::F32(1.0),
DType::F64 => Scalar::F64(1.0),
}
}
pub fn dtype(&self) -> DType {
match self {
Scalar::U8(_) => DType::U8,
Scalar::U32(_) => DType::U32,
Scalar::I64(_) => DType::I64,
Scalar::BF16(_) => DType::BF16,
Scalar::F16(_) => DType::F16,
Scalar::F32(_) => DType::F32,
Scalar::F64(_) => DType::F64,
}
}
pub fn to_f64(&self) -> f64 {
match self {
Scalar::U8(v) => *v as f64,
Scalar::U32(v) => *v as f64,
Scalar::I64(v) => *v as f64,
Scalar::BF16(v) => v.to_f64(),
Scalar::F16(v) => v.to_f64(),
Scalar::F32(v) => *v as f64,
Scalar::F64(v) => *v,
}
}
}
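A minimal sketch of how the `Scalar` helper defined above might be used; the public `candle_core::scalar::Scalar` path is an assumption here, everything else mirrors the methods shown in the diff.

```rust
use candle_core::scalar::Scalar;
use candle_core::DType;

fn main() {
    // Dtype-aware zero/one constructors.
    let z = Scalar::zero(DType::F32);
    let o = Scalar::one(DType::BF16);
    assert_eq!(z.dtype(), DType::F32);
    assert_eq!(o.to_f64(), 1.0);
    // Any `WithDType` value converts through the blanket `From` impl.
    let s: Scalar = 3u32.into();
    assert_eq!(s.to_f64(), 3.0);
}
```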
pub enum TensorScalar {
Tensor(Tensor),

View File

@ -56,7 +56,7 @@ impl ArgSort {
mod cuda {
use super::*;
use crate::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
CudaSlice, DeviceRepr, LaunchConfig, ValidAsZeroBits,
};
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
use crate::{CudaDevice, WithDType};
@ -69,27 +69,33 @@ mod cuda {
layout: &crate::Layout,
_wrap: W,
) -> Result<S> {
use cudarc::driver::PushKernelArg;
let slice = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
let dst = unsafe { dev.alloc::<u32>(elem_count)? };
let func = if self.asc {
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
} else {
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), &kernels::SORT)?
};
let ncols = self.last_dim;
let nrows = elem_count / ncols;
let ncols_pad = next_power_of_2(ncols);
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
let cfg = LaunchConfig {
grid_dim: (1, nrows as u32, 1),
block_dim: (ncols_pad as u32, 1, 1),
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
};
unsafe { func.launch(cfg, params) }.w()?;
let stream = dev.cuda_stream();
let mut builder = stream.launch_builder(&func);
let ncols = ncols as i32;
let ncols_pad = ncols_pad as i32;
builder.arg(&slice).arg(&dst).arg(&ncols).arg(&ncols_pad);
unsafe { builder.launch(cfg) }.w()?;
Ok(S::U32(dst))
}
}

View File

@ -1,5 +1,6 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, ReduceOp};
use crate::scalar::Scalar;
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
@ -73,6 +74,14 @@ impl Storage {
}
}
pub(crate) fn const_set(&mut self, v: Scalar, l: &Layout) -> Result<()> {
match self {
Storage::Cpu(storage) => storage.const_set(v, l),
Storage::Cuda(storage) => storage.const_set(v, l),
Storage::Metal(storage) => storage.const_set(v, l),
}
}
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
@ -619,32 +628,56 @@ impl Storage {
}
}
pub(crate) fn scatter_add(
&self,
pub(crate) fn scatter_set(
&mut self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
source: &Self,
source_l: &Layout,
d: usize,
) -> Result<Self> {
) -> Result<()> {
self.same_device(indexes, "scatter-set")?;
self.same_device(source, "scatter-set")?;
match (self, indexes, source) {
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
}
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
}
_ => unreachable!(),
}
Ok(())
}
pub(crate) fn scatter_add(
&mut self,
l: &Layout,
indexes: &Self,
indexes_l: &Layout,
source: &Self,
source_l: &Layout,
d: usize,
) -> Result<()> {
self.same_device(indexes, "scatter-add")?;
self.same_device(source, "scatter-add")?;
match (self, indexes, source) {
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cpu(storage))
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
}
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Metal(storage))
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
}
_ => unreachable!(),
}
Ok(())
}
pub(crate) fn index_add(

View File

@ -3,7 +3,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::shape::{Dim, Dims, ShapeWithOneHole};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@ -185,7 +185,9 @@ impl Tensor {
) -> Result<Self> {
let none = BackpropOp::none();
let shape = shape.into();
let storage = device.ones(&shape, dtype)?;
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
let layout = Layout::contiguous(shape.clone());
storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
Ok(from_storage(storage, shape, none, is_variable))
}
@ -202,6 +204,18 @@ impl Tensor {
Self::ones_impl(shape, dtype, device, false)
}
pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
self.storage_mut().const_set(value, self.layout())
}
pub fn zero_set(&self) -> Result<()> {
self.const_set(crate::scalar::Scalar::zero(self.dtype()))
}
pub fn one_set(&self) -> Result<()> {
self.const_set(crate::scalar::Scalar::one(self.dtype()))
}
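A short usage sketch for the new in-place setters; it mirrors the `full`/`const_set` test added further down in this change, so only the standalone `main` wrapper is new.

```rust
use candle_core::{DType, Device, IndexOp, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::zeros((3, 4), DType::U32, &Device::Cpu)?;
    // Fill the whole tensor in place with a scalar of a matching dtype.
    t.const_set(42u32.into())?;
    // Views share storage, so setting a column updates the parent tensor.
    t.i((.., 2))?.const_set(1337u32.into())?;
    assert_eq!(
        t.to_vec2::<u32>()?,
        [[42, 42, 1337, 42], [42, 42, 1337, 42], [42, 42, 1337, 42]]
    );
    Ok(())
}
```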
/// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
///
/// ```rust
@ -368,8 +382,7 @@ impl Tensor {
Self::new_impl(array, shape, device, false)
}
/// Returns a new tensor with all the elements having the same specified value. Note that
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
/// Returns a new tensor with all the elements having the same specified value.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
@ -384,7 +397,12 @@ impl Tensor {
shape: S,
device: &Device,
) -> Result<Self> {
Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
let none = BackpropOp::none();
let shape = shape.into();
let mut storage = unsafe { device.alloc_uninit(&shape, D::DTYPE)? };
let layout = Layout::contiguous(shape.clone());
storage.const_set(value.to_scalar(), &layout)?;
Ok(from_storage(storage, shape, none, false))
}
/// Creates a new 1D tensor from an iterator.
@ -452,17 +470,13 @@ impl Tensor {
Self::from_vec_impl(data, len, device, false)
}
pub(crate) fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
data: Vec<D>,
shape: S,
device: &Device,
is_variable: bool,
) -> Result<Self> {
let shape = shape.into();
let buffer_size = data.len();
if buffer_size != shape.elem_count() {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let shape = shape.into_shape(data.len())?;
let storage = device.storage_owned(data)?;
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, is_variable))
@ -481,7 +495,7 @@ impl Tensor {
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
data: Vec<D>,
shape: S,
device: &Device,
@ -502,17 +516,12 @@ impl Tensor {
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
array: &[D],
shape: S,
device: &Device,
) -> Result<Self> {
let shape = shape.into();
let n: usize = shape.elem_count();
let buffer_size: usize = array.len();
if buffer_size != n {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let shape = shape.into_shape(array.len())?;
let storage = device.storage_from_slice(array)?;
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, false))
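With `ShapeWithOneHole`, `from_vec` and `from_slice` can leave one dimension to be inferred from the data length. A sketch, assuming they accept the same `()` hole tuples that `reshape` already does:

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // The `()` dimension is inferred from the number of elements.
    let t = Tensor::from_vec(vec![0f32, 1., 2., 3., 4., 5.], ((), 3usize), &Device::Cpu)?;
    assert_eq!(t.dims(), [2, 3]);
    let t = Tensor::from_slice(&[0u32, 1, 2, 3], (2usize, ()), &Device::Cpu)?;
    assert_eq!(t.dims(), [2, 2]);
    Ok(())
}
```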
@ -1349,8 +1358,7 @@ impl Tensor {
self.index_select(ids, 0)
}
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "scatter-add")?;
fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
let source_dims = source.dims();
let self_dims = self.dims();
let mismatch = if source_dims.len() != self_dims.len() {
@ -1367,7 +1375,7 @@ impl Tensor {
};
if mismatch {
Err(Error::ShapeMismatchBinaryOp {
op: "scatter-add (self, src)",
op: "scatter (self, src)",
lhs: self.shape().clone(),
rhs: source.shape().clone(),
}
@ -1375,13 +1383,44 @@ impl Tensor {
}
if indexes.dims() != source.dims() {
Err(Error::ShapeMismatchBinaryOp {
op: "scatter-add (indexes, src)",
op: "scatter (indexes, src)",
lhs: indexes.shape().clone(),
rhs: source.shape().clone(),
}
.bt())?
}
let storage = self.storage().scatter_add(
Ok(())
}
pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "scatter")?;
self.scatter_checks(indexes, source, dim)?;
let shape = self.shape();
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let layout = Layout::contiguous(shape);
storage.scatter_set(
&layout,
&indexes.storage(),
indexes.layout(),
&source.storage(),
source.layout(),
dim,
)?;
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
Op::Scatter(t1, t2, t3, dim)
});
Ok(from_storage(storage, self.shape(), op, false))
}
pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
if self.same_storage(source) {
crate::bail!("cannot use slice_set when self and src share their storage")
}
let dim = dim.to_index(self.shape(), "scatter-set")?;
self.scatter_checks(indexes, source, dim)?;
self.storage_mut().scatter_set(
self.layout(),
&indexes.storage(),
indexes.layout(),
@ -1389,12 +1428,48 @@ impl Tensor {
source.layout(),
dim,
)?;
Ok(())
}
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "scatter-add")?;
self.scatter_checks(indexes, source, dim)?;
let shape = self.shape();
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let layout = Layout::contiguous(shape);
storage.scatter_add(
&layout,
&indexes.storage(),
indexes.layout(),
&source.storage(),
source.layout(),
dim,
)?;
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
Op::ScatterAdd(t1, t2, t3, dim)
});
Ok(from_storage(storage, self.shape(), op, false))
}
pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
if self.same_storage(source) {
crate::bail!("cannot use slice_set when self and src share their storage")
}
let dim = dim.to_index(self.shape(), "scatter-add-set")?;
self.scatter_checks(indexes, source, dim)?;
self.storage_mut().scatter_add(
self.layout(),
&indexes.storage(),
indexes.layout(),
&source.storage(),
source.layout(),
dim,
)?;
Ok(())
}
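A minimal sketch, with 1-D shapes chosen purely for illustration, contrasting the new `scatter`/`scatter_set` with `scatter_add`/`scatter_add_set`:

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let init = Tensor::ones(4, DType::F32, &dev)?;
    let ids = Tensor::new(&[1u32, 3], &dev)?;
    let src = Tensor::new(&[10f32, 20.], &dev)?;
    // `scatter` overwrites the indexed positions along the given dim...
    let overwritten = init.scatter(&ids, &src, 0)?;
    assert_eq!(overwritten.to_vec1::<f32>()?, [1.0, 10.0, 1.0, 20.0]);
    // ...while `scatter_add` accumulates into the existing values.
    let accumulated = init.scatter_add(&ids, &src, 0)?;
    assert_eq!(accumulated.to_vec1::<f32>()?, [1.0, 11.0, 1.0, 21.0]);
    // The `_set` variants apply the same update in place on `init`.
    init.scatter_set(&ids, &src, 0)?;
    assert_eq!(init.to_vec1::<f32>()?, [1.0, 10.0, 1.0, 20.0]);
    Ok(())
}
```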
/// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
let dim = dim.to_index(self.shape(), "slice-scatter")?;
@ -2197,7 +2272,7 @@ impl Tensor {
///
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
let shape = s.into_shape(self.elem_count())?;
if shape.elem_count() != self.elem_count() {
return Err(Error::ShapeMismatchBinaryOp {
@ -2580,6 +2655,28 @@ impl Tensor {
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
rhs.broadcast_mul(&self.log()?)?.exp()
}
/// Returns a new tensor with the order of elements reversed along the specified dimensions.
/// This function makes a copy of the tensor's data.
///
/// ```rust
/// # use candle_core::{Tensor, Device};
/// let t = Tensor::arange(0., 6., &Device::Cpu)?.reshape((2, 3))?;
/// assert_eq!(t.to_vec2::<f64>()?, &[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
/// let t_flipped = t.flip(&[0])?;
/// assert_eq!(t_flipped.to_vec2::<f64>()?, &[[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
let mut result = self.clone();
for &dim in dims.iter() {
let size = result.dim(dim)?;
let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
result = result.index_select(&indices_tensor, dim)?;
}
Ok(result)
}
}
macro_rules! bin_trait {

View File

@ -241,7 +241,7 @@ impl Tensor {
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
/// has to be greater than or equal to `offset` plus the `src` size.
///
/// Note that this modifies `self` in place and as such is not compatibel with
/// Note that this modifies `self` in place and as such is not compatible with
/// back-propagation.
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
let dim = dim.to_index(self.shape(), "slice-set")?;
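For reference, a small sketch of the in-place behaviour described in the doc comment above; the shapes are illustrative.

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dst = Tensor::zeros((4, 3), DType::F32, &Device::Cpu)?;
    let src = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    // Copy `src` into `dst` along dim 0, starting at offset 1.
    dst.slice_set(&src, 0, 1)?;
    assert_eq!(
        dst.to_vec2::<f32>()?,
        [[0., 0., 0.], [1., 1., 1.], [1., 1., 1.], [0., 0., 0.]]
    );
    Ok(())
}
```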

View File

@ -24,6 +24,15 @@ macro_rules! test_device {
};
}
pub fn assert_tensor_eq(t1: &Tensor, t2: &Tensor) -> Result<()> {
assert_eq!(t1.shape(), t2.shape());
// Default U8 may not be large enough to hold the sum (`t.sum_all` defaults to the dtype of `t`)
let eq_tensor = t1.eq(t2)?.to_dtype(crate::DType::U32)?;
let all_equal = eq_tensor.sum_all()?;
assert_eq!(all_equal.to_scalar::<u32>()?, eq_tensor.elem_count() as u32);
Ok(())
}
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
let b = 10f32.powi(digits);
let t = t.to_vec0::<f32>()?;

View File

@ -53,6 +53,20 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
let res = {
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?
};
assert_eq!(res.dims(), [3, 2, 5]);
// Same as pytorch default padding: use zeros.
assert_eq!(
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
);
assert_eq!(
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
let w = w.transpose(0, 1)?;
// The CPU kernels applied in the contiguous and non contiguous cases are different.
@ -163,6 +177,22 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
let res = {
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
t.conv2d(&w, 0, 1, 1, 1)?
};
assert_eq!(res.dims(), [3, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
);
assert_eq!(
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
[
-4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;

View File

@ -1,6 +1,6 @@
#![allow(clippy::approx_constant)]
use anyhow::{Context, Result};
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
use candle_core::{test_device, test_utils, DType, Device, Shape, Tensor, Var};
fn simple_grad(device: &Device) -> Result<()> {
let x = Var::new(&[3f32, 1., 4.], device)?;
@ -505,6 +505,36 @@ fn binary_grad(device: &Device) -> Result<()> {
Ok(())
}
#[test]
fn test_flip_backprop() -> Result<()> {
let device = &Device::Cpu;
// Create a tensor (leaf node) that requires gradients
let x = Var::ones((2, 2), DType::F64, device)?;
let weights = Tensor::arange(1.0, 5.0, device)?.reshape((2, 2))?;
let y = x.matmul(&weights)?;
let expected_y = Tensor::from_vec(vec![4.0, 6.0, 4.0, 6.0], (2, 2), device)?;
candle_core::test_utils::assert_tensor_eq(&y, &expected_y)?;
let z = y.flip(&[1])?;
let expected_z = Tensor::from_vec(vec![6.0, 4.0, 6.0, 4.0], (2, 2), device)?;
candle_core::test_utils::assert_tensor_eq(&z, &expected_z)?;
let loss = z.sum_all()?;
let grad_store = loss.backward()?;
let grad_x = grad_store.get_id(x.id()).unwrap();
let flipped_weights = weights.flip(&[1])?;
let dloss_dy = Tensor::ones((2, 2), DType::F64, device)?;
// dloss/dx = dloss/dy @ dy/dx = ones @ weight.flip.T
let expected_grad = dloss_dy.matmul(&flipped_weights.t()?)?;
candle_core::test_utils::assert_tensor_eq(grad_x, &expected_grad)?;
Ok(())
}
test_device!(
simple_grad,
simple_grad_cpu,

View File

@ -25,10 +25,12 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
if !device.is_metal() {
assert_eq!(
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
}
assert_eq!(
Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
[
@ -63,6 +65,26 @@ fn ones(device: &Device) -> Result<()> {
}
fn full(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((3, 4), DType::U32, device)?;
tensor.const_set(42u32.into())?;
assert_eq!(
tensor.to_vec2::<u32>()?,
[[42, 42, 42, 42], [42, 42, 42, 42], [42, 42, 42, 42]]
);
tensor.i((.., 2))?.const_set(1337u32.into())?;
assert_eq!(
tensor.to_vec2::<u32>()?,
[[42, 42, 1337, 42], [42, 42, 1337, 42], [42, 42, 1337, 42]]
);
tensor.i((2, ..))?.const_set(1u32.into())?;
assert_eq!(
tensor.to_vec2::<u32>()?,
[[42, 42, 1337, 42], [42, 42, 1337, 42], [1, 1, 1, 1]]
);
Ok(())
}
fn const_set(device: &Device) -> Result<()> {
assert_eq!(
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
[[42, 42, 42], [42, 42, 42]],
@ -823,9 +845,37 @@ fn embeddings(device: &Device) -> Result<()> {
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
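// u32::MAX indexes act as a sentinel: the corresponding output rows come back as zeros instead of erroring.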
let ids = Tensor::new(&[u32::MAX, 2u32, u32::MAX], device)?;
let hs = t.index_select(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]]);
Ok(())
}
#[test]
fn index_select_fail() -> Result<()> {
// Check that an error is properly reported on out of bounds.
let ids = Tensor::new(&[4u32, 2u32, 1u32], &Device::Cpu)?;
let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &Device::Cpu)?;
let hs = t.index_select(&ids, 0);
assert!(hs.is_err());
Ok(())
}
// The test below triggers an unwinding panic as there is a panic within the
// #[cfg(feature = "cuda")]
// #[test]
// #[should_panic]
// fn index_select_fail_gpu() {
// // Check that a panic happens for out of bounds in cuda
// if let Ok(device) = Device::new_cuda(0) {
// if let Ok(ids) = Tensor::new(&[4u32, 2u32, 1u32], &device) {
// if let Ok(t) = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &device) {
// let _ = t.index_select(&ids, 0);
// }
// }
// }
// }
fn cmp(device: &Device) -> Result<()> {
let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?;
@ -980,7 +1030,7 @@ fn slice_scatter(device: &Device) -> Result<()> {
Ok(())
}
fn scatter_add(device: &Device) -> Result<()> {
fn scatter(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
t.to_vec2::<f32>()?,
@ -1004,6 +1054,17 @@ fn scatter_add(device: &Device) -> Result<()> {
]
);
let hs = init.scatter(&ids, &t, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0, 1.0, 1.0],
[5.0, 1.0, 1.0, 3.0, 4.0],
[1.0, 8.0, 1.0, 7.0, 1.0],
[10.0, 1.0, 9.0, 1.0, 11.0]
]
);
let init = Tensor::ones((6, 3), DType::F32, device)?;
let hs = init.scatter_add(&ids, &t, 0)?;
assert_eq!(
@ -1017,6 +1078,56 @@ fn scatter_add(device: &Device) -> Result<()> {
[1.0, 1.0, 1.0]
]
);
let hs = init.scatter(&ids, &t, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 10.0, 5.0],
[1.0, 1.0, 8.0],
[9.0, 1.0, 2.0],
[6.0, 7.0, 1.0],
[1.0, 4.0, 11.0],
[1.0, 1.0, 1.0]
]
);
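// u32::MAX entries in the index tensor are treated as a no-op: the corresponding source values are simply dropped.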
let hs = {
let ids = Tensor::new(
&[
[0u32, u32::MAX, 2],
[3, 4, u32::MAX],
[3, 3, 1],
[u32::MAX, u32::MAX, 4],
],
device,
)?;
init.scatter(&ids, &t, 0)?
};
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 1.0],
[1.0, 1.0, 8.0],
[1.0, 1.0, 2.0],
[6.0, 7.0, 1.0],
[1.0, 4.0, 11.0],
[1.0, 1.0, 1.0]
]
);
init.scatter_set(&ids, &t, 0)?;
assert_eq!(
init.to_vec2::<f32>()?,
&[
[0.0, 10.0, 5.0],
[1.0, 1.0, 8.0],
[9.0, 1.0, 2.0],
[6.0, 7.0, 1.0],
[1.0, 4.0, 11.0],
[1.0, 1.0, 1.0]
]
);
Ok(())
}
@ -1050,6 +1161,23 @@ fn gather(device: &Device) -> Result<()> {
let hs = t.gather(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
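// As with index_select, u32::MAX indexes gather zeros instead of erroring.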
let hs = {
let ids = Tensor::new(
&[
[0u32, 0u32],
[2u32, u32::MAX],
[u32::MAX, 1u32],
[0u32, 2u32],
],
device,
)?;
t.gather(&ids, 1)?
};
assert_eq!(
hs.to_vec2::<f32>()?,
&[[0.0, 0.0], [5.0, 0.0], [0.0, 7.0], [9.0, 11.0]]
);
// Random data
// Dim: 0
@ -1484,6 +1612,7 @@ fn zero_dim(device: &Device) -> Result<()> {
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
test_device!(const_set, cs_cpu, cs_gpu, cs_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
@ -1515,12 +1644,7 @@ test_device!(
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
scatter_add,
scatter_add_cpu,
scatter_add_gpu,
scatter_add_metal
);
test_device!(scatter, scatter_cpu, scatter_gpu, scatter_metal);
test_device!(
slice_scatter,
slice_scatter_cpu,
@ -1682,3 +1806,77 @@ fn pow() -> Result<()> {
);
Ok(())
}
#[test]
fn test_flip_1d() -> Result<()> {
// 1D: [0, 1, 2, 3, 4]
let t = Tensor::arange(0.0, 5.0, &Device::Cpu)?.reshape((5,))?;
let flipped = t.flip(&[0])?;
// Expected: [4, 3, 2, 1, 0]
let expected = Tensor::from_vec(vec![4.0, 3.0, 2.0, 1.0, 0.0], (5,), &Device::Cpu)?;
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
Ok(())
}
#[test]
fn test_flip_2d() -> Result<()> {
// 2D:
// [[0, 1, 2],
// [3, 4, 5]]
let t = Tensor::arange(0.0, 6.0, &Device::Cpu)?.reshape((2, 3))?;
let flipped = t.flip(&[0, 1])?;
// Expected:
// [[5, 4, 3],
// [2, 1, 0]]
let expected = Tensor::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0, 0.0], (2, 3), &Device::Cpu)?;
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
Ok(())
}
#[test]
fn test_flip_3d_channels() -> Result<()> {
// 3D:
// [[[0,1,2],
// [3,4,5]],
//
// [[6,7,8],
// [9,10,11]]]
let t = Tensor::arange(0.0, 12.0, &Device::Cpu)?.reshape((2, 2, 3))?;
let flipped = t.flip(&[2])?;
// Expected:
// [[[2,1,0],
// [5,4,3]],
//
// [[8,7,6],
// [11,10,9]]]
let expected = Tensor::from_vec(
vec![2.0, 1.0, 0.0, 5.0, 4.0, 3.0, 8.0, 7.0, 6.0, 11.0, 10.0, 9.0],
(2, 2, 3),
&Device::Cpu,
)?;
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
Ok(())
}
#[test]
fn tensor_new() -> Result<()> {
let t1 = Tensor::new(vec![1f32, 2.0, 3.0], &Device::Cpu)?;
assert_eq!(t1.to_vec1::<f32>()?, [1.0, 2.0, 3.0]);
let t2 = Tensor::new(vec![vec![1f32, 2., 3.], vec![4., 5., 6.]], &Device::Cpu)?;
assert_eq!(t2.to_vec2::<f32>()?, [[1., 2., 3.], [4., 5., 6.]]);
let t3 = Tensor::new(
vec![
vec![vec![1f32, 2., 3.], vec![4., 5., 6.]],
vec![vec![3f32, 1., 4.], vec![1., 5., 9.]],
],
&Device::Cpu,
)?;
assert_eq!(
t3.to_vec3::<f32>()?,
[
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
[[3.0, 1.0, 4.0], [1.0, 5.0, 9.0]]
]
);
Ok(())
}

View File

@ -72,6 +72,8 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
// image-rs crate convention is to load in (width, height, channels) order
// See: https://docs.rs/image/latest/image/trait.ImageDecoder.html#tymethod.dimensions
let image = image::load_from_memory(value.data()).unwrap();
buffer_images.extend(image.to_rgb8().as_raw());
}
@ -81,8 +83,10 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
}
}
}
let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
.to_dtype(DType::U8)?
// Reorder image-rs convention (width, height, channels) to candle/pytorch convolution convention (channels, height, width)
let images = (Tensor::from_vec(buffer_images, (samples, 32, 32, 3), &Device::Cpu)?
.to_dtype(DType::F32)?
.permute((0, 3, 2, 1))?
/ 255.)?;
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
Ok((images, labels))
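A standalone sketch of the axis reordering performed above, with illustrative sizes and values (the actual loader uses the decoded CIFAR pixels):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let n = 2usize;
    // Flat pixel buffer laid out as (samples, width, height, channels),
    // following the convention described in the comments above.
    let raw: Vec<f32> = (0..(n * 32 * 32 * 3)).map(|v| v as f32).collect();
    let images = Tensor::from_vec(raw, (n, 32usize, 32, 3), &Device::Cpu)?
        .permute((0, 3, 2, 1))?;
    // The conv layers expect (samples, channels, height, width).
    assert_eq!(images.dims(), [n, 3, 32, 32]);
    Ok(())
}
```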

View File

@ -16,10 +16,9 @@ fn read_u32<T: Read>(reader: &mut T) -> std::io::Result<u32> {
fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {
let magic_number = read_u32(reader)?;
if magic_number != expected {
Err(io::Error::new(
io::ErrorKind::Other,
format!("incorrect magic number {magic_number} != {expected}"),
))?;
Err(io::Error::other(format!(
"incorrect magic number {magic_number} != {expected}"
)))?;
}
Ok(())
}

View File

@ -60,7 +60,7 @@ bindgen_cuda = { version = "0.1.1", optional = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
cudnn = ["candle/cudnn"]
cudnn = ["candle/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
@ -69,6 +69,7 @@ metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal", "rubato"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
snac = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
[[example]]
@ -107,6 +108,10 @@ required-features = ["candle-datasets"]
name = "mimi"
required-features = ["mimi"]
[[example]]
name = "snac"
required-features = ["snac"]
[[example]]
name = "encodec"
required-features = ["encodec"]

View File

@ -4,7 +4,7 @@ Experimental, not instruction-tuned small LLM from the Hazy Research group, comb
[Blogpost](https://hazyresearch.stanford.edu/blog/2024-03-03-based)
[Simple linear attention language models balance the recall-throughput tradeoff](https://arxiv.org/abs/2402.18668)
[Simple linear attention language models balance the recall-throughput tradeoff](https://huggingface.co/papers/2402.18668)
## Running an example

View File

@ -1,6 +1,6 @@
# candle-beit
[Beit](https://arxiv.org/abs/2106.08254) is a computer vision model.
[Beit](https://huggingface.co/papers/2106.08254) is a computer vision model.
In this example, it is used as an ImageNet classifier: the model returns the
probability for the image to belong to each of the 1000 ImageNet categories.

View File

@ -0,0 +1,13 @@
# candle-chatglm
Uses `THUDM/chatglm3-6b` to generate Chinese text. It will usually not generate English text.
## Text Generation
```bash
cargo run --example chatglm --release -- --prompt "部署门槛较低等众多优秀特 "
> 部署门槛较低等众多优秀特 点使得其成为了一款备受欢迎的AI助手。
>
> 作为一款人工智能助手ChatGLM3-6B
```

View File

@ -0,0 +1,42 @@
# candle-chinese-clip
Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
pairs of images with related texts. This one is trained on Chinese text instead of English.
## Running on cpu
```bash
$ cargo run --example chinese_clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "一场自行车比赛","两只猫的照片","一个机器人拿着蜡烛"
> Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
>
> 2025-03-25T19:22:01.325177Z INFO chinese_clip: Probability: 0.0000% Text: 一场自行车比赛
> 2025-03-25T19:22:01.325179Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
> 2025-03-25T19:22:01.325181Z INFO chinese_clip: Probability: 100.0000% Text: 一个机器人拿着蜡烛
> 2025-03-25T19:22:01.325183Z INFO chinese_clip:
>
> Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
>
> 2025-03-25T19:22:01.325184Z INFO chinese_clip: Probability: 100.0000% Text: 一场自行车比赛
> 2025-03-25T19:22:01.325186Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
> 2025-03-25T19:22:01.325187Z INFO chinese_clip: Probability: 0.0000% Text: 一个机器人拿着蜡烛
```
## Running on metal
```bash
$ cargo run --features metal --example chinese_clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "一场自行车比赛","两只猫的照片","一个机器人拿着蜡烛"
> Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
>
> 2025-03-25T19:22:01.325177Z INFO chinese_clip: Probability: 0.0000% Text: 一场自行车比赛
> 2025-03-25T19:22:01.325179Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
> 2025-03-25T19:22:01.325181Z INFO chinese_clip: Probability: 100.0000% Text: 一个机器人拿着蜡烛
> 2025-03-25T19:22:01.325183Z INFO chinese_clip:
>
> Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
>
> 2025-03-25T19:22:01.325184Z INFO chinese_clip: Probability: 100.0000% Text: 一场自行车比赛
> 2025-03-25T19:22:01.325186Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
> 2025-03-25T19:22:01.325187Z INFO chinese_clip: Probability: 0.0000% Text: 一个机器人拿着蜡烛
```

View File

@ -0,0 +1,17 @@
# candle-convmixer
A lightweight CNN architecture that processes image patches similarly to a vision transformer, with separate spatial and channel convolutions.
ConvMixer from [Patches Are All You Need?](https://huggingface.co/papers/2201.09792) and [ConvMixer](https://github.com/locuslab/convmixer).
## Running an example
```bash
$ cargo run --example convmixer --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
> mountain bike, all-terrain bike, off-roader: 61.75%
> unicycle, monocycle : 5.73%
> moped : 3.66%
> bicycle-built-for-two, tandem bicycle, tandem: 3.51%
> crash helmet : 0.85%
```

View File

@ -1,7 +1,7 @@
# candle-convnext
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).
[A ConvNet for the 2020s](https://huggingface.co/papers/2201.03545) and
[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://huggingface.co/papers/2301.00808).
This candle implementation uses a pre-trained ConvNeXt network for inference. The
classification head has been trained on the ImageNet dataset and returns the

View File

@ -0,0 +1,14 @@
# Conversational Speech Model (CSM)
CSM is a speech generation model from Sesame,
[SesameAILabs/csm](https://github.com/SesameAILabs/csm).
It can generate conversational speech between two different speakers.
The speaker turns are delimited by the `|` character in the prompt.
```bash
cargo run --example csm --features cuda -r -- \
--voices candle-examples/examples/csm/voices.safetensors \
--prompt "Hey how are you doing?|Pretty good, pretty good. How about you?"
```

View File

@ -0,0 +1,243 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::csm::{Config, Model};
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "1b")]
Csm1b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
/// The prompt to be used for the generation, use a | to separate the speakers.
#[arg(long, default_value = "Hey how are you doing today?")]
prompt: String,
/// The voices to be used, in safetensors format.
#[arg(long)]
voices: String,
/// The output file using the wav format.
#[arg(long, default_value = "out.wav")]
out_file: String,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.7)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
/// The model size to use.
#[arg(long, default_value = "1b")]
which: Which,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
config: Option<String>,
#[arg(long)]
weights: Option<String>,
/// The mimi model weight file, in safetensor format.
#[arg(long)]
mimi_weights: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => {
let name = match args.which {
Which::Csm1b => "sesame/csm-1b",
};
name.to_string()
}
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let filenames = match args.weights {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => vec![repo.get("model.safetensors")?],
};
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("meta-llama/Llama-3.2-1B".to_string())
.get("tokenizer.json")?,
};
let mimi_filename = match args.mimi_weights {
Some(model) => std::path::PathBuf::from(model),
None => Api::new()?
.model("kyutai/mimi".to_string())
.get("model.safetensors")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config: Config = match args.config {
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
None => {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
}
};
let device = candle_examples::device(args.cpu)?;
let (mut model, device) = {
let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
(model, device)
};
let mut mimi_model = {
use candle_transformers::models::mimi;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[mimi_filename], DType::F32, &device)? };
let config = mimi::Config::v0_1(Some(32));
mimi::Model::new(config, vb)?
};
let cb = config.audio_num_codebooks;
println!("loaded the model in {:?}", start.elapsed());
let voices = candle::safetensors::load(args.voices, &device)?;
let mut lp = candle_transformers::generation::LogitsProcessor::new(
args.seed,
Some(args.temperature),
None,
);
let tokens = voices
.get("tokens")
.expect("no tokens in prompt")
.to_dtype(DType::U32)?;
let mask = voices.get("mask").expect("no mask in prompt").clone();
let mut pos = 0;
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
pos += tokens.dim(1)?;
let mut all_pcms = vec![];
for (turn_idx, prompt) in args.prompt.split('|').enumerate() {
println!("{prompt:?}");
let speaker_idx = turn_idx % 2;
let prompt = format!("[{speaker_idx}]{}<|end_of_text|>", prompt);
let prompt = tokenizer.encode(prompt, true).map_err(E::msg)?;
let (mut tokens, mut mask) = model.text_tokens_and_mask(prompt.get_ids())?;
let mut generated_tokens = vec![];
loop {
let frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
pos += tokens.dim(1)?;
let is_done = frame.iter().all(|&x| x == 0);
(tokens, mask) = model.audio_tokens_and_mask(frame)?;
print!("\rframe {pos}");
if is_done {
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
pos += tokens.dim(1)?;
break;
}
generated_tokens.push(tokens.clone());
}
println!();
let generated_tokens = Tensor::cat(&generated_tokens, 1)?.narrow(2, 0, cb)?.t()?;
let pcm = mimi_model.decode(&generated_tokens)?;
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
all_pcms.push(pcm);
}
let pcm = Tensor::cat(&all_pcms, 0)?;
let pcm = pcm.to_vec1::<f32>()?;
println!("writing output file {}", args.out_file);
let mut output = std::fs::File::create(args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
Ok(())
}

Binary file not shown.

View File

@ -0,0 +1,17 @@
# candle-custom-ops
This example illustrates how to implement forward and backward passes for custom operations on the CPU and GPU.
The custom op in this example implements RMS normalization for the CPU and CUDA.
## Running an example
```bash
$ cargo run --example custom-ops
> [[ 0., 1., 2., 3., 4., 5., 6.],
> [ 7., 8., 9., 10., 11., 12., 13.]]
> Tensor[[2, 7], f32]
> [[0.0000, 0.2773, 0.5547, 0.8320, 1.1094, 1.3867, 1.6641],
> [0.6864, 0.7845, 0.8825, 0.9806, 1.0786, 1.1767, 1.2748]]
> Tensor[[2, 7], f32]
```

View File

@ -56,7 +56,7 @@ impl CustomOp1 for LayerNorm {
layout: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig};
use candle::cuda_backend::cudarc::driver::{LaunchConfig, PushKernelArg};
use candle::cuda_backend::WrapErr;
let (d1, d2) = layout.shape().dims2()?;
let d1 = d1 as u32;
@ -68,15 +68,19 @@ impl CustomOp1 for LayerNorm {
Some((o1, o2)) => slice.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
let params = (&dst, &slice, self.eps, d1, d2);
let dst = unsafe { dev.alloc::<f32>(elem_count) }?;
let func =
dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
let cfg = LaunchConfig {
grid_dim: (d1, 1, 1),
block_dim: (d2, 1, 1),
shared_mem_bytes: 0,
};
unsafe { func.launch(cfg, params) }.w()?;
let mut builder = func.builder();
builder.arg(&dst);
builder.arg(&slice);
candle::builder_arg!(builder, self.eps, d1, d2);
unsafe { builder.launch(cfg) }.w()?;
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
Ok((dst, layout.shape().clone()))

View File

@ -20,8 +20,8 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{Encoding, PaddingParams, Tokenizer};
enum TaskType {
Ner(DebertaV2NERModel),
TextClassification(DebertaV2SeqClassificationModel),
Ner(Box<DebertaV2NERModel>),
TextClassification(Box<DebertaV2SeqClassificationModel>),
}
#[derive(Parser, Debug, Clone, ValueEnum)]
@ -169,21 +169,16 @@ impl Args {
match self.task {
ArgsTask::Ner => Ok((
TaskType::Ner(DebertaV2NERModel::load(
vb,
&config,
Some(id2label.clone()),
)?),
TaskType::Ner(DebertaV2NERModel::load(vb, &config, Some(id2label.clone()))?.into()),
config,
tokenizer,
id2label,
)),
ArgsTask::TextClassification => Ok((
TaskType::TextClassification(DebertaV2SeqClassificationModel::load(
vb,
&config,
Some(id2label.clone()),
)?),
TaskType::TextClassification(
DebertaV2SeqClassificationModel::load(vb, &config, Some(id2label.clone()))?
.into(),
),
config,
tokenizer,
id2label,

View File

@ -1,6 +1,6 @@
# candle-dinov2-reg4
[DINOv2-reg4](https://arxiv.org/abs/2309.16588) is the lastest version of DINOv2 with registers.
[DINOv2-reg4](https://huggingface.co/papers/2309.16588) is the latest version of DINOv2 with registers.
In this example, it is used as a plant species classifier: the model returns the
probability for the image to belong to each of the 7806 PlantCLEF2024 categories.

View File

@ -1,5 +1,5 @@
//! DINOv2 reg4 finetuned on PlantCLEF 2024
//! https://arxiv.org/abs/2309.16588
//! https://huggingface.co/papers/2309.16588
//! https://huggingface.co/spaces/BVRA/PlantCLEF2024
//! https://zenodo.org/records/10848263

View File

@ -8,7 +8,7 @@ DistilBert is used to compute the sentence embeddings for a prompt. The model we
are downloaded from the hub on the first run.
```bash
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
$ cargo run --example distilbert --release -- --prompt "Here is a test sentence"
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
@ -20,3 +20,25 @@ cargo run --example distilbert --release -- --prompt "Here is a test sentence"
> Tensor[[1, 7, 768], f32]
```
## Masked Token
DistilBert is used to compute the top K choices for a masked token.
```bash
$ cargo run --example distilbert -- --prompt "The capital of France is [MASK]." --top-k 10
> Input: The capital of France is [MASK].
> Predictions for [MASK] at position 6:
> 1: marseille (probability: 12.14%)
> 2: paris (probability: 10.84%)
> 3: toulouse (probability: 8.57%)
> 4: lyon (probability: 7.61%)
> 5: montpellier (probability: 5.18%)
> 6: bordeaux (probability: 4.88%)
> 7: nantes (probability: 4.82%)
> 8: lille (probability: 4.07%)
> 9: strasbourg (probability: 3.12%)
> 10: cannes (probability: 3.04%)
```

View File

@ -3,15 +3,48 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
use candle_transformers::models::distilbert::{
Config, DistilBertForMaskedLM, DistilBertModel, DTYPE,
};
use anyhow::{Error as E, Result};
use anyhow::{Context, Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::path::PathBuf;
use tokenizers::Tokenizer;
enum ModelType {
Masked(Box<DistilBertForMaskedLM>),
UnMasked(Box<DistilBertModel>),
}
impl ModelType {
fn device(&self) -> &Device {
match self {
ModelType::Masked(model) => &model.bert.device,
ModelType::UnMasked(model) => &model.device,
}
}
fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
match self {
ModelType::Masked(model) => Ok(model.forward(input_ids, attention_mask)?),
ModelType::UnMasked(model) => Ok(model.forward(input_ids, attention_mask)?),
}
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "distilbert")]
DistilBert,
#[value(name = "distilbertformaskedlm")]
DistilbertForMaskedLM,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -23,10 +56,14 @@ struct Args {
#[arg(long)]
tracing: bool,
#[arg(long, default_value = "distilbert")]
model: Which,
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
/// Revision or branch
#[arg(long)]
revision: Option<String>,
@ -42,94 +79,248 @@ struct Args {
#[arg(long, default_value = "1")]
n: usize,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
/// Number of top predictions to show for each mask
#[arg(long, default_value = "5")]
top_k: usize,
}
impl Args {
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
fn build_model_and_tokenizer(&self) -> Result<(ModelType, Tokenizer)> {
let device = candle_examples::device(self.cpu)?;
let (model_id, revision) = self.resolve_model_and_revision();
let (config_path, tokenizer_path, weights_path) =
self.download_model_files(&model_id, &revision)?;
let config = std::fs::read_to_string(config_path)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
let vb = self.load_variables(&weights_path, &device)?;
let model = self.create_model(&config, vb)?;
Ok((model, tokenizer))
}
fn resolve_model_and_revision(&self) -> (String, String) {
let default_model = "distilbert-base-uncased".to_string();
let default_revision = "main".to_string();
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
match (self.model_id.clone(), self.revision.clone()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(Some(model_id), None) => (model_id, default_revision),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
};
}
}
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
let api = api.repo(repo);
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let weights = if self.use_pth {
api.get("pytorch_model.bin")?
} else {
api.get("model.safetensors")?
};
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
fn download_model_files(
&self,
model_id: &str,
revision: &str,
) -> Result<(PathBuf, PathBuf, PathBuf)> {
let repo = Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string());
let api = Api::new()?;
let api = api.repo(repo);
let vb = if self.use_pth {
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let weights = if self.use_pth {
api.get("pytorch_model.bin")?
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
api.get("model.safetensors")?
};
let model = DistilBertModel::load(vb, &config)?;
Ok((model, tokenizer))
Ok((config, tokenizer, weights))
}
fn load_variables(&self, weights_path: &PathBuf, device: &Device) -> Result<VarBuilder> {
if self.use_pth {
Ok(VarBuilder::from_pth(weights_path, DTYPE, device)?)
} else {
Ok(unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, device)? })
}
}
fn create_model(&self, config: &Config, vb: VarBuilder) -> Result<ModelType> {
match self.model {
Which::DistilbertForMaskedLM => Ok(ModelType::Masked(
DistilBertForMaskedLM::load(vb, config)?.into(),
)),
Which::DistilBert => Ok(ModelType::UnMasked(
DistilBertModel::load(vb, config)?.into(),
)),
}
}
}
fn get_mask(size: usize, device: &Device) -> Tensor {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
Tensor::from_slice(&mask, (size, size), device).unwrap()
fn main() -> Result<()> {
let args = Args::parse();
let _guard = setup_tracing(&args);
let (model, tokenizer) = args.build_model_and_tokenizer()?;
let device = model.device();
let (token_ids, mask) = prepare_inputs(&args, &tokenizer, device)?;
let output = model.forward(&token_ids, &mask)?;
process_output(&model, &output, &token_ids, &tokenizer, &args)?;
Ok(())
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
fn setup_tracing(args: &Args) -> Option<impl Drop> {
if args.tracing {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
}
}
let tokenizer = tokenizer
fn prepare_inputs(args: &Args, tokenizer: &Tokenizer, device: &Device) -> Result<(Tensor, Tensor)> {
let mut binding = tokenizer.clone();
let tokenizer_configured = binding
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(args.prompt, true)
let tokens = tokenizer_configured
.encode(args.prompt.clone(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mask = get_mask(tokens.len(), device);
println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
println!("mask: {:?}", mask.to_vec2::<u8>());
let mask = match args.model {
Which::DistilbertForMaskedLM => attention_mask_maskedlm(tokenizer, &args.prompt, device)?,
Which::DistilBert => attention_mask(tokens.len(), device)?,
};
let ys = model.forward(&token_ids, &mask)?;
println!("{ys}");
println!("token_ids: {:?}", token_ids.to_vec2::<u32>()?);
Ok((token_ids, mask))
}
fn process_output(
model: &ModelType,
output: &Tensor,
token_ids: &Tensor,
tokenizer: &Tokenizer,
args: &Args,
) -> Result<()> {
match model {
ModelType::UnMasked(_) => {
println!("embeddings");
println!("{output}");
}
ModelType::Masked(_) => {
process_masked_output(output, token_ids, tokenizer, args)?;
}
}
Ok(())
}
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
fn process_masked_output(
output: &Tensor,
token_ids: &Tensor,
tokenizer: &Tokenizer,
args: &Args,
) -> Result<()> {
let input_ids_vec = token_ids.to_vec2::<u32>()?;
let mask_token_id = tokenizer
.token_to_id("[MASK]")
.context("Mask token, \"[MASK]\", not found in tokenizer.")?;
println!("\nInput: {}", args.prompt);
for (token_idx, &token_id) in input_ids_vec[0].iter().enumerate() {
if token_id == mask_token_id {
println!("Predictions for [MASK] at position {}:", token_idx);
let pos_logits = output.get(0)?.get(token_idx)?;
let probs = candle_nn::ops::softmax(&pos_logits, 0)?;
let (top_values, top_indices) = get_top_k(&probs, args.top_k)?;
let values = top_values.to_vec1::<f32>()?;
let indices = top_indices.to_vec1::<u32>()?;
for (i, (&token_id, &prob)) in indices.iter().zip(values.iter()).enumerate() {
let token = tokenizer.decode(&[token_id], false).map_err(E::msg)?;
println!(
" {}: {:15} (probability: {:.2}%)",
i + 1,
token,
prob * 100.0
);
}
}
}
Ok(())
}
fn get_top_k(tensor: &Tensor, k: usize) -> Result<(Tensor, Tensor)> {
let n = tensor.dims().iter().product::<usize>();
let k = std::cmp::min(k, n);
let values = tensor.to_vec1::<f32>()?;
let mut value_indices: Vec<(f32, usize)> = values
.into_iter()
.enumerate()
.map(|(idx, val)| (val, idx))
.collect();
value_indices.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let top_k_values: Vec<f32> = value_indices.iter().take(k).map(|(val, _)| *val).collect();
let top_k_indices: Vec<u32> = value_indices
.iter()
.take(k)
.map(|(_, idx)| *idx as u32)
.collect();
let device = tensor.device();
let top_values = Tensor::from_vec(top_k_values, (k,), device)?;
let top_indices = Tensor::from_vec(top_k_indices, (k,), device)?;
Ok((top_values, top_indices))
}
fn attention_mask(size: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
Ok(Tensor::from_slice(&mask, (size, size), device)?)
}
fn attention_mask_maskedlm(tokenizer: &Tokenizer, input: &str, device: &Device) -> Result<Tensor> {
let tokens = tokenizer.encode(input, true).map_err(E::msg)?;
let seq_len = tokens.get_attention_mask().to_vec().len();
let mask_token_id = tokenizer
.token_to_id("[MASK]")
.context("Mask token, \"[MASK]\", not found in tokenizer.")?;
let mut attention_mask_vec = Vec::with_capacity(seq_len * seq_len);
let ids = tokens.get_ids();
for _ in 0..seq_len {
for id in ids.iter() {
let mask_value = if id == &mask_token_id { 1u8 } else { 0u8 };
attention_mask_vec.push(mask_value);
}
}
let shape = (1, 1, seq_len, seq_len);
let mask = Tensor::from_vec(attention_mask_vec, shape, device)?;
Ok(mask)
}

View File

@ -0,0 +1,15 @@
# candle-efficientnet
Demonstrates a Candle implementation of EfficientNet for image classification based on ImageNet classes.
## Running an example
```bash
$ cargo run --example efficientnet --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which b1
> bicycle-built-for-two, tandem bicycle, tandem: 45.85%
> mountain bike, all-terrain bike, off-roader: 30.45%
> crash helmet : 2.58%
> unicycle, monocycle : 2.21%
> tricycle, trike, velocipede: 1.53%
```

View File

@ -1,6 +1,6 @@
//! EfficientNet implementation.
//!
//! https://arxiv.org/abs/1905.11946
//! https://huggingface.co/papers/1905.11946
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

View File

@ -1,6 +1,6 @@
# candle-efficientvit
[EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention](https://arxiv.org/abs/2305.07027).
[EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention](https://huggingface.co/papers/2305.07027).
This candle implementation uses a pre-trained EfficientViT (from Microsoft Research Asia) network for inference.
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.

View File

@ -1,6 +1,6 @@
# candle-eva2
[EVA-02](https://arxiv.org/abs/2303.11331) is a computer vision model.
[EVA-02](https://huggingface.co/papers/2303.11331) is a computer vision model.
In this example, it is used as an ImageNet classifier: the model returns the
probability for the image to belong to each of the 1000 ImageNet categories.

View File

@ -1,3 +1,10 @@
# candle-falcon
Falcon is a general large language model.
## Running an example
Make sure to include the `--use-f32` flag if using CPU, because there isn't a BFloat16 implementation yet.
```
cargo run --example falcon --release -- --prompt "Flying monkeys are" --use-f32
```

View File

@ -1,6 +1,6 @@
# candle-fastvit
[FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization](https://arxiv.org/abs/2303.14189).
[FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization](https://huggingface.co/papers/2303.14189).
This candle implementation uses a pre-trained FastViT network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.

View File

@ -50,6 +50,8 @@ enum Which {
InstructV2_9B,
#[value(name = "3-1b")]
BaseV3_1B,
#[value(name = "3-1b-it")]
InstructV3_1B,
}
enum Model {
@ -122,6 +124,17 @@ impl TextGeneration {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
println!(
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
);
eos_token
}
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
@ -144,7 +157,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
@ -272,6 +285,7 @@ fn main() -> Result<()> {
Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
@ -292,13 +306,10 @@ fn main() -> Result<()> {
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
if args.which == Which::BaseV3_1B {
vec![repo.get("model.safetensors")?]
} else {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
}
None => match args.which {
Which::BaseV3_1B | Which::InstructV3_1B => vec![repo.get("model.safetensors")?],
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
@ -331,7 +342,7 @@ fn main() -> Result<()> {
let model = Model2::new(args.use_flash_attn, &config, vb)?;
Model::V2(model)
}
Which::BaseV3_1B => {
Which::BaseV3_1B | Which::InstructV3_1B => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model3::new(args.use_flash_attn, &config, vb)?;
Model::V3(model)
@ -350,6 +361,31 @@ fn main() -> Result<()> {
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
let prompt = match args.which {
Which::Base2B
| Which::Base7B
| Which::Instruct2B
| Which::Instruct7B
| Which::InstructV1_1_2B
| Which::InstructV1_1_7B
| Which::CodeBase2B
| Which::CodeBase7B
| Which::CodeInstruct2B
| Which::CodeInstruct7B
| Which::BaseV2_2B
| Which::InstructV2_2B
| Which::BaseV2_9B
| Which::InstructV2_9B
| Which::BaseV3_1B => args.prompt,
Which::InstructV3_1B => {
format!(
"<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
args.prompt
)
}
};
pipeline.run(&prompt, args.sample_len)?;
Ok(())
}

View File

@ -12,7 +12,7 @@ GLM-4-9B is the open-source version of the latest generation of pre-trained mode
** Running with ~cpu~
#+begin_src shell
cargo run --example glm4 --release -- --cpu--prompt "Hello world"
cargo run --example glm4 --release -- --cpu --prompt "Hello world"
#+end_src
** Output Example

View File

@ -3,7 +3,7 @@
gte-Qwen1.5-7B-instruct is a variant of the GTE embedding model family.
- [Model card](https://huggingface.co/Alibaba-NLP/gte-Qwen1.5-7B-instruct) on the HuggingFace Hub.
- [Technical report](https://arxiv.org/abs/2308.03281) *Towards General Text Embeddings with Multi-stage Contrastive Learning*
- [Technical report](https://huggingface.co/papers/2308.03281) *Towards General Text Embeddings with Multi-stage Contrastive Learning*
## Running the example

View File

@ -7,7 +7,10 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::helium::{Config, Model};
use candle_transformers::models::helium::{Config as ConfigPreview, Model as ModelPreview};
use candle_transformers::models::llama::{
Cache as CacheV1, Llama as ModelV1, LlamaConfig as ConfigV1, LlamaEosToks,
};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -16,6 +19,44 @@ use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Debug, Clone)]
enum Model {
V1 { model: ModelV1, cache: CacheV1 },
Preview(ModelPreview),
}
impl Model {
fn forward(&mut self, input: &Tensor, start_pos: usize) -> Result<Tensor> {
let model = match self {
Model::V1 { model, cache } => model.forward(input, start_pos, cache)?,
Model::Preview(m) => m.forward(input, start_pos)?,
};
Ok(model)
}
}
#[derive(Debug, Clone)]
enum Config {
V1(ConfigV1),
Preview(ConfigPreview),
}
impl Config {
fn bos_token_id(&self) -> Option<u32> {
match self {
Config::V1(c) => c.bos_token_id,
Config::Preview(c) => Some(c.bos_token_id),
}
}
fn eos_token_id(&self) -> Option<LlamaEosToks> {
match self {
Config::V1(c) => c.eos_token_id.clone(),
Config::Preview(c) => Some(LlamaEosToks::Single(c.eos_token_id)),
}
}
}
struct TextGeneration {
model: Model,
device: Device,
@ -46,7 +87,7 @@ impl TextGeneration {
Sampling::ArgMax
} else {
match (top_k, top_p) {
(None, None) => Sampling::All { temperature },
(None, None) => Sampling::GumbelSoftmax { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
@ -106,7 +147,15 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == self.config.bos_token_id || next_token == self.config.eos_token_id {
let is_eos = self
.config
.eos_token_id()
.as_ref()
.is_some_and(|v| match v {
LlamaEosToks::Single(eos) => *eos == next_token,
LlamaEosToks::Multiple(eos) => eos.contains(&next_token),
});
if Some(next_token) == self.config.bos_token_id() || is_eos {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
@ -131,6 +180,8 @@ impl TextGeneration {
enum Which {
#[value(name = "v1-preview")]
V1Preview,
#[value(name = "v1")]
V1,
}
#[derive(Parser, Debug)]
@ -144,9 +195,6 @@ struct Args {
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
prompt: String,
@ -171,7 +219,7 @@ struct Args {
sample_len: usize,
/// The model size to use.
#[arg(long, default_value = "v1-preview")]
#[arg(long, default_value = "v1")]
which: Which,
#[arg(long)]
@ -230,6 +278,7 @@ fn main() -> Result<()> {
None => {
let name = match args.which {
Which::V1Preview => "kyutai/helium-1-preview-2b",
Which::V1 => "kyutai/helium-1-2b",
};
name.to_string()
}
@ -254,18 +303,27 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config: Config = match args.config {
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
None => {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
}
let config_file = match args.config {
Some(config_file) => std::path::PathBuf::from(config_file),
None => repo.get("config.json")?,
};
let config = match args.which {
Which::V1Preview => Config::Preview(serde_json::from_slice(&std::fs::read(config_file)?)?),
Which::V1 => Config::V1(serde_json::from_slice(&std::fs::read(config_file)?)?),
};
let device = candle_examples::device(args.cpu)?;
let (model, device) = {
let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
let model = match &config {
Config::V1(c) => {
let c = c.clone().into_config(false);
let model = ModelV1::load(vb, &c)?;
let cache = CacheV1::new(true, dtype, &c, &device)?;
Model::V1 { model, cache }
}
Config::Preview(c) => Model::Preview(ModelPreview::new(c, vb)?),
};
(model, device)
};

View File

@ -1,6 +1,6 @@
# hiera
[Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles](https://arxiv.org/abs/2306.00989)
[Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles](https://huggingface.co/papers/2306.00989)
This candle implementation uses pre-trained Hiera models from timm for inference.
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.

View File

@ -0,0 +1,11 @@
# candle-llama
Candle implementations of various Llama-based architectures.
## Running an example
```bash
$ cargo run --example llama -- --prompt "Machine learning is " --which v32-3b-instruct
> Machine learning is the part of computer science which deals with the development of algorithms and
```

View File

@ -256,6 +256,12 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let tokenizer = common_args.tokenizer()?;
let device = candle_examples::device(common_args.cpu)?;
#[cfg(feature = "cuda")]
if let candle::Device::Cuda(d) = &device {
unsafe {
d.disable_event_tracking();
}
};
let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
let is_safetensors = config_path

View File

@ -21,7 +21,7 @@ impl Config {
}
fn dt_rank(&self) -> usize {
(self.d_model + 15) / 16
self.d_model.div_ceil(16)
}
fn d_conv(&self) -> usize {

View File

@ -5,13 +5,13 @@ the transformer architecture. It leverages State Space Models (SSMs) with the
goal of being computationally efficient on long sequences. The implementation is
based on [mamba.rs](https://github.com/LaurentMazare/mamba.rs).
- [1]. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752).
- [1]. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://huggingface.co/papers/2312.00752).
Compared to the mamba-minimal example, this version is far more efficient but
only supports inference.
## Running the example
```bash
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
$ cargo run --example mamba --release -- --prompt "Mamba is the"
```

View File

@ -18,21 +18,19 @@ I know you are waiting for me. I will go through the forest, I will go through t
mountain. I cannot stay far from you any longer.</s>
```
### Changing model and language pairs
```bash
$ cargo run --example marian-mt --release -- --text "hello, how are you." --which base --language-pair en-zh
你好,你好吗?
```
## Generating the tokenizer.json files
You can use the following script to generate the `tokenizer.json` config files
from the hf-hub repos. This requires the `tokenizers` and `sentencepiece`
packages to be installed, and uses the `convert_slow_tokenizer.py` script from this
directory.
```python
from convert_slow_tokenizer import MarianConverter
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save(f"tokenizer-marian-base-en.json")
```
The tokenizer for each `marian-mt` model was trained independently,
meaning each new model needs unique tokenizer encoders and decoders.
You can use the `./python/convert_slow_tokenizer.py` script in this directory to generate
the `tokenizer.json` config files from the hf-hub repos.
The script requires all the packages in `./python/requirements.txt` or `./python/uv.lock`
to be installed, and has only been tested with Python 3.12.7.
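As a concrete sketch (the exact paths below are assumptions based on the layout described above, not commands taken verbatim from the repository), running the conversion could look like:
```bash
# Hypothetical invocation: the script writes tokenizer-marian-base-*.json files
# for the source and target languages into the current directory.
cd candle-examples/examples/marian-mt/python
pip install -r requirements.txt
python convert_slow_tokenizer.py
```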

File diff suppressed because it is too large Load Diff

View File

@ -20,6 +20,22 @@ enum Which {
Big,
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum LanguagePair {
#[value(name = "fr-en")]
FrEn,
#[value(name = "en-zh")]
EnZh,
#[value(name = "en-hi")]
EnHi,
#[value(name = "en-es")]
EnEs,
#[value(name = "en-fr")]
EnFr,
#[value(name = "en-ru")]
EnRu,
}
// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
@ -36,6 +52,10 @@ struct Args {
#[arg(long, default_value = "big")]
which: Which,
// Choose which language pair to use
#[arg(long, default_value = "fr-en")]
language_pair: LanguagePair,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
@ -53,21 +73,43 @@ pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse();
let config = match args.which {
Which::Base => marian::Config::opus_mt_fr_en(),
Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
let config = match (args.which, args.language_pair) {
(Which::Base, LanguagePair::FrEn) => marian::Config::opus_mt_fr_en(),
(Which::Big, LanguagePair::FrEn) => marian::Config::opus_mt_tc_big_fr_en(),
(Which::Base, LanguagePair::EnZh) => marian::Config::opus_mt_en_zh(),
(Which::Base, LanguagePair::EnHi) => marian::Config::opus_mt_en_hi(),
(Which::Base, LanguagePair::EnEs) => marian::Config::opus_mt_en_es(),
(Which::Base, LanguagePair::EnFr) => marian::Config::opus_mt_en_fr(),
(Which::Base, LanguagePair::EnRu) => marian::Config::opus_mt_en_ru(),
(Which::Big, lp) => anyhow::bail!("big is not supported for language pair {lp:?}"),
};
let tokenizer_default_repo = match args.language_pair {
LanguagePair::FrEn => "lmz/candle-marian",
LanguagePair::EnZh
| LanguagePair::EnHi
| LanguagePair::EnEs
| LanguagePair::EnFr
| LanguagePair::EnRu => "KeighBee/candle-marian",
};
let tokenizer = {
let tokenizer = match args.tokenizer {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => {
let name = match args.which {
Which::Base => "tokenizer-marian-base-fr.json",
Which::Big => "tokenizer-marian-fr.json",
let filename = match (args.which, args.language_pair) {
(Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-fr.json",
(Which::Big, LanguagePair::FrEn) => "tokenizer-marian-fr.json",
(Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-en.json",
(Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-en.json",
(Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-en.json",
(Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-en.json",
(Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-en.json",
(Which::Big, lp) => {
anyhow::bail!("big is not supported for language pair {lp:?}")
}
};
Api::new()?
.model("lmz/candle-marian".to_string())
.get(name)?
.model(tokenizer_default_repo.to_string())
.get(filename)?
}
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
@ -77,13 +119,21 @@ pub fn main() -> anyhow::Result<()> {
let tokenizer = match args.tokenizer_dec {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => {
let name = match args.which {
Which::Base => "tokenizer-marian-base-en.json",
Which::Big => "tokenizer-marian-en.json",
let filename = match (args.which, args.language_pair) {
(Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-en.json",
(Which::Big, LanguagePair::FrEn) => "tokenizer-marian-en.json",
(Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-zh.json",
(Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-hi.json",
(Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-es.json",
(Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-fr.json",
(Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-ru.json",
(Which::Big, lp) => {
anyhow::bail!("big is not supported for language pair {lp:?}")
}
};
Api::new()?
.model("lmz/candle-marian".to_string())
.get(name)?
.model(tokenizer_default_repo.to_string())
.get(filename)?
}
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
@ -94,18 +144,48 @@ pub fn main() -> anyhow::Result<()> {
let vb = {
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => match args.which {
Which::Base => Api::new()?
.repo(hf_hub::Repo::with_revision(
None => {
let api = Api::new()?;
let api = match (args.which, args.language_pair) {
(Which::Base, LanguagePair::FrEn) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-fr-en".to_string(),
hf_hub::RepoType::Model,
"refs/pr/4".to_string(),
))
.get("model.safetensors")?,
Which::Big => Api::new()?
.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
.get("model.safetensors")?,
},
)),
(Which::Big, LanguagePair::FrEn) => {
api.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
}
(Which::Base, LanguagePair::EnZh) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-en-zh".to_string(),
hf_hub::RepoType::Model,
"refs/pr/13".to_string(),
)),
(Which::Base, LanguagePair::EnHi) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-en-hi".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
)),
(Which::Base, LanguagePair::EnEs) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-en-es".to_string(),
hf_hub::RepoType::Model,
"refs/pr/4".to_string(),
)),
(Which::Base, LanguagePair::EnFr) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-en-fr".to_string(),
hf_hub::RepoType::Model,
"refs/pr/9".to_string(),
)),
(Which::Base, LanguagePair::EnRu) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-en-ru".to_string(),
hf_hub::RepoType::Model,
"refs/pr/7".to_string(),
)),
(Which::Big, lp) => {
anyhow::bail!("big is not supported for language pair {lp:?}")
}
};
api.get("model.safetensors")?
}
};
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
};

View File

@ -0,0 +1,53 @@
from pathlib import Path
import warnings
from transformers import AutoTokenizer
from transformers.convert_slow_tokenizer import SpmConverter, requires_backends, import_protobuf
class MarianConverter(SpmConverter):
def __init__(self, *args, index: int = 0):
requires_backends(self, "protobuf")
super(SpmConverter, self).__init__(*args)
# from .utils import sentencepiece_model_pb2 as model_pb2
model_pb2 = import_protobuf()
m = model_pb2.ModelProto()
print(self.original_tokenizer.spm_files)
with open(self.original_tokenizer.spm_files[index], "rb") as f:
m.ParseFromString(f.read())
self.proto = m
print(self.original_tokenizer)
#with open(self.original_tokenizer.vocab_path, "r") as f:
dir_path = Path(self.original_tokenizer.spm_files[0]).parents[0]
with open(dir_path / "vocab.json", "r") as f:
import json
self._vocab = json.load(f)
if self.proto.trainer_spec.byte_fallback:
if not getattr(self, "handle_byte_fallback", None):
warnings.warn(
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
" which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
" tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
"unknown tokens into a sequence of byte tokens matching the original piece of text."
)
def vocab(self, proto):
vocab_size = max(self._vocab.values()) + 1
vocab = [("<NIL>", -100) for _ in range(vocab_size)]
for piece in proto.pieces:
try:
index = self._vocab[piece.piece]
except Exception:
print(f"Ignored missing piece {piece.piece}")
vocab[index] = (piece.piece, piece.score)
return vocab
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save("tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save("tokenizer-marian-base-en.json")

View File

@ -0,0 +1,22 @@
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
filelock==3.18.0
fsspec==2025.3.2
huggingface-hub==0.30.1
idna==3.10
joblib==1.4.2
numpy==2.2.4
packaging==24.2
protobuf==6.30.2
pyyaml==6.0.2
regex==2024.11.6
requests==2.32.3
sacremoses==0.1.1
safetensors==0.5.3
sentencepiece==0.2.0
tokenizers==0.21.1
tqdm==4.67.1
transformers==4.50.3
typing-extensions==4.13.0
urllib3==2.3.0

View File

@ -13,6 +13,6 @@ Note that the current candle implementation suffers from some limitations as of
## Run an example
```bash
cargo run --example metavoice --release -- \\
cargo run --example metavoice --release -- \
--prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
```

View File

@ -0,0 +1,16 @@
# candle-mnist-training
Training a 2-layer MLP on MNIST in Candle.
## Running an example
```bash
$ cargo run --example mnist-training --features candle-datasets
> train-images: [60000, 784]
> train-labels: [60000]
> test-images: [10000, 784]
> test-labels: [10000]
> 1 train loss: 2.30265 test acc: 68.08%
> 2 train loss: 1.50815 test acc: 60.77%
```
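As a rough sketch of the kind of model this example trains (the layer sizes here are illustrative assumptions, not copied from the example source), a 2-layer MLP in `candle_nn` could look like:
```rust
use candle::{Result, Tensor};
use candle_nn::{linear, Linear, Module, VarBuilder};

// A small MLP: 784 input pixels -> hidden layer -> 10 class logits.
struct Mlp {
    fc1: Linear,
    fc2: Linear,
}

impl Mlp {
    fn new(vb: VarBuilder, hidden: usize) -> Result<Self> {
        let fc1 = linear(784, hidden, vb.pp("fc1"))?;
        let fc2 = linear(hidden, 10, vb.pp("fc2"))?;
        Ok(Self { fc1, fc2 })
    }
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Flattened images pass through a ReLU non-linearity between the two layers.
        self.fc2.forward(&self.fc1.forward(xs)?.relu()?)
    }
}
```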

View File

@ -7,6 +7,7 @@ extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use rand::prelude::*;
use rand::rng;
use candle::{DType, Result, Tensor, D};
use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};
@ -138,7 +139,7 @@ fn training_loop_cnn(
let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
for epoch in 1..args.epochs {
let mut sum_loss = 0f32;
batch_idxs.shuffle(&mut thread_rng());
batch_idxs.shuffle(&mut rng());
for batch_idx in batch_idxs.iter() {
let train_images = train_images.narrow(0, batch_idx * BSIZE, BSIZE)?;
let train_labels = train_labels.narrow(0, batch_idx * BSIZE, BSIZE)?;

View File

@ -2,7 +2,7 @@
MobileCLIP is family of efficient CLIP-like models using FastViT-based image encoders.
See [MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training](https://arxiv.org/abs/2311.17049)
See [MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training](https://huggingface.co/papers/2311.17049)
## Running an example on cpu

View File

@ -1,6 +1,6 @@
# candle-mobilenetv4
[MobileNetV4 - Universal Models for the Mobile Ecosystem](https://arxiv.org/abs/2404.10518)
[MobileNetV4 - Universal Models for the Mobile Ecosystem](https://huggingface.co/papers/2404.10518)
This candle implementation uses pre-trained MobileNetV4 models from timm for inference.
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.

View File

@ -1,6 +1,6 @@
# candle-mobileone
[MobileOne: An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040).
[MobileOne: An Improved One millisecond Mobile Backbone](https://huggingface.co/papers/2206.04040).
This candle implementation uses a pre-trained MobileOne network for inference. The
classification head has been trained on the ImageNet dataset and returns the

View File

@ -12,7 +12,7 @@ $ wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jp
Now you can run Moondream from the `candle-examples` crate:
```bash
$ cargo run --example moondream --release -- --prompt "What is the girl eating?" --image "./demo-1.jpg"
$ cargo run --example moondream --release -- --prompt "Describe the people behind the bikers?" --image "candle-examples/examples/yolo-v8/assets/bike.jpg"
avx: false, neon: true, simd128: false, f16c: false
temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64

View File

@ -0,0 +1,20 @@
# candle-musicgen
Candle implementation of musicgen from [Simple and Controllable Music Generation](https://huggingface.co/papers/2306.05284).
## Running an example
```bash
$ cargo run --example musicgen -- --prompt "90s rock song with loud guitars and heavy drums"
> tokens: [2777, 7, 2480, 2324, 28, 8002, 5507, 7, 11, 2437, 5253, 7, 1]
> Tensor[dims 1, 13; u32]
> [[[ 0.0902, 0.1256, -0.0585, ..., 0.1057, -0.5141, -0.4675],
> [ 0.1972, -0.0268, -0.3368, ..., -0.0495, -0.3597, -0.3940],
> [-0.0855, -0.0007, 0.2225, ..., -0.2804, -0.5360, -0.2436],
> ...
> [ 0.0515, 0.0235, -0.3855, ..., -0.4728, -0.6858, -0.2923],
> [-0.3728, -0.1442, -0.1179, ..., -0.4388, -0.0287, -0.3242],
> [ 0.0163, 0.0012, -0.0020, ..., 0.0142, 0.0173, -0.0103]]]
> Tensor[[1, 13, 768], f32]
```

View File

@ -3,7 +3,7 @@
OLMo is a series of Open Language Models designed to enable the science of language models.
- **Project Page:** https://allenai.org/olmo
- **Paper:** [Link](https://arxiv.org/abs/2402.00838)
- **Papers:** [OLMo](https://huggingface.co/papers/2402.00838), [OLMo 2](https://huggingface.co/papers/2501.00656)
- **Technical blog post:** https://blog.allenai.org/olmo-open-language-model-87ccfc95f580
- **W&B Logs:** https://wandb.ai/ai2-llm/OLMo-1B/reports/OLMo-1B--Vmlldzo2NzY1Njk1
<!-- - **Press release:** TODO -->

View File

@ -8,6 +8,7 @@ use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::olmo::{Config, Model as OLMo};
use candle_transformers::models::olmo2::{Config as Config2, Model as OLMo2};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -18,6 +19,7 @@ use tokenizers::Tokenizer;
enum Model {
OLMo(OLMo),
OLMo2(OLMo2),
}
struct TextGeneration {
@ -82,6 +84,7 @@ impl TextGeneration {
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = match &mut self.model {
Model::OLMo(m) => m.forward(&input, start_pos)?,
Model::OLMo2(m) => m.forward(&input, start_pos)?,
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
@ -129,6 +132,8 @@ enum Which {
W7bTwin2T,
#[value(name = "1.7-7b")]
V1_7W7b,
#[value(name = "2-1b")]
V2W1b,
}
#[derive(Parser, Debug)]
@ -220,6 +225,7 @@ fn main() -> Result<()> {
Which::W7b => "allenai/OLMo-7B-hf".to_string(),
Which::W7bTwin2T => "allenai/OLMo-7B-Twin-2T-hf".to_string(),
Which::V1_7W7b => "allenai/OLMo-1.7-7B-hf".to_string(),
Which::V2W1b => "allenai/OLMo-2-0425-1B-Instruct".to_string(),
},
};
@ -238,33 +244,36 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
Which::W1b => {
Which::W1b | Which::V2W1b => {
vec![repo.get("model.safetensors")?]
}
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
},
};
let config_filename = repo.get("config.json")?;
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = {
let config_filename = repo.get("config.json")?;
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
config
};
let device = candle_examples::device(args.cpu)?;
let model = {
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = OLMo::new(&config, vb)?;
Model::OLMo(model)
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = match args.model {
Which::W1b | Which::W7b | Which::W7bTwin2T | Which::V1_7W7b => {
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let model = OLMo::new(&config, vb)?;
Model::OLMo(model)
}
Which::V2W1b => {
let config: Config2 = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let model = OLMo2::new(&config, vb)?;
Model::OLMo2(model)
}
};
println!("loaded the model in {:?}", start.elapsed());

View File

@ -2,7 +2,7 @@
This example demonstrates how to run [ONNX](https://github.com/onnx/onnx) based models in Candle.
It contains small variants of two models, [SqueezeNet](https://arxiv.org/pdf/1602.07360.pdf) (default) and [EfficientNet](https://arxiv.org/pdf/1905.11946.pdf).
It contains small variants of two models, [SqueezeNet](https://huggingface.co/papers/1602.07360) (default) and [EfficientNet](https://huggingface.co/papers/1905.11946).
You can run the examples with the following commands:
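For instance, a hypothetical invocation for the default SqueezeNet variant (flag names follow the example's `Args` struct; the value name assumes clap's default kebab-case, and the crate may additionally need its onnx feature enabled):
```bash
$ cargo run --example onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which squeeze-net
```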

View File

@ -5,12 +5,14 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use candle::{IndexOp, D};
use candle_examples::save_image;
use clap::{Parser, ValueEnum};
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
SqueezeNet,
EfficientNet,
EsrGan,
}
#[derive(Parser)]
@ -28,10 +30,21 @@ struct Args {
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let image = candle_examples::imagenet::load_image224(args.image)?;
let image = match args.which {
Which::SqueezeNet | Which::EfficientNet => {
candle_examples::imagenet::load_image224(&args.image)?
}
Which::EsrGan => candle_examples::imagenet::load_image_with_std_mean(
&args.image,
128,
&[0.0f32, 0.0, 0.0],
&[1.0f32, 1.0, 1.0],
)?,
};
let image = match args.which {
Which::SqueezeNet => image,
Which::EfficientNet => image.permute((1, 2, 0))?,
Which::EsrGan => image,
};
println!("loaded image {image:?}");
@ -45,6 +58,9 @@ pub fn main() -> anyhow::Result<()> {
Which::EfficientNet => hf_hub::api::sync::Api::new()?
.model("onnx/EfficientNet-Lite4".into())
.get("efficientnet-lite4-11.onnx")?,
Which::EsrGan => hf_hub::api::sync::Api::new()?
.model("qualcomm/Real-ESRGAN-x4plus".into())
.get("Real-ESRGAN-x4plus.onnx")?,
},
};
@ -57,21 +73,40 @@ pub fn main() -> anyhow::Result<()> {
let prs = match args.which {
Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
Which::EfficientNet => output,
Which::EsrGan => output,
};
let prs = prs.i(0)?.to_vec1::<f32>()?;
// Sort the predictions and take the top 5
let mut top: Vec<_> = prs.iter().enumerate().collect();
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
let top = top.into_iter().take(5).collect::<Vec<_>>();
match args.which {
Which::EfficientNet | Which::SqueezeNet => {
let prs = prs.i(0)?.to_vec1::<f32>()?;
// Print the top predictions
for &(i, p) in &top {
println!(
"{:50}: {:.2}%",
candle_examples::imagenet::CLASSES[i],
p * 100.0
);
// Sort the predictions and take the top 5
let mut top: Vec<_> = prs.iter().enumerate().collect();
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
let top = top.into_iter().take(5).collect::<Vec<_>>();
// Print the top predictions
for &(i, p) in &top {
println!(
"{:50}: {:.2}%",
candle_examples::imagenet::CLASSES[i],
p * 100.0
);
}
}
Which::EsrGan => {
let max_pixel_val = candle::Tensor::try_from(255.0f32)?
.to_device(prs.device())?
.broadcast_as(prs.shape())?;
let out = (prs * max_pixel_val)?.i(0)?.to_dtype(candle::DType::U8)?;
let pb = std::path::PathBuf::from(args.image);
let input_file_name = pb.file_name().unwrap();
let mut output_file_name = std::ffi::OsString::from("super_");
output_file_name.push(input_file_name);
save_image(&out, output_file_name)?;
}
}
Ok(())

View File

@ -0,0 +1,14 @@
# Orpheus
Orpheus is a 3B text-to-speech model based on Llama.
- Weights on HuggingFace
[canopylabs/orpheus-3b-0.1-ft](https://huggingface.co/canopylabs/orpheus-3b-0.1-ft).
- Code on GitHub [canopyai/Orpheus-TTS](https://github.com/canopyai/Orpheus-TTS).
```bash
cargo run --example orpheus --features cuda -r
```

View File

@ -0,0 +1,329 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::llama::{Cache, Llama, LlamaConfig};
use candle_transformers::models::snac::{Config as SnacConfig, Model as SnacModel};
use tokenizers::Tokenizer;
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/realtime_streaming_example/main.py#L43
const STOP_TOKEN_ID: u32 = 128258;
#[derive(Parser)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long, default_value = "Hey, how are you doing today?")]
prompt: String,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.6)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
model_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// The output wav file.
#[arg(long, default_value = "out.wav")]
out_file: String,
#[arg(long, default_value = "3b-0.1-ft")]
which: Which,
#[arg(long, default_value = "tara")]
voice: Voice,
#[arg(long)]
use_flash_attn: bool,
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Voice {
#[value(name = "tara")]
Tara,
#[value(name = "leah")]
Leah,
#[value(name = "jess")]
Jess,
#[value(name = "leo")]
Leo,
#[value(name = "dan")]
Dan,
#[value(name = "mia")]
Mia,
#[value(name = "zac")]
Zac,
#[value(name = "zoe")]
Zoe,
}
impl Voice {
fn as_str(&self) -> &'static str {
match self {
Voice::Tara => "tara",
Voice::Leah => "leah",
Voice::Jess => "jess",
Voice::Leo => "leo",
Voice::Dan => "dan",
Voice::Mia => "mia",
Voice::Zac => "zac",
Voice::Zoe => "zoe",
}
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "3b-0.1-ft")]
ThreeB0_1Ft,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let prompt = args.prompt.clone();
let mut model = Model::load(args)?;
model.run(&prompt)?;
Ok(())
}
struct Model {
model: Llama,
tokenizer: Tokenizer,
logits_processor: candle_transformers::generation::LogitsProcessor,
cache: Cache,
device: Device,
verbose_prompt: bool,
snac: SnacModel,
out_file: String,
voice: Voice,
}
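// Load the 24 kHz SNAC codec used to decode the generated audio tokens back into a waveform.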
fn load_snac(device: &Device) -> Result<SnacModel> {
let api = hf_hub::api::sync::Api::new()?;
let m = api.model("hubertsiuzdak/snac_24khz".to_string());
let config = m.get("config.json")?;
let config: SnacConfig = serde_json::from_reader(std::fs::File::open(config)?)?;
let m = api.model("lmz/candle-snac".to_string());
let model = m.get("snac_24khz.safetensors")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, device)? };
let model = SnacModel::new(&config, vb)?;
Ok(model)
}
impl Model {
fn load(args: Args) -> Result<Self> {
let start = std::time::Instant::now();
let api = hf_hub::api::sync::Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::ThreeB0_1Ft => "canopylabs/orpheus-3b-0.1-ft".to_string(),
},
};
let revision = match args.revision {
Some(r) => r,
None => "main".to_string(),
};
let repo = api.repo(hf_hub::Repo::with_revision(
model_id,
hf_hub::RepoType::Model,
revision,
));
let model_files = match args.model_file {
Some(m) => vec![m.into()],
None => match args.which {
Which::ThreeB0_1Ft => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
};
let config = match args.config_file {
Some(m) => m.into(),
None => repo.get("config.json")?,
};
let tokenizer = match args.tokenizer_file {
Some(m) => m.into(),
None => repo.get("tokenizer.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, dtype, &device)? };
let config: LlamaConfig = serde_json::from_reader(std::fs::File::open(config)?)?;
let config = config.into_config(args.use_flash_attn);
let model = Llama::load(vb, &config)?;
let logits_processor = {
use candle_transformers::generation::{LogitsProcessor, Sampling};
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k.as_ref(), args.top_p.as_ref()) {
(None, None) => Sampling::All { temperature },
(Some(&k), None) => Sampling::TopK { k, temperature },
(None, Some(&p)) => Sampling::TopP { p, temperature },
(Some(&k), Some(&p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
println!("loaded the model in {:?}", start.elapsed());
let cache = Cache::new(true, dtype, &config, &device)?;
let snac = load_snac(&device)?;
Ok(Self {
model,
tokenizer,
logits_processor,
cache,
device,
verbose_prompt: args.verbose_prompt,
snac,
voice: args.voice,
out_file: args.out_file,
})
}
fn run(&mut self, prompt: &str) -> Result<()> {
println!("running the model on '{}'", prompt);
let device = &self.device;
let prompt = format!("{voice}: {prompt}", voice = self.voice.as_str());
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/engine_class.py#L82
let mut tokens = [
&[128259],
tokens.get_ids(),
&[128009, 128260, 128261, 128257],
]
.concat();
if self.verbose_prompt {
println!("{:?}", tokens);
}
let mut cache = self.cache.clone();
println!("starting the inference loop");
let mut index_pos = 0;
let mut audio_tokens = vec![];
for index in 0..2000 {
let (context_size, context_index) = if index > 0 {
(1, index_pos)
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, context_index, &mut cache)?;
let logits = logits.squeeze(0)?;
index_pos += ctxt.len();
let next_token = self.logits_processor.sample(&logits)?;
if let Some(tok) = self.tokenizer.id_to_token(next_token) {
match tok.strip_prefix("<custom_token_") {
Some(tok) => match tok.strip_suffix('>') {
Some(tok) => {
let tok = tok.parse::<u32>()?;
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/decoder.py#L86C35-L86C63
let tok = tok - 10 - ((audio_tokens.len() as u32 % 7) * 4096);
audio_tokens.push(tok);
}
None => {
println!("{index}: unexpected custom token {next_token} {tok}");
}
},
None => {
println!("{index}: unexpected token {next_token} {tok}");
}
}
}
if next_token == STOP_TOKEN_ID {
println!("reached stop token");
break;
}
tokens.push(next_token);
}
println!("generated {} audio tokens", audio_tokens.len());
let mut codes0 = vec![];
let mut codes1 = vec![];
let mut codes2 = vec![];
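// Each frame of 7 audio tokens is de-interleaved into the 3 SNAC codebook levels:
// 1 coarse code, 2 medium codes, and 4 fine codes.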
for audio_tokens in audio_tokens.chunks_exact(7) {
codes0.push(audio_tokens[0]);
for i in [1, 4] {
codes1.push(audio_tokens[i]);
}
for i in [2, 3, 5, 6] {
codes2.push(audio_tokens[i]);
}
}
let codes0 = Tensor::new(codes0, device)?.unsqueeze(0)?;
let codes1 = Tensor::new(codes1, device)?.unsqueeze(0)?;
let codes2 = Tensor::new(codes2, device)?.unsqueeze(0)?;
let pcm = self.snac.decode(&[&codes0, &codes1, &codes2])?;
println!("decoded to pcm {pcm:?}");
let mut output = std::fs::File::create(&self.out_file)?;
let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24000)?;
Ok(())
}
}

View File

@ -0,0 +1,18 @@
# candle-quantized-gemma
Candle implementation of quantized Gemma.
## Running an example
```bash
$ cargo run --example quantized-gemma -- --prompt "Write a function to calculate fibonacci numbers. "
> ```python
> def fibonacci(n):
> """Calculates the nth Fibonacci number using recursion."""
> if n <= 1:
> return n
> else:
> return fibonacci(n-1) + fibonacci(n-2
> ```
```

View File

@ -0,0 +1,344 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_gemma3::ModelWeights;
const DEFAULT_PROMPT: &str = "Write a function to calculate fibonacci num";
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "gemma3-4b-it")]
Gemma3_4bIt,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// GGUF file to load, typically a .gguf file generated by quantization
#[arg(long)]
model: Option<String>,
/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
/// and 'chat' for an interactive mode where the history of previous prompts and generated tokens
/// is preserved.
#[arg(long)]
prompt: Option<String>,
/// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 1000)]
sample_len: usize,
/// The tokenizer config in json format.
#[arg(long)]
tokenizer: Option<String>,
/// The temperature used to generate samples, use 0 for greedy sampling.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Process prompt elements separately.
#[arg(long)]
split_prompt: bool,
/// Run on CPU rather than GPU even if a GPU is available.
#[arg(long)]
cpu: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "gemma3-4b-it")]
which: Which,
}
impl Args {
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
let tokenizer_path = match &self.tokenizer {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = "google/gemma-3-4b-it";
println!("DEBUG: Downloading tokenizer from {}", repo);
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
}
};
println!("DEBUG: Loading tokenizer from {:?}", tokenizer_path);
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
Ok(tokenizer)
}
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
let model_path = match &self.model {
Some(config) => std::path::PathBuf::from(config),
None => {
let (repo, filename) = match self.which {
Which::Gemma3_4bIt => (
"google/gemma-3-4b-it-qat-q4_0-gguf",
"gemma-3-4b-it-q4_0.gguf",
),
};
let api = hf_hub::api::sync::Api::new()?;
api.repo(hf_hub::Repo::with_revision(
repo.to_string(),
hf_hub::RepoType::Model,
"main".to_string(),
))
.get(filename)?
}
};
Ok(model_path)
}
}
fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
} else if size_in_bytes < 1_000_000 {
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
} else if size_in_bytes < 1_000_000_000 {
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
} else {
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
}
}
#[derive(Debug)]
enum Prompt {
Interactive,
Chat,
One(String),
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let mut model = {
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(&model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
total_size_in_bytes +=
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
model.tensor_infos.len(),
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
ModelWeights::from_gguf(model, &mut file, &device)?
};
println!("model built");
let tokenizer = args.tokenizer()?;
let mut tos = TokenOutputStream::new(tokenizer);
println!(
"DEBUG: Tokenizer vocabulary size: {}",
tos.tokenizer().get_vocab(true).len()
);
let prompt = match args.prompt.as_deref() {
Some("chat") => Prompt::Chat,
Some("interactive") => Prompt::Interactive,
Some(s) => Prompt::One(s.to_string()),
None => Prompt::One(DEFAULT_PROMPT.to_string()),
};
let mut pre_prompt_tokens = vec![];
for _ in 0.. {
let prompt_str = match &prompt {
Prompt::One(prompt) => prompt.clone(),
Prompt::Interactive | Prompt::Chat => {
print!("> ");
std::io::stdout().flush()?;
let mut prompt = String::new();
std::io::stdin().read_line(&mut prompt)?;
if prompt.ends_with('\n') {
prompt.pop();
if prompt.ends_with('\r') {
prompt.pop();
}
}
// Format for Gemma 3 chat/instruction format
format!("<start_of_turn> user\n{prompt}<end_of_turn>\n<start_of_turn> model\n")
}
};
print!("{}", &prompt_str);
let tokens = tos
.tokenizer()
.encode(prompt_str, true)
.map_err(anyhow::Error::msg)?;
let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat();
let to_sample = args.sample_len.saturating_sub(1);
let max_seq_len = 8192; // Gemma 3 context length
let prompt_tokens = if prompt_tokens.len() + to_sample > max_seq_len - 10 {
let to_remove = prompt_tokens.len() + to_sample + 10 - max_seq_len;
prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
} else {
prompt_tokens
};
let mut all_tokens = vec![];
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let start_prompt_processing = std::time::Instant::now();
let mut next_token = if !args.split_prompt {
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
} else {
let mut next_token = 0;
for (pos, token) in prompt_tokens.iter().enumerate() {
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, pos)?;
let logits = logits.squeeze(0)?;
next_token = logits_processor.sample(&logits)?
}
next_token
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
// For Gemma 3, use the correct end of sequence token
let eos_token = *tos
.tokenizer()
.get_vocab(true)
.get("<end_of_turn>")
.unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&all_tokens[start_at..],
)?
};
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
if let Some(t) = tos.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
sampled += 1;
if next_token == eos_token {
break;
};
}
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
let dt = start_post_prompt.elapsed();
println!(
"\n\n{:4} prompt tokens processed: {:.2} token/s",
prompt_tokens.len(),
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
);
println!(
"{sampled:4} tokens generated: {:.2} token/s",
sampled as f64 / dt.as_secs_f64(),
);
match prompt {
Prompt::One(_) => break,
Prompt::Interactive => {}
Prompt::Chat => {
pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat()
}
}
}
Ok(())
}

Some files were not shown because too many files have changed in this diff Show More