Compare commits

..

57 Commits

Author SHA1 Message Date
9453cc3095 Bump the crate version to 0.8.0. (#2612) 2024-11-12 14:11:46 +01:00
3769206583 Update docs (#2553)
* add module docs for candle-core

* doc each of the candle-nn modules and add the links to the doc page
2024-11-11 22:13:52 +01:00
e2b6b367fa Add some fast Metal MLX SDPA kernels (#2584)
* Add some fast Metal MLX SDPA kernels (#32)

* Sketch the sdpa kernel

* Add full sdpa kernel,

* Add test

* Add vectorized kernel for decoding

* Update tests

* Add some docs

* Fix sdpa_vector names

* Add softcapping for vectorized sdpa

* Add softcapping for full sdpa

* Add support for head dim 32, 96, 256

* Add support for head dim 32, 96, 256

* Update docs

* Add update notice

* Clippy and format

* Conditional compilation for bf16

* Use it in quantized llama

* Some review comments

* Use set_params!

* Remove unused

* Remove feature

* Fix metal sdpa for v stride

* Remove comma

* Add the dim method to layout and shape.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-05 09:28:00 +01:00
6454597943 Improved launch config for layer-norm/rms-norm. (#2591)
* Improved launch config for layer-norm/rms-norm.

* Add more testing for the fused layer/rms norm kernels.
2024-11-04 10:42:18 +01:00
3fba2b5fc4 Add the SmolLM2 models. (#2595)
* Add the SmolLM2 models.

* More SmolLM2 support.
2024-11-03 17:11:12 +01:00
530ab96036 Support Skip Layer Guidance (SLG) for Stable Diffusion 3.5 Medium (#2590)
* support skip layer guidance (slg) for stable diffusion 3.5 medium

* Tweak the comments formatting.

* Proper error message.

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-11-01 18:10:40 +01:00
7ac0de15a9 Lazy upcasting for t5. (#2589) 2024-10-30 18:08:51 +01:00
d232e132f6 Support sd3.5 medium and MMDiT-X (#2587)
* extract attn out of joint_attn

* further adjust attn and joint_attn

* add mmdit-x support

* support sd3.5-medium in the example

* update README.md
2024-10-30 06:19:07 +01:00
139ff56aeb Reduce memory usage for sd 3.5. (#2582) 2024-10-28 22:45:02 +01:00
498bc2cdc9 Release the mmdit model earlier to reduce memory usage. (#2581)
* Stable diffusion 3.5 support.

* Clippy fixes.

* CFG fix.

* Remove some unnecessary clones.

* Avoid duplicating some of the code.

* Release the mmdit model earlier to reduce memory usage.
2024-10-28 16:06:53 +01:00
0e2c8c17fb UG metal integration. (#2580) 2024-10-27 15:20:37 +01:00
594d984f9c Support for UG kernels. (#2579)
* Support for UG kernels.

* Add a dedicated test.
2024-10-27 13:37:19 +01:00
37e0ab8c64 Stable diffusion 3.5 support. (#2578)
* Stable diffusion 3.5 support.

* Clippy fixes.

* CFG fix.

* Remove some unnecessary clones.

* Avoid duplicating some of the code.
2024-10-27 10:01:04 +01:00
07849aa595 Update README.md (#2577) 2024-10-26 18:23:52 +02:00
3699c1a053 Fix the repo name for llama 3.1. (#2576)
* Fix the repo name for llama 3.1.

* Fix the book.
2024-10-26 11:25:04 +02:00
a2e9d41b20 use softmax_last_dim (metal and cuda kernel) in llama attention layer (#2572) 2024-10-23 20:07:09 +02:00
7c09215ef4 ONNX: GatherElements, Xor (#2568) 2024-10-17 20:22:35 +02:00
dcd83336b6 Testcases (#2567) 2024-10-17 13:00:45 +02:00
a01aa89799 onnx: ReduceMin/Max Ops (#2563)
* Stella_en_1.5B_v5

* Separated  creation. This is a critical step for numerical accuracy and would be documented in the readme

* EmbedDim would require clone and copy

* WIP: example

* Examples added

* a litte more in README

* WIP: ONNX Reduce-max ops

* WIP: tests for ReduceMin

* Reduce min/ max v18+

* Reformatting tests for better review readability

* Error on empty set, backward compatibility (13 and below) with 'axes'
2024-10-15 10:34:07 +02:00
3d1dc06cdb Enable stable-diffusion 3 on metal. (#2560) 2024-10-14 08:59:12 +02:00
f553ab5eb4 Adds support for Stella_en_v5 embedding model - 1.5B variant (#2551)
* Stella_en_1.5B_v5

* Separated  creation. This is a critical step for numerical accuracy and would be documented in the readme

* EmbedDim would require clone and copy

* WIP: example

* Examples added

* a litte more in README
2024-10-13 23:09:12 +02:00
41ade774e8 fix: Allow marian configs to deserialize from json. (#2556) 2024-10-13 23:05:50 +02:00
6eab6b57f5 Fix the guide to gain access to Stable Diffusion 3 Medium (#2559) 2024-10-13 22:55:26 +02:00
ca7cf5cb3b Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example

Add get_qkv_linear to handle different dimensionality in linears

Add stable diffusion 3 example

Add use_quant_conv and use_post_quant_conv for vae in stable diffusion

adapt existing AutoEncoderKLConfig to the change

add forward_until_encoder_layer to ClipTextTransformer

rename sd3 config to sd3_medium in mmdit; minor clean-up

Enable flash-attn for mmdit impl when the feature is enabled.

Add sd3 example codebase

add document

crediting references

pass the cargo fmt test

pass the clippy test

* fix typos

* expose cfg_scale and time_shift as options

* Replace the sample image with JPG version. Change image output format accordingly.

* make meaningful error messages

* remove the tail-end assignment in sd3_vae_vb_rename

* remove the CUDA requirement

* use default_value in clap args

* add use_flash_attn to turn on/off flash-attn for MMDiT at runtime

* resolve clippy errors and warnings

* use default_value_t

* Pin the web-sys dependency.

* Clippy fix.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-10-13 22:08:40 +02:00
0d96ec31e8 feat: intergrate chinese clip and add example (#2555)
* start to impl chinese clip

* impl vision model

* copy code from bert

* refactor use

* refactor use again

* fix text model

* refactor

* try to fix text model

* tuning

* tuning chinese clip

* delete useless code

* revert code

* Clippy fixes.

* Also apply cargo fmt.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-10-10 15:18:55 +02:00
937e8eda74 Add BertForMaskedLM to support SPLADE Models (#2550)
* add bert for masked lm

* working example

* add example readme

* Clippy fix.

* And apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-10-07 23:28:21 +02:00
edf7668291 improve (#2548) 2024-10-07 17:30:56 +02:00
e4a96f9e7c Switch to using the MLX matmul by default. (#2547) 2024-10-06 23:24:55 +02:00
f856b5c3a7 pyo3 update. (#2545)
* pyo3 update.

* Stub fix.
2024-10-06 10:09:38 +02:00
d2e432914e Tensor tools print all (#2543)
* Support whisper large-v3 turbo in the whisper-microphone example.

* Print all tensors when no argument is provided.
2024-10-05 10:05:14 +02:00
410c89f72a Add required feature for whisper example in Readme (#2539) 2024-10-04 14:29:55 +02:00
56aacb05da Make the RNN configs accessible from the models. (#2541) 2024-10-04 14:22:23 +02:00
6faecaa616 Fix for cudnn bf16 conv2d. (#2535) 2024-10-02 23:18:55 +02:00
90d04ff622 Support whisper large-v3 turbo in the whisper-microphone example. (#2533) 2024-10-02 22:09:14 +02:00
7b60bda4ed Add support for cuda streams. (#2532) 2024-10-02 21:30:58 +02:00
936300678d Add whisper large-v3 turbo to the example. (#2531) 2024-10-02 21:07:08 +02:00
f479840ce6 Add a seed to the flux example. (#2529) 2024-10-02 10:52:02 +02:00
fd08d3d0a4 Tweak some metal tests. (#2528) 2024-10-02 10:22:31 +02:00
a2bcc227df Efficient implementation of Tensor::ones() for metal (#2512)
* WIP: hopefully better const impl

* with GPU

* More tests on

* Reverting primitive for

* Incorporating review changes - added check elem count check in kerner, using  for call strategy

* rustfmt ran
2024-10-01 19:11:59 +02:00
def4c6cdee Cuda quantized mmv bugfix. (#2526) 2024-10-01 12:57:55 +02:00
888d886dd8 Add ColPali (#2524)
* add colpali

* cleanup

* fix clippy
2024-10-01 11:48:39 +02:00
6110ad8d4f Refactor the whisper microphone example. (#2523)
* Refactor the whisper microphone example.

* Tweak the whisper microphone example more.
2024-10-01 00:24:17 +02:00
aa35bf2ff5 Add/lstm direction (#2455)
* add: direction for lstm layer

* lint: remove unused Error import

* refactor: remove unnecessary int assignment to Direction enum:

* refactor: use &'static str type instead of String for direction_str:

* Run cargofmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-09-30 22:44:07 +02:00
724650446c Yet another cuda qmm padding fix. (#2509) 2024-09-30 21:53:30 +02:00
dfe9a00683 Pixtral polishing. (#2522)
* Pixtral polishing.

* Clippy fix.
2024-09-30 21:23:54 +02:00
683ab698de Add Pixtral. (#2521)
* Add Pixtral.

* More pixtral vision encoder.

* Sketch a pixtral example.

* Sketch a pixtral example.

* Better image loading.

* Support loading images embedded in safetensor files.

* Clippy fixes.

* Add the llava multimodal adapter.

* Add more of the llava bits.

* Add the pixtral config.

* More pixtral inference.

* Add the text generation bits.

* Get the example to work.

* Bugfix.

* Run some bits of the model in f32.

* Blessed version :)

* Better rope frequency computations.

* README update.
2024-09-30 19:31:14 +02:00
2f49e1b534 Add PaliGemma. (#2519)
* Add PaliGemma.

* PaliGemma inference loop.

* Running PaliGemma example.

* Tweak the prompt.
2024-09-29 19:56:56 +02:00
0ebb38813b Paligemma siglip vision config (#2518)
* Add the paligemma siglip vision config.

* More paligemma configs.
2024-09-29 17:53:52 +02:00
3a3c48b14b Bump the crate version to 0.7.2. (#2517) 2024-09-29 10:56:50 +02:00
261ed65f36 Add the SigLIP model. (#2515)
* Add the SigLIP model.

* Add more to the forward pass of the vision model.

* Complete the forward pass.

* Add the siglip example.

* Fix.

* Another fix.

* Get everything in place.

* Add a readme.
2024-09-28 23:48:00 +02:00
62525e8352 Remove some extra whitelines. (#2513) 2024-09-28 14:41:28 +02:00
2c25754281 Clippy fixes for onnx + fix a broken test. (#2510) 2024-09-26 23:37:59 +02:00
ed48f54b54 Expand split ops (#2505)
* candle-onnx: Add Split and Expand operators, Fix Where Op

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining Split examples as tests
TODO: Add.test case that motivates Where fix

* candle-onnx: Add ReduceSum operator

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining ReduceSum examples as tests

* candle-onnx: Add ReduceL2 operator

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining ReduceSum examples as tests

* candle-onnx: Fix Clip operator empty string as default arg issue

Optional input args may be signified by an empty string. The length of the input array is not enough because non optional args may follow optional ones.

I encountered this when trying to use the ONNX model found at https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 for example.

The LSTM op has a utility which I factored to be more generally accessible, and I have used it in the ops I have recently created or debugged.

I believe it is likely that this issue may also manifest in other ops, but I didn't want to change anything that I'm not testing.

* fix formatting

* fix small mistake made during refactor
2024-09-26 22:57:55 +02:00
ad8a4c5e5a Add some llama-3.2 examples. (#2508)
* Add some llama-3.2 examples.

* Support tie-word-embeddings for llama.
2024-09-26 21:00:18 +02:00
c3c392f45c Merge pull request #2507 from huggingface/ci-move
move CI/Cuda runner
2024-09-26 18:48:52 +02:00
a0184a4fe4 move CI/Cuda runner 2024-09-26 17:09:26 +02:00
10d47183c0 Quantized version of flux. (#2500)
* Quantized version of flux.

* More generic sampling.

* Hook the quantized model.

* Use the newly minted gguf file.

* Fix for the quantized model.

* Default to avoid the faster cuda kernels.
2024-09-26 10:23:43 +02:00
114 changed files with 12127 additions and 755 deletions

View File

@ -9,7 +9,8 @@ jobs:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on: [single-gpu, nvidia-gpu, t4, ci]
runs-on:
group: aws-g4dn-2xlarge
container:
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
options: --gpus 0

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.7.1"
version = "0.8.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,20 +33,20 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.7.1" }
candle-datasets = { path = "./candle-datasets", version = "0.7.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.1" }
candle-kernels = { path = "./candle-kernels", version = "0.7.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.1" }
candle-nn = { path = "./candle-nn", version = "0.7.1" }
candle-onnx = { path = "./candle-onnx", version = "0.7.1" }
candle-transformers = { path = "./candle-transformers", version = "0.7.1" }
candle = { path = "./candle-core", package = "candle-core", version = "0.8.0" }
candle-datasets = { path = "./candle-datasets", version = "0.8.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.0" }
candle-kernels = { path = "./candle-kernels", version = "0.8.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.0" }
candle-nn = { path = "./candle-nn", version = "0.8.0" }
candle-onnx = { path = "./candle-onnx", version = "0.8.0" }
candle-transformers = { path = "./candle-transformers", version = "0.8.0" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
hf-hub = { version = "0.3.3", package = "candle-hf-hub" }
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
@ -70,6 +70,9 @@ tokenizers = { version = "0.19.1", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
ug = "0.0.2"
ug-cuda = "0.0.2"
ug-metal = "0.0.2"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}

View File

@ -2,7 +2,8 @@
[![discord server](https://dcbadge.vercel.app/api/server/hugging-face-879548962464493619)](https://discord.gg/hugging-face-879548962464493619)
[![Latest version](https://img.shields.io/crates/v/candle-core.svg)](https://crates.io/crates/candle-core)
[![Documentation](https://docs.rs/candle-core/badge.svg)](https://docs.rs/candle-core)
![License](https://img.shields.io/crates/l/candle-core.svg)
[![License](https://img.shields.io/github/license/base-org/node?color=blue)](https://github.com/huggingface/candle/blob/main/LICENSE-MIT)
[![License](https://img.shields.io/badge/license-Apache%202.0-blue?style=flat-square)](https://github.com/huggingface/candle/blob/main/LICENSE-APACHE)
Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
and ease of use. Try our online demos:
@ -187,6 +188,7 @@ And then head over to
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.
- [`atoma-infer`](https://github.com/atoma-network/atoma-infer): A Rust library for fast inference at scale, leveraging FlashAttention2 for efficient attention computation, PagedAttention for efficient KV-cache memory management, and multi-GPU support. It is OpenAI api compatible.
If you have an addition to this list, please submit a pull request.

View File

@ -11,8 +11,8 @@ Then let's start by downloading the [model file](https://huggingface.co/bert-bas
```rust
# extern crate candle_core;
# extern crate hf_hub;
use hf_hub::api::sync::Api;
# extern crate candle_hf_hub;
use candle_hf_hub::api::sync::Api;
use candle_core::Device;
let api = Api::new().unwrap();
@ -50,8 +50,8 @@ Now that we have our weights, we can use them in our bert architecture:
```rust
# extern crate candle_core;
# extern crate candle_nn;
# extern crate hf_hub;
# use hf_hub::api::sync::Api;
# extern crate candle_hf_hub;
# use candle_hf_hub::api::sync::Api;
#
# let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string());

View File

@ -28,6 +28,9 @@ rand_distr = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
ug = { workspace = true }
ug-cuda = { workspace = true, optional = true }
ug-metal = { workspace = true, optional = true }
yoke = { workspace = true }
zip = { workspace = true }
@ -39,11 +42,11 @@ criterion = { workspace = true }
[features]
default = []
cuda = ["cudarc", "dep:candle-kernels"]
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"]
[[bench]]
name = "bench_main"

View File

@ -26,6 +26,7 @@ impl From<cudarc::driver::DriverError> for crate::Error {
pub(crate) fn launch_conv2d<
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
Y: cudarc::cudnn::CudnnDataType,
>(
src: &CudaView<T>,
src_l: &crate::Layout,
@ -48,7 +49,7 @@ pub(crate) fn launch_conv2d<
}
c
})?;
let conv = cudnn.create_conv2d::<T>(
let conv = cudnn.create_conv2d::<Y>(
/* pad */ [params.padding as i32, params.padding as i32],
/* stride */ [params.stride as i32, params.stride as i32],
/* dilation */ [params.dilation as i32, params.dilation as i32],
@ -62,18 +63,18 @@ pub(crate) fn launch_conv2d<
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
cudnn.create_4d_tensor(
cudnn.create_4d_tensor::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
x_shape,
)?
} else {
let s = src_l.stride();
cudnn.create_4d_tensor_ex(
cudnn.create_4d_tensor_ex::<T>(
x_shape,
[s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32],
)?
};
let w = cudnn.create_4d_filter(
let w = cudnn.create_4d_filter::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[
params.c_out as i32,
@ -83,7 +84,7 @@ pub(crate) fn launch_conv2d<
],
)?;
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
let y = cudnn.create_4d_tensor(
let y = cudnn.create_4d_tensor::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, h_out, w_out],
)?;

View File

@ -51,6 +51,27 @@ impl CudaDevice {
self.device.clone()
}
pub fn compile(
&self,
func_name: &'static str,
kernel: ug::lang::ssa::Kernel,
) -> Result<CudaFunction> {
let mut buf = vec![];
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
let cuda_code = String::from_utf8(buf)?;
let opts = cudarc::nvrtc::CompileOptions {
use_fast_math: Some(true),
..Default::default()
};
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
let func = match self.device.get_func("ug", func_name) {
Some(func) => func,
None => crate::bail!("unknown function ug::{func_name}"),
};
Ok(func)
}
pub fn id(&self) -> DeviceId {
self.id
}
@ -144,6 +165,20 @@ impl CudaDevice {
}
}
impl CudaDevice {
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
Ok(Self {
id: DeviceId::new(),
device,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
})
}
}
impl BackendDevice for CudaDevice {
type Storage = CudaStorage;

View File

@ -1522,7 +1522,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<u8>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<u8>(inp, inp_l, k, &mut out, params, &device)
crate::cudnn::launch_conv2d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::U8(out)
}
@ -1530,7 +1530,10 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<bf16>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<bf16>(inp, inp_l, k, &mut out, params, &device)
// Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
// version.
// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
crate::cudnn::launch_conv2d::<bf16, f32>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::BF16(out)
}
@ -1538,7 +1541,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f16>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<f16>(inp, inp_l, k, &mut out, params, &device)
crate::cudnn::launch_conv2d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F16(out)
}
@ -1546,7 +1549,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f32>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<f32>(inp, inp_l, k, &mut out, params, &device)
crate::cudnn::launch_conv2d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F32(out)
}
@ -1554,7 +1557,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f64>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<f64>(inp, inp_l, k, &mut out, params, &device)
crate::cudnn::launch_conv2d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F64(out)
}

View File

@ -375,3 +375,110 @@ impl Tensor {
)
}
}
pub struct UgIOp1 {
name: &'static str,
#[cfg(feature = "cuda")]
func: cudarc::driver::CudaFunction,
#[cfg(feature = "metal")]
func: metal::ComputePipelineState,
}
impl UgIOp1 {
#[allow(unused)]
pub fn new(
name: &'static str,
kernel: ug::lang::ssa::Kernel,
device: &crate::Device,
) -> Result<Self> {
#[cfg(feature = "cuda")]
{
let device = device.as_cuda_device()?;
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
}
#[cfg(feature = "metal")]
{
let device = device.as_metal_device()?;
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
}
#[cfg(not(any(feature = "cuda", feature = "metal")))]
{
Ok(Self { name })
}
}
}
impl InplaceOp1 for UgIOp1 {
fn name(&self) -> &'static str {
self.name
}
fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
crate::bail!("ug ops are only supported on metal/cuda at the moment")
}
#[cfg(feature = "metal")]
fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
use crate::backend::BackendStorage;
use candle_metal_kernels::utils::EncoderProvider;
let elem_count = layout.shape().elem_count();
if sto.dtype() != crate::DType::F32 {
// TODO: support more dtypes.
crate::bail!("input is not a f32 tensor")
}
let device = sto.device();
println!("here");
let command_buffer = device.command_buffer()?;
let command_buffer = &command_buffer;
let encoder = command_buffer.encoder();
let encoder = encoder.as_ref();
encoder.set_compute_pipeline_state(&self.func);
let (g, b) = if elem_count % 32 == 0 {
(elem_count / 32, 32)
} else {
(elem_count, 1)
};
let grid_dims = metal::MTLSize {
width: g as u64,
height: 1,
depth: 1,
};
let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1);
candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize));
encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write);
encoder.dispatch_threads(grid_dims, group_dims);
Ok(())
}
#[cfg(feature = "cuda")]
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
use crate::cuda_backend::WrapErr;
use cudarc::driver::LaunchAsync;
let elem_count = layout.shape().elem_count();
// TODO: support more dtypes.
let sto = sto.as_cuda_slice::<f32>()?;
let sto = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => sto.slice(o1..o2),
};
let params = (&sto,);
let (g, b) = if elem_count % 32 == 0 {
(elem_count / 32, 32)
} else {
(elem_count, 1)
};
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (g as u32, 1, 1),
block_dim: (b as u32, 1, 1),
shared_mem_bytes: 0,
};
unsafe { self.func.clone().launch(cfg, params) }.w()?;
Ok(())
}
}

View File

@ -130,6 +130,26 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
match self {
Self::Cuda(d) => Ok(d),
Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
}
}
pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
match self {
Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
Self::Cpu => crate::bail!("expected a metal device, got cpu"),
Self::Metal(d) => Ok(d),
}
}
pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
}
pub fn new_metal(ordinal: usize) -> Result<Self> {
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
}

View File

@ -14,6 +14,12 @@ macro_rules! fail {
};
}
impl CudaDevice {
pub fn new_with_stream(_: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
}
impl crate::backend::BackendStorage for CudaStorage {
type Device = CudaDevice;

View File

@ -165,6 +165,9 @@ pub enum Error {
#[error("Metal error {0}")]
Metal(#[from] MetalError),
#[error(transparent)]
Ug(#[from] ug::Error),
#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),
@ -179,6 +182,10 @@ pub enum Error {
#[error(transparent)]
ParseInt(#[from] std::num::ParseIntError),
/// Utf8 parse error.
#[error(transparent)]
FromUtf8(#[from] std::string::FromUtf8Error),
/// I/O error.
#[error(transparent)]
Io(#[from] std::io::Error),

View File

@ -35,6 +35,12 @@ impl Layout {
self.shape.dims()
}
/// The dimension size for a specified dimension index.
pub fn dim<D: crate::shape::Dim>(&self, dim: D) -> Result<usize> {
let dim = dim.to_index(&self.shape, "dim")?;
Ok(self.dims()[dim])
}
pub fn shape(&self) -> &Shape {
&self.shape
}

View File

@ -32,6 +32,20 @@
//! Python can really add overhead in more complex workflows and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-future/) is a notorious source of headaches.
//!
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
//!
//! ## Other Crates
//!
//! Candle consists of a number of crates. This crate holds core the common data structures but you may wish
//! to look at the docs for the other crates which can be found here:
//!
//! - [candle-core](https://docs.rs/candle-core/). Core Datastructures and DataTypes.
//! - [candle-nn](https://docs.rs/candle-nn/). Building blocks for Neural Nets.
//! - [candle-datasets](https://docs.rs/candle-datasets/). Rust access to commonly used Datasets like MNIST.
//! - [candle-examples](https://docs.rs/candle-examples/). Examples of Candle in Use.
//! - [candle-onnx](https://docs.rs/candle-onnx/). Loading and using ONNX models.
//! - [candle-pyo3](https://docs.rs/candle-pyo3/). Access to Candle from Python.
//! - [candle-transformers](https://docs.rs/candle-transformers/). Candle implemntation of many published transformer models.
//!
#[cfg(feature = "accelerate")]
mod accelerate;
@ -77,7 +91,7 @@ mod variable;
pub use cuda_backend::cudnn;
pub use cpu_backend::{CpuStorage, CpuStorageRef};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};

View File

@ -144,6 +144,28 @@ impl MetalDevice {
self.use_mlx_mm = use_mlx_mm
}
pub fn compile(
&self,
func_name: &'static str,
kernel: ug::lang::ssa::Kernel,
) -> Result<metal::ComputePipelineState> {
let mut buf = vec![];
ug_metal::code_gen::gen(&mut buf, func_name, &kernel)?;
let metal_code = String::from_utf8(buf)?;
let lib = self
.device
.new_library_with_source(&metal_code, &metal::CompileOptions::new())
.map_err(MetalError::from)?;
let func = lib
.get_function(func_name, None)
.map_err(MetalError::from)?;
let pl = self
.device
.new_compute_pipeline_state_with_function(&func)
.map_err(MetalError::from)?;
Ok(pl)
}
pub fn id(&self) -> DeviceId {
self.id
}

View File

@ -1865,9 +1865,9 @@ impl BackendDevice for MetalDevice {
let device = metal::Device::all().swap_remove(ordinal);
let command_queue = device.new_command_queue();
let kernels = Arc::new(Kernels::new());
let use_mlx_mm = match std::env::var("CANDLE_USE_MLX_MM").as_deref() {
Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => false,
Ok(_) => true,
let use_mlx_mm = match std::env::var("CANDLE_USE_MFA_MM").as_deref() {
Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => true,
Ok(_) => false,
};
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void,
@ -1917,10 +1917,38 @@ impl BackendDevice for MetalDevice {
))
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
// TODO Is there a faster way ?
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
self.storage_from_cpu_storage(&cpu_storage)
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let name = match dtype {
DType::U8 => "fill_u8",
DType::U32 => "fill_u32",
DType::I64 => "fill_i64",
DType::F16 => "fill_f16",
DType::BF16 => "fill_bf16",
DType::F32 => "fill_f32",
DType::F64 => {
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
return self.storage_from_cpu_storage(&cpu_storage);
}
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_const_fill(
&self.device,
&command_buffer,
&self.kernels,
name,
shape.elem_count(),
&buffer,
1.,
)
.map_err(MetalError::from)?;
Ok(MetalStorage::new(
buffer,
self.clone(),
shape.elem_count(),
dtype,
))
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {

View File

@ -6,9 +6,15 @@ use half::f16;
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
#[derive(Clone, Debug)]
struct PaddedCudaSlice {
inner: CudaSlice<u8>,
len: usize,
}
#[derive(Clone, Debug)]
pub struct QCudaStorage {
data: CudaSlice<u8>,
data: PaddedCudaSlice,
dtype: GgmlDType,
device: CudaDevice,
}
@ -34,10 +40,7 @@ fn ceil_div(p: usize, q: usize) -> usize {
}
fn pad(p: usize, q: usize) -> usize {
// Overallocate by q rather than just padding by q as this should pad the last row
// and we don't have enough information here to know how many elements to add :(
// ceil_div(p, q) * q
p + q
ceil_div(p, q) * q
}
fn quantize_q8_1(
@ -64,7 +67,7 @@ fn quantize_q8_1(
}
fn dequantize_f32(
data: &CudaSlice<u8>,
data: &PaddedCudaSlice,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
@ -107,21 +110,21 @@ fn dequantize_f32(
};
if is_k {
let params = (data, &dst);
let params = (&data.inner, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (data, &dst, nb32 as i32);
let params = (&data.inner, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_f16(
data: &CudaSlice<u8>,
data: &PaddedCudaSlice,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
@ -164,21 +167,21 @@ fn dequantize_f16(
};
if is_k {
let params = (data, &dst);
let params = (&data.inner, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (data, &dst, nb32 as i32);
let params = (&data.inner, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_mul_mat_vec(
data: &CudaSlice<u8>,
data: &PaddedCudaSlice,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
@ -187,7 +190,7 @@ fn dequantize_mul_mat_vec(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
@ -216,13 +219,13 @@ fn dequantize_mul_mat_vec(
shared_mem_bytes: 0,
};
let params = (data, y, &dst, ncols as i32, nrows as i32);
let params = (&data.inner, y, &dst, ncols as i32, nrows as i32);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn mul_mat_vec_via_q8_1(
data: &CudaSlice<u8>,
data: &PaddedCudaSlice,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
@ -232,7 +235,7 @@ fn mul_mat_vec_via_q8_1(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
@ -279,7 +282,7 @@ fn mul_mat_vec_via_q8_1(
};
let params = (
data,
&data.inner,
&y_q8_1,
&dst,
/* ncols_x */ ncols as i32,
@ -293,7 +296,7 @@ fn mul_mat_vec_via_q8_1(
#[allow(clippy::too_many_arguments)]
fn mul_mat_via_q8_1(
data: &CudaSlice<u8>,
data: &PaddedCudaSlice,
y: &CudaView<f32>,
dtype: GgmlDType,
x_rows: usize,
@ -304,7 +307,7 @@ fn mul_mat_via_q8_1(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < x_rows * x_cols {
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
}
@ -318,7 +321,7 @@ fn mul_mat_via_q8_1(
// Start by quantizing y
let k_padded = pad(k, MATRIX_ROW_PADDING);
let y_size_in_bytes =
k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
@ -348,7 +351,7 @@ fn mul_mat_via_q8_1(
};
let params = (
/* vx */ data,
/* vx */ &data.inner,
/* vy */ &y_q8_1,
/* dst */ &dst,
/* ncols_x */ x_cols as i32,
@ -364,9 +367,14 @@ fn mul_mat_via_q8_1(
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
let padded_size_in_bytes =
ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size();
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes).w()?;
Ok(QCudaStorage {
data,
data: PaddedCudaSlice {
inner,
len: size_in_bytes,
},
device: device.clone(),
dtype,
})
@ -406,7 +414,10 @@ impl QCudaStorage {
}
// Run the dequantization on cpu.
let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
let buffer = self
.device
.dtoh_sync_copy(&self.data.inner.slice(..self.data.len))
.w()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
@ -442,18 +453,26 @@ impl QCudaStorage {
}
_ => crate::bail!("only f32 can be quantized"),
};
let src_len = pad(src.len(), MATRIX_ROW_PADDING);
let src_len = src.len();
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
qcpu_storage.quantize(&src)?;
let data = qcpu_storage.data()?;
let data = self.device.htod_sync_copy(data.as_ref()).w()?;
self.data = data;
let padded_len =
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
self.device
.htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len()))
.w()?;
self.data = PaddedCudaSlice {
inner,
len: data.len(),
};
Ok(())
}
pub fn storage_size_in_bytes(&self) -> usize {
self.data.len()
self.data.len
}
pub fn fwd(
@ -576,11 +595,19 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
let data = unsafe {
std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
};
let data = device.htod_sync_copy(data).w()?;
let dtype = T::DTYPE;
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
device
.htod_sync_copy_into(data, &mut inner.slice_mut(..data.len()))
.w()?;
Ok(QStorage::Cuda(QCudaStorage {
data,
data: PaddedCudaSlice {
inner,
len: data.len(),
},
device: device.clone(),
dtype: T::DTYPE,
dtype,
}))
}
@ -680,4 +707,28 @@ mod test {
assert_eq!(vs[15], 13138824.0);
Ok(())
}
// The following test used to fail under compute-sanitizer until #2526.
#[test]
fn cuda_mm_q8_1_pad() -> Result<()> {
let dev = CudaDevice::new(0)?;
let (x_rows, ncols, y_cols) = (4, 16, 2048);
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* x_rows */ x_rows,
/* x_cols */ ncols,
/* y_rows */ ncols,
/* y_cols */ y_cols,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let _vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
Ok(())
}
}

View File

@ -18,7 +18,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
let actual_blocks = ys.len();
// Validate that the input is the right size
if actual_blocks < expected_blocks {
if expected_blocks != actual_blocks {
crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
}

View File

@ -142,6 +142,12 @@ impl Shape {
&self.0
}
/// The dimension size for a specified dimension index.
pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
let dim = dim.to_index(self, "dim")?;
Ok(self.dims()[dim])
}
/// The total number of elements, this is the product of all dimension sizes.
pub fn elem_count(&self) -> usize {
self.0.iter().product()

View File

@ -1520,14 +1520,15 @@ impl Tensor {
/// # Arguments
///
/// * `self` - The input tensor.
/// * `indexes` - The indices of elements to gather, this should have the same shape as `self`
/// but can have a different number of elements on the target dimension.
/// * `indexes` - The indices of elements to gather, this should have same number of dimensions as `self`
/// and indexes.dims()[d] <= self.dims()[d] for all dimensions d != dim
/// * `dim` - the target dimension.
///
/// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on
/// dimension `dim` by the values in `indexes`.
pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "gather")?;
let self_dims = self.dims();
let indexes_dims = indexes.dims();
let mismatch = if indexes_dims.len() != self_dims.len() {
@ -1535,7 +1536,7 @@ impl Tensor {
} else {
let mut mismatch = false;
for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
if i != dim && d1 != d2 {
if i != dim && d1 < d2 {
mismatch = true;
break;
}

View File

@ -143,3 +143,39 @@ fn inplace_op1() -> Result<()> {
);
Ok(())
}
#[cfg(any(feature = "cuda", feature = "metal"))]
#[allow(clippy::approx_constant)]
#[test]
fn ug_op() -> Result<()> {
let kernel = {
use ug::lang::op;
let layout = ug::Layout::from_shape(&[12]);
let ptr = op::Arg::ptr(ug::DType::F32);
let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?;
let src = op::unary(op::UnaryOp::Exp, src)?;
let st = op::store(ptr.id(), layout, src)?;
let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]);
let opts: ug::lower_op::Opts = Default::default();
kernel.lower(&opts.with_global(0, 12))?
};
let device = if candle_core::utils::cuda_is_available() {
Device::new_cuda(0)?
} else if candle_core::utils::metal_is_available() {
Device::new_metal(0)?
} else {
candle_core::bail!("metal/cuda is mandatory for this test")
};
let op = candle_core::UgIOp1::new("test", kernel, &device)?;
let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?;
t.inplace_op1(&op)?;
assert_eq!(
to_vec1_round(&t, 2)?,
&[
1.0, 2.72, 7.39, 20.09, 54.6, 148.41, 403.43, 1096.63, 2980.96, 8103.08, 22026.47,
59874.13
]
);
Ok(())
}

View File

@ -29,6 +29,36 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
[
[
half::f16::from_f32(1.0),
half::f16::from_f32(1.0),
half::f16::from_f32(1.0)
],
[
half::f16::from_f32(1.0),
half::f16::from_f32(1.0),
half::f16::from_f32(1.0)
]
],
);
assert_eq!(
Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?,
[
[
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0)
],
[
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0)
]
],
);
Ok(())
}
@ -1017,6 +1047,280 @@ fn gather(device: &Device) -> Result<()> {
let ids = Tensor::new(&[[0u32, 2u32, 0u32], [0u32, 1u32, 1u32]], device)?;
let hs = t.gather(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
// Random data
// Dim: 0
let t = Tensor::new(
&[
[
[108_f32, -47., 16., -56., -83., -130., 210.],
[253., 95., 151., 228., -210., -123., -127.],
[-9., -217., 2., -78., 163., 245., -204.],
[-246., 79., -238., 88., -226., -184., 171.],
[8., -48., -153., 234., -34., 166., -153.],
[124., 0., -10., -61., -242., -15., -238.],
],
[
[12., -64., -199., 244., -240., 156., -128.],
[173., -57., 4., -198., 233., -110., 238.],
[95., 82., 0., 240., 53., -211., 209.],
[-122., 167., -212., 227., -144., 61., 118.],
[-63., -146., 200., 244., 168., -167., 116.],
[-125., -147., 110., -253., -178., -250., -18.],
],
[
[57., 86., -50., 56., 92., 205., -78.],
[-137., -156., -18., 248., -61., -239., 14.],
[-248., -30., -50., -70., -251., 250., -83.],
[-221., 67., 72., 59., -24., -154., 232.],
[-144., -23., -74., 5., 93., 171., 205.],
[46., -77., -38., -226., 246., 161., -17.],
],
[
[-153., -231., -236., 161., 126., 2., -22.],
[-229., -41., 209., 164., 234., 160., 57.],
[223., 254., -186., -162., -46., -160., -102.],
[65., 30., 213., -253., 59., 224., -154.],
[-82., -203., -177., 17., 31., -256., -246.],
[176., -135., -65., 54., -56., 210., 76.],
],
[
[-10., -245., 168., 124., -14., -33., -178.],
[25., -43., -39., 132., -89., 169., 179.],
[187., -215., 32., -133., 87., -7., -168.],
[-224., -215., -5., -230., -58., -162., 128.],
[158., -137., -122., -100., -202., -83., 136.],
[30., -185., -144., 250., 209., -40., 127.],
],
[
[-196., 108., -245., 122., 146., -228., 62.],
[-1., -66., 160., 137., 13., -172., -21.],
[244., 199., -164., 28., 119., -175., 198.],
[-62., 253., -162., 195., -95., -230., -211.],
[123., -72., -26., -107., -139., 64., 245.],
[11., -126., -182., 108., -12., 184., -127.],
],
[
[-159., 126., 176., 161., 73., -111., -138.],
[-187., 214., -217., -33., -223., -201., -212.],
[-61., -120., -166., -172., -95., 53., 196.],
[-33., 86., 134., -152., 154., -53., 74.],
[186., -28., -154., -174., 141., -109., 217.],
[82., 35., 252., 145., 181., 74., -87.],
],
],
device,
)?;
let ids = Tensor::new(
&[
[
[6_u32, 6, 4, 3, 4, 4, 6],
[3, 3, 2, 4, 4, 4, 6],
[3, 3, 0, 2, 4, 6, 4],
[2, 5, 1, 2, 6, 6, 1],
[2, 1, 6, 5, 3, 2, 3],
[6, 1, 0, 1, 0, 2, 6],
],
[
[4, 6, 4, 3, 3, 3, 2],
[4, 3, 2, 4, 4, 4, 6],
[2, 3, 0, 2, 4, 6, 4],
[6, 5, 1, 2, 6, 6, 1],
[4, 1, 6, 5, 3, 2, 3],
[1, 1, 0, 1, 0, 2, 6],
],
[
[3, 6, 4, 3, 3, 3, 2],
[2, 3, 2, 4, 4, 4, 6],
[4, 3, 0, 2, 4, 6, 4],
[0, 5, 1, 2, 6, 6, 1],
[6, 1, 6, 5, 3, 2, 3],
[4, 1, 0, 1, 0, 2, 6],
],
[
[0, 6, 4, 3, 3, 3, 2],
[5, 3, 2, 4, 4, 4, 6],
[0, 3, 0, 2, 4, 6, 4],
[3, 5, 1, 2, 6, 6, 1],
[0, 1, 6, 5, 3, 2, 3],
[3, 1, 0, 1, 0, 2, 6],
],
],
device,
)?;
let hs = t.gather(&ids, 0)?;
assert_eq!(
hs.to_vec3::<f32>()?,
&[
[
[-159_f32, 126., 168., 161., -14., -33., -138.],
[-229., -41., -18., 132., -89., 169., -212.],
[223., 254., 2., -70., 87., 53., -168.],
[-221., 253., -212., 59., 154., -53., 118.],
[-144., -146., -154., -107., 31., 171., -246.],
[82., -147., -10., -253., -242., 161., -87.]
],
[
[-10., 126., 168., 161., 126., 2., -78.],
[25., -41., -18., 132., -89., 169., -212.],
[-248., 254., 2., -70., 87., 53., -168.],
[-33., 253., -212., 59., 154., -53., 118.],
[158., -146., -154., -107., 31., 171., -246.],
[-125., -147., -10., -253., -242., 161., -87.]
],
[
[-153., 126., 168., 161., 126., 2., -78.],
[-137., -41., -18., 132., -89., 169., -212.],
[187., 254., 2., -70., 87., 53., -168.],
[-246., 253., -212., 59., 154., -53., 118.],
[186., -146., -154., -107., 31., 171., -246.],
[30., -147., -10., -253., -242., 161., -87.]
],
[
[108., 126., 168., 161., 126., 2., -78.],
[-1., -41., -18., 132., -89., 169., -212.],
[-9., 254., 2., -70., 87., 53., -168.],
[65., 253., -212., 59., 154., -53., 118.],
[8., -146., -154., -107., 31., 171., -246.],
[176., -147., -10., -253., -242., 161., -87.]
]
]
);
// Dim: 1
let t = Tensor::new(
&[
[
[-117_f32, -175., 69., -163.],
[200., 242., -21., -67.],
[179., 150., -126., -75.],
[-118., 38., -138., -13.],
[-221., 136., -185., 180.],
[58., 182., -204., -149.],
],
[
[3., -148., -58., -154.],
[-43., 45., -108., 4.],
[-69., -249., -71., -21.],
[80., 110., -152., -235.],
[-88., 7., 92., -250.],
[-186., 207., -242., 98.],
],
[
[238., 19., 64., -242.],
[-150., -97., 218., 58.],
[111., -233., 204., -212.],
[-242., -232., 83., 42.],
[153., 62., -251., 219.],
[-117., 36., -119., 10.],
],
[
[215., 159., -169., -27.],
[-83., 101., -88., 169.],
[-205., 93., 225., -64.],
[-162., 240., 214., 23.],
[-112., 6., 21., 245.],
[-38., 113., 93., 215.],
],
[
[91., -188., -148., 101.],
[74., 203., -35., 55.],
[-116., -130., -153., -96.],
[58., 22., -45., -194.],
[-221., -134., 73., 159.],
[-203., -254., 31., 235.],
],
[
[105., -53., 61., 186.],
[-195., 234., 75., -1.],
[51., 139., 160., -108.],
[-173., -167., 161., 19.],
[83., -246., 156., -222.],
[109., 39., -149., 137.],
],
],
device,
)?;
let ids = Tensor::new(
&[
[[4_u32, 4, 4, 2]],
[[0, 4, 4, 3]],
[[1, 5, 3, 4]],
[[0, 3, 3, 2]],
[[1, 1, 5, 2]],
[[1, 4, 5, 4]],
],
device,
)?;
let hs = t.gather(&ids, 1)?;
assert_eq!(
hs.to_vec3::<f32>()?,
&[
[[-221., 136., -185., -75.]],
[[3., 7., 92., -235.]],
[[-150., 36., 83., 219.]],
[[215., 240., 214., -64.]],
[[74., 203., 31., -96.]],
[[-195., -246., -149., -222.]]
]
);
// Dim: 2
let t = Tensor::new(
&[
[[-162_f32, 202.], [-126., -39.], [35., -65.], [1., 80.]],
[[37., 248.], [-191., 89.], [117., -40.], [-217., 220.]],
],
device,
)?;
let ids = Tensor::new(&[[[1_u32], [0], [1], [1]], [[0], [1], [0], [1]]], device)?;
let hs = t.gather(&ids, 2)?;
assert_eq!(
hs.to_vec3::<f32>()?,
&[
[[202.], [-126.], [-65.], [80.]],
[[37.], [89.], [117.], [220.]]
]
);
let t = Tensor::new(
&[
[[-21_f32, -197.], [194., 122.]],
[[255., -106.], [-191., 250.]],
[[33., -117.], [43., 10.]],
[[-130., 238.], [-217., -92.]],
],
device,
)?;
let ids = Tensor::new(
&[
[[0_u32, 1], [1, 0]],
[[1, 0], [0, 1]],
[[0, 1], [0, 1]],
[[1, 0], [1, 0]],
],
device,
)?;
let hs = t.gather(&ids, 2)?;
assert_eq!(
hs.to_vec3::<f32>()?,
&[
[[-21., -197.], [122., 194.]],
[[-106., 255.], [-191., 250.]],
[[33., -117.], [43., 10.]],
[[238., -130.], [-92., -217.]]
]
);
Ok(())
}

View File

@ -27,7 +27,7 @@ intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true}
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
pyo3 = { version = "0.22.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }
safetensors = { workspace = true }
@ -36,6 +36,7 @@ serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { workspace = true, features = ["onig"] }
cpal = { version = "0.15.2", optional = true }
pdf2image = { version = "0.1.2" , optional = true}
[dev-dependencies]
anyhow = { workspace = true }
@ -65,7 +66,7 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
microphone = ["cpal", "rubato"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
@ -117,3 +118,7 @@ required-features = ["depth_anything_v2"]
[[example]]
name = "silero-vad"
required-features = ["onnx"]
[[example]]
name = "colpali"
required-features = ["pdf2image"]

View File

@ -0,0 +1,224 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{DType, Device, Tensor};
use candle_nn as nn;
use candle_transformers::models::chinese_clip::{ChineseClipConfig, ChineseClipModel};
use clap::Parser;
use tokenizers::Tokenizer;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long, use_value_delimiter = true)]
images: Option<Vec<String>>,
#[arg(long)]
cpu: bool,
#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
tracing_subscriber::fmt::init();
let device = candle_examples::device(args.cpu)?;
let var = load_weights(args.model, &device)?;
let clip_model = ChineseClipModel::new(var, &ChineseClipConfig::clip_vit_base_patch16())?;
tracing::info!("Transformer loaded. ");
let (pixel_values, vec_imgs) = load_images(args.images, &device)?;
tracing::info!("Images loaded. ");
let tokenizer = load_tokenizer()?;
let (input_ids, type_ids, attention_mask, text_sequences) =
tokenize_sequences(args.sequences, &tokenizer, &device)?;
tracing::info!("Computing ... ");
let (_logits_per_text, logits_per_image) = clip_model.forward(
&pixel_values,
&input_ids,
Some(&type_ids),
Some(&attention_mask),
)?;
let softmax_image = nn::ops::softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
tracing::info!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
tracing::info!("Probability: {:.4}% Text: {} ", p, text_sequences[i]);
}
}
Ok(())
}
pub fn load_weights(model: Option<String>, device: &Device) -> anyhow::Result<nn::VarBuilder> {
let model_file = match model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = hf_hub::Repo::with_revision(
"OFA-Sys/chinese-clip-vit-base-patch16".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
);
let api = api.repo(repo);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
Ok(unsafe { nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, device)? })
}
pub fn load_tokenizer() -> anyhow::Result<Tokenizer> {
let tokenizer_file = {
let api = hf_hub::api::sync::Api::new()?;
let repo = hf_hub::Repo::with_revision(
"OFA-Sys/chinese-clip-vit-base-patch16".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
);
let api = api.repo(repo);
api.get("tokenizer.json")?
};
Tokenizer::from_file(tokenizer_file).map_err(anyhow::Error::msg)
}
pub fn tokenize_sequences(
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Tensor, Tensor, Vec<String>)> {
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
"自行车比赛".to_string(),
"两只猫咪".to_string(),
"拿着蜡烛的机器人".to_string(),
],
};
let mut input_ids = vec![];
let mut type_ids = vec![];
let mut attention_mask = vec![];
let mut max_len = 0;
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(anyhow::Error::msg)?;
input_ids.push(encoding.get_ids().to_vec());
type_ids.push(encoding.get_type_ids().to_vec());
attention_mask.push(encoding.get_attention_mask().to_vec());
if encoding.get_ids().len() > max_len {
max_len = encoding.get_ids().len();
}
}
let pad_id = *tokenizer
.get_vocab(true)
.get("[PAD]")
.ok_or(anyhow::Error::msg("No pad token"))?;
let input_ids: Vec<Vec<u32>> = input_ids
.iter_mut()
.map(|item| {
item.extend(vec![pad_id; max_len - item.len()]);
item.to_vec()
})
.collect();
let type_ids: Vec<Vec<u32>> = type_ids
.iter_mut()
.map(|item| {
item.extend(vec![0; max_len - item.len()]);
item.to_vec()
})
.collect();
let attention_mask: Vec<Vec<u32>> = attention_mask
.iter_mut()
.map(|item| {
item.extend(vec![0; max_len - item.len()]);
item.to_vec()
})
.collect();
let input_ids = Tensor::new(input_ids, device)?;
let type_ids = Tensor::new(type_ids, device)?;
let attention_mask = Tensor::new(attention_mask, device)?;
Ok((input_ids, type_ids, attention_mask, vec_seq))
}
pub fn load_images(
images: Option<Vec<String>>,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
let vec_imgs = match images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let mut images = vec![];
for path in vec_imgs.iter() {
let tensor = load_image(path, 224, device)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?.to_device(device)?;
Ok((images, vec_imgs))
}
fn load_image<T: AsRef<std::path::Path>>(
path: T,
image_size: usize,
device: &Device,
) -> anyhow::Result<Tensor> {
let img = image::ImageReader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8().into_raw();
let img = Tensor::from_vec(img, (height, width, 3), device)?.permute((2, 0, 1))?;
let mean = Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], device)?.reshape((3, 1, 1))?;
let std =
Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], device)?.reshape((3, 1, 1))?;
let img = (img.to_dtype(DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)?;
Ok(img)
}

View File

@ -12,7 +12,6 @@ use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::clip;
use tokenizers::Tokenizer;
use tracing::info;
#[derive(Parser)]
struct Args {
@ -40,15 +39,12 @@ fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
// .unsqueeze(0)?;
Ok(img)
}
@ -57,24 +53,16 @@ fn load_images<T: AsRef<std::path::Path>>(
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = load_image(path, image_size)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
pub fn main() -> anyhow::Result<()> {
// std::env::set_var("RUST_BACKTRACE", "full");
let args = Args::parse();
tracing_subscriber::fmt::init();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
@ -89,13 +77,9 @@ pub fn main() -> anyhow::Result<()> {
}
Some(model) => model.into(),
};
let tokenizer = get_tokenizer(args.tokenizer)?;
let config = clip::ClipConfig::vit_base_patch32();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
@ -103,43 +87,29 @@ pub fn main() -> anyhow::Result<()> {
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
// let image = load_image(args.image, config.image_size)?.to_device(&device)?;
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
let model = clip::ClipModel::new(vb, &config)?;
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
info!("softmax_image_vec: {:?}", softmax_image_vec);
println!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
info!("\n\nResults for image: {}\n", img);
println!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
}
}
Ok(())
}
@ -156,7 +126,6 @@ pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
}
Some(file) => file.into(),
};
Tokenizer::from_file(tokenizer).map_err(E::msg)
}
@ -169,7 +138,6 @@ pub fn tokenize_sequences(
.get_vocab(true)
.get("<|endoftext|>")
.ok_or(E::msg("No pad token"))?;
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
@ -178,16 +146,12 @@ pub fn tokenize_sequences(
"a robot holding a candle".to_string(),
],
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}
let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
@ -195,8 +159,6 @@ pub fn tokenize_sequences(
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -0,0 +1,18 @@
# Colpali
[HuggingFace Model Card](https://huggingface.co/vidore/colpali-v1.2-merged)
```
wget https://arxiv.org/pdf/1706.03762.pdf
cargo run --features cuda,pdf2image --release --example colpali -- --prompt "What is Positional Encoding" --pdf "1706.03762.pdf"
```
```
Prompt: what is position encoding?
top 3 page numbers that contain similarity to the prompt
-----------------------------------
Page: 6
Page: 11
Page: 15
-----------------------------------
```

View File

@ -0,0 +1,268 @@
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::colpali::Model;
use candle_transformers::models::{colpali, paligemma};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use image::DynamicImage;
use pdf2image::{RenderOptionsBuilder, PDF};
use tokenizers::Tokenizer;
struct PageRetriever {
model: Model,
config: paligemma::Config,
pdf: PDF,
device: Device,
tokenizer: Tokenizer,
range: pdf2image::Pages,
batch_size: usize,
top_k: usize,
}
impl PageRetriever {
fn new(
model: Model,
config: paligemma::Config,
pdf: PDF,
tokenizer: Tokenizer,
device: &Device,
range: Option<pdf2image::Pages>,
batch_size: usize,
top_k: usize,
) -> Self {
let page_count = pdf.page_count();
Self {
model,
config,
pdf,
device: device.clone(),
tokenizer,
range: range.unwrap_or_else(|| pdf2image::Pages::Range(1..=page_count)),
batch_size,
top_k,
}
}
fn get_images_from_pdf(&self) -> Result<Vec<DynamicImage>> {
let pages = self
.pdf
.render(self.range.clone(), RenderOptionsBuilder::default().build()?)?;
Ok(pages)
}
fn tokenize_batch(&self, prompts: Vec<&str>) -> Result<Tensor> {
let tokens = self.tokenizer.encode_batch(prompts, true).map_err(E::msg)?;
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), &self.device)
})
.collect::<candle::Result<Vec<_>>>()?;
let input = Tensor::stack(&token_ids, 0)?;
Ok(input)
}
fn images_to_tensor(
&self,
pages: &[DynamicImage],
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for page in pages.iter() {
let img = page.resize_to_fill(
image_size as u32,
image_size as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (image_size, image_size, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
images.push(img);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
fn retrieve(&mut self, prompt: &str) -> Result<Vec<usize>> {
let dtype = if self.device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let dummy_prompt: &str = "Describe the image";
let input = self.tokenize_batch(vec![prompt])?;
let dummy_input = self.tokenize_batch(vec![dummy_prompt])?;
let pages = self.get_images_from_pdf()?;
let mut all_scores = Vec::new();
for batch in pages.chunks(self.batch_size) {
let page_images = self
.images_to_tensor(batch, self.config.vision_config.image_size)?
.to_device(&self.device)?
.to_dtype(dtype)?;
let dummy_input = dummy_input.repeat((page_images.dims()[0], 0))?;
let image_embeddings = self.model.forward_images(&page_images, &dummy_input)?;
let text_embeddings = self.model.forward_text(&input)?;
let scores = text_embeddings
.unsqueeze(1)?
.broadcast_matmul(&image_embeddings.unsqueeze(0)?.transpose(3, 2)?)?
.max(3)?
.sum(2)?;
let batch_scores: Vec<f32> = scores
.to_dtype(DType::F32)?
.to_vec2()?
.into_iter()
.flatten()
.collect();
all_scores.extend(batch_scores);
}
let mut indices: Vec<usize> = (0..all_scores.len()).collect();
indices.sort_by(|a, b| all_scores[*b].partial_cmp(&all_scores[*a]).unwrap());
let top_k_indices = indices[0..self.top_k].to_vec();
Ok(top_k_indices)
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// number of top pages to show.
#[arg(long, default_value_t = 3)]
top_k: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
#[arg(long)]
pdf: String,
#[arg(long)]
start: Option<u32>,
#[arg(long)]
end: Option<u32>,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => "vidore/colpali-v1.2-merged".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => api
.repo(Repo::with_revision(
"vidore/colpali".to_string(),
RepoType::Model,
"main".to_string(),
))
.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
let start = std::time::Instant::now();
let config: paligemma::Config = paligemma::Config::paligemma_3b_448();
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = candle_examples::device(false)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = colpali::Model::new(&config, vb)?;
let pdf = PDF::from_file(args.pdf)?;
// check if start and end given in arg
let range = if let (Some(start), Some(end)) = (args.start, args.end) {
pdf2image::Pages::Range(start..=end)
} else {
pdf2image::Pages::Range(1..=pdf.page_count()) // can use pdf2image::Pages::All but there is a bug in the library which causes the first page to rendered twice.
};
let mut retriever =
PageRetriever::new(model, config, pdf, tokenizer, &device, Some(range), 4, 3);
let top_k_indices = retriever.retrieve(&args.prompt)?;
println!("Prompt: {}", args.prompt);
println!(
"top {} page numbers that contain similarity to the prompt",
retriever.top_k
);
println!("-----------------------------------");
for index in top_k_indices {
println!("Page: {:?}", index + 1);
}
println!("-----------------------------------");
Ok(())
}

View File

@ -1,4 +1,3 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::sync::{Arc, Mutex};

View File

@ -44,6 +44,14 @@ struct Args {
#[arg(long, value_enum, default_value = "schnell")]
model: Model,
/// Use the slower kernels.
#[arg(long)]
use_dmmv: bool,
/// The seed to use when generating random samples.
#[arg(long)]
seed: Option<u64>,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
@ -65,6 +73,7 @@ fn run(args: Args) -> Result<()> {
decode_only,
model,
quantized,
..
} = args;
let width = width.unwrap_or(1360);
let height = height.unwrap_or(768);
@ -86,6 +95,9 @@ fn run(args: Args) -> Result<()> {
api.repo(hf_hub::Repo::model(name.to_string()))
};
let device = candle_examples::device(cpu)?;
if let Some(seed) = args.seed {
device.set_seed(seed)?;
}
let dtype = device.bf16_default_to_f32();
let img = match decode_only {
None => {
@ -244,5 +256,7 @@ fn run(args: Args) -> Result<()> {
fn main() -> Result<()> {
let args = Args::parse();
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(args.use_dmmv);
run(args)
}

View File

@ -35,10 +35,26 @@ enum Which {
V31,
V3Instruct,
V31Instruct,
V32_1b,
V32_1bInstruct,
V32_3b,
V32_3bInstruct,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
TinyLlama1_1BChat,
#[value(name = "SmoLM2-1.7B")]
SmolLM2_1B,
#[value(name = "SmoLM2-1.7B-Instruct")]
SmolLM2_1BInstruct,
#[value(name = "SmoLM2-360M")]
SmolLM2_360M,
#[value(name = "SmoLM2-360M-Instruct")]
SmolLM2_360MInstruct,
#[value(name = "SmoLM2-135M")]
SmolLM2_135M,
#[value(name = "SmoLM2-135M-Instruct")]
SmolLM2_135MInstruct,
}
#[derive(Parser, Debug)]
@ -130,15 +146,28 @@ fn main() -> Result<()> {
};
let (llama, tokenizer_filename, mut cache, config) = {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.which {
Which::V1 => "Narsil/amall-7b".to_string(),
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
let model_id = args.model_id.unwrap_or_else(|| {
let str = match args.which {
Which::V1 => "Narsil/amall-7b",
Which::V2 => "meta-llama/Llama-2-7b-hf",
Which::V3 => "meta-llama/Meta-Llama-3-8B",
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct",
Which::V31 => "meta-llama/Llama-3.1-8B",
Which::V31Instruct => "meta-llama/Llama-3.1-8B-Instruct",
Which::V32_1b => "meta-llama/Llama-3.2-1B",
Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct",
Which::V32_3b => "meta-llama/Llama-3.2-3B",
Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct",
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0",
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
Which::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
Which::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
Which::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
Which::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
Which::SmolLM2_1B => "HuggingFaceTB/SmolLM2-1.7B",
Which::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
};
str.to_string()
});
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
@ -156,10 +185,22 @@ fn main() -> Result<()> {
| Which::V3Instruct
| Which::V31
| Which::V31Instruct
| Which::V32_3b
| Which::V32_3bInstruct
| Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
Which::SmolLM2_360M
| Which::SmolLM2_360MInstruct
| Which::SmolLM2_135M
| Which::SmolLM2_135MInstruct
| Which::SmolLM2_1B
| Which::SmolLM2_1BInstruct
| Which::V32_1b
| Which::V32_1bInstruct
| Which::TinyLlama1_1BChat => {
vec![api.get("model.safetensors")?]
}
};
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

View File

@ -1,4 +1,3 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::sync::{Arc, Mutex};

View File

@ -60,7 +60,6 @@ fn load_images<T: AsRef<std::path::Path>>(
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = candle_examples::imagenet::load_image_with_std_mean(
path,
@ -70,9 +69,7 @@ fn load_images<T: AsRef<std::path::Path>>(
)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
@ -80,24 +77,17 @@ pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let model_name = args.which.model_name();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
let model_file = if args.use_pth {
api.get("open_clip_pytorch_model.bin")?
} else {
api.get("open_clip_model.safetensors")?
};
let tokenizer = api.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let config = &args.which.config();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
@ -105,9 +95,7 @@ pub fn main() -> anyhow::Result<()> {
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
let vb = if args.use_pth {
VarBuilder::from_pth(&model_file, DType::F32, &device)?
} else {
@ -115,22 +103,15 @@ pub fn main() -> anyhow::Result<()> {
};
let model = mobileclip::MobileClipModel::new(vb, config)?;
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
println!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
@ -171,7 +152,6 @@ pub fn tokenize_sequences(
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
@ -185,8 +165,6 @@ pub fn tokenize_sequences(
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -0,0 +1,28 @@
# PaliGemma
[HuggingFace Model Card](https://huggingface.co/google/paligemma-3b-pt-224) -
[Model Page](https://ai.google.dev/gemma/docs/paligemma)
```bash
cargo run --features cuda --release --example paligemma -- \
--prompt "caption fr" --image candle-examples/examples/yolo-v8/assets/bike.jpg
```
```
loaded image with shape Tensor[dims 1, 3, 224, 224; bf16, cuda:0]
loaded the model in 1.267744448s
caption fr. Un groupe de cyclistes qui sont dans la rue.
13 tokens generated (56.52 token/s)
```
```bash
cargo run --features cuda --release --example paligemma -- \
--prompt "caption fr" --image candle-examples/examples/flux/assets/flux-robot.jpg
```
```
loaded image with shape Tensor[dims 1, 3, 224, 224; bf16, cuda:0]
loaded the model in 1.271492621s
caption fr une image d' un robot sur la plage avec le mot rouillé
15 tokens generated (62.78 token/s)
```

View File

@ -0,0 +1,276 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::paligemma::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
image: Tensor,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
image: Tensor,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
image,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = if index > 0 {
self.model.forward(&input)?
} else {
self.model.setup(&self.image, &input)?
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long)]
image: String,
}
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
let img = image::ImageReader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
Ok(img)
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => "google/paligemma-3b-mix-224".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let config = Config::paligemma_3b_224();
let image = load_image(&args.image, config.vision_config.image_size)?
.to_device(&device)?
.to_dtype(dtype)?
.unsqueeze(0)?;
println!("loaded image with shape {:?}", image);
let start = std::time::Instant::now();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
image,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
let prompt = format!("{}\n", args.prompt);
pipeline.run(&prompt, args.sample_len)?;
Ok(())
}

View File

@ -0,0 +1,28 @@
# pixtral
Pixtral-12B is a 12B text+vision model.
[Blog Post](https://mistral.ai/news/pixtral-12b/) -
[HF Model Card](https://huggingface.co/mistralai/Pixtral-12B-2409) -
[HF Community Model Card](https://huggingface.co/mistral-community/pixtral-12b).
```bash
cargo run --profile=release-with-debug --features cuda --example pixtral -- \
--image candle-examples/examples/flux/assets/flux-robot.jpg
```
```
Describe the image.
The image depicts a charming, rustic robot standing on a sandy beach at sunset.
The robot has a vintage, steampunk aesthetic with visible gears and mechanical
parts. It is holding a small lantern in one hand, which emits a warm glow, and
its other arm is extended forward as if reaching out or guiding the way. The
robot's body is adorned with the word "RUST" in bright orange letters, adding to
its rustic theme.
The background features a dramatic sky filled with clouds, illuminated by the
setting sun, casting a golden hue over the scene. Gentle waves lap against the
shore, creating a serene and picturesque atmosphere. The overall mood of the
image is whimsical and nostalgic, evoking a sense of adventure and tranquility.
```

View File

@ -0,0 +1,327 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::pixtral::{vision_model, Config, Model};
use candle::{DType, Device, Module, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
image: Tensor,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
image: Tensor,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
image,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut generated_tokens = 0usize;
let get_token = |v| match self.tokenizer.get_token(v) {
Some(token) => Ok(token),
None => anyhow::bail!("cannot find the {v} token"),
};
let bos_token = get_token("<s>")?;
let eos_token = get_token("</s>")?;
let inst_token = get_token("[INST]")?;
let end_inst_token = get_token("[/INST]")?;
let img_break = get_token("[IMG_BREAK]")?;
let img_end = get_token("[IMG_END]")?;
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let logits = if index > 0 {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
self.model.lm_forward(&input)?
} else {
let (_b, _c, h, w) = self.image.dims4()?;
let h = h / self.model.patch_size;
let w = w / self.model.patch_size;
let image_embeds = self.model.encode_image(&self.image)?;
println!("generated image embeddings {image_embeds:?}");
let image_embeds = image_embeds.to_dtype(self.model.dtype)?;
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let break_embeds = {
let input = Tensor::new(&[img_break], &self.device)?.unsqueeze(0)?;
self.model.language_model.embed_tokens().forward(&input)?
};
let start_embeds = {
let mut in_tokens = vec![bos_token, inst_token];
in_tokens.extend_from_slice(tokens.as_slice());
let input = Tensor::new(in_tokens.as_slice(), &self.device)?.unsqueeze(0)?;
self.model.language_model.embed_tokens().forward(&input)?
};
let end_embeds = {
let input =
Tensor::new(&[img_end, end_inst_token], &self.device)?.unsqueeze(0)?;
self.model.language_model.embed_tokens().forward(&input)?
};
let mut input_embeds = vec![start_embeds];
for h_idx in 0..h {
if h_idx > 0 {
input_embeds.push(break_embeds.clone())
}
let row = image_embeds.narrow(1, h_idx * w, w)?;
input_embeds.push(row);
}
input_embeds.push(end_embeds);
let input_embeds = Tensor::cat(&input_embeds, 1)?;
self.model.lm_forward_embeds(&input_embeds)?
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long, default_value = "Describe the image.\n")]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long)]
image: String,
#[arg(long)]
vision_only: bool,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => "mistral-community/pixtral-12b".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let device = candle_examples::device(args.cpu)?;
let dtype = if device.supports_bf16() && !args.vision_only {
DType::BF16
} else {
DType::F32
};
let config: Config = match args.config_file {
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
None => {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
}
};
let image = if args.image.ends_with(".safetensors") {
match candle::safetensors::load(&args.image, &device)?.remove("img") {
None => anyhow::bail!("no img tensor in {}", args.image),
Some(v) => v,
}
} else {
candle_examples::imagenet::load_image_with_std_mean(
&args.image,
1024,
&[0.48145466, 0.4578275, 0.40821073],
&[0.26862954, 0.261_302_6, 0.275_777_1],
)?
};
let image = image.to_device(&device)?.unsqueeze(0)?;
println!("loaded image with shape {:?}", image);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
if args.vision_only {
let start = std::time::Instant::now();
let model = vision_model::Model::new(&config.vision_config, vb.pp("vision_tower"))?;
println!("loaded the model in {:?}", start.elapsed());
let embs = model.forward(&image)?;
println!("EMBS\n{embs}");
} else {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
image,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
}
Ok(())
}

View File

@ -71,6 +71,10 @@ enum Which {
L8b,
#[value(name = "phi3")]
Phi3,
#[value(name = "SmoLM2-360M-Instruct")]
SmolLM2_360MInstruct,
#[value(name = "SmoLM2-1.7B-Instruct")]
SmolLM2_1BInstruct,
}
impl Which {
@ -88,7 +92,9 @@ impl Which {
| Self::Leo7b
| Self::Leo13b
| Self::L8b
| Self::Phi3 => false,
| Self::Phi3
| Self::SmolLM2_1BInstruct
| Self::SmolLM2_360MInstruct => false,
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
// same way. Starling is a fine tuned version of OpenChat.
Self::OpenChat35
@ -124,6 +130,8 @@ impl Which {
| Self::OpenChat35
| Self::Starling7bAlpha
| Self::L8b
| Self::SmolLM2_1BInstruct
| Self::SmolLM2_360MInstruct
| Self::Phi3 => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
}
@ -150,6 +158,8 @@ impl Which {
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
| Self::L8b
| Self::SmolLM2_1BInstruct
| Self::SmolLM2_360MInstruct
| Self::Phi3 => false,
Self::OpenChat35 | Self::Starling7bAlpha => true,
}
@ -179,6 +189,8 @@ impl Which {
Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
Self::L8b => "meta-llama/Meta-Llama-3-8B",
Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
Self::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
Self::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
}
}
}
@ -343,6 +355,14 @@ impl Args {
"microsoft/Phi-3-mini-4k-instruct-gguf",
"Phi-3-mini-4k-instruct-q4.gguf",
),
Which::SmolLM2_360MInstruct => (
"HuggingFaceTB/SmolLM2-360M-Instruct-GGUF",
"smollm2-360m-instruct-q8_0.gguf",
),
Which::SmolLM2_1BInstruct => (
"HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF",
"smollm2-1.7b-instruct-q4_k_m.gguf",
),
};
let revision = if self.which == Which::Phi3 {
"5eef2ce24766d31909c0b269fe90c817a8f263fb"
@ -455,6 +475,8 @@ fn main() -> anyhow::Result<()> {
| Which::Leo7b
| Which::Leo13b
| Which::L8b
| Which::SmolLM2_1BInstruct
| Which::SmolLM2_360MInstruct
| Which::Phi3 => 1,
Which::Mixtral
| Which::MixtralInstruct
@ -573,6 +595,7 @@ fn main() -> anyhow::Result<()> {
}
let eos_token = match args.which {
Which::SmolLM2_360MInstruct | Which::SmolLM2_1BInstruct => "<|endoftext|>",
Which::L8b => "<|end_of_text|>",
_ => match args.which.is_open_chat() {
true => "<|end_of_turn|>",

View File

@ -0,0 +1,24 @@
## SigLIP
SigLIP is multi-modal text-vision model that improves over CLIP by using a sigmoid based loss,
[HuggingFace](https://huggingface.co/google/siglip-base-patch16-224).
### Running an example
```
$ cargo run --features cuda -r --example siglip -
softmax_image_vec: [2.1912122e-14, 2.3624872e-14, 1.0, 1.0, 2.4787932e-8, 3.2784535e-12]
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
Probability: 0.0000% Text: a cycling race
Probability: 0.0000% Text: a photo of two cats
Probability: 100.0000% Text: a robot holding a candle
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
Probability: 100.0000% Text: a cycling race
Probability: 0.0000% Text: a photo of two cats
Probability: 0.0000% Text: a robot holding a candle
```

View File

@ -0,0 +1,153 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::siglip;
use tokenizers::Tokenizer;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long, use_value_delimiter = true)]
images: Option<Vec<String>>,
#[arg(long)]
cpu: bool,
#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,
}
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
let img = image::ImageReader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
Ok(img)
}
fn load_images<T: AsRef<std::path::Path>>(
paths: &Vec<T>,
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = load_image(path, image_size)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("google/siglip-base-patch16-224".to_string());
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let tokenizer = get_tokenizer(args.tokenizer)?;
let config = siglip::Config::base_patch16_224();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let images = load_images(&vec_imgs, config.vision_config.image_size)?.to_device(&device)?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
let model = siglip::Model::new(&config, vb)?;
let (input_ids, vec_seq) = tokenize_sequences(&config, args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
println!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
println!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
}
}
Ok(())
}
pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
let tokenizer = match tokenizer {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("google/siglip-base-patch16-224".to_string());
api.get("tokenizer.json")?
}
Some(file) => file.into(),
};
Tokenizer::from_file(tokenizer).map_err(E::msg)
}
pub fn tokenize_sequences(
config: &siglip::Config,
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
let pad_id = config.text_config.pad_token_id;
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
"a cycling race".to_string(),
"a photo of two cats".to_string(),
"a robot holding a candle".to_string(),
],
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}
let max_len = config.text_config.max_position_embeddings;
// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
if len_diff > 0 {
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -0,0 +1,28 @@
# candle-splade
SPLADE is a neural retrieval model which learns query/document sparse expansion via the BERT MLM head and sparse regularization. Sparse representations benefit from several advantages compared to dense approaches: efficient use of inverted index, explicit lexical match, interpretability... They also seem to be better at generalizing on out-of-domain data. In this example we can do the following two tasks:
- Compute sparse embedding for a given query.
- Compute similarities between a set of sentences using sparse embeddings.
## Sparse Sentence embeddings
SPLADE is used to compute the sparse embedding for a given query. The model weights
are downloaded from the hub on the first run. This makes use of the BertForMaskedLM model.
```bash
cargo run --example splade --release -- --prompt "Here is a test sentence"
> "the out there still house inside position outside stay standing hotel sitting dog animal sit bird cat statue cats"
> [0.10270107, 0.269471, 0.047469813, 0.0016636598, 0.05394874, 0.23105666, 0.037475716, 0.45949644, 0.009062732, 0.06790692, 0.0327835, 0.33122346, 0.16863061, 0.12688516, 0.340983, 0.044972017, 0.47724655, 0.01765311, 0.37331146]
```
```bash
cargo run --example splade --release --features
> score: 0.47 'The new movie is awesome' 'The new movie is so great'
> score: 0.43 'The cat sits outside' 'The cat plays in the garden'
> score: 0.14 'I love pasta' 'Do you like pizza?'
> score: 0.11 'A man is playing guitar' 'The cat plays in the garden'
> score: 0.05 'A man is playing guitar' 'A woman watches TV'
```

View File

@ -0,0 +1,210 @@
use std::path::PathBuf;
use anyhow::{Error as E, Result};
use candle::Tensor;
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{self, BertForMaskedLM, Config};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
// Path to the tokenizer file.
#[arg(long)]
tokenizer_file: Option<String>,
// Path to the weight files.
#[arg(long)]
weight_files: Option<String>,
// Path to the config file.
#[arg(long)]
config_file: Option<String>,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: Option<String>,
}
fn main() -> Result<()> {
let args = Args::parse();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => "prithivida/Splade_PP_en_v1".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let weights_filename = match args.weight_files {
Some(files) => PathBuf::from(files),
None => match repo.get("model.safetensors") {
Ok(safetensors) => safetensors,
Err(_) => match repo.get("pytorch_model.bin") {
Ok(pytorch_model) => pytorch_model,
Err(e) => {
return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e)));
}
},
},
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = candle_examples::device(args.cpu)?;
let dtype = bert::DTYPE;
let vb = if weights_filename.ends_with("model.safetensors") {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], dtype, &device).unwrap() }
} else {
println!("Loading weights from pytorch_model.bin");
VarBuilder::from_pth(&weights_filename, dtype, &device).unwrap()
};
let model = BertForMaskedLM::load(vb, &config)?;
if let Some(prompt) = args.prompt {
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
let token_type_ids = token_ids.zeros_like()?;
let ys = model.forward(&token_ids, &token_type_ids, None)?;
let vec = Tensor::log(
&Tensor::try_from(1.0)?
.to_dtype(dtype)?
.to_device(&device)?
.broadcast_add(&ys.relu()?)?,
)?
.max(1)?;
let vec = normalize_l2(&vec)?;
let vec = vec.squeeze(0)?.to_vec1::<f32>()?;
let indices = (0..vec.len())
.filter(|&i| vec[i] != 0.0)
.map(|x| x as u32)
.collect::<Vec<_>>();
let tokens = tokenizer.decode(&indices, true).unwrap();
println!("{tokens:?}");
let values = indices.iter().map(|&i| vec[i as usize]).collect::<Vec<_>>();
println!("{values:?}");
} else {
let sentences = [
"The cat sits outside",
"A man is playing guitar",
"I love pasta",
"The new movie is awesome",
"The cat plays in the garden",
"A woman watches TV",
"The new movie is so great",
"Do you like pizza?",
];
let n_sentences = sentences.len();
if let Some(pp) = tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
} else {
let pp = PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
tokenizer.with_padding(Some(pp));
}
let tokens = tokenizer
.encode_batch(sentences.to_vec(), true)
.map_err(E::msg)?;
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Ok(Tensor::new(tokens.as_slice(), &device)?)
})
.collect::<Result<Vec<_>>>()?;
let attention_mask = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_attention_mask().to_vec();
Ok(Tensor::new(tokens.as_slice(), &device)?)
})
.collect::<Result<Vec<_>>>()?;
let token_ids = Tensor::stack(&token_ids, 0)?;
let attention_mask = Tensor::stack(&attention_mask, 0)?;
let token_type_ids = token_ids.zeros_like()?;
let ys = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
let vector = Tensor::log(
&Tensor::try_from(1.0)?
.to_dtype(dtype)?
.to_device(&device)?
.broadcast_add(&ys.relu()?)?,
)?;
let vector = vector
.broadcast_mul(&attention_mask.unsqueeze(2)?.to_dtype(dtype)?)?
.max(1)?;
let vec = normalize_l2(&vector)?;
let mut similarities = vec![];
for i in 0..n_sentences {
let e_i = vec.get(i)?;
for j in (i + 1)..n_sentences {
let e_j = vec.get(j)?;
let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
similarities.push((cosine_similarity, i, j))
}
}
similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
for &(score, i, j) in similarities[..5].iter() {
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
}
}
Ok(())
}
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}

View File

@ -0,0 +1,71 @@
# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3/3.5
![](assets/stable-diffusion-3.jpg)
*A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k*, generated by Stable Diffusion 3 Medium
Stable Diffusion 3 Medium is a text-to-image model based on Multimodal Diffusion Transformer (MMDiT) architecture.
- [huggingface repo](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
- [research paper](https://arxiv.org/pdf/2403.03206)
- [announcement blog post](https://stability.ai/news/stable-diffusion-3-medium)
Stable Diffusion 3.5 is a family of text-to-image models with latest improvements:
- [announcement blog post](https://stability.ai/news/introducing-stable-diffusion-3-5)
It has three variants:
- [Stable Diffusion 3.5 Large](https://huggingface.co/stabilityai/stable-diffusion-3.5-large) @ 8.1b params, with scaled and slightly modified MMDiT architecture.
- [Stable Diffusion 3.5 Large Turbo](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo) distilled version that enables 4-step inference.
- [Stable Diffusion 3.5 Medium](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) @ 2.5b params, with improved MMDiT-X architecture.
## Getting access to the weights
The weights of Stable Diffusion 3/3.5 is released by Stability AI under the Stability Community License. You will need to accept the conditions and acquire a license by visiting the repos on HuggingFace Hub to gain access to the weights for your HuggingFace account.
To allow your computer to gain access to the public-gated repos on HuggingFace, you might need to create a [HuggingFace User Access Tokens](https://huggingface.co/docs/hub/en/security-tokens) (recommended) and log in on your computer if you haven't done that before. A convenient way to do the login is to use [huggingface-cli](https://huggingface.co/docs/huggingface_hub/en/guides/cli):
```shell
huggingface-cli login
```
and you will be prompted to enter your token.
On the first run, the weights will be automatically downloaded from the Huggingface Hub. After the download, the weights will be [cached](https://huggingface.co/docs/datasets/en/cache) and remain accessible locally.
## Running the model
```shell
cargo run --example stable-diffusion-3 --release --features=cuda -- \
--which 3-medium --height 1024 --width 1024 \
--prompt 'A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k'
```
To use different models, changed the value of `--which` option. (Possible values: `3-medium`, `3.5-large`, `3.5-large-turbo` and `3.5-medium`).
To display other options available,
```shell
cargo run --example stable-diffusion-3 --release --features=cuda -- --help
```
If GPU supports, Flash-Attention is a strongly recommended feature as it can greatly improve the speed of inference, as MMDiT is a transformer model heavily depends on attentions. To utilize [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) in the demo, you will need both `--features flash-attn` and `--use-flash-attn`.
```shell
cargo run --example stable-diffusion-3 --release --features=cuda,flash-attn -- --use-flash-attn ...
```
## Performance Benchmark
Below benchmark is done with Stable Diffusion 3 Medium by generating 1024-by-1024 image from 28 steps of Euler sampling and measure the average speed (iteration per seconds).
[candle](https://github.com/huggingface/candle) and [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) is based on the commit of [0d96ec3](https://github.com/huggingface/candle/commit/0d96ec31e8be03f844ed0aed636d6217dee9c7bc).
System specs (Desktop PCIE 5 x8/x8 dual-GPU setup):
- Operating System: Ubuntu 23.10
- CPU: i9 12900K w/o overclocking.
- RAM: 64G dual-channel DDR5 @ 4800 MT/s
| Speed (iter/s) | w/o flash-attn | w/ flash-attn |
| -------------- | -------------- | ------------- |
| RTX 3090 Ti | 0.83 | 2.15 |
| RTX 4090 | 1.72 | 4.06 |

Binary file not shown.

After

Width:  |  Height:  |  Size: 81 KiB

View File

@ -0,0 +1,234 @@
use anyhow::{Error as E, Ok, Result};
use candle::{DType, IndexOp, Module, Tensor, D};
use candle_transformers::models::{stable_diffusion, t5};
use std::path::PathBuf;
use tokenizers::tokenizer::Tokenizer;
struct ClipWithTokenizer {
clip: stable_diffusion::clip::ClipTextTransformer,
config: stable_diffusion::clip::Config,
tokenizer: Tokenizer,
max_position_embeddings: usize,
}
impl ClipWithTokenizer {
fn new(
vb: candle_nn::VarBuilder,
config: stable_diffusion::clip::Config,
tokenizer_path: &str,
max_position_embeddings: usize,
) -> Result<Self> {
let clip = stable_diffusion::clip::ClipTextTransformer::new(vb, &config)?;
let path_buf = hf_hub::api::sync::Api::new()?
.model(tokenizer_path.to_string())
.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(path_buf.to_str().ok_or(E::msg(
"Failed to serialize huggingface PathBuf of CLIP tokenizer",
))?)
.map_err(E::msg)?;
Ok(Self {
clip,
config,
tokenizer,
max_position_embeddings,
})
}
fn encode_text_to_embedding(
&self,
prompt: &str,
device: &candle::Device,
) -> Result<(Tensor, Tensor)> {
let pad_id = match &self.config.pad_with {
Some(padding) => *self
.tokenizer
.get_vocab(true)
.get(padding.as_str())
.ok_or(E::msg("Failed to tokenize CLIP padding."))?,
None => *self
.tokenizer
.get_vocab(true)
.get("<|endoftext|>")
.ok_or(E::msg("Failed to tokenize CLIP end-of-text."))?,
};
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let eos_position = tokens.len() - 1;
while tokens.len() < self.max_position_embeddings {
tokens.push(pad_id)
}
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
let (text_embeddings, text_embeddings_penultimate) = self
.clip
.forward_until_encoder_layer(&tokens, usize::MAX, -2)?;
let text_embeddings_pooled = text_embeddings.i((0, eos_position, ..))?;
Ok((text_embeddings_penultimate, text_embeddings_pooled))
}
}
struct T5WithTokenizer {
t5: t5::T5EncoderModel,
tokenizer: Tokenizer,
max_position_embeddings: usize,
}
impl T5WithTokenizer {
fn new(vb: candle_nn::VarBuilder, max_position_embeddings: usize) -> Result<Self> {
let api = hf_hub::api::sync::Api::new()?;
let repo = api.repo(hf_hub::Repo::with_revision(
"google/t5-v1_1-xxl".to_string(),
hf_hub::RepoType::Model,
"refs/pr/2".to_string(),
));
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: t5::Config = serde_json::from_str(&config)?;
let model = t5::T5EncoderModel::load(vb, &config)?;
let tokenizer_filename = api
.model("lmz/mt5-tokenizers".to_string())
.get("t5-v1_1-xxl.tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
Ok(Self {
t5: model,
tokenizer,
max_position_embeddings,
})
}
fn encode_text_to_embedding(
&mut self,
prompt: &str,
device: &candle::Device,
) -> Result<Tensor> {
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
tokens.resize(self.max_position_embeddings, 0);
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let embeddings = self.t5.forward_dt(&input_token_ids, Some(DType::F32))?;
Ok(embeddings)
}
}
pub struct StableDiffusion3TripleClipWithTokenizer {
clip_l: ClipWithTokenizer,
clip_g: ClipWithTokenizer,
clip_g_text_projection: candle_nn::Linear,
t5: T5WithTokenizer,
}
impl StableDiffusion3TripleClipWithTokenizer {
pub fn new_split(
clip_g_file: &PathBuf,
clip_l_file: &PathBuf,
t5xxl_file: &PathBuf,
device: &candle::Device,
) -> Result<Self> {
let vb_clip_g = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_g_file], DType::F16, device)?
};
let vb_clip_l = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_l_file], DType::F16, device)?
};
let vb_t5 = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F16, device)?
};
let max_position_embeddings = 77usize;
let clip_l = ClipWithTokenizer::new(
vb_clip_l,
stable_diffusion::clip::Config::sdxl(),
"openai/clip-vit-large-patch14",
max_position_embeddings,
)?;
let text_projection =
candle_nn::linear_no_bias(1280, 1280, vb_clip_g.pp("text_projection"))?;
let clip_g = ClipWithTokenizer::new(
vb_clip_g,
stable_diffusion::clip::Config::sdxl2(),
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
max_position_embeddings,
)?;
let t5 = T5WithTokenizer::new(vb_t5, max_position_embeddings)?;
Ok(Self {
clip_l,
clip_g,
clip_g_text_projection: text_projection,
t5,
})
}
pub fn new(vb: candle_nn::VarBuilder) -> Result<Self> {
let max_position_embeddings = 77usize;
let clip_l = ClipWithTokenizer::new(
vb.pp("clip_l.transformer"),
stable_diffusion::clip::Config::sdxl(),
"openai/clip-vit-large-patch14",
max_position_embeddings,
)?;
let clip_g = ClipWithTokenizer::new(
vb.pp("clip_g.transformer"),
stable_diffusion::clip::Config::sdxl2(),
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
max_position_embeddings,
)?;
let text_projection =
candle_nn::linear_no_bias(1280, 1280, vb.pp("clip_g.transformer.text_projection"))?;
let t5 = T5WithTokenizer::new(vb.pp("t5xxl.transformer"), max_position_embeddings)?;
Ok(Self {
clip_l,
clip_g,
clip_g_text_projection: text_projection,
t5,
})
}
pub fn encode_text_to_embedding(
&mut self,
prompt: &str,
device: &candle::Device,
) -> Result<(Tensor, Tensor)> {
let (clip_l_embeddings, clip_l_embeddings_pooled) =
self.clip_l.encode_text_to_embedding(prompt, device)?;
let (clip_g_embeddings, clip_g_embeddings_pooled) =
self.clip_g.encode_text_to_embedding(prompt, device)?;
let clip_g_embeddings_pooled = self
.clip_g_text_projection
.forward(&clip_g_embeddings_pooled.unsqueeze(0)?)?
.squeeze(0)?;
let y = Tensor::cat(&[&clip_l_embeddings_pooled, &clip_g_embeddings_pooled], 0)?
.unsqueeze(0)?;
let clip_embeddings_concat = Tensor::cat(
&[&clip_l_embeddings, &clip_g_embeddings],
D::Minus1,
)?
.pad_with_zeros(D::Minus1, 0, 2048)?;
let t5_embeddings = self
.t5
.encode_text_to_embedding(prompt, device)?
.to_dtype(DType::F16)?;
let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?;
Ok((context, y))
}
}

View File

@ -0,0 +1,273 @@
mod clip;
mod sampling;
mod vae;
use candle::{DType, IndexOp, Tensor};
use candle_transformers::models::mmdit::model::{Config as MMDiTConfig, MMDiT};
use crate::clip::StableDiffusion3TripleClipWithTokenizer;
use crate::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};
use anyhow::{Ok, Result};
use clap::Parser;
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "3-medium")]
V3Medium,
#[value(name = "3.5-large")]
V3_5Large,
#[value(name = "3.5-large-turbo")]
V3_5LargeTurbo,
#[value(name = "3.5-medium")]
V3_5Medium,
}
impl Which {
fn is_3_5(&self) -> bool {
match self {
Self::V3Medium => false,
Self::V3_5Large | Self::V3_5LargeTurbo | Self::V3_5Medium => true,
}
}
}
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The prompt to be used for image generation.
#[arg(
long,
default_value = "A cute rusty robot holding a candle torch in its hand, \
with glowing neon text \"LETS GO RUSTY\" displayed on its chest, \
bright background, high quality, 4k"
)]
prompt: String,
#[arg(long, default_value = "")]
uncond_prompt: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Use flash_attn to accelerate attention operation in the MMDiT.
#[arg(long)]
use_flash_attn: bool,
/// The height in pixels of the generated image.
#[arg(long, default_value_t = 1024)]
height: usize,
/// The width in pixels of the generated image.
#[arg(long, default_value_t = 1024)]
width: usize,
/// The model to use.
#[arg(long, default_value = "3-medium")]
which: Which,
/// The seed to use when generating random samples.
#[arg(long)]
num_inference_steps: Option<usize>,
/// CFG scale.
#[arg(long)]
cfg_scale: Option<f64>,
/// Time shift factor (alpha).
#[arg(long, default_value_t = 3.0)]
time_shift: f64,
/// Use Skip Layer Guidance (SLG) for the sampling.
/// Currently only supports Stable Diffusion 3.5 Medium.
#[arg(long)]
use_slg: bool,
/// The seed to use when generating random samples.
#[arg(long)]
seed: Option<u64>,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let Args {
prompt,
uncond_prompt,
cpu,
tracing,
use_flash_attn,
height,
width,
num_inference_steps,
cfg_scale,
time_shift,
seed,
which,
use_slg,
} = Args::parse();
let _guard = if tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let device = candle_examples::device(cpu)?;
let default_inference_steps = match which {
Which::V3_5Large => 28,
Which::V3_5LargeTurbo => 4,
Which::V3_5Medium => 28,
Which::V3Medium => 28,
};
let num_inference_steps = num_inference_steps.unwrap_or(default_inference_steps);
let default_cfg_scale = match which {
Which::V3_5Large => 4.0,
Which::V3_5LargeTurbo => 1.0,
Which::V3_5Medium => 4.0,
Which::V3Medium => 4.0,
};
let cfg_scale = cfg_scale.unwrap_or(default_cfg_scale);
let api = hf_hub::api::sync::Api::new()?;
let (mmdit_config, mut triple, vb) = if which.is_3_5() {
let sai_repo_for_text_encoders = {
let name = match which {
Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",
// Unfortunately, stabilityai/stable-diffusion-3.5-medium doesn't have the monolithic text encoders that's usually
// placed under the text_encoders directory, like the case in stabilityai/stable-diffusion-3.5-large and -large-turbo.
// To make things worse, it currently only has partitioned model.fp16-00001-of-00002.safetensors and model.fp16-00002-of-00002.safetensors
// under the text_encoder_3 directory, for the t5xxl_fp16.safetensors model. This means that we need to merge the two partitions
// to get the monolithic text encoders. This is not a trivial task.
// Since the situation can change, we do not want to spend efforts to handle the uniqueness of stabilityai/stable-diffusion-3.5-medium,
// which involves different paths and merging the two partitions files for t5xxl_fp16.safetensors.
// so for now, we'll use the text encoder models from the stabilityai/stable-diffusion-3.5-large repository.
// TODO: Change to "stabilityai/stable-diffusion-3.5-medium" once the maintainers of the repository add back the monolithic text encoders.
Which::V3_5Medium => "stabilityai/stable-diffusion-3.5-large",
Which::V3Medium => unreachable!(),
};
api.repo(hf_hub::Repo::model(name.to_string()))
};
let sai_repo_for_mmdit = {
let name = match which {
Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",
Which::V3_5Medium => "stabilityai/stable-diffusion-3.5-medium",
Which::V3Medium => unreachable!(),
};
api.repo(hf_hub::Repo::model(name.to_string()))
};
let clip_g_file = sai_repo_for_text_encoders.get("text_encoders/clip_g.safetensors")?;
let clip_l_file = sai_repo_for_text_encoders.get("text_encoders/clip_l.safetensors")?;
let t5xxl_file = sai_repo_for_text_encoders.get("text_encoders/t5xxl_fp16.safetensors")?;
let model_file = {
let model_file = match which {
Which::V3_5Large => "sd3.5_large.safetensors",
Which::V3_5LargeTurbo => "sd3.5_large_turbo.safetensors",
Which::V3_5Medium => "sd3.5_medium.safetensors",
Which::V3Medium => unreachable!(),
};
sai_repo_for_mmdit.get(model_file)?
};
let triple = StableDiffusion3TripleClipWithTokenizer::new_split(
&clip_g_file,
&clip_l_file,
&t5xxl_file,
&device,
)?;
let vb = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F16, &device)?
};
match which {
Which::V3_5Large => (MMDiTConfig::sd3_5_large(), triple, vb),
Which::V3_5LargeTurbo => (MMDiTConfig::sd3_5_large(), triple, vb),
Which::V3_5Medium => (MMDiTConfig::sd3_5_medium(), triple, vb),
Which::V3Medium => unreachable!(),
}
} else {
let sai_repo = {
let name = "stabilityai/stable-diffusion-3-medium";
api.repo(hf_hub::Repo::model(name.to_string()))
};
let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
let vb = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], DType::F16, &device)?
};
let triple = StableDiffusion3TripleClipWithTokenizer::new(vb.pp("text_encoders"))?;
(MMDiTConfig::sd3_medium(), triple, vb)
};
let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
let (context_uncond, y_uncond) =
triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
// Drop the text model early to avoid using too much memory.
drop(triple);
let context = Tensor::cat(&[context, context_uncond], 0)?;
let y = Tensor::cat(&[y, y_uncond], 0)?;
if let Some(seed) = seed {
device.set_seed(seed)?;
}
let slg_config = if use_slg {
match which {
// https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/sd3_infer.py#L388-L394
Which::V3_5Medium => Some(sampling::SkipLayerGuidanceConfig {
scale: 2.5,
start: 0.01,
end: 0.2,
layers: vec![7, 8, 9],
}),
_ => anyhow::bail!("--use-slg can only be used with 3.5-medium"),
}
} else {
None
};
let start_time = std::time::Instant::now();
let x = {
let mmdit = MMDiT::new(
&mmdit_config,
use_flash_attn,
vb.pp("model.diffusion_model"),
)?;
sampling::euler_sample(
&mmdit,
&y,
&context,
num_inference_steps,
cfg_scale,
time_shift,
height,
width,
slg_config,
)?
};
let dt = start_time.elapsed().as_secs_f32();
println!(
"Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
dt,
num_inference_steps as f32 / dt
);
let img = {
let vb_vae = vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
let autoencoder = build_sd3_vae_autoencoder(vb_vae)?;
// Apply TAESD3 scale factor. Seems to be significantly improving the quality of the image.
// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
autoencoder.decode(&((x / 1.5305)? + 0.0609)?)?
};
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
candle_examples::save_image(&img.i(0)?, "out.jpg")?;
Ok(())
}

View File

@ -0,0 +1,83 @@
use anyhow::{Ok, Result};
use candle::{DType, IndexOp, Tensor};
use candle_transformers::models::flux;
use candle_transformers::models::mmdit::model::MMDiT;
pub struct SkipLayerGuidanceConfig {
pub scale: f64,
pub start: f64,
pub end: f64,
pub layers: Vec<usize>,
}
#[allow(clippy::too_many_arguments)]
pub fn euler_sample(
mmdit: &MMDiT,
y: &Tensor,
context: &Tensor,
num_inference_steps: usize,
cfg_scale: f64,
time_shift: f64,
height: usize,
width: usize,
slg_config: Option<SkipLayerGuidanceConfig>,
) -> Result<Tensor> {
let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?;
let sigmas = (0..=num_inference_steps)
.map(|x| x as f64 / num_inference_steps as f64)
.rev()
.map(|x| time_snr_shift(time_shift, x))
.collect::<Vec<f64>>();
for (step, window) in sigmas.windows(2).enumerate() {
let (s_curr, s_prev) = match window {
[a, b] => (a, b),
_ => continue,
};
let timestep = (*s_curr) * 1000.0;
let noise_pred = mmdit.forward(
&Tensor::cat(&[&x, &x], 0)?,
&Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
y,
context,
None,
)?;
let mut guidance = apply_cfg(cfg_scale, &noise_pred)?;
if let Some(slg_config) = slg_config.as_ref() {
if (num_inference_steps as f64) * slg_config.start < (step as f64)
&& (step as f64) < (num_inference_steps as f64) * slg_config.end
{
let slg_noise_pred = mmdit.forward(
&x,
&Tensor::full(timestep as f32, (1,), x.device())?.contiguous()?,
&y.i(..1)?,
&context.i(..1)?,
Some(&slg_config.layers),
)?;
guidance = (guidance
+ (slg_config.scale * (noise_pred.i(..1)? - slg_noise_pred.i(..1))?)?)?;
}
}
x = (x + (guidance * (*s_prev - *s_curr))?)?;
}
Ok(x)
}
// The "Resolution-dependent shifting of timestep schedules" recommended in the SD3 tech report paper
// https://arxiv.org/pdf/2403.03206
// Following the implementation in ComfyUI:
// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/
// comfy/model_sampling.py#L181
fn time_snr_shift(alpha: f64, t: f64) -> f64 {
alpha * t / (1.0 + (alpha - 1.0) * t)
}
fn apply_cfg(cfg_scale: f64, noise_pred: &Tensor) -> Result<Tensor> {
Ok(((cfg_scale * noise_pred.narrow(0, 0, 1)?)?
- ((cfg_scale - 1.0) * noise_pred.narrow(0, 1, 1)?)?)?)
}

View File

@ -0,0 +1,93 @@
use anyhow::{Ok, Result};
use candle_transformers::models::stable_diffusion::vae;
pub fn build_sd3_vae_autoencoder(vb: candle_nn::VarBuilder) -> Result<vae::AutoEncoderKL> {
let config = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 16,
norm_num_groups: 32,
use_quant_conv: false,
use_post_quant_conv: false,
};
Ok(vae::AutoEncoderKL::new(vb, 3, 3, config)?)
}
pub fn sd3_vae_vb_rename(name: &str) -> String {
let parts: Vec<&str> = name.split('.').collect();
let mut result = Vec::new();
let mut i = 0;
while i < parts.len() {
match parts[i] {
"down_blocks" => {
result.push("down");
}
"mid_block" => {
result.push("mid");
}
"up_blocks" => {
result.push("up");
match parts[i + 1] {
// Reverse the order of up_blocks.
"0" => result.push("3"),
"1" => result.push("2"),
"2" => result.push("1"),
"3" => result.push("0"),
_ => {}
}
i += 1; // Skip the number after up_blocks.
}
"resnets" => {
if i > 0 && parts[i - 1] == "mid_block" {
match parts[i + 1] {
"0" => result.push("block_1"),
"1" => result.push("block_2"),
_ => {}
}
i += 1; // Skip the number after resnets.
} else {
result.push("block");
}
}
"downsamplers" => {
result.push("downsample");
i += 1; // Skip the 0 after downsamplers.
}
"conv_shortcut" => {
result.push("nin_shortcut");
}
"attentions" => {
if parts[i + 1] == "0" {
result.push("attn_1")
}
i += 1; // Skip the number after attentions.
}
"group_norm" => {
result.push("norm");
}
"query" => {
result.push("q");
}
"key" => {
result.push("k");
}
"value" => {
result.push("v");
}
"proj_attn" => {
result.push("proj_out");
}
"conv_norm_out" => {
result.push("norm_out");
}
"upsamplers" => {
result.push("upsample");
i += 1; // Skip the 0 after upsamplers.
}
part => result.push(part),
}
i += 1;
}
result.join(".")
}

View File

@ -0,0 +1,45 @@
# candle-stella-en-v5: Implementation of [stella_en_1.5B_v5](https://huggingface.co/dunzhang/stella_en_1.5B_v5) embedding model
As of 7th Oct 2024, *Stella_en_1.5B_v5* is one of the top ranking model on `retrieval` and `reranking` tasks in [MTEB](https://huggingface.co/spaces/mteb/leaderboard) leaderboard.
[Model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) on the HuggingFace Hub.
## Running the example
Stella_en_1.5B_v5 is used to generate text embeddings embeddings for a prompt. The model weights
are downloaded from the hub on the first run.
```bash
$ cargo run --example stella-en-v5 --release -- --query "What are safetensors?"
> [[ 0.3905, -0.0130, 0.2072, ..., -0.1100, -0.0086, 0.6002]]
> Tensor[[1, 1024], f32]
```
Stella_en_1.5B_v5 is trained by [MRL](https://arxiv.org/abs/2205.13147) enabling multiple embedding dimensions.
The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p). The sample queries and docs are hardcoded in the example.
```bash
$ cargo run --example stella-en-v5 --release --features <metal | cuda>
>
> Score: 0.8178786
> Query: What are some ways to reduce stress?
> Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending
> time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent
> stress from building up.
>
>
> Score: 0.7853528
> Query: What are the benefits of drinking green tea?
> Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage
> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types >
> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
>
```
## Supported options:
- `Stella_en_15B_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.
- As per the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require a slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled though `--task` option.

View File

@ -0,0 +1,359 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use std::path::Path;
use anyhow::{anyhow, Error as E, Result};
use clap::Parser;
use candle_transformers::models::stella_en_v5::{
Config, EmbedDim as StellaEmbedDim, EmbeddingModel,
};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo};
use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer};
struct Embedding {
model: EmbeddingModel,
device: Device,
tokenizer: Tokenizer,
}
impl Embedding {
fn new(model: EmbeddingModel, tokenizer: Tokenizer, device: &Device) -> Self {
Self {
model,
tokenizer,
device: device.clone(),
}
}
fn encode(&mut self, task: EncodeTask, text: Option<String>) -> Result<()> {
// Just shocasing embeddings, this has no real value
if let Some(text) = text {
let qry = task.query_preproc(&[text]);
let encoding = self.tokenizer.encode(qry, true).map_err(|e| anyhow!(e))?;
let shape = (1, encoding.len());
let input = Tensor::from_slice(encoding.get_ids(), shape, &self.device)?;
let mask = Tensor::from_slice(encoding.get_attention_mask(), shape, &self.device)?;
let result = self.model.forward(&input, &mask)?;
println!("embeddings: {result}");
} else {
// Examples copied from [Model Card](https://huggingface.co/dunzhang/stella_en_1.5B_v5#transformers)
let queries = [
"What are some ways to reduce stress?".to_string(),
"What are the benefits of drinking green tea?".to_string(),
];
let docs = [
"There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.".to_string(),
"Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.".to_string(),
];
// We only encode the queries and not the data
let qry = task.query_preproc(&queries);
let mut qry_encoded = self
.tokenizer
.encode_batch(qry, true)
.map_err(|e| anyhow!(e))?;
let mut docs_encoded = self
.tokenizer
.encode_batch(docs.to_vec(), true)
.map_err(|e| anyhow!(e))?;
let qry_embed = {
// Now, we generate the tensors for the `input` and `mask`
let shape = (qry_encoded.len(), qry_encoded[1].len());
let mut ids = Tensor::zeros(shape, DType::U32, &self.device)?;
let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?;
for (i, e) in qry_encoded.drain(..).enumerate() {
let input_id =
Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze(0)?;
let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)?
.to_dtype(DType::U8)?
.unsqueeze(0)?;
ids =
ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?;
masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?;
}
// Let's generate the embeddings for the query, we are going to be normalizing the result.
// For larger datasets, you can call `.forward()` on batches and run a `l2 norm` pass on the entire data
self.model.forward_norm(&ids, &masks)?
};
let doc_embed = {
let shape = (docs_encoded.len(), docs_encoded[1].len());
let mut ids = Tensor::zeros(shape, DType::U32, &self.device)?;
let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?;
for (i, e) in docs_encoded.drain(..).enumerate() {
let input_id =
Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze(0)?;
let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)?
.to_dtype(DType::U8)?
.unsqueeze(0)?;
ids =
ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?;
masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?;
}
// Let's generate the embeddings for the query, we are going to be normalizing the result.
// For larger datasets, you can call `.forward()` on batches and run a `l2 norm` pass on the entire data
self.model.forward_norm(&ids, &masks)?
};
println!(
"Embed shapes:\nQuery: {:?}\nDocs: {:?}",
qry_embed.shape(),
doc_embed.shape()
); // [2, 1024] for head dim `1024`
// a matmul to generate the `similarity` score
let res = qry_embed.matmul(&doc_embed.t()?)?;
for (k, v) in queries.iter().enumerate() {
let tnsr = res.get(k)?;
let max = tnsr.argmax(0)?.to_scalar::<u32>()?;
println!(
"\nScore: {}\nQuery: {}\nAnswer: {}\n\n",
tnsr.get(max as usize)?.to_scalar::<f32>()?,
v,
docs[k]
);
}
}
Ok(())
}
}
#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
enum EmbedDim {
#[value(name = "256")]
Dim256,
#[value(name = "768")]
Dim768,
#[value(name = "1024")]
Dim1024,
#[value(name = "2048")]
Dim2048,
#[value(name = "4096")]
Dim4096,
#[value(name = "6144")]
Dim6144,
#[value(name = "8192")]
Dim8192,
}
impl EmbedDim {
/// Returns dir path to the embed head weights int he repo
pub fn embed_dim_default_dir(&self) -> &'static str {
match self {
Self::Dim256 => "2_Dense_256",
Self::Dim768 => "2_Dense_768",
Self::Dim1024 => "2_Dense_1024",
Self::Dim2048 => "2_Dense_2048",
Self::Dim4096 => "2_Dense_4096",
Self::Dim6144 => "2_Dense_6144",
Self::Dim8192 => "2_Dense_8192",
}
}
/// Resolves the `EmbedDim` for given variant
pub fn embed_dim(&self) -> StellaEmbedDim {
match self {
Self::Dim256 => StellaEmbedDim::Dim256,
Self::Dim768 => StellaEmbedDim::Dim768,
Self::Dim1024 => StellaEmbedDim::Dim1024,
Self::Dim2048 => StellaEmbedDim::Dim2048,
Self::Dim4096 => StellaEmbedDim::Dim4096,
Self::Dim6144 => StellaEmbedDim::Dim6144,
Self::Dim8192 => StellaEmbedDim::Dim8192,
}
}
}
#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
pub enum EncodeTask {
/// `s2p` is the `retrieval` task
/// Default in this example
#[value(name = "s2p")]
S2P,
/// `s2s` is the semantic similarity task
#[value(name = "s2s")]
S2S,
}
impl EncodeTask {
/// Preprocess a set of inputs basef on a template suggested by the model authors
/// See: https://huggingface.co/dunzhang/stella_en_1.5B_v5#introduction
pub fn query_preproc(&self, txt: &[String]) -> Vec<String> {
let instruct = match self {
Self::S2P => {
"Given a web search query, retrieve relevant passages that answer the query."
}
Self::S2S => "Retrieve semantically similar text.",
};
txt.iter()
.map(|s| format!("Instruct: {instruct}\nQuery: {s}"))
.collect::<Vec<_>>()
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
query: Option<String>,
#[arg(long, default_value = "1024")]
embed_dim: Option<EmbedDim>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
base_weight_files: Option<String>,
#[arg(long)]
embed_head_weight_files: Option<String>,
/// `Stella` is trained on 2 tasks: See [`Model Card`](https://huggingface.co/dunzhang/stella_en_1.5B_v5)
/// `s2s`: Semantic textual similarity
/// `s2p`: Retrieval task - `Default` in this example
#[arg(long, default_value = "s2p")]
task: Option<EncodeTask>,
}
// Tokenizer creation is super critical in our case.
// We are going to be `padding: Left` for each batch
fn create_tokenizer(tokenizer_file: &Path) -> Result<Tokenizer> {
let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
pad_id
} else {
return Err(anyhow!(
"Tokenizer doesn't contain expected `<|endoftext|>` token"
));
};
// This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Left,
pad_id,
pad_token: "<|endoftext|>".to_string(),
..Default::default()
}));
Ok(tokenizer)
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let start = std::time::Instant::now();
let api = Api::new()?;
let embed_dim = match args.embed_dim {
Some(d) => d,
None => EmbedDim::Dim1024,
};
let repo = api.repo(Repo::model("dunzhang/stella_en_1.5B_v5".to_string()));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
// Note, if you are providing `weight_files`, ensure that the `--embed_dim` dimensions provided matches the weights
// E.g. if you are using `--embed_dim 1024`, the weight files should include the `.safetensors` file from `2_Dense_1024` dir of the repo
let base_weight_files = match args.base_weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
vec![repo.get("model.safetensors")?]
}
};
let embed_weight_files = match args.embed_head_weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
let head_w_path = format!("{}/model.safetensors", embed_dim.embed_dim_default_dir());
vec![repo.get(&head_w_path)?]
}
};
println!("retrieved the files in {:?}", start.elapsed());
// Initializing the tokenizer which would require us to add padding to the `left` for batch encoding
let tokenizer = create_tokenizer(tokenizer_filename.as_path())?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let dtype = DType::F32;
let base_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&base_weight_files, dtype, &device)? };
// Embedding layer is always built on F32 for accuracy
let embed_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? };
let model = EmbeddingModel::new(
&Config::new_1_5_b_v5(embed_dim.embed_dim()),
base_vb,
embed_vb,
)?;
println!("loaded the model in {:?}", start.elapsed());
let mut embedding = Embedding::new(model, tokenizer, &device);
let task = args.task.map_or(EncodeTask::S2P, |t| t);
embedding.encode(task, args.query)
}

View File

@ -10,7 +10,6 @@ use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use std::iter;
use tokenizers::Tokenizer;
mod multilingual;
@ -18,7 +17,6 @@ mod multilingual;
use candle_transformers::models::whisper::{self as m, audio, Config};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use std::sync::{Arc, Mutex};
pub enum Model {
Normal(m::model::Whisper),
@ -391,6 +389,7 @@ enum WhichModel {
Large,
LargeV2,
LargeV3,
LargeV3Turbo,
#[value(name = "distil-medium.en")]
DistilMediumEn,
#[value(name = "distil-large-v2")]
@ -407,6 +406,7 @@ impl WhichModel {
| Self::Large
| Self::LargeV2
| Self::LargeV3
| Self::LargeV3Turbo
| Self::DistilLargeV2 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
false
@ -427,6 +427,7 @@ impl WhichModel {
Self::Large => ("openai/whisper-large", "refs/pr/36"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
Self::LargeV3Turbo => ("openai/whisper-large-v3-turbo", "main"),
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
}
@ -479,6 +480,10 @@ struct Args {
/// Print the full DecodingResult structure rather than just the text.
#[arg(long)]
verbose: bool,
/// The input device to use.
#[arg(long)]
device: Option<String>,
}
pub fn main() -> Result<()> {
@ -543,13 +548,12 @@ pub fn main() -> Result<()> {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
Model::Normal(m::model::Whisper::load(&vb, config.clone())?)
};
let language_token = None;
let mut dc = Decoder::new(
let mut decoder = Decoder::new(
model,
tokenizer.clone(),
args.seed,
&device,
language_token,
/* language_token */ None,
args.task,
args.timestamps,
args.verbose,
@ -565,47 +569,69 @@ pub fn main() -> Result<()> {
// Set up the input device and stream with the default input config.
let host = cpal::default_host();
let _device = "default";
let _device = if _device == "default" {
host.default_input_device()
} else {
host.input_devices()?
.find(|x| x.name().map(|y| y == _device).unwrap_or(false))
let audio_device = match args.device.as_ref() {
None => host.default_input_device(),
Some(device) => host
.input_devices()?
.find(|x| x.name().map_or(false, |y| &y == device)),
}
.expect("failed to find input device");
.expect("failed to find the audio input device");
let _config = _device
let audio_config = audio_device
.default_input_config()
.expect("Failed to get default input config");
println!("audio config {audio_config:?}");
let channel_count = _config.channels() as usize;
let audio_ring_buffer = Arc::new(Mutex::new(Vec::new()));
let audio_ring_buffer_2 = audio_ring_buffer.clone();
std::thread::spawn(move || loop {
let data = record_audio(&_device, &_config, 300).unwrap();
audio_ring_buffer.lock().unwrap().extend_from_slice(&data);
let max_len = data.len() * 16;
let data_len = data.len();
let len = audio_ring_buffer.lock().unwrap().len();
if len > max_len {
let mut data = audio_ring_buffer.lock().unwrap();
let new_data = data[data_len..].to_vec();
*data = new_data;
}
});
let channel_count = audio_config.channels() as usize;
let in_sample_rate = audio_config.sample_rate().0 as usize;
let resample_ratio = 16000. / in_sample_rate as f64;
let mut resampler = rubato::FastFixedIn::new(
resample_ratio,
10.,
rubato::PolynomialDegree::Septic,
1024,
1,
)?;
let (tx, rx) = std::sync::mpsc::channel();
let stream = audio_device.build_input_stream(
&audio_config.config(),
move |pcm: &[f32], _: &cpal::InputCallbackInfo| {
let pcm = pcm
.iter()
.step_by(channel_count)
.copied()
.collect::<Vec<f32>>();
if !pcm.is_empty() {
tx.send(pcm).unwrap()
}
},
move |err| {
eprintln!("an error occurred on stream: {}", err);
},
None,
)?;
stream.play()?;
// loop to process the audio data forever (until the user stops the program)
println!("Transcribing audio...");
for (i, _) in iter::repeat(()).enumerate() {
std::thread::sleep(std::time::Duration::from_millis(1000));
let data = audio_ring_buffer_2.lock().unwrap().clone();
let pcm_data: Vec<_> = data[..data.len() / channel_count as usize]
.iter()
.map(|v| *v as f32 / 32768.)
.collect();
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
println!("transcribing audio...");
let mut buffered_pcm = vec![];
let mut language_token_set = false;
while let Ok(pcm) = rx.recv() {
use rubato::Resampler;
buffered_pcm.extend_from_slice(&pcm);
if buffered_pcm.len() < 10 * in_sample_rate {
continue;
}
let mut resampled_pcm = vec![];
for buffered_pcm in buffered_pcm.chunks(1024) {
let pcm = resampler.process(&[&buffered_pcm], None)?;
resampled_pcm.extend_from_slice(&pcm[0])
}
let pcm = resampled_pcm;
println!("{} {}", buffered_pcm.len(), pcm.len());
buffered_pcm.clear();
let mel = audio::pcm_to_mel(&config, &pcm, &mel_filters);
let mel_len = mel.len();
let mel = Tensor::from_vec(
mel,
@ -614,9 +640,13 @@ pub fn main() -> Result<()> {
)?;
// on the first iteration, we detect the language and set the language token.
if i == 0 {
if !language_token_set {
let language_token = match (args.model.is_multilingual(), args.language.clone()) {
(true, None) => Some(multilingual::detect_language(dc.model(), &tokenizer, &mel)?),
(true, None) => Some(multilingual::detect_language(
decoder.model(),
&tokenizer,
&mel,
)?),
(false, None) => None,
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
Ok(token_id) => Some(token_id),
@ -627,47 +657,12 @@ pub fn main() -> Result<()> {
}
};
println!("language_token: {:?}", language_token);
dc.set_language_token(language_token);
decoder.set_language_token(language_token);
language_token_set = true;
}
dc.run(
&mel,
Some((
i as f64,
i as f64 + data.len() as f64 / m::SAMPLE_RATE as f64,
)),
)?;
dc.reset_kv_cache();
decoder.run(&mel, None)?;
decoder.reset_kv_cache();
}
Ok(())
}
fn record_audio(
device: &cpal::Device,
config: &cpal::SupportedStreamConfig,
milliseconds: u64,
) -> Result<Vec<i16>> {
let writer = Arc::new(Mutex::new(Vec::new()));
let writer_2 = writer.clone();
let stream = device.build_input_stream(
&config.config(),
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let processed = data
.iter()
.map(|v| (v * 32768.0) as i16)
.collect::<Vec<i16>>();
writer_2.lock().unwrap().extend_from_slice(&processed);
},
move |err| {
eprintln!("an error occurred on stream: {}", err);
},
None,
)?;
stream.play()?;
std::thread::sleep(std::time::Duration::from_millis(milliseconds));
drop(stream);
let data = writer.lock().unwrap().clone();
let step = 3;
let data: Vec<i16> = data.iter().step_by(step).copied().collect();
Ok(data)
}

View File

@ -12,7 +12,7 @@ file](https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/sample
from the hub.
```bash
cargo run --example whisper --release
cargo run --example whisper --release --features="symphonia"
> No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
> loaded wav data: Header { audio_format: 1, channel_count: 1, sampling_rate: 16000, bytes_per_second: 32000, bytes_per_sample: 2, bits_per_sample: 16 }

View File

@ -370,6 +370,7 @@ enum WhichModel {
Large,
LargeV2,
LargeV3,
LargeV3Turbo,
#[value(name = "distil-medium.en")]
DistilMediumEn,
#[value(name = "distil-large-v2")]
@ -388,6 +389,7 @@ impl WhichModel {
| Self::Large
| Self::LargeV2
| Self::LargeV3
| Self::LargeV3Turbo
| Self::DistilLargeV2
| Self::DistilLargeV3 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
@ -409,6 +411,7 @@ impl WhichModel {
Self::Large => ("openai/whisper-large", "refs/pr/36"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
Self::LargeV3Turbo => ("openai/whisper-large-v3-turbo", "main"),
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
Self::DistilLargeV3 => ("distil-whisper/distil-large-v3", "main"),

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.7.1"
version = "0.8.0"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.1" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.0" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.7.1"
version = "0.8.0"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -70,10 +70,9 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
// LayerNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477
template <typename T>
__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) {
__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const int block_size, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int block_size = blockDim.x;
float2 mean_var = make_float2(0.f, 0.f);
@ -134,10 +133,9 @@ __device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta,
// RmsNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
template <typename T>
__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const float eps) {
__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const int block_size, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int block_size = blockDim.x;
float tmp = 0.0f; // partial sum for thread in warp
@ -530,15 +528,15 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
#define RMSNORM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
const int n_cols, const float eps) { \
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
const int n_cols, const int block_size, const float eps) { \
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, block_size, eps); \
} \
#define LAYERNORM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
const TYPENAME *beta, const int n_cols, const float eps) { \
layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, eps); \
const TYPENAME *beta, const int n_cols, const int block_size, const float eps) { \
layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, block_size, eps); \
} \
#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.7.1"
version = "0.8.0"
edition = "2021"
description = "Metal kernels for Candle"

View File

@ -0,0 +1,39 @@
#include <metal_stdlib>
using namespace metal;
template<typename T> METAL_FUNC void fill_with(
device T *out,
constant float &value,
constant size_t &numel,
uint tid [[thread_position_in_grid]]
) {
if (tid >= numel) {
return;
}
out[tid] = static_cast<T>(value);
}
#define FILL_OP(NAME, T) \
kernel void fill_##NAME( \
device T *out, \
constant float &value, \
constant size_t &numel, \
uint tid [[thread_position_in_grid]] \
) { \
fill_with<T>(out, value, numel, tid); \
} \
#define FILL_OPS(NAME, T) \
FILL_OP(NAME, T) \
FILL_OPS(u8, uchar)
FILL_OPS(u32, uint)
FILL_OPS(i64, long)
FILL_OPS(f16, half)
FILL_OPS(f32, float)
#if __METAL_VERSION__ >= 310
FILL_OPS(bf16, bfloat)
#endif

View File

@ -6,14 +6,15 @@ use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
mod utils;
pub mod utils;
pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split, EncoderProvider};
use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
const AFFINE: &str = include_str!("affine.metal");
const BINARY: &str = include_str!("binary.metal");
const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const FILL: &str = include_str!("fill.metal");
const INDEXING: &str = include_str!("indexing.metal");
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
@ -24,6 +25,7 @@ const REDUCE: &str = include_str!("reduce.metal");
const SORT: &str = include_str!("sort.metal");
const TERNARY: &str = include_str!("ternary.metal");
const UNARY: &str = include_str!("unary.metal");
const SDPA: &str = include_str!("scaled_dot_product_attention.metal");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
@ -31,6 +33,7 @@ pub enum Source {
Binary,
Cast,
Conv,
Fill,
Gemm,
Indexing,
Mfa,
@ -40,6 +43,7 @@ pub enum Source {
Sort,
Ternary,
Unary,
Sdpa,
}
pub mod copy2d {
@ -157,6 +161,17 @@ pub enum MetalKernelError {
rhs_stride: Vec<usize>,
mnk: (usize, usize, usize),
},
#[error("Sdpa {variation} head size was {got}, expectd {expected:?}")]
SdpaHeadSizeMismatch {
variation: &'static str,
got: usize,
expected: Vec<usize>,
},
#[error("Sdpa {variation} got dtype {got:?}")]
SdpaHeadDTypeMismatch {
variation: &'static str,
got: SdpaDType,
},
}
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
@ -196,6 +211,7 @@ impl Kernels {
Source::Binary => BINARY,
Source::Cast => CAST,
Source::Conv => CONV,
Source::Fill => FILL,
Source::Gemm => MLX_GEMM,
Source::Indexing => INDEXING,
Source::Quantized => QUANTIZED,
@ -204,6 +220,7 @@ impl Kernels {
Source::Sort => SORT,
Source::Ternary => TERNARY,
Source::Unary => UNARY,
Source::Sdpa => SDPA,
Source::Mfa => panic!("Invalid lib"),
}
}
@ -1624,6 +1641,313 @@ pub fn call_gemm(
Ok(())
}
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum SdpaDType {
BF16,
F16,
F32,
}
/// SDPA full is supported when:
/// - q head dim == 64, 128
/// - no mask
/// - q heads == kv heads
/// - final type != bf16 (TODO maybe just template this kernel too?)
/// - q,k,v are contiguous
#[allow(clippy::too_many_arguments)]
pub fn call_sdpa_full(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
q_offset: usize,
q_shape: &[usize],
q_buffer: &Buffer,
k_offset: usize,
k_buffer: &Buffer,
v_offset: usize,
v_buffer: &Buffer,
output: &Buffer,
alpha: f32,
softcapping: f32,
itype: SdpaDType,
) -> Result<(), MetalKernelError> {
#[derive(Debug)]
#[repr(C)]
struct MLXFastAttentionParams {
m: i32,
n: i32,
k: i32,
ldq: i32, // ldq == ldo
ldk: i32,
ldv: i32,
lds: i32,
ldo: i32,
tiles_n: i32,
tiles_m: i32,
batch_stride_q: i32,
batch_stride_k: i32,
batch_stride_v: i32,
batch_stride_o: i32,
swizzle_log: i32,
gemm_n_iterations_aligned: i32,
gemm_k_iterations_aligned: i32,
gemm_sv_m_block_iterations: i32,
batch_ndim: i32,
alpha: f32,
softcapping: f32,
}
let bk = q_shape.last().unwrap();
const BN: usize = 16;
const BM: usize = 16;
const WM: usize = 2;
const WN: usize = 2;
let name = match (bk, itype) {
(32, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_half",
(64, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_half",
(96, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_half",
(128, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_half",
(256, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_half",
(32, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_float",
(64, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_float",
(96, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_float",
(128, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_float",
(256, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_float",
(other, SdpaDType::F16 | SdpaDType::F32) => {
return Err(MetalKernelError::SdpaHeadSizeMismatch {
variation: "full",
got: *other,
expected: vec![32, 64, 96, 128, 256],
})
}
(_, SdpaDType::BF16) => {
return Err(MetalKernelError::SdpaHeadDTypeMismatch {
variation: "full",
got: SdpaDType::BF16,
})
}
};
let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
// q = (bs, qhead, seq, hidden)
// k/v = (bs, kv_head, seq, hidden)
let qseq = q_shape[q_shape.len() - 2];
let m = q_shape[q_shape.len() - 2];
let n = m;
let k = q_shape[q_shape.len() - 1];
let bs_out = q_shape[0] * q_shape[1];
let batch_shape = [q_shape[0] * q_shape[1]];
let dk = q_shape[q_shape.len() - 1];
let ldq = dk;
let ldk = dk;
let ldv = dk;
let lds = BN;
let ldo = dk;
let tn = 1;
let tm = (m + BM - 1) / BM;
let b_stride_q = dk * qseq;
let b_stride_k = dk * qseq;
let b_stride_v = dk * qseq;
let b_stride_o = dk * qseq;
let swizzle_log = 0;
let gemm_n_iterations_aligned = (n + BN - 1) / BN;
let gemm_k_iterations_aligned = (k + bk - 1) / bk;
let gemm_sv_m_block_iterations = (m + BM - 1) / BM;
let batch_ndim = batch_shape.len();
let alpha = if softcapping != 1. {
alpha / softcapping
} else {
alpha
};
let params = MLXFastAttentionParams {
m: m as i32,
n: n as i32,
k: k as i32,
ldq: ldq as i32,
ldk: ldk as i32,
ldv: ldv as i32,
lds: lds as i32,
ldo: ldo as i32,
tiles_n: tn,
tiles_m: tm as i32,
batch_stride_q: b_stride_q as i32,
batch_stride_k: b_stride_k as i32,
batch_stride_v: b_stride_v as i32,
batch_stride_o: b_stride_o as i32,
swizzle_log,
gemm_n_iterations_aligned: gemm_n_iterations_aligned as i32,
gemm_k_iterations_aligned: gemm_k_iterations_aligned as i32,
gemm_sv_m_block_iterations: gemm_sv_m_block_iterations as i32,
batch_ndim: batch_ndim as i32,
alpha,
softcapping,
};
let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o];
impl EncoderParam for MLXFastAttentionParams {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
core::mem::size_of::<MLXFastAttentionParams>() as u64,
&data as *const MLXFastAttentionParams as *const c_void,
);
}
}
set_params!(
encoder,
(
(q_buffer, q_offset),
(k_buffer, k_offset),
(v_buffer, v_offset),
output,
params,
&batch_shape[..],
&batch_strides[..]
)
);
let grid_dims = MTLSize {
width: 1,
height: tm as u64,
depth: bs_out as u64,
};
let group_dims = MTLSize {
width: 32,
height: WM as u64,
depth: WN as u64,
};
encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_dims, group_dims);
Ok(())
}
/// SDPA full is supported when:
/// - q head dim == 64, 96, 128
/// - no mask
/// - q,k,v are contiguous
#[allow(clippy::too_many_arguments)]
pub fn call_sdpa_vector(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
q_offset: usize,
q_shape: &[usize],
q_buffer: &Buffer,
k_offset: usize,
k_shape: &[usize],
k_stride: &[usize],
k_buffer: &Buffer,
v_offset: usize,
v_stride: &[usize],
v_buffer: &Buffer,
output: &Buffer,
alpha: f32,
softcapping: f32,
itype: SdpaDType,
) -> Result<(), MetalKernelError> {
let bk = q_shape.last().unwrap();
let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
let n = k_shape[2] as i32;
let b = (q_shape[0] * q_shape[1]) as i32;
let kstride = k_stride[1];
let vstride = v_stride[1];
let name = match (bk, itype) {
(32, SdpaDType::F16) => "sdpa_vector_float16_t_32",
(64, SdpaDType::F16) => "sdpa_vector_float16_t_64",
(96, SdpaDType::F16) => "sdpa_vector_float16_t_96",
(128, SdpaDType::F16) => "sdpa_vector_float16_t_128",
(256, SdpaDType::F16) => "sdpa_vector_float16_t_256",
(32, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_32",
(64, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_64",
(96, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_96",
(128, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_128",
(256, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_256",
(32, SdpaDType::F32) => "sdpa_vector_float_32",
(64, SdpaDType::F32) => "sdpa_vector_float_64",
(96, SdpaDType::F32) => "sdpa_vector_float_96",
(128, SdpaDType::F32) => "sdpa_vector_float_128",
(256, SdpaDType::F32) => "sdpa_vector_float_256",
(other, _) => {
return Err(MetalKernelError::SdpaHeadSizeMismatch {
variation: "vector",
got: *other,
expected: vec![32, 64, 96, 128, 256],
})
}
};
let alpha = if softcapping != 1. {
alpha / softcapping
} else {
alpha
};
let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
// q = (bs, qhead, seq, hidden)
// k/v = (bs, kv_head, kv_seq, hidden)
set_params!(
encoder,
(
(q_buffer, q_offset),
(k_buffer, k_offset),
(v_buffer, v_offset),
output,
gqa_factor,
n,
kstride,
vstride,
alpha,
softcapping
)
);
let grid_dims = MTLSize {
width: 1,
height: b as u64,
depth: 1 as u64,
};
let group_dims = MTLSize {
width: 1024,
height: 1,
depth: 1,
};
encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_dims, group_dims);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_im2col1d_strided(
device: &Device,
@ -2357,5 +2681,25 @@ pub fn call_mlx_gemm(
Ok(())
}
pub fn call_const_fill(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
length: usize,
output: &Buffer,
v: f32,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (output, v, length));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[cfg(test)]
mod tests;

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,7 @@
use super::*;
use half::{bf16, f16};
use metal::MTLResourceOptions;
use rand::Rng;
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
@ -2307,3 +2308,33 @@ fn conv_transpose1d_u32() {
let expected = vec![1, 4, 10, 20, 25, 24, 16];
assert_eq!(results, expected);
}
#[test]
fn const_fill() {
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
let dev = device();
let kernels = Kernels::new();
let command_queue = dev.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let buffer = dev.new_buffer(
(len * std::mem::size_of::<T>()) as u64,
MTLResourceOptions::StorageModePrivate,
);
call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec::<T>(&buffer, len)
}
fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
let value = rand::thread_rng().gen_range(1. ..19.);
let v = constant_fill::<T>(name, len, value);
assert_eq!(v, vec![f(value); len])
}
test::<u8, _>("fill_u8", |v| v as u8);
test::<u32, _>("fill_u32", |v| v as u32);
test::<i64, _>("fill_i64", |v| v as i64);
test::<f16, _>("fill_f16", f16::from_f32);
test::<bf16, _>("fill_bf16", bf16::from_f32);
test::<f32, _>("fill_f32", |v| v);
}

View File

@ -24,7 +24,7 @@ pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (M
}
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
let mut pows0 = 0u64;
let mut pows1 = 0u64;
let mut pows2 = 0u64;
@ -61,18 +61,14 @@ pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
}
}
pub(crate) fn set_param<P: EncoderParam>(
encoder: &ComputeCommandEncoderRef,
position: u64,
data: P,
) {
pub fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}
/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
pub(crate) trait EncoderParam {
pub trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {

View File

@ -1,3 +1,5 @@
//! Activation Functions
//!
use candle::{Result, Tensor};
use serde::Deserialize;

View File

@ -1,3 +1,5 @@
//! Cache Implementations
//!
use candle::{Device, Result, Tensor};
#[derive(Debug, Clone)]

View File

@ -1,3 +1,20 @@
//! candle-nn
//!
//! ## Other Crates
//!
//! Candle consists of a number of crates. This crate holds structs and functions
//! that allow you to build and train neural nets. You may wish
//! to look at the docs for the other crates which can be found here:
//!
//! - [candle-core](https://docs.rs/candle-core/). Core Datastructures and DataTypes.
//! - [candle-nn](https://docs.rs/candle-nn/). Building blocks for Neural Nets.
//! - [candle-datasets](https://docs.rs/candle-datasets/). Rust access to commonly used Datasets like MNIST.
//! - [candle-examples](https://docs.rs/candle-examples/). Examples of Candle in Use.
//! - [candle-onnx](https://docs.rs/candle-onnx/). Loading and using ONNX models.
//! - [candle-pyo3](https://docs.rs/candle-pyo3/). Access to Candle from Python.
//! - [candle-transformers](https://docs.rs/candle-transformers/). Candle implemntation of many published transformer models.
//!
pub mod activation;
pub mod batch_norm;
pub mod conv;

View File

@ -1,3 +1,5 @@
//! Loss Calculations
//!
use candle::{Result, Tensor};
/// The negative log likelihood loss.

View File

@ -1,3 +1,6 @@
//! Tensor ops.
//!
use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
use rayon::prelude::*;
@ -543,15 +546,23 @@ impl candle::CustomOp2 for RmsNorm {
let dim_m1 = dims[dims.len() - 1];
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
let block_size = if n_cols < 1024 { 32 } else { 1024 };
let cfg = LaunchConfig {
grid_dim: (n_rows as u32, 1, 1),
block_dim: (1024, 1, 1),
block_dim: (block_size, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &dst, &alpha, n_cols as i32, self.eps);
let params = (
&src,
&dst,
&alpha,
n_cols as i32,
block_size as i32,
self.eps,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
@ -776,15 +787,24 @@ impl candle::CustomOp3 for LayerNorm {
let dim_m1 = dims[dims.len() - 1];
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
let block_size = if n_cols < 1024 { 32 } else { 1024 };
let cfg = LaunchConfig {
grid_dim: (n_rows as u32, 1, 1),
block_dim: (1024, 1, 1),
block_dim: (block_size, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &dst, &alpha, &beta, n_cols as i32, self.eps);
let params = (
&src,
&dst,
&alpha,
&beta,
n_cols as i32,
block_size as i32,
self.eps,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
@ -947,3 +967,193 @@ impl Module for Identity {
Ok(xs.clone())
}
}
#[allow(dead_code)]
struct Sdpa {
scale: f32,
softcapping: f32,
}
impl candle::CustomOp3 for Sdpa {
fn name(&self) -> &'static str {
"metal-sdpa"
}
fn cpu_fwd(
&self,
_s1: &CpuStorage,
_l1: &Layout,
_s2: &CpuStorage,
_l2: &Layout,
_s3: &CpuStorage,
_l3: &Layout,
) -> Result<(CpuStorage, Shape)> {
candle::bail!("SDPA has no cpu impl")
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
q: &candle::MetalStorage,
q_l: &Layout,
k: &candle::MetalStorage,
k_l: &Layout,
v: &candle::MetalStorage,
v_l: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
use candle_metal_kernels::SdpaDType;
let device = q.device();
let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?];
let elem_count: usize = out_dims.iter().product();
let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?;
// q,k must have matching emb dim
if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? {
candle::bail!("`q` and `k` last dims must match");
}
// k,v must have matching n kv heads
if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? {
candle::bail!("`k` and `v` head dims must match");
}
// n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1.
if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 {
candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
}
let k_head = k_l.dim(D::Minus1)?;
let q_head = q_l.dim(D::Minus1)?;
let q_seq = q_l.dim(2)?;
let mut implementation_supports_use_case = q_head == k_head;
let supported_head_dim =
q_head == 32 || q_head == 64 || q_head == 96 || q_head == 128 || q_head == 256;
const SDPA_FULL_THRESHOLD: usize = 2;
let supports_sdpa_full =
q_seq >= SDPA_FULL_THRESHOLD && supported_head_dim && q_head == k_head;
let supports_sdpa_vector = q_seq == 1 && supported_head_dim;
implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector;
if !supported_head_dim {
candle::bail!(
"Meta SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.",
q_l.dims(),
k_l.dims(),
v_l.dims()
);
}
if !implementation_supports_use_case {
candle::bail!(
"Meta SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.",
q_l.dims(),
k_l.dims(),
v_l.dims()
);
}
for t in [k.dtype(), v.dtype()] {
if q.dtype() != t {
candle::bail!("all q, k, v dtypes must match.");
}
}
let itype = match q.dtype() {
DType::BF16 => SdpaDType::BF16,
DType::F16 => SdpaDType::F16,
DType::F32 => SdpaDType::F32,
other => candle::bail!("unsupported sdpa type {other:?}"),
};
let command_buffer = q.device().command_buffer()?;
if supports_sdpa_vector {
command_buffer.set_label("vector_attention");
candle_metal_kernels::call_sdpa_vector(
q.device().device(),
&command_buffer,
q.device().kernels(),
q_l.start_offset(),
q_l.dims(),
q.buffer(),
k_l.start_offset(),
k_l.dims(),
k_l.stride(),
k.buffer(),
v_l.start_offset(),
v_l.stride(),
v.buffer(),
&output,
self.scale,
self.softcapping,
itype,
)
.map_err(candle::Error::wrap)?;
} else if supports_sdpa_full {
if q_l.dim(2)? != k_l.dim(2)? {
candle::bail!(
"query and key sequence length must be equal if using full metal sdpa"
)
}
command_buffer.set_label("full_attention");
candle_metal_kernels::call_sdpa_full(
q.device().device(),
&command_buffer,
q.device().kernels(),
q_l.start_offset(),
q_l.dims(),
q.buffer(),
k_l.start_offset(),
k.buffer(),
v_l.start_offset(),
v.buffer(),
&output,
self.scale,
self.softcapping,
itype,
)
.map_err(candle::Error::wrap)?;
} else {
candle::bail!("must be vector or full sdpa kernel");
}
let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype());
Ok((newstorage, Shape::from_dims(&out_dims)))
}
}
/// Scaled dot product attention with a fused kernel.
///
/// Computes softmax(qk^T*scale)v.
///
/// **Inputs shapes:**
/// - `q`: (bs, qhead, seq, hidden)
/// - `k`: (bs, kv_head, kv_seq, hidden)
/// - `k`: (bs, kv_head, kv_seq, v_hidden)
/// - `scale` is applied before softmax.
/// - If `softcapping` != 1.0:
/// - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v
///
/// **Output shape:** (bs, qhead, seq, v_hidden)
///
/// **Supported head dims:** 32, 64, 96, 128, 256.
///
/// ## On Metal:
/// - If `seq` == 1:
/// - Use a vectorized kernel
/// - Supports `seq` != `kv_seq` (cross attn. support)
/// - Supports GQA when `qhead` is a multiple of `kv_head`
/// - Otherwise:
/// - Use an alternate kernel
/// - Requires `seq` == `kv_seq`
/// - GQA is not supported (requires `qhead` == `kv_head`)
pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result<Tensor> {
q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping })
}

View File

@ -70,6 +70,12 @@ impl LSTMState {
}
}
#[derive(Debug, Clone, Copy)]
pub enum Direction {
Forward,
Backward,
}
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy)]
pub struct LSTMConfig {
@ -78,6 +84,7 @@ pub struct LSTMConfig {
pub b_ih_init: Option<super::Init>,
pub b_hh_init: Option<super::Init>,
pub layer_idx: usize,
pub direction: Direction,
}
impl Default for LSTMConfig {
@ -88,6 +95,7 @@ impl Default for LSTMConfig {
b_ih_init: Some(super::Init::Const(0.)),
b_hh_init: Some(super::Init::Const(0.)),
layer_idx: 0,
direction: Direction::Forward,
}
}
}
@ -100,6 +108,7 @@ impl LSTMConfig {
b_ih_init: None,
b_hh_init: None,
layer_idx: 0,
direction: Direction::Forward,
}
}
}
@ -107,7 +116,7 @@ impl LSTMConfig {
/// A Long Short-Term Memory (LSTM) layer.
///
/// <https://en.wikipedia.org/wiki/Long_short-term_memory>
#[allow(clippy::upper_case_acronyms, unused)]
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug)]
pub struct LSTM {
w_ih: Tensor,
@ -120,6 +129,62 @@ pub struct LSTM {
dtype: DType,
}
impl LSTM {
/// Creates a LSTM layer.
pub fn new(
in_dim: usize,
hidden_dim: usize,
config: LSTMConfig,
vb: crate::VarBuilder,
) -> Result<Self> {
let layer_idx = config.layer_idx;
let direction_str = match config.direction {
Direction::Forward => "",
Direction::Backward => "_reverse",
};
let w_ih = vb.get_with_hints(
(4 * hidden_dim, in_dim),
&format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported.
config.w_ih_init,
)?;
let w_hh = vb.get_with_hints(
(4 * hidden_dim, hidden_dim),
&format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported.
config.w_hh_init,
)?;
let b_ih = match config.b_ih_init {
Some(init) => Some(vb.get_with_hints(
4 * hidden_dim,
&format!("bias_ih_l{layer_idx}{direction_str}"),
init,
)?),
None => None,
};
let b_hh = match config.b_hh_init {
Some(init) => Some(vb.get_with_hints(
4 * hidden_dim,
&format!("bias_hh_l{layer_idx}{direction_str}"),
init,
)?),
None => None,
};
Ok(Self {
w_ih,
w_hh,
b_ih,
b_hh,
hidden_dim,
config,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
pub fn config(&self) -> &LSTMConfig {
&self.config
}
}
/// Creates a LSTM layer.
pub fn lstm(
in_dim: usize,
@ -127,39 +192,7 @@ pub fn lstm(
config: LSTMConfig,
vb: crate::VarBuilder,
) -> Result<LSTM> {
let layer_idx = config.layer_idx;
let w_ih = vb.get_with_hints(
(4 * hidden_dim, in_dim),
&format!("weight_ih_l{layer_idx}"), // Only a single layer is supported.
config.w_ih_init,
)?;
let w_hh = vb.get_with_hints(
(4 * hidden_dim, hidden_dim),
&format!("weight_hh_l{layer_idx}"), // Only a single layer is supported.
config.w_hh_init,
)?;
let b_ih = match config.b_ih_init {
Some(init) => {
Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_ih_l{layer_idx}"), init)?)
}
None => None,
};
let b_hh = match config.b_hh_init {
Some(init) => {
Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_hh_l{layer_idx}"), init)?)
}
None => None,
};
Ok(LSTM {
w_ih,
w_hh,
b_ih,
b_hh,
hidden_dim,
config,
device: vb.device().clone(),
dtype: vb.dtype(),
})
LSTM::new(in_dim, hidden_dim, config, vb)
}
impl RNN for LSTM {
@ -253,7 +286,7 @@ impl GRUConfig {
/// A Gated Recurrent Unit (GRU) layer.
///
/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
#[allow(clippy::upper_case_acronyms, unused)]
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug)]
pub struct GRU {
w_ih: Tensor,
@ -266,41 +299,56 @@ pub struct GRU {
dtype: DType,
}
/// Creates a GRU layer.
impl GRU {
/// Creates a GRU layer.
pub fn new(
in_dim: usize,
hidden_dim: usize,
config: GRUConfig,
vb: crate::VarBuilder,
) -> Result<Self> {
let w_ih = vb.get_with_hints(
(3 * hidden_dim, in_dim),
"weight_ih_l0", // Only a single layer is supported.
config.w_ih_init,
)?;
let w_hh = vb.get_with_hints(
(3 * hidden_dim, hidden_dim),
"weight_hh_l0", // Only a single layer is supported.
config.w_hh_init,
)?;
let b_ih = match config.b_ih_init {
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
None => None,
};
let b_hh = match config.b_hh_init {
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
None => None,
};
Ok(Self {
w_ih,
w_hh,
b_ih,
b_hh,
hidden_dim,
config,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
pub fn config(&self) -> &GRUConfig {
&self.config
}
}
pub fn gru(
in_dim: usize,
hidden_dim: usize,
config: GRUConfig,
vb: crate::VarBuilder,
) -> Result<GRU> {
let w_ih = vb.get_with_hints(
(3 * hidden_dim, in_dim),
"weight_ih_l0", // Only a single layer is supported.
config.w_ih_init,
)?;
let w_hh = vb.get_with_hints(
(3 * hidden_dim, hidden_dim),
"weight_hh_l0", // Only a single layer is supported.
config.w_hh_init,
)?;
let b_ih = match config.b_ih_init {
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
None => None,
};
let b_hh = match config.b_hh_init {
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
None => None,
};
Ok(GRU {
w_ih,
w_hh,
b_ih,
b_hh,
hidden_dim,
config,
device: vb.device().clone(),
dtype: vb.dtype(),
})
GRU::new(in_dim, hidden_dim, config, vb)
}
impl RNN for GRU {

View File

@ -1,3 +1,5 @@
//! Rotary Embeddings
//!
use candle::{CpuStorage, Layout, Result, Shape, Tensor, D};
use rayon::prelude::*;

View File

@ -1,3 +1,5 @@
//! Sequential Layer
//!
//! A sequential layer used to chain multiple layers and closures.
use candle::{Module, Result, Tensor};

View File

@ -1,3 +1,5 @@
//! A `VarBuilder` for variable retrieval from models
//!
//! A `VarBuilder` is used to retrieve variables used by a model. These variables can either come
//! from a pre-trained checkpoint, e.g. using `VarBuilder::from_mmaped_safetensors`, or initialized
//! for training, e.g. using `VarBuilder::from_varmap`.
@ -14,6 +16,7 @@ use std::sync::Arc;
pub struct VarBuilderArgs<'a, B: Backend> {
data: Arc<TensorData<B>>,
path: Vec<String>,
pub dtype: DType,
_phantom: std::marker::PhantomData<&'a B>,
}
@ -22,6 +25,7 @@ impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path: self.path.clone(),
dtype: self.dtype,
_phantom: self._phantom,
}
}
@ -33,7 +37,6 @@ pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>;
struct TensorData<B: Backend> {
backend: B,
pub dtype: DType,
pub device: Device,
}
@ -95,12 +98,12 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self {
let data = TensorData {
backend,
dtype,
device: dev.clone(),
};
Self {
data: Arc::new(data),
path: vec![],
dtype,
_phantom: std::marker::PhantomData,
}
}
@ -115,6 +118,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path: vec![],
dtype: self.dtype,
_phantom: std::marker::PhantomData,
}
}
@ -124,6 +128,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path: vec![prefix.to_string()],
dtype: self.dtype,
_phantom: std::marker::PhantomData,
}
}
@ -136,6 +141,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path,
dtype: self.dtype,
_phantom: std::marker::PhantomData,
}
}
@ -152,7 +158,17 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
/// The dtype used by default.
pub fn dtype(&self) -> DType {
self.data.dtype
self.dtype
}
/// Clone the VarBuilder tweaking its dtype
pub fn to_dtype(&self, dtype: DType) -> Self {
Self {
data: self.data.clone(),
path: self.path.clone(),
dtype,
_phantom: std::marker::PhantomData,
}
}
fn path(&self, tensor_name: &str) -> String {
@ -178,7 +194,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
name: &str,
hints: B::Hints,
) -> Result<Tensor> {
self.get_with_hints_dtype(s, name, hints, self.data.dtype)
self.get_with_hints_dtype(s, name, hints, self.dtype)
}
/// Retrieve the tensor associated with the given name at the current path.
@ -460,14 +476,11 @@ impl<'a> VarBuilder<'a> {
dtype: DType,
device: Device,
) -> Self {
let data = TensorData {
backend,
dtype,
device,
};
let data = TensorData { backend, device };
Self {
data: Arc::new(data),
path: vec![],
dtype,
_phantom: std::marker::PhantomData,
}
}
@ -567,13 +580,10 @@ impl<'a> VarBuilder<'a> {
let path = self.path.clone();
let backend = Rename::new(self, renamer);
let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
let data = TensorData {
backend,
dtype,
device,
};
let data = TensorData { backend, device };
Self {
data: Arc::new(data),
dtype,
path,
_phantom: std::marker::PhantomData,
}

View File

@ -1,3 +1,5 @@
//! A `VarMap` is a store that holds named variables.
//!
use candle::{DType, Device, Result, Shape, Tensor, Var};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

View File

@ -77,6 +77,27 @@ fn rms_norm(device: &Device) -> Result<()> {
Ok(())
}
fn rms_norml(device: &Device) -> Result<()> {
use rand::{rngs::StdRng, Rng, SeedableRng};
let (b_size, seq_len, head_dim) = (24, 70, 64);
let el_count = b_size * seq_len * head_dim;
let mut rng = StdRng::seed_from_u64(299792458);
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?;
let alpha = Tensor::ones(head_dim, candle::DType::F32, device)?;
let t = candle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?;
let t2 = candle_nn::ops::rms_norm_slow(&tensor, &alpha, 1e-5)?;
let diff = (t - t2)?
.abs()?
.flatten_all()?
.max(0)?
.reshape(())?
.to_vec0::<f32>()?;
assert!(diff < 1e-5);
Ok(())
}
fn layer_norm(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
@ -103,6 +124,28 @@ fn layer_norm(device: &Device) -> Result<()> {
Ok(())
}
fn layer_norml(device: &Device) -> Result<()> {
use rand::{rngs::StdRng, Rng, SeedableRng};
let (b_size, seq_len, head_dim) = (24, 70, 64);
let el_count = b_size * seq_len * head_dim;
let mut rng = StdRng::seed_from_u64(299792458);
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?;
let alpha = Tensor::ones(head_dim, candle::DType::F32, device)?;
let beta = Tensor::zeros(head_dim, candle::DType::F32, device)?;
let t = candle_nn::ops::layer_norm(&tensor, &alpha, &beta, 1e-5)?;
let t2 = candle_nn::ops::layer_norm_slow(&tensor, &alpha, &beta, 1e-5)?;
let diff = (t - t2)?
.abs()?
.flatten_all()?
.max(0)?
.reshape(())?
.to_vec0::<f32>()?;
assert!(diff < 1e-5);
Ok(())
}
#[test]
fn softmax_numerical_stability() -> Result<()> {
let dev = &Device::Cpu;
@ -211,5 +254,7 @@ test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
test_device!(rms_norml, rms_norml_cpu, rms_norml_gpu, rms_norml_metal);
test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
test_device!(layer_norml, lnl_cpu, lnl_gpu, lnl_metal);
test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);

206
candle-nn/tests/sdpa.rs Normal file
View File

@ -0,0 +1,206 @@
#[cfg(feature = "metal")]
mod metal_sdpa_tests {
#[test]
fn sdpa_full() -> candle::Result<()> {
use candle::{DType, Device, Tensor};
// Force seqlen = 100
const BS: usize = 4;
const R: usize = 4;
const L: usize = 4;
const DK: usize = 64;
const H: usize = 3;
let scale: f64 = f64::from(DK as u32).sqrt().recip();
let device = Device::new_metal(0)?;
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let ground_truth = {
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
.to_dtype(q.dtype())?;
att.matmul(&v.clone())?
};
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
assert_eq!(ground_truth.shape(), sdpa_output.shape());
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
.sum_all()?
.to_scalar()?;
assert!(error <= 0.0005, "{}", error);
Ok(())
}
#[test]
fn sdpa_vector() -> candle::Result<()> {
use candle::{DType, Device, Tensor};
// Allow vectorized, seqlen = 1
const BS: usize = 4;
const R: usize = 1;
const L: usize = 1;
const DK: usize = 64;
const H: usize = 3;
let scale: f64 = f64::from(DK as u32).sqrt().recip();
let device = Device::new_metal(0)?;
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let ground_truth = {
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
.to_dtype(q.dtype())?;
att.matmul(&v.clone())?
};
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
assert_eq!(ground_truth.shape(), sdpa_output.shape());
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
.sum_all()?
.to_scalar()?;
assert!(error <= 0.0001, "{}", error);
Ok(())
}
#[test]
fn sdpa_full_softcapping() -> candle::Result<()> {
use candle::{DType, Device, Tensor};
use std::ops::{Div, Mul};
// Allow vectorized, seqlen = 1
const BS: usize = 4;
const R: usize = 4;
const L: usize = 4;
const DK: usize = 64;
const H: usize = 3;
const SOFTCAP: f64 = 50.;
let scale: f64 = f64::from(DK as u32).sqrt().recip();
let device = Device::new_metal(0)?;
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let ground_truth = {
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
let att = candle_nn::ops::softmax_last_dim(
&att.to_dtype(DType::F32)?
.div(SOFTCAP)?
.tanh()?
.mul(SOFTCAP)?,
)?
.to_dtype(q.dtype())?;
att.matmul(&v.clone())?
};
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
assert_eq!(ground_truth.shape(), sdpa_output.shape());
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
.sum_all()?
.to_scalar()?;
assert!(error <= 0.0004, "{}", error);
Ok(())
}
#[test]
fn sdpa_vector_softcapping() -> candle::Result<()> {
use candle::{DType, Device, Tensor};
use std::ops::{Div, Mul};
// Allow vectorized, seqlen = 1
const BS: usize = 4;
const R: usize = 1;
const L: usize = 1;
const DK: usize = 64;
const H: usize = 3;
const SOFTCAP: f64 = 50.;
let scale: f64 = f64::from(DK as u32).sqrt().recip();
let device = Device::new_metal(0)?;
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let ground_truth = {
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
let att = candle_nn::ops::softmax_last_dim(
&att.to_dtype(DType::F32)?
.div(SOFTCAP)?
.tanh()?
.mul(SOFTCAP)?,
)?
.to_dtype(q.dtype())?;
att.matmul(&v.clone())?
};
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?;
assert_eq!(ground_truth.shape(), sdpa_output.shape());
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
.sum_all()?
.to_scalar()?;
assert!(error <= 0.0001, "{}", error);
Ok(())
}
#[test]
fn sdpa_vector_cross() -> candle::Result<()> {
use candle::{DType, Device, Tensor};
// Allow vectorized, seqlen = 1. Simulat cross attention case where R != L, R = 1
const BS: usize = 4;
const R: usize = 1;
const L: usize = 24;
const DK: usize = 64;
const H: usize = 3;
let scale: f64 = f64::from(DK as u32).sqrt().recip();
let device = Device::new_metal(0)?;
let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?;
let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?;
let ground_truth = {
let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?;
let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)?
.to_dtype(q.dtype())?;
att.matmul(&v.clone())?
};
let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?;
assert_eq!(ground_truth.shape(), sdpa_output.shape());
let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)?
.sum_all()?
.to_scalar()?;
assert!(error <= 0.0013, "{}", error);
Ok(())
}
}

View File

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.7.1"
version = "0.8.0"
edition = "2021"
description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.7.1" }
candle-nn = { path = "../candle-nn", version = "0.7.1" }
candle = { path = "../candle-core", package = "candle-core", version = "0.8.0" }
candle-nn = { path = "../candle-nn", version = "0.8.0" }
prost = "0.12.1"
[build-dependencies]

View File

@ -2,7 +2,7 @@ use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use candle::{bail, DType, Device, Result, Tensor};
use std::{collections::HashMap, usize};
use std::collections::{HashMap, HashSet};
pub type Value = Tensor;
@ -321,8 +321,15 @@ fn simple_eval_(
for node in graph.node.iter() {
let get = |input_name: &str| match values.get(input_name) {
Some(value) => Ok(value),
None => bail!("cannot find {input_name} for op {}", node.name),
None => bail!("cannot find {input_name} for op '{}'", node.name),
};
let get_opt = |i: usize| {
node.input
.get(i)
.filter(|s: &&String| !s.is_empty())
.map(|s| get(s))
};
// TODO: Validate node.input for each operator.
match node.op_type.as_str() {
"Add" => {
@ -355,7 +362,7 @@ fn simple_eval_(
// HACK: current implementation of broadcast_pow cannot handle negative base,
// so we use powf where we can, which *does* correctly handle negative base.
if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::<f64>())() {
let output = input0.powf(exp as f64)?;
let output = input0.powf(exp)?;
values.insert(node.output[0].clone(), output);
} else {
let output = input0.broadcast_pow(input1)?;
@ -608,15 +615,13 @@ fn simple_eval_(
}
"Clip" => {
let xs = get(&node.input[0])?;
let xs = if node.input.len() >= 2 {
let mins = get(&node.input[1])?;
xs.broadcast_maximum(mins)?
let xs = if let Some(mins) = get_opt(1) {
xs.broadcast_maximum(mins?)?
} else {
xs.clone()
};
let xs = if node.input.len() >= 3 {
let maxs = get(&node.input[2])?;
xs.broadcast_minimum(maxs)?
let xs = if let Some(maxs) = get_opt(2) {
xs.broadcast_minimum(maxs?)?
} else {
xs.clone()
};
@ -638,7 +643,7 @@ fn simple_eval_(
let mask = indices.lt(&zeros)?;
mask.to_dtype(indices.dtype())?
.broadcast_mul(&max)?
.add(&indices)?
.add(indices)?
};
// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
@ -665,6 +670,49 @@ fn simple_eval_(
};
values.insert(node.output[0].clone(), xs);
}
// https://onnx.ai/onnx/operators/onnx__GatherElements.html#gatherelements
// A Note to fellow lurkers:
// The numpy based `gather_elements` implementation in `onnx` tests [here](https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/gatherelements.py)
// and examples is incorrect.
// Use `torch.gather` for the validating/ verifying against the proper behaviour
"GatherElements" => {
let data = get(&node.input[0])?;
let indices = get(&node.input[1])?;
let rank = data.rank();
if rank != indices.rank() {
bail!("indices must have same rank as input data. Data rank [{}] != indices rank [{}]", data.rank(), indices.rank());
}
let axis = {
let axis_i64 = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = data.normalize_axis(axis_i64)?;
if axis >= rank {
bail!(
"axis ({}) out of accepted range [-rank, rank-1] which was [-{rank}, {}]",
axis_i64,
rank - 1
)
}
axis
};
// index_select does not support negative indices, so normalize them
// to positive indices.
let indices = &{
let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?;
let max = Tensor::new(data.dims()[axis] as i64, indices.device())?
.to_dtype(indices.dtype())?;
let mask = indices.lt(&zeros)?;
mask.to_dtype(indices.dtype())?
.broadcast_mul(&max)?
.add(indices)?
};
values.insert(node.output[0].clone(), data.gather(indices, axis)?);
}
"Shape" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
let xs = get(&node.input[0])?;
@ -759,7 +807,14 @@ fn simple_eval_(
let cond = get(&node.input[0])?;
let a = get(&node.input[1])?;
let b = get(&node.input[2])?;
let output = cond.where_cond(a, b)?;
// where_cond requires that all inputs are the same shape.
// In contrast, the Where op in ONNX only requires that they are broadcastable.
let shape = broadcast_shape_from_many(&[cond.dims(), a.dims(), b.dims()])?;
let cond = cond.broadcast_as(shape.clone())?;
let a = a.broadcast_as(shape.clone())?;
let b = b.broadcast_as(shape)?;
let output = cond.where_cond(&a, &b)?;
values.insert(node.output[0].clone(), output);
}
"Conv" => {
@ -962,6 +1017,7 @@ fn simple_eval_(
}
rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
};
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
@ -1176,6 +1232,92 @@ fn simple_eval_(
}
values.insert(node.output[0].clone(), out);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax
"ReduceMax" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1) == 1;
let axes = if let Some(Ok(axes)) = axes {
// Satisfies version 18+
axes.to_vec1::<i64>().ok()
} else if let Ok(Some(axes)) = get_attr_opt::<[i64]>(node, "axes") {
// Backward compatiblity with version 13 and below
Some(axes.to_vec())
} else {
None
};
let axes = if let Some(axes) = axes {
let rank = input.rank();
let mut axes_set = HashSet::new();
let mut axes = axes
.iter()
.map(|a| {
let axis = if *a < 0 {
(rank as i64 + *a) as usize
} else {
*a as usize
};
axes_set.insert(axis);
axis
})
.collect::<Vec<_>>();
if axes_set.len() < axes.len() {
bail!("Duplicate value in 'axes'");
}
if axes.len() > 1 {
axes.sort();
}
Some(axes)
} else {
None
};
// TODO: Handle empty set
// Definition:
// "Reduction over an empty set of values yields minus infinity (if supported by the datatype) or the minimum value of the data type otherwise"
// For now, this will throw an error
if input.elem_count() == 0 {
bail!("reduction over zero-size tensor not supported");
}
let output = if let Some(axes) = axes {
let mut result = input.clone();
for &axis in axes.iter().rev() {
result = if keepdims {
result.max_keepdim(axis)?
} else {
result.max(axis)?
}
}
result
} else {
// If `axes` is empty and `noop_with_empty_axes` is set to `true (1)`
// ""input tensor will not be reduced,and the output tensor would be equivalent to input tensor.""
if get_attr_opt::<i64>(node, "noop_with_empty_axes")?.copied() == Some(1) {
input.clone()
} else {
let mut result = input.flatten_all()?;
if keepdims {
result = result.max_keepdim(0)?;
// If keepdims is true, reshape to match input dimensions
let shape = vec![1; input.rank()];
result.reshape(shape)?
} else {
result.max(0)?
}
}
};
values.insert(node.output[0].clone(), output);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
// TODO: This version is only compatible with ReduceMean V13 and below.
"ReduceMean" => {
@ -1199,6 +1341,237 @@ fn simple_eval_(
};
values.insert(node.output[0].clone(), output);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin
"ReduceMin" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1) == 1;
let axes = if let Some(Ok(axes)) = axes {
// Satisfies version 18+
axes.to_vec1::<i64>().ok()
} else if let Ok(Some(axes)) = get_attr_opt::<[i64]>(node, "axes") {
// Backward compatiblity with version 13 and below
Some(axes.to_vec())
} else {
None
};
let axes = if let Some(axes) = axes {
let rank = input.rank();
let mut axes_set = HashSet::new();
let mut axes = axes
.iter()
.map(|a| {
let axis = if *a < 0 {
(rank as i64 + *a) as usize
} else {
*a as usize
};
axes_set.insert(axis);
axis
})
.collect::<Vec<_>>();
if axes_set.len() < axes.len() {
bail!("Duplicate value in 'axes'");
}
if axes.len() > 1 {
axes.sort();
}
Some(axes)
} else {
None
};
// TODO: Handle empty set
// Definition:
// "Reduction over an empty set of values yields positive infinity (if supported by the datatype) or the max value of the data type otherwise"
// For now, this will throw an error
if input.elem_count() == 0 {
bail!("reduction over zero-size tensor not supported");
}
let output = if let Some(axes) = axes {
let mut result = input.clone();
for &axis in axes.iter().rev() {
result = if keepdims {
result.min_keepdim(axis)?
} else {
result.min(axis)?
}
}
result
} else {
// If `axes` is empty and `noop_with_empty_axes` is set to `true (1)`
// ""input tensor will not be reduced,and the output tensor would be equivalent to input tensor.""
if get_attr_opt::<i64>(node, "noop_with_empty_axes")?.copied() == Some(1) {
input.clone()
} else {
let mut result = input.flatten_all()?;
if keepdims {
result = result.min_keepdim(0)?;
// If keepdims is true, reshape to match input dimensions
let shape = vec![1; input.rank()];
result.reshape(shape)?
} else {
result.min(0)?
}
}
};
values.insert(node.output[0].clone(), output);
}
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split
// Version 18 impl
"Split" => {
let input_tensor = get(&node.input[0])?;
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = input_tensor.normalize_axis(axis)?;
// Determine split sizes
let splits = if node.input.len() > 1 {
// If the split tensor is provided, use it to determine sizes
let split_tensor = get(&node.input[1])?.to_vec1::<i64>()?;
split_tensor.iter().map(|&x| x as usize).collect::<Vec<_>>()
} else {
let num_outputs = if let Some(&num_outputs_attrib) =
get_attr_opt::<i64>(node, "num_outputs")?
{
num_outputs_attrib as usize
} else {
node.output.len()
};
let input_dim = input_tensor.dim(axis)?;
let mut split_sizes =
vec![input_dim / num_outputs as usize; num_outputs as usize];
let remainder = input_dim % num_outputs as usize;
if remainder > 0 {
// If there's a remainder, add it to the last split size
split_sizes[num_outputs as usize - 1] += remainder;
}
split_sizes
};
// Perform the split operation
let mut outputs = vec![];
let mut start = 0;
for &size in &splits {
let end = start + size;
let slice = input_tensor.narrow(axis, start, size)?;
outputs.push(slice);
start = end;
}
// Insert the split outputs into the values map
for (output, slice) in node.output.iter().zip(outputs.into_iter()) {
values.insert(output.clone(), slice);
}
}
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#Expand
// Version 13 impl
"Expand" => {
// unlike broadcast_to, expand allows for the output shape to
// be different from the specified shape.
let input_tensor = get(&node.input[0])?;
let input_shape = get(&node.input[1])?;
// Check that the shape tensor is 1D
if input_shape.rank() != 1 {
bail!(
"Expand expects 'shape' input to be 1D tensor: {:?}",
input_shape
);
}
let input_tensor_dims = input_tensor.dims();
let input_shape_dims = input_shape
.to_vec1::<i64>()?
.into_iter()
.map(|x| x as usize)
.collect::<Vec<_>>();
let target_shape = broadcast_shape(input_tensor_dims, input_shape_dims.as_slice())?;
let expanded_tensor = input_tensor.broadcast_as(target_shape)?;
values.insert(node.output[0].clone(), expanded_tensor);
}
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSum
// Version 13 impl
"ReduceSum" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
.copied()
.unwrap_or(0);
let axes = match axes {
Some(Ok(axes)) => axes
.to_vec1::<i64>()?
.into_iter()
.map(|x| x as usize)
.collect::<Vec<_>>(),
Some(Err(_)) | None => {
if noop_with_empty_axes == 1 {
vec![]
} else {
(0..input.rank()).collect()
}
}
};
let output = if keepdims == 1 {
input.sum_keepdim(axes)?
} else {
input.sum(axes)?
};
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceL2
// Version 18 impl
"ReduceL2" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
.copied()
.unwrap_or(0);
let input_sq = input.sqr()?;
let axes = match axes {
Some(axes) => axes?
.to_vec1::<i64>()?
.into_iter()
.map(|x| x as usize)
.collect::<Vec<_>>(),
None => {
if noop_with_empty_axes == 1 {
vec![]
} else {
(0..input_sq.rank()).collect()
}
}
};
let output = if keepdims == 1 {
input_sq.sum_keepdim(axes)?.sqrt()?
} else {
input_sq.sum(axes)?.sqrt()?
};
values.insert(node.output[0].clone(), output);
}
random_type @ ("RandomUniform" | "RandomNormal") => {
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
// type by
@ -1395,13 +1768,6 @@ fn simple_eval_(
// This tensor has shape `[num_directions, 4*hidden_size, hidden_size]`.
let r = get(&node.input[2])?;
let get_opt = |i: usize| {
node.input
.get(i)
.filter(|s: &&String| !s.is_empty())
.map(|s| get(s))
};
// The bias tensor for input gate.
// Concatenation of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0.
// This tensor has shape `[num_directions, 8*hidden_size]`.
@ -1488,7 +1854,7 @@ fn simple_eval_(
let w = w.get(0)?; // w[iofc] has shape [4*hidden_size, input_size]
let r = r.get(0)?; // r[iofc] has shape [4*hidden_size, hidden_size]
let b = b.get(0)?; // concat of [wb[iofc],rb[iofc]] has shape [8*hidden_size]
let idx_wb = Tensor::arange(0 * hidden_size, 4 * hidden_size, x.device())?;
let idx_wb = Tensor::arange(0, 4 * hidden_size, x.device())?;
let idx_rb = Tensor::arange(4 * hidden_size, 8 * hidden_size, x.device())?;
let wb = b.index_select(&idx_wb, 0)?;
let rb = b.index_select(&idx_rb, 0)?;
@ -1497,8 +1863,8 @@ fn simple_eval_(
// w, r, wb, rb are all iofc but lstm expects ifco
// so we need to move some stuff around
let idx_i = Tensor::arange(0 * hidden_size, 1 * hidden_size, x.device())?;
let idx_o = Tensor::arange(1 * hidden_size, 2 * hidden_size, x.device())?;
let idx_i = Tensor::arange(0, hidden_size, x.device())?;
let idx_o = Tensor::arange(hidden_size, 2 * hidden_size, x.device())?;
let idx_f = Tensor::arange(2 * hidden_size, 3 * hidden_size, x.device())?;
let idx_c = Tensor::arange(3 * hidden_size, 4 * hidden_size, x.device())?;
let idx_ifco = Tensor::cat(&[&idx_i, &idx_f, &idx_c, &idx_o], 0)?;
@ -1522,7 +1888,7 @@ fn simple_eval_(
)?;
let mut lstm_state = candle_nn::rnn::LSTMState::new(h, c);
let mut h_acc = if node.output.get(0).map(String::as_str).unwrap_or("") != "" {
let mut h_acc = if node.output.first().map(String::as_str).unwrap_or("") != "" {
Some(vec![])
} else {
None
@ -1536,7 +1902,7 @@ fn simple_eval_(
}
assert_eq!(num_directions, 1, "if support for bidirectional is ever added, outputs will have to be concatenated, not simply reshaped");
if let Some(name) = node.output.get(0) {
if let Some(name) = node.output.first() {
let h_acc = h_acc.as_ref().unwrap();
let h_acc = lstm.states_to_tensor(h_acc)?;
let h_acc = h_acc.reshape((
@ -1568,6 +1934,16 @@ fn simple_eval_(
);
}
}
// https://onnx.ai/onnx/operators/onnx__Xor.html
"Xor" => {
// Since we don't have a `DType::Bool` yet, this ensures that we are working with `0`(False) & `1`(True)
let a = get(&node.input[0])?.gt(0_u8)?;
let b = get(&node.input[1])?.gt(0_u8)?;
let out = a.broadcast_add(&b)?.eq(1_u8)?;
values.insert(node.output[0].clone(), out);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
@ -1580,3 +1956,36 @@ fn simple_eval_(
})
.collect()
}
fn broadcast_shape(shape_a: &[usize], shape_b: &[usize]) -> Result<Vec<usize>> {
let (longest, shortest) = if shape_a.len() > shape_b.len() {
(shape_a, shape_b)
} else {
(shape_b, shape_a)
};
let diff = longest.len() - shortest.len();
let mut target_shape = longest[0..diff].to_vec();
for (dim1, dim2) in longest[diff..].iter().zip(shortest.iter()) {
if *dim1 == *dim2 || *dim2 == 1 || *dim1 == 1 {
target_shape.push(usize::max(*dim1, *dim2));
} else {
bail!(
"Expand: incompatible shapes for broadcast, {:?} and {:?}",
shape_a,
shape_b
);
}
}
Ok(target_shape)
}
fn broadcast_shape_from_many(shapes: &[&[usize]]) -> Result<Vec<usize>> {
if shapes.is_empty() {
return Ok(Vec::new());
}
let mut shape_out = shapes[0].to_vec();
for shape in shapes[1..].iter() {
shape_out = broadcast_shape(&shape_out, shape)?;
}
Ok(shape_out)
}

File diff suppressed because it is too large Load Diff

View File

@ -20,10 +20,10 @@ candle-nn = { workspace = true }
candle-onnx = { workspace = true, optional = true }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.21.0", features = ["extension-module", "abi3-py38"] }
pyo3 = { version = "0.22.0", features = ["extension-module", "abi3-py38"] }
[build-dependencies]
pyo3-build-config = "0.21"
pyo3-build-config = "0.22"
[features]
default = []

View File

@ -33,9 +33,7 @@ def has_mkl() -> bool:
pass
@staticmethod
def load_ggml(
path: Union[str, PathLike], device: Optional[Device] = None
) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
def load_ggml(path, device=None) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
"""
Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
@ -43,9 +41,7 @@ def load_ggml(
pass
@staticmethod
def load_gguf(
path: Union[str, PathLike], device: Optional[Device] = None
) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
def load_gguf(path, device=None) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
"""
Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
and the second maps metadata keys to metadata values.
@ -60,7 +56,7 @@ def load_safetensors(path: Union[str, PathLike]) -> Dict[str, Tensor]:
pass
@staticmethod
def save_gguf(path: Union[str, PathLike], tensors: Dict[str, QTensor], metadata: Dict[str, Any]):
def save_gguf(path, tensors, metadata):
"""
Save quanitzed tensors and metadata to a GGUF file.
"""

View File

@ -6,7 +6,6 @@ use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::os::raw::c_long;
use std::sync::Arc;
use half::{bf16, f16};
@ -115,7 +114,7 @@ impl PyDevice {
}
impl<'source> FromPyObject<'source> for PyDevice {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
let device: String = ob.extract()?;
let device = match device.as_str() {
"cpu" => PyDevice::Cpu,
@ -217,11 +216,11 @@ enum Indexer {
IndexSelect(Tensor),
}
#[derive(Clone, Debug)]
#[derive(Debug)]
struct TorchTensor(PyObject);
impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
Ok(TorchTensor(numpy_value))
}
@ -540,7 +539,7 @@ impl PyTensor {
))
} else if let Ok(slice) = py_indexer.downcast::<pyo3::types::PySlice>() {
// Handle a single slice e.g. tensor[0:1] or tensor[0:-1]
let index = slice.indices(dims[current_dim] as c_long)?;
let index = slice.indices(dims[current_dim] as isize)?;
Ok((
Indexer::Slice(index.start as usize, index.stop as usize),
current_dim + 1,
@ -1284,7 +1283,7 @@ fn save_safetensors(
}
#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
#[pyo3(signature = (path, device = None))]
/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]]
@ -1325,7 +1324,7 @@ fn load_ggml(
}
#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
#[pyo3(signature = (path, device = None))]
/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
/// and the second maps metadata keys to metadata values.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
@ -1384,7 +1383,7 @@ fn load_gguf(
#[pyfunction]
#[pyo3(
text_signature = "(path:Union[str,PathLike], tensors:Dict[str,QTensor], metadata:Dict[str,Any])"
signature = (path, tensors, metadata)
)]
/// Save quanitzed tensors and metadata to a GGUF file.
fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> {
@ -1430,7 +1429,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>)
Ok(v)
}
let tensors = tensors
.extract::<&PyDict>(py)
.downcast_bound::<PyDict>(py)
.map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
.iter()
.map(|(key, value)| {
@ -1443,7 +1442,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>)
.collect::<PyResult<Vec<_>>>()?;
let metadata = metadata
.extract::<&PyDict>(py)
.downcast_bound::<PyDict>(py)
.map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
.iter()
.map(|(key, value)| {

View File

@ -6,7 +6,7 @@ use pyo3::prelude::*;
pub struct PyShape(Vec<usize>);
impl<'source> pyo3::FromPyObject<'source> for PyShape {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
if ob.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Shape cannot be None",
@ -16,10 +16,10 @@ impl<'source> pyo3::FromPyObject<'source> for PyShape {
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
if tuple.len() == 1 {
let first_element = tuple.get_item(0)?;
let dims: Vec<usize> = pyo3::FromPyObject::extract(first_element)?;
let dims: Vec<usize> = pyo3::FromPyObject::extract_bound(&first_element)?;
Ok(PyShape(dims))
} else {
let dims: Vec<usize> = pyo3::FromPyObject::extract(tuple)?;
let dims: Vec<usize> = pyo3::FromPyObject::extract_bound(tuple)?;
Ok(PyShape(dims))
}
}
@ -36,7 +36,7 @@ impl From<PyShape> for ::candle::Shape {
pub struct PyShapeWithHole(Vec<isize>);
impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
if ob.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Shape cannot be None",
@ -46,9 +46,9 @@ impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
let dims: Vec<isize> = if tuple.len() == 1 {
let first_element = tuple.get_item(0)?;
pyo3::FromPyObject::extract(first_element)?
pyo3::FromPyObject::extract_bound(&first_element)?
} else {
pyo3::FromPyObject::extract(tuple)?
pyo3::FromPyObject::extract_bound(tuple)?
};
// Ensure we have only positive numbers and at most one "hole" (-1)

View File

@ -504,3 +504,100 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
(attention_mask.ones_like()? - &attention_mask)?
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
}
//https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
struct BertPredictionHeadTransform {
dense: Linear,
activation: HiddenActLayer,
layer_norm: LayerNorm,
}
impl BertPredictionHeadTransform {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
let activation = HiddenActLayer::new(config.hidden_act);
let layer_norm = layer_norm(
config.hidden_size,
config.layer_norm_eps,
vb.pp("LayerNorm"),
)?;
Ok(Self {
dense,
activation,
layer_norm,
})
}
}
impl Module for BertPredictionHeadTransform {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let hidden_states = self
.activation
.forward(&self.dense.forward(hidden_states)?)?;
self.layer_norm.forward(&hidden_states)
}
}
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1
pub struct BertLMPredictionHead {
transform: BertPredictionHeadTransform,
decoder: Linear,
}
impl BertLMPredictionHead {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let transform = BertPredictionHeadTransform::load(vb.pp("transform"), config)?;
let decoder = linear(config.hidden_size, config.vocab_size, vb.pp("decoder"))?;
Ok(Self { transform, decoder })
}
}
impl Module for BertLMPredictionHead {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
self.decoder
.forward(&self.transform.forward(hidden_states)?)
}
}
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792
pub struct BertOnlyMLMHead {
predictions: BertLMPredictionHead,
}
impl BertOnlyMLMHead {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let predictions = BertLMPredictionHead::load(vb.pp("predictions"), config)?;
Ok(Self { predictions })
}
}
impl Module for BertOnlyMLMHead {
fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
self.predictions.forward(sequence_output)
}
}
pub struct BertForMaskedLM {
bert: BertModel,
cls: BertOnlyMLMHead,
}
impl BertForMaskedLM {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let bert = BertModel::load(vb.pp("bert"), config)?;
let cls = BertOnlyMLMHead::load(vb.pp("cls"), config)?;
Ok(Self { bert, cls })
}
pub fn forward(
&self,
input_ids: &Tensor,
token_type_ids: &Tensor,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let sequence_output = self
.bert
.forward(input_ids, token_type_ids, attention_mask)?;
self.cls.forward(&sequence_output)
}
}

View File

@ -0,0 +1,208 @@
//! Chinese contrastive Language-Image Pre-Training
//!
//! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images with related texts.
//!
//! https://github.com/OFA-Sys/Chinese-CLIP
//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py
use candle::{Module, Result, Tensor, D};
use candle_nn as nn;
use text_model::ChineseClipTextTransformer;
use vision_model::ChineseClipVisionTransformer;
pub mod text_model;
pub mod vision_model;
#[derive(Debug, Clone, Copy)]
pub enum Activation {
QuickGelu,
Gelu,
GeluNew,
Relu,
}
impl From<String> for Activation {
fn from(value: String) -> Self {
match value.as_str() {
"quick_gelu" => Activation::QuickGelu,
"gelu" => Activation::Gelu,
"gelu_new" => Activation::GeluNew,
"relu" => Activation::Relu,
_ => panic!("Invalid activation function: {}", value),
}
}
}
impl Module for Activation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
Activation::Gelu => xs.gelu_erf(),
Activation::GeluNew => xs.gelu(),
Activation::Relu => xs.relu(),
}
}
}
#[derive(Clone, Debug)]
pub struct ChineseClipConfig {
pub text_config: text_model::ChineseClipTextConfig,
pub vision_config: vision_model::ChineseClipVisionConfig,
pub projection_dim: usize,
pub logit_scale_init_value: f32,
pub image_size: usize,
}
impl ChineseClipConfig {
/// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json
pub fn clip_vit_base_patch16() -> Self {
let text_config = text_model::ChineseClipTextConfig::clip_vit_base_patch16();
let vision_config = vision_model::ChineseClipVisionConfig::clip_vit_base_patch16();
Self {
text_config,
vision_config,
projection_dim: 512,
logit_scale_init_value: 2.6592,
image_size: 512,
}
}
}
#[derive(Clone, Debug)]
pub enum EncoderConfig {
Text(text_model::ChineseClipTextConfig),
Vision(vision_model::ChineseClipVisionConfig),
}
impl EncoderConfig {
pub fn embed_dim(&self) -> usize {
match self {
Self::Text(c) => c.hidden_size,
Self::Vision(c) => c.hidden_size,
}
}
pub fn num_attention_heads(&self) -> usize {
match self {
Self::Text(c) => c.num_attention_heads,
Self::Vision(c) => c.num_attention_heads,
}
}
pub fn intermediate_size(&self) -> usize {
match self {
Self::Text(c) => c.intermediate_size,
Self::Vision(c) => c.intermediate_size,
}
}
pub fn num_hidden_layers(&self) -> usize {
match self {
Self::Text(c) => c.num_hidden_layers,
Self::Vision(c) => c.num_hidden_layers,
}
}
pub fn activation(&self) -> Activation {
match self {
Self::Text(c) => c.hidden_act,
Self::Vision(c) => c.hidden_act,
}
}
pub fn layer_norm_eps(&self) -> f64 {
match self {
Self::Text(c) => c.layer_norm_eps,
Self::Vision(c) => c.layer_norm_eps,
}
}
}
#[derive(Clone, Debug)]
pub struct ChineseClipModel {
text_model: ChineseClipTextTransformer,
vision_model: ChineseClipVisionTransformer,
visual_projection: nn::Linear,
text_projection: nn::Linear,
logit_scale: Tensor,
}
impl ChineseClipModel {
pub fn new(vs: nn::VarBuilder, c: &ChineseClipConfig) -> Result<Self> {
let text_model = ChineseClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?;
let vision_model =
ChineseClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?;
let vision_embed_dim = c.vision_config.hidden_size;
let vision_projection = nn::linear_no_bias(
vision_embed_dim,
c.projection_dim,
vs.pp("visual_projection"),
)?;
let text_embed_dim = c.text_config.hidden_size;
let text_projection =
nn::linear_no_bias(text_embed_dim, c.projection_dim, vs.pp("text_projection"))?;
let logit_scale = if vs.contains_tensor("logit_scale") {
vs.get(&[], "logit_scale")?
} else {
Tensor::new(&[c.logit_scale_init_value], vs.device())?
};
Ok(Self {
text_model,
vision_model,
visual_projection: vision_projection,
text_projection,
logit_scale,
})
}
pub fn get_text_features(
&self,
input_ids: &Tensor,
token_type_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let output = self
.text_model
.forward(input_ids, token_type_ids, attention_mask)?;
self.text_projection.forward(&output)
}
pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
pixel_values
.apply(&self.vision_model)?
.apply(&self.visual_projection)
}
pub fn forward(
&self,
pixel_values: &Tensor,
input_ids: &Tensor,
token_type_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<(Tensor, Tensor)> {
let image_features = self.get_image_features(pixel_values)?;
let text_features = self.get_text_features(input_ids, token_type_ids, attention_mask)?;
let image_features_normalized = div_l2_norm(&image_features)?;
let text_features_normalized = div_l2_norm(&text_features)?;
let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
let logit_scale = self.logit_scale.exp()?;
let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
let logits_per_image = logits_per_text.t()?;
Ok((logits_per_text, logits_per_image))
}
}
pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
v.broadcast_div(&l2_norm)
}

View File

@ -0,0 +1,540 @@
//! Chinese contrastive Language-Image Pre-Training
//!
//! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images with related texts.
//!
//! https://github.com/OFA-Sys/Chinese-CLIP
//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py
use candle::{DType, Device, IndexOp, Module, Result, Tensor};
use candle_nn as nn;
use super::Activation;
/// Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
/// positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
/// [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
/// For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
/// with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
#[derive(Clone, Debug)]
pub enum PositionEmbeddingType {
Absolute,
RelativeKey,
RelativeKeyQuery,
}
#[derive(Clone, Debug)]
pub struct ChineseClipTextConfig {
pub vocab_size: usize,
pub hidden_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub intermediate_size: usize,
pub hidden_act: Activation,
pub hidden_dropout_prob: f32,
pub attention_probs_dropout_prob: f64,
pub max_position_embeddings: usize,
pub type_vocab_size: usize,
pub initializer_range: f64,
pub initializer_factor: f64,
pub layer_norm_eps: f64,
pub pad_token_id: usize,
pub position_embedding_type: PositionEmbeddingType,
pub use_cache: bool,
}
impl Default for ChineseClipTextConfig {
fn default() -> Self {
Self {
vocab_size: 30522,
hidden_size: 768,
num_hidden_layers: 12,
num_attention_heads: 12,
intermediate_size: 3072,
hidden_act: Activation::Gelu,
hidden_dropout_prob: 0.1,
attention_probs_dropout_prob: 0.1,
max_position_embeddings: 512,
type_vocab_size: 2,
initializer_range: 0.02,
initializer_factor: 1.0,
layer_norm_eps: 1e-12,
pad_token_id: 0,
position_embedding_type: PositionEmbeddingType::Absolute,
use_cache: true,
}
}
}
impl ChineseClipTextConfig {
/// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json
pub fn clip_vit_base_patch16() -> Self {
Self {
vocab_size: 21128,
hidden_size: 768,
num_hidden_layers: 12,
num_attention_heads: 12,
intermediate_size: 3072,
hidden_act: Activation::Gelu,
hidden_dropout_prob: 0.1,
attention_probs_dropout_prob: 0.1,
max_position_embeddings: 512,
type_vocab_size: 2,
initializer_range: 0.02,
initializer_factor: 1.0,
layer_norm_eps: 1e-12,
pad_token_id: 0,
position_embedding_type: PositionEmbeddingType::Absolute,
use_cache: true,
}
}
}
#[derive(Clone, Debug)]
pub struct ChineseClipTextEmbeddings {
word_embeddings: nn::Embedding,
position_embeddings: nn::Embedding,
token_type_embeddings: nn::Embedding,
layer_norm: nn::LayerNorm,
dropout: nn::Dropout,
position_embedding_type: PositionEmbeddingType,
position_ids: Tensor,
token_type_ids: Tensor,
}
impl ChineseClipTextEmbeddings {
pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let word_embeddings = nn::embedding(
config.vocab_size,
config.hidden_size,
var.pp("word_embeddings"),
)?;
let position_embeddings = nn::embedding(
config.max_position_embeddings,
config.hidden_size,
var.pp("position_embeddings"),
)?;
let token_type_embeddings = nn::embedding(
config.type_vocab_size,
config.hidden_size,
var.pp("token_type_embeddings"),
)?;
let layer_norm = nn::layer_norm::<f64>(
config.hidden_size,
config.layer_norm_eps,
var.pp("LayerNorm"),
)?;
let dropout = nn::Dropout::new(config.hidden_dropout_prob);
let position_ids =
Tensor::arange(0u32, config.max_position_embeddings as u32, var.device())?
.unsqueeze(0)?;
let token_type_ids = Tensor::zeros(position_ids.shape(), DType::I64, var.device())?;
Ok(Self {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout,
position_embedding_type: config.position_embedding_type.clone(),
position_ids,
token_type_ids,
})
}
fn forward(&self, xs: &Tensor, token_type_ids: Option<&Tensor>) -> Result<Tensor> {
let (_batch_size, seq_length) = xs.dims2()?;
let position_ids = (0..seq_length as u32).collect::<Vec<_>>();
let position_ids = self.position_ids.index_select(
&Tensor::new(&position_ids[..], self.position_ids.device())?,
1,
)?;
let word_embeddings = self.word_embeddings.forward(xs)?;
let token_type_ids = match token_type_ids {
Some(token_type_ids) => token_type_ids,
None => &self.token_type_ids.i((.., 0..seq_length))?,
};
let token_type_ids = token_type_ids.expand(xs.shape())?;
let token_type_embeddings = self.token_type_embeddings.forward(&token_type_ids)?;
let embeddings = (&word_embeddings + token_type_embeddings)?;
let embeddings = match self.position_embedding_type {
PositionEmbeddingType::Absolute => {
let position_embeddings = self.position_embeddings.forward(&position_ids)?;
let position_embeddings = position_embeddings.expand(embeddings.shape())?;
(embeddings + position_embeddings)?
}
_ => embeddings,
};
let embeddings = self.layer_norm.forward(&embeddings)?;
let embeddings = self.dropout.forward(&embeddings, false)?;
Ok(embeddings)
}
}
/// Copied from [`crate::models::bert::BertSelfOutput`] to [`ChineseClipTextSelfOutput`]
#[derive(Clone, Debug)]
struct ChineseClipTextSelfOutput {
dense: nn::Linear,
layer_norm: nn::LayerNorm,
dropout: nn::Dropout,
span: tracing::Span,
}
impl ChineseClipTextSelfOutput {
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let dense = nn::linear(config.hidden_size, config.hidden_size, var.pp("dense"))?;
let layer_norm = nn::layer_norm(
config.hidden_size,
config.layer_norm_eps,
var.pp("LayerNorm"),
)?;
let dropout = nn::Dropout::new(config.hidden_dropout_prob);
Ok(Self {
dense,
layer_norm,
dropout,
span: tracing::span!(tracing::Level::TRACE, "self-out"),
})
}
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.dropout.forward(&hidden_states, false)?;
self.layer_norm.forward(&(hidden_states + input_tensor)?)
}
}
/// Copied from [`crate::models::bert::BertSelfAttention`] to [`ChineseClipTextSelfAttention`]
#[derive(Clone, Debug)]
struct ChineseClipTextSelfAttention {
query: nn::Linear,
key: nn::Linear,
value: nn::Linear,
dropout: nn::Dropout,
num_attention_heads: usize,
attention_head_size: usize,
span: tracing::Span,
span_softmax: tracing::Span,
}
impl ChineseClipTextSelfAttention {
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let attention_head_size = config.hidden_size / config.num_attention_heads;
let all_head_size = config.num_attention_heads * attention_head_size;
let dropout = nn::Dropout::new(config.hidden_dropout_prob);
let hidden_size = config.hidden_size;
let query = nn::linear(hidden_size, all_head_size, var.pp("query"))?;
let value = nn::linear(hidden_size, all_head_size, var.pp("value"))?;
let key = nn::linear(hidden_size, all_head_size, var.pp("key"))?;
Ok(Self {
query,
key,
value,
dropout,
num_attention_heads: config.num_attention_heads,
attention_head_size,
span: tracing::span!(tracing::Level::TRACE, "self-attn"),
span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
})
}
fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
let mut new_x_shape = xs.dims().to_vec();
new_x_shape.pop();
new_x_shape.push(self.num_attention_heads);
new_x_shape.push(self.attention_head_size);
let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
xs.contiguous()
}
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let query_layer = self.query.forward(hidden_states)?;
let key_layer = self.key.forward(hidden_states)?;
let value_layer = self.value.forward(hidden_states)?;
let query_layer = self.transpose_for_scores(&query_layer)?;
let key_layer = self.transpose_for_scores(&key_layer)?;
let value_layer = self.transpose_for_scores(&value_layer)?;
let attention_scores = query_layer.matmul(&key_layer.t()?)?;
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
let attention_scores = attention_scores.broadcast_add(attention_mask)?;
let attention_probs = {
let _enter_sm = self.span_softmax.enter();
nn::ops::softmax(&attention_scores, candle::D::Minus1)?
};
let attention_probs = self.dropout.forward(&attention_probs, false)?;
let context_layer = attention_probs.matmul(&value_layer)?;
let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
Ok(context_layer)
}
}
/// Copied from [`crate::models::bert::BertAttention`] to [`ChineseClipTextAttention`]
#[derive(Clone, Debug)]
struct ChineseClipTextAttention {
self_attention: ChineseClipTextSelfAttention,
self_output: ChineseClipTextSelfOutput,
span: tracing::Span,
}
impl ChineseClipTextAttention {
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let self_attention = ChineseClipTextSelfAttention::new(var.pp("self"), config)?;
let self_output = ChineseClipTextSelfOutput::new(var.pp("output"), config)?;
Ok(Self {
self_attention,
self_output,
span: tracing::span!(tracing::Level::TRACE, "attn"),
})
}
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
Ok(attention_output)
}
}
type HiddenActLayer = Activation;
/// Copied from [`crate::models::bert::BertIntermediate`] to [`ChineseClipTextIntermediate`]
#[derive(Clone, Debug)]
struct ChineseClipTextIntermediate {
dense: nn::Linear,
intermediate_act: HiddenActLayer,
span: tracing::Span,
}
impl ChineseClipTextIntermediate {
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let dense = nn::linear(
config.hidden_size,
config.intermediate_size,
var.pp("dense"),
)?;
Ok(Self {
dense,
intermediate_act: config.hidden_act,
span: tracing::span!(tracing::Level::TRACE, "inter"),
})
}
}
impl Module for ChineseClipTextIntermediate {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hidden_states = self.dense.forward(hidden_states)?;
let ys = self.intermediate_act.forward(&hidden_states)?;
Ok(ys)
}
}
/// Copied from [`crate::models::bert::BertOutput`] to [`ChineseClipTextOutput`]
#[derive(Clone, Debug)]
struct ChineseClipTextOutput {
dense: nn::Linear,
layer_norm: nn::LayerNorm,
dropout: nn::Dropout,
span: tracing::Span,
}
impl ChineseClipTextOutput {
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let dense = nn::linear(
config.intermediate_size,
config.hidden_size,
var.pp("dense"),
)?;
let layer_norm = nn::layer_norm(
config.hidden_size,
config.layer_norm_eps,
var.pp("LayerNorm"),
)?;
let dropout = nn::Dropout::new(config.hidden_dropout_prob);
Ok(Self {
dense,
layer_norm,
dropout,
span: tracing::span!(tracing::Level::TRACE, "out"),
})
}
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.dropout.forward(&hidden_states, false)?;
self.layer_norm.forward(&(hidden_states + input_tensor)?)
}
}
/// Copied from [`crate::models::bert::BertLayer`] to [`ChineseClipTextLayer`]
#[derive(Clone, Debug)]
struct ChineseClipTextLayer {
attention: ChineseClipTextAttention,
intermediate: ChineseClipTextIntermediate,
output: ChineseClipTextOutput,
span: tracing::Span,
}
impl ChineseClipTextLayer {
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let attention = ChineseClipTextAttention::new(var.pp("attention"), config)?;
let intermediate = ChineseClipTextIntermediate::new(var.pp("intermediate"), config)?;
let output = ChineseClipTextOutput::new(var.pp("output"), config)?;
Ok(Self {
attention,
intermediate,
output,
span: tracing::span!(tracing::Level::TRACE, "layer"),
})
}
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let attention_output = self.attention.forward(hidden_states, attention_mask)?;
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
let intermediate_output = self.intermediate.forward(&attention_output)?;
let layer_output = self
.output
.forward(&intermediate_output, &attention_output)?;
Ok(layer_output)
}
}
#[derive(Clone, Debug)]
struct Tanh;
impl Tanh {
pub fn new() -> Self {
Self {}
}
}
impl Module for Tanh {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.tanh()
}
}
#[derive(Clone, Debug)]
struct ChineseClipTextPooler {
dense: nn::Linear,
activation: Tanh,
}
impl ChineseClipTextPooler {
pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let dense = nn::linear(config.hidden_size, config.hidden_size, var.pp("dense"))?;
let activation = Tanh::new();
Ok(Self { dense, activation })
}
}
impl Module for ChineseClipTextPooler {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let first_token_tensor = hidden_states.i((.., 0))?;
let pooled_output = self.dense.forward(&first_token_tensor)?;
let pooled_output = self.activation.forward(&pooled_output)?;
Ok(pooled_output)
}
}
#[derive(Clone, Debug)]
struct ChineseClipTextEncoder {
layers: Vec<ChineseClipTextLayer>,
span: tracing::Span,
}
impl ChineseClipTextEncoder {
fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let layers = (0..config.num_hidden_layers)
.map(|index| ChineseClipTextLayer::new(var.pp(format!("layer.{index}")), config))
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "encoder");
Ok(ChineseClipTextEncoder { layers, span })
}
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut hidden_states = hidden_states.clone();
// Use a loop rather than a fold as it's easier to modify when adding debug/...
for layer in self.layers.iter() {
hidden_states = layer.forward(&hidden_states, attention_mask)?
}
Ok(hidden_states)
}
}
#[derive(Clone, Debug)]
pub struct ChineseClipTextTransformer {
embeddings: ChineseClipTextEmbeddings,
encoder: ChineseClipTextEncoder,
pooler: Option<ChineseClipTextPooler>,
pub device: Device,
span: tracing::Span,
}
impl ChineseClipTextTransformer {
pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result<Self> {
let embeddings = ChineseClipTextEmbeddings::new(var.pp("embeddings"), config)?;
let encoder = ChineseClipTextEncoder::new(var.pp("encoder"), config)?;
// see: https://github.com/huggingface/transformers/blob/e40bb4845e0eefb52ec1e9cac9c2446ab36aef81/src/transformers/models/chinese_clip/modeling_chinese_clip.py#L1362
// In the original Python version of the code, the pooler is not used, and there are no parameters for the pooler in the weight file.
let pooler = if var.contains_tensor("pooler") {
Some(ChineseClipTextPooler::new(var.pp("pooler"), config)?)
} else {
None
};
Ok(Self {
embeddings,
encoder,
pooler,
device: var.device().clone(),
span: tracing::span!(tracing::Level::TRACE, "model"),
})
}
pub fn forward(
&self,
input_ids: &Tensor,
token_type_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let _enter = self.span.enter();
let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
let attention_mask = match attention_mask {
Some(attention_mask) => attention_mask.clone(),
None => input_ids.ones_like()?,
};
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?;
let encoder_output = encoder_outputs.i((.., 0, ..))?;
let pooled_output = match &self.pooler {
Some(pooler) => pooler.forward(&encoder_output)?,
None => encoder_output,
};
Ok(pooled_output)
}
}
fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
let attention_mask = match attention_mask.rank() {
3 => attention_mask.unsqueeze(1)?,
2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
_ => candle::bail!("Wrong shape for input_ids or attention_mask"),
};
let attention_mask = attention_mask.to_dtype(dtype)?;
// torch.finfo(dtype).min
(attention_mask.ones_like()? - &attention_mask)?
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
}

View File

@ -0,0 +1,385 @@
//! Chinese contrastive Language-Image Pre-Training
//!
//! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
//! pairs of images with related texts.
//!
//! https://github.com/OFA-Sys/Chinese-CLIP
//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py
use candle::{DType, IndexOp, Module, Result, Shape, Tensor, D};
use candle_nn as nn;
use super::{Activation, EncoderConfig};
#[derive(Clone, Debug)]
pub struct ChineseClipVisionConfig {
pub hidden_size: usize,
pub intermediate_size: usize,
pub projection_dim: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_channels: usize,
pub image_size: usize,
pub patch_size: usize,
pub hidden_act: Activation,
pub layer_norm_eps: f64,
pub attention_dropout: f32,
pub initializer_range: f32,
pub initializer_factor: f32,
}
impl Default for ChineseClipVisionConfig {
fn default() -> Self {
ChineseClipVisionConfig {
hidden_size: 768,
intermediate_size: 3072,
projection_dim: 512,
num_hidden_layers: 12,
num_attention_heads: 12,
num_channels: 3,
image_size: 224,
patch_size: 32,
hidden_act: Activation::QuickGelu,
layer_norm_eps: 1e-5,
attention_dropout: 0.0,
initializer_range: 0.02,
initializer_factor: 1.0,
}
}
}
impl ChineseClipVisionConfig {
/// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json
pub fn clip_vit_base_patch16() -> Self {
Self {
hidden_size: 768,
intermediate_size: 3072,
projection_dim: 512,
num_hidden_layers: 12,
num_attention_heads: 12,
num_channels: 3,
image_size: 224,
patch_size: 16,
hidden_act: Activation::QuickGelu,
layer_norm_eps: 1e-5,
attention_dropout: 0.0,
initializer_range: 0.02,
initializer_factor: 1.0,
}
}
}
#[derive(Clone, Debug)]
pub struct ChineseClipVisionEmbeddings {
patch_embedding: nn::Conv2d,
position_ids: Tensor,
class_embedding: Tensor,
position_embedding: nn::Embedding,
}
impl ChineseClipVisionEmbeddings {
pub fn new(var: nn::VarBuilder, config: &ChineseClipVisionConfig) -> Result<Self> {
let embed_dim = config.hidden_size;
// originally nn.Parameter
let class_embedding = if var.contains_tensor("class_embedding") {
var.get(embed_dim, "class_embedding")?
} else {
Tensor::randn(0f32, 1f32, embed_dim, var.device())?
};
let num_patches = (config.image_size / config.patch_size).pow(2);
let num_positions = num_patches + 1;
let position_ids = Tensor::arange(0, num_positions as i64, var.device())?;
let conv2dconfig = nn::Conv2dConfig {
stride: config.patch_size,
..Default::default()
};
let position_embedding =
nn::embedding(num_positions, embed_dim, var.pp("position_embedding"))?;
let patch_embedding = nn::conv2d_no_bias(
config.num_channels,
embed_dim,
config.patch_size,
conv2dconfig,
var.pp("patch_embedding"),
)?;
Ok(Self {
patch_embedding,
position_ids,
class_embedding,
position_embedding,
})
}
}
impl Module for ChineseClipVisionEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let batch_size = xs.shape().dims();
let patch_embeds = self
.patch_embedding
.forward(xs)?
.flatten_from(2)?
.transpose(1, 2)?;
let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?));
let class_embeds = self.class_embedding.expand(shape)?;
let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
embeddings.broadcast_add(&position_embedding)
}
}
#[derive(Clone, Debug)]
struct ChineseClipVisionAttention {
k_proj: nn::Linear,
v_proj: nn::Linear,
q_proj: nn::Linear,
out_proj: nn::Linear,
head_dim: usize,
scale: f64,
num_attention_heads: usize,
}
impl ChineseClipVisionAttention {
fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
let embed_dim = config.embed_dim();
let num_attention_heads = config.num_attention_heads();
let k_proj = nn::linear(embed_dim, embed_dim, var.pp("k_proj"))?;
let v_proj = nn::linear(embed_dim, embed_dim, var.pp("v_proj"))?;
let q_proj = nn::linear(embed_dim, embed_dim, var.pp("q_proj"))?;
let out_proj = nn::linear(embed_dim, embed_dim, var.pp("out_proj"))?;
let head_dim = embed_dim / num_attention_heads;
let scale = (head_dim as f64).powf(-0.5);
Ok(ChineseClipVisionAttention {
k_proj,
v_proj,
q_proj,
out_proj,
head_dim,
scale,
num_attention_heads,
})
}
fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()
}
fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
let in_dtype = xs.dtype();
let (bsz, seq_len, embed_dim) = xs.dims3()?;
let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
let query_states = self
.shape(&(self.q_proj.forward(xs)? * self.scale)?, seq_len, bsz)?
.reshape(proj_shape)?
.to_dtype(DType::F32)?;
let key_states = self
.shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
.reshape(proj_shape)?
.to_dtype(DType::F32)?;
let value_states = self
.shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
.reshape(proj_shape)?
.to_dtype(DType::F32)?;
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
let src_len = key_states.dim(1)?;
let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
attn_weights
.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
.broadcast_add(causal_attention_mask)?
.reshape((bsz * self.num_attention_heads, seq_len, src_len))?
} else {
attn_weights
};
let attn_weights = nn::ops::softmax(&attn_weights, D::Minus1)?;
let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
let attn_output = attn_output
.reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
.transpose(1, 2)?
.reshape((bsz, seq_len, embed_dim))?;
self.out_proj.forward(&attn_output)
}
}
#[derive(Clone, Debug)]
struct ChineseClipVisionMlp {
fc1: nn::Linear,
fc2: nn::Linear,
activation: Activation,
}
impl ChineseClipVisionMlp {
fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
let fc1 = nn::linear(
config.embed_dim(),
config.intermediate_size(),
var.pp("fc1"),
)?;
let fc2 = nn::linear(
config.intermediate_size(),
config.embed_dim(),
var.pp("fc2"),
)?;
Ok(ChineseClipVisionMlp {
fc1,
fc2,
activation: config.activation(),
})
}
}
impl ChineseClipVisionMlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.fc1.forward(xs)?;
self.fc2.forward(&self.activation.forward(&xs)?)
}
}
#[derive(Clone, Debug)]
struct ChineseClipVisionEncoderLayer {
self_attn: ChineseClipVisionAttention,
layer_norm1: nn::LayerNorm,
mlp: ChineseClipVisionMlp,
layer_norm2: nn::LayerNorm,
}
impl ChineseClipVisionEncoderLayer {
fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
let self_attn = ChineseClipVisionAttention::new(var.pp("self_attn"), config)?;
let layer_norm1 = nn::layer_norm(
config.embed_dim(),
config.layer_norm_eps(),
var.pp("layer_norm1"),
)?;
let mlp = ChineseClipVisionMlp::new(var.pp("mlp"), config)?;
let layer_norm2 = nn::layer_norm(
config.embed_dim(),
config.layer_norm_eps(),
var.pp("layer_norm2"),
)?;
Ok(ChineseClipVisionEncoderLayer {
self_attn,
layer_norm1,
mlp,
layer_norm2,
})
}
fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
let residual = xs;
let xs = self.layer_norm1.forward(xs)?;
let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self.layer_norm2.forward(&xs)?;
let xs = self.mlp.forward(&xs)?;
xs + residual
}
}
#[derive(Clone, Debug)]
pub struct ChineseClipVisionEncoder {
layers: Vec<ChineseClipVisionEncoderLayer>,
}
impl ChineseClipVisionEncoder {
pub fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result<Self> {
let vs = var.pp("layers");
let mut layers: Vec<ChineseClipVisionEncoderLayer> = Vec::new();
for index in 0..config.num_hidden_layers() {
let layer = ChineseClipVisionEncoderLayer::new(vs.pp(index.to_string()), config)?;
layers.push(layer)
}
Ok(ChineseClipVisionEncoder { layers })
}
pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, causal_attention_mask)?;
}
Ok(xs)
}
// required by LLaVA
pub fn output_hidden_states(
&self,
xs: &Tensor,
causal_attention_mask: Option<&Tensor>,
) -> Result<Vec<Tensor>> {
let mut xs = xs.clone();
let mut hidden_states = Vec::new();
for layer in self.layers.iter() {
xs = layer.forward(&xs, causal_attention_mask)?;
hidden_states.push(xs.clone());
}
Ok(hidden_states)
}
}
#[derive(Clone, Debug)]
pub struct ChineseClipVisionTransformer {
embeddings: ChineseClipVisionEmbeddings,
encoder: ChineseClipVisionEncoder,
pre_layer_norm: nn::LayerNorm,
final_layer_norm: nn::LayerNorm,
}
impl ChineseClipVisionTransformer {
pub fn new(var: nn::VarBuilder, config: &ChineseClipVisionConfig) -> Result<Self> {
let embed_dim = config.hidden_size;
let embeddings = ChineseClipVisionEmbeddings::new(var.pp("embeddings"), config)?;
let pre_layer_norm =
nn::layer_norm(embed_dim, config.layer_norm_eps, var.pp("pre_layrnorm"))?;
let encoder = ChineseClipVisionEncoder::new(
var.pp("encoder"),
&EncoderConfig::Vision(config.clone()),
)?;
let final_layer_norm =
nn::layer_norm(embed_dim, config.layer_norm_eps, var.pp("post_layernorm"))?;
Ok(Self {
embeddings,
encoder,
final_layer_norm,
pre_layer_norm,
})
}
// required by LLaVA
pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
let hidden_states = pixel_values
.apply(&self.embeddings)?
.apply(&self.pre_layer_norm)?;
let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
let encoder_outputs = result.last().unwrap();
let pooled_output = encoder_outputs.i((.., 0, ..))?;
result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
Ok(result)
}
}
impl Module for ChineseClipVisionTransformer {
fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
let hidden_states = pixel_values
.apply(&self.embeddings)?
.apply(&self.pre_layer_norm)?;
let encoder_outputs = self.encoder.forward(&hidden_states, None)?;
// referer: https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787
let pooled_output = encoder_outputs.i((.., 0, ..))?;
self.final_layer_norm.forward(&pooled_output)
}
}

View File

@ -92,28 +92,23 @@ impl ClipConfig {
impl ClipModel {
pub fn new(vs: candle_nn::VarBuilder, c: &ClipConfig) -> Result<Self> {
let text_model = ClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?;
let vision_model = ClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?;
let visual_projection = candle_nn::linear_no_bias(
c.vision_config.embed_dim,
c.vision_config.projection_dim,
vs.pp("visual_projection"),
)?;
let text_projection = candle_nn::linear_no_bias(
c.text_config.embed_dim,
c.text_config.projection_dim,
vs.pp("text_projection"),
)?;
// originally nn.Parameter
let logit_scale = if vs.contains_tensor("logit_scale") {
vs.get(&[], "logit_scale")?
} else {
Tensor::new(&[c.logit_scale_init_value], vs.device())?
};
Ok(Self {
text_model,
vision_model,

View File

@ -77,7 +77,7 @@ impl ClipTextEmbeddings {
)?;
let position_ids =
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
Ok(ClipTextEmbeddings {
Ok(Self {
token_embedding,
position_embedding,
position_ids,
@ -298,7 +298,7 @@ impl ClipTextTransformer {
})
}
// TODO: rewrrite to newer version
// TODO: rewrite to newer version
fn build_causal_attention_mask(
bsz: usize,
seq_len: usize,

View File

@ -0,0 +1,42 @@
use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;
use super::paligemma;
use candle_nn::{linear, Linear};
pub struct Model {
pub model: paligemma::Model,
pub custom_text_projection: Linear,
}
impl Model {
pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
let model = paligemma::Model::new(config, vb.pp("model"))?;
let custom_text_projection = linear(
config.text_config.hidden_size,
128,
vb.pp("custom_text_proj"),
)?;
Ok(Self {
model,
custom_text_projection,
})
}
pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
let outputs = self
.model
.setup_without_projection(pixel_values, input_ids)?;
let outputs = self.custom_text_projection.forward(&outputs)?;
let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
Ok(outputs)
}
pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
let outputs = self.model.forward_without_projection(input_ids)?;
let outputs = self.custom_text_projection.forward(&outputs)?;
let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
Ok(outputs)
}
}

View File

@ -11,13 +11,13 @@ use candle_nn::{
BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
};
#[derive(Clone, Debug)]
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
pub struct Config {
exp_ratio: usize,
in_channels: usize,
blocks: [usize; 4],
attn: bool,
lkc_use_act: bool,
pub exp_ratio: usize,
pub in_channels: usize,
pub blocks: [usize; 4],
pub attn: bool,
pub lkc_use_act: bool,
}
impl Config {
@ -495,7 +495,6 @@ fn fastvit_model(cfg: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Resul
.apply(&stage3)?
.apply(&stage4)?
.apply(&final_conv)?;
match &cls {
None => Ok(xs),
Some(cls) => xs.mean(D::Minus2)?.mean(D::Minus1)?.apply(cls),

View File

@ -362,6 +362,10 @@ impl Model {
})
}
pub fn embed_tokens(&self) -> &candle_nn::Embedding {
&self.embed_tokens
}
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
@ -399,6 +403,36 @@ impl Model {
.apply(&self.norm)?
.apply(&self.lm_head)
}
pub fn forward_embeds(
&mut self,
xs: &Tensor,
attn_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (_, seq_len, _) = xs.dims3()?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attn_mask, seqlen_offset)?
}
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)
}
// Forward the model and return the hidden states without the lm_head
pub fn forward_embeds_without_projection(
&mut self,
xs: &Tensor,
attn_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (_, _, _) = xs.dims3()?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attn_mask, seqlen_offset)?
}
Ok(xs)
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {

View File

@ -44,6 +44,7 @@ pub struct LlamaConfig {
pub eos_token_id: Option<LlamaEosToks>,
pub rope_scaling: Option<Llama3RopeConfig>,
pub max_position_embeddings: usize,
pub tie_word_embeddings: Option<bool>,
}
impl LlamaConfig {
@ -72,6 +73,7 @@ impl LlamaConfig {
eos_token_id: self.eos_token_id,
rope_scaling: self.rope_scaling,
max_position_embeddings: self.max_position_embeddings,
tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
}
}
}
@ -91,6 +93,7 @@ pub struct Config {
pub eos_token_id: Option<LlamaEosToks>,
pub rope_scaling: Option<Llama3RopeConfig>,
pub max_position_embeddings: usize,
pub tie_word_embeddings: bool,
}
impl Config {
@ -109,6 +112,7 @@ impl Config {
eos_token_id: None,
rope_scaling: None,
max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
tie_word_embeddings: false,
}
}
@ -127,6 +131,7 @@ impl Config {
eos_token_id: None,
rope_scaling: None,
max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
tie_word_embeddings: false,
}
}
}
@ -336,7 +341,8 @@ impl CausalSelfAttention {
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
masked_fill(&att, &mask, f32::NEG_INFINITY)?
};
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
};
@ -504,7 +510,11 @@ impl Llama {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let lm_head = if cfg.tie_word_embeddings {
Linear::from_weights(wte.embeddings().clone(), None)
} else {
linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
};
let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap())

View File

@ -43,6 +43,7 @@ pub struct LLaVAConfig {
pub image_token_index: isize,
#[serde(default = "default_hf")]
pub hf: bool,
pub tie_word_embeddings: Option<bool>,
}
fn default_hf() -> bool {
@ -77,6 +78,7 @@ impl LLaVAConfig {
use_flash_attn: false,
rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1
max_position_embeddings: self.max_position_embeddings,
tie_word_embeddings: self.tie_word_embeddings.unwrap_or(false),
}
}
}
@ -264,6 +266,7 @@ impl HFLLaVAConfig {
use_cache: self.text_config.use_cache,
vocab_size: self.vocab_size,
image_token_index: self.image_token_index,
tie_word_embeddings: None,
}
}
}

View File

@ -279,7 +279,7 @@ impl LLaVA {
(),
))?
} else {
todo!("not implemented in original python LLaVA yet")
bail!("not implemented in original python LLaVA yet")
};
let new_image_feature = if mm_patch_merge_type.contains("unpad") {
let new_image_feature = new_image_feature

View File

@ -2,7 +2,7 @@ use super::with_tracing::{linear, Embedding, Linear};
use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
#[derive(Debug, Clone)]
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub decoder_vocab_size: Option<usize>,

View File

@ -4,19 +4,29 @@ use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;
fn default_num_attention_heads() -> usize {
32
}
fn default_use_flash_attn() -> bool {
false
}
fn default_hidden_act() -> candle_nn::Activation {
candle_nn::Activation::Silu
}
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
#[serde(default = "default_num_attention_heads")]
pub num_attention_heads: usize,
pub head_dim: Option<usize>,
pub num_key_value_heads: usize,
#[serde(default = "default_hidden_act")]
pub hidden_act: Activation,
pub max_position_embeddings: usize,
pub rms_norm_eps: f64,
@ -107,14 +117,14 @@ impl RotaryEmbedding {
.map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
sin: freqs.sin()?.to_dtype(dtype)?,
cos: freqs.cos()?.to_dtype(dtype)?,
})
}
@ -404,6 +414,10 @@ impl Model {
.to_dtype(self.dtype)
}
pub fn embed_tokens(&self) -> &candle_nn::Embedding {
&self.embed_tokens
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (_b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
@ -421,6 +435,22 @@ impl Model {
.apply(&self.lm_head)
}
pub fn forward_embeds(
&mut self,
xs: &Tensor,
attn_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (_b_size, seq_len, _) = xs.dims3()?;
let mut xs = xs.clone();
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attn_mask, seqlen_offset)?
}
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()

View File

@ -36,7 +36,6 @@ impl Module for LayerNormNoAffine {
impl DiTBlock {
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
// {'hidden_size': 1536, 'num_heads': 24}
let norm1 = LayerNormNoAffine::new(1e-6);
let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
let norm2 = LayerNormNoAffine::new(1e-6);
@ -103,6 +102,117 @@ impl DiTBlock {
}
}
pub struct SelfAttnModulateIntermediates {
gate_msa: Tensor,
shift_mlp: Tensor,
scale_mlp: Tensor,
gate_mlp: Tensor,
gate_msa2: Tensor,
}
pub struct SelfAttnDiTBlock {
norm1: LayerNormNoAffine,
attn: AttnProjections,
attn2: AttnProjections,
norm2: LayerNormNoAffine,
mlp: Mlp,
ada_ln_modulation: nn::Sequential,
}
impl SelfAttnDiTBlock {
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
let norm1 = LayerNormNoAffine::new(1e-6);
let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
let attn2 = AttnProjections::new(hidden_size, num_heads, vb.pp("attn2"))?;
let norm2 = LayerNormNoAffine::new(1e-6);
let mlp_ratio = 4;
let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?;
let n_mods = 9;
let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
hidden_size,
n_mods * hidden_size,
vb.pp("adaLN_modulation.1"),
)?);
Ok(Self {
norm1,
attn,
attn2,
norm2,
mlp,
ada_ln_modulation,
})
}
pub fn pre_attention(
&self,
x: &Tensor,
c: &Tensor,
) -> Result<(Qkv, Qkv, SelfAttnModulateIntermediates)> {
let modulation = self.ada_ln_modulation.forward(c)?;
let chunks = modulation.chunk(9, D::Minus1)?;
let (
shift_msa,
scale_msa,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
shift_msa2,
scale_msa2,
gate_msa2,
) = (
chunks[0].clone(),
chunks[1].clone(),
chunks[2].clone(),
chunks[3].clone(),
chunks[4].clone(),
chunks[5].clone(),
chunks[6].clone(),
chunks[7].clone(),
chunks[8].clone(),
);
let norm_x = self.norm1.forward(x)?;
let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
let qkv = self.attn.pre_attention(&modulated_x)?;
let modulated_x2 = modulate(&norm_x, &shift_msa2, &scale_msa2)?;
let qkv2 = self.attn2.pre_attention(&modulated_x2)?;
Ok((
qkv,
qkv2,
SelfAttnModulateIntermediates {
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
gate_msa2,
},
))
}
pub fn post_attention(
&self,
attn: &Tensor,
attn2: &Tensor,
x: &Tensor,
mod_interm: &SelfAttnModulateIntermediates,
) -> Result<Tensor> {
let attn_out = self.attn.post_attention(attn)?;
let x = x.add(&attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?)?;
let attn_out2 = self.attn2.post_attention(attn2)?;
let x = x.add(&attn_out2.broadcast_mul(&mod_interm.gate_msa2.unsqueeze(1)?)?)?;
let norm_x = self.norm2.forward(&x)?;
let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?;
let mlp_out = self.mlp.forward(&modulated_x)?;
let x = x.add(&mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?)?;
Ok(x)
}
}
pub struct QkvOnlyDiTBlock {
norm1: LayerNormNoAffine,
attn: QkvOnlyAttnProjections,
@ -190,14 +300,24 @@ fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> Result<Tensor> {
shift.broadcast_add(&x.broadcast_mul(&scale_plus_one)?)
}
pub struct JointBlock {
pub trait JointBlock {
fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)>;
}
pub struct MMDiTJointBlock {
x_block: DiTBlock,
context_block: DiTBlock,
num_heads: usize,
use_flash_attn: bool,
}
impl JointBlock {
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
impl MMDiTJointBlock {
pub fn new(
hidden_size: usize,
num_heads: usize,
use_flash_attn: bool,
vb: nn::VarBuilder,
) -> Result<Self> {
let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
@ -205,13 +325,17 @@ impl JointBlock {
x_block,
context_block,
num_heads,
use_flash_attn,
})
}
}
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
impl JointBlock for MMDiTJointBlock {
fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
let (context_attn, x_attn) =
joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
let context_out =
self.context_block
.post_attention(&context_attn, context, &context_interm)?;
@ -220,20 +344,70 @@ impl JointBlock {
}
}
pub struct MMDiTXJointBlock {
x_block: SelfAttnDiTBlock,
context_block: DiTBlock,
num_heads: usize,
use_flash_attn: bool,
}
impl MMDiTXJointBlock {
pub fn new(
hidden_size: usize,
num_heads: usize,
use_flash_attn: bool,
vb: nn::VarBuilder,
) -> Result<Self> {
let x_block = SelfAttnDiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
Ok(Self {
x_block,
context_block,
num_heads,
use_flash_attn,
})
}
}
impl JointBlock for MMDiTXJointBlock {
fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
let (x_qkv, x_qkv2, x_interm) = self.x_block.pre_attention(x, c)?;
let (context_attn, x_attn) =
joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
let x_attn2 = attn(&x_qkv2, self.num_heads, self.use_flash_attn)?;
let context_out =
self.context_block
.post_attention(&context_attn, context, &context_interm)?;
let x_out = self
.x_block
.post_attention(&x_attn, &x_attn2, x, &x_interm)?;
Ok((context_out, x_out))
}
}
pub struct ContextQkvOnlyJointBlock {
x_block: DiTBlock,
context_block: QkvOnlyDiTBlock,
num_heads: usize,
use_flash_attn: bool,
}
impl ContextQkvOnlyJointBlock {
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
pub fn new(
hidden_size: usize,
num_heads: usize,
use_flash_attn: bool,
vb: nn::VarBuilder,
) -> Result<Self> {
let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
Ok(Self {
x_block,
context_block,
num_heads,
use_flash_attn,
})
}
@ -241,7 +415,7 @@ impl ContextQkvOnlyJointBlock {
let context_qkv = self.context_block.pre_attention(context, c)?;
let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
Ok(x_out)
@ -266,29 +440,58 @@ fn flash_compatible_attention(
attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2)
}
fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> {
#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
fn joint_attn(
context_qkv: &Qkv,
x_qkv: &Qkv,
num_heads: usize,
use_flash_attn: bool,
) -> Result<(Tensor, Tensor)> {
let qkv = Qkv {
q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?,
k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?,
v: Tensor::cat(&[&context_qkv.v, &x_qkv.v], 1)?,
};
let (batch_size, seqlen, _) = qkv.q.dims3()?;
let qkv = Qkv {
q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?,
k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?,
v: qkv.v,
};
let headdim = qkv.q.dim(D::Minus1)?;
let softmax_scale = 1.0 / (headdim as f64).sqrt();
// let attn: Tensor = candle_flash_attn::flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?;
let attn = flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?;
let attn = attn.reshape((batch_size, seqlen, ()))?;
let seqlen = qkv.q.dim(1)?;
let attn = attn(&qkv, num_heads, use_flash_attn)?;
let context_qkv_seqlen = context_qkv.q.dim(1)?;
let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?;
let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?;
Ok((context_attn, x_attn))
}
fn attn(qkv: &Qkv, num_heads: usize, use_flash_attn: bool) -> Result<Tensor> {
let batch_size = qkv.q.dim(0)?;
let seqlen = qkv.q.dim(1)?;
let qkv = Qkv {
q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?,
k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?,
v: qkv.v.clone(),
};
let headdim = qkv.q.dim(D::Minus1)?;
let softmax_scale = 1.0 / (headdim as f64).sqrt();
let attn = if use_flash_attn {
flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?
} else {
flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?
};
attn.reshape((batch_size, seqlen, ()))
}

View File

@ -1,10 +1,15 @@
// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206).
// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206),
// as well as the MMDiT-X variant introduced for Stable Diffusion 3.5-medium (https://huggingface.co/stabilityai/stable-diffusion-3.5-medium)
// This follows the implementation of the MMDiT model in the ComfyUI repository.
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L1
// with MMDiT-X support following the Stability-AI/sd3.5 repository.
// https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/mmditx.py#L1
use candle::{Module, Result, Tensor, D};
use candle_nn as nn;
use super::blocks::{ContextQkvOnlyJointBlock, FinalLayer, JointBlock};
use super::blocks::{
ContextQkvOnlyJointBlock, FinalLayer, JointBlock, MMDiTJointBlock, MMDiTXJointBlock,
};
use super::embedding::{
PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder,
};
@ -23,7 +28,7 @@ pub struct Config {
}
impl Config {
pub fn sd3() -> Self {
pub fn sd3_medium() -> Self {
Self {
patch_size: 2,
in_channels: 16,
@ -36,6 +41,34 @@ impl Config {
frequency_embedding_size: 256,
}
}
pub fn sd3_5_medium() -> Self {
Self {
patch_size: 2,
in_channels: 16,
out_channels: 16,
depth: 24,
head_size: 64,
adm_in_channels: 2048,
pos_embed_max_size: 384,
context_embed_size: 4096,
frequency_embedding_size: 256,
}
}
pub fn sd3_5_large() -> Self {
Self {
patch_size: 2,
in_channels: 16,
out_channels: 16,
depth: 38,
head_size: 64,
adm_in_channels: 2048,
pos_embed_max_size: 192,
context_embed_size: 4096,
frequency_embedding_size: 256,
}
}
}
pub struct MMDiT {
@ -49,7 +82,7 @@ pub struct MMDiT {
}
impl MMDiT {
pub fn new(cfg: &Config, vb: nn::VarBuilder) -> Result<Self> {
pub fn new(cfg: &Config, use_flash_attn: bool, vb: nn::VarBuilder) -> Result<Self> {
let hidden_size = cfg.head_size * cfg.depth;
let core = MMDiTCore::new(
cfg.depth,
@ -57,6 +90,7 @@ impl MMDiT {
cfg.depth,
cfg.patch_size,
cfg.out_channels,
use_flash_attn,
vb.clone(),
)?;
let patch_embedder = PatchEmbedder::new(
@ -96,7 +130,14 @@ impl MMDiT {
})
}
pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> {
pub fn forward(
&self,
x: &Tensor,
t: &Tensor,
y: &Tensor,
context: &Tensor,
skip_layers: Option<&[usize]>,
) -> Result<Tensor> {
// Following the convention of the ComfyUI implementation.
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
//
@ -116,14 +157,14 @@ impl MMDiT {
let c = (c + y)?;
let context = self.context_embedder.forward(context)?;
let x = self.core.forward(&context, &x, &c)?;
let x = self.core.forward(&context, &x, &c, skip_layers)?;
let x = self.unpatchifier.unpatchify(&x, h, w)?;
x.narrow(2, 0, h)?.narrow(3, 0, w)
}
}
pub struct MMDiTCore {
joint_blocks: Vec<JointBlock>,
joint_blocks: Vec<Box<dyn JointBlock>>,
context_qkv_only_joint_block: ContextQkvOnlyJointBlock,
final_layer: FinalLayer,
}
@ -135,15 +176,29 @@ impl MMDiTCore {
num_heads: usize,
patch_size: usize,
out_channels: usize,
use_flash_attn: bool,
vb: nn::VarBuilder,
) -> Result<Self> {
let mut joint_blocks = Vec::with_capacity(depth - 1);
for i in 0..depth - 1 {
joint_blocks.push(JointBlock::new(
hidden_size,
num_heads,
vb.pp(format!("joint_blocks.{}", i)),
)?);
let joint_block_vb_pp = format!("joint_blocks.{}", i);
let joint_block: Box<dyn JointBlock> =
if vb.contains_tensor(&format!("{}.x_block.attn2.qkv.weight", joint_block_vb_pp)) {
Box::new(MMDiTXJointBlock::new(
hidden_size,
num_heads,
use_flash_attn,
vb.pp(&joint_block_vb_pp),
)?)
} else {
Box::new(MMDiTJointBlock::new(
hidden_size,
num_heads,
use_flash_attn,
vb.pp(&joint_block_vb_pp),
)?)
};
joint_blocks.push(joint_block);
}
Ok(Self {
@ -151,6 +206,7 @@ impl MMDiTCore {
context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new(
hidden_size,
num_heads,
use_flash_attn,
vb.pp(format!("joint_blocks.{}", depth - 1)),
)?,
final_layer: FinalLayer::new(
@ -162,9 +218,20 @@ impl MMDiTCore {
})
}
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
pub fn forward(
&self,
context: &Tensor,
x: &Tensor,
c: &Tensor,
skip_layers: Option<&[usize]>,
) -> Result<Tensor> {
let (mut context, mut x) = (context.clone(), x.clone());
for joint_block in &self.joint_blocks {
for (i, joint_block) in self.joint_blocks.iter().enumerate() {
if let Some(skip_layers) = &skip_layers {
if skip_layers.contains(&i) {
continue;
}
}
(context, x) = joint_block.forward(&context, &x, c)?;
}
let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;

View File

@ -42,7 +42,6 @@ pub struct QkvOnlyAttnProjections {
impl QkvOnlyAttnProjections {
pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
// {'dim': 1536, 'num_heads': 24}
let head_dim = dim / num_heads;
let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
Ok(Self { qkv, head_dim })
@ -57,6 +56,8 @@ impl QkvOnlyAttnProjections {
pub struct AttnProjections {
head_dim: usize,
qkv: nn::Linear,
ln_k: Option<candle_nn::RmsNorm>,
ln_q: Option<candle_nn::RmsNorm>,
proj: nn::Linear,
}
@ -65,16 +66,42 @@ impl AttnProjections {
let head_dim = dim / num_heads;
let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
let proj = nn::linear(dim, dim, vb.pp("proj"))?;
let (ln_k, ln_q) = if vb.contains_tensor("ln_k.weight") {
let ln_k = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_k"))?;
let ln_q = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_q"))?;
(Some(ln_k), Some(ln_q))
} else {
(None, None)
};
Ok(Self {
head_dim,
qkv,
proj,
ln_k,
ln_q,
})
}
pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
let qkv = self.qkv.forward(x)?;
split_qkv(&qkv, self.head_dim)
let Qkv { q, k, v } = split_qkv(&qkv, self.head_dim)?;
let q = match self.ln_q.as_ref() {
None => q,
Some(l) => {
let (b, t, h) = q.dims3()?;
l.forward(&q.reshape((b, t, (), self.head_dim))?)?
.reshape((b, t, h))?
}
};
let k = match self.ln_k.as_ref() {
None => k,
Some(l) => {
let (b, t, h) = k.dims3()?;
l.forward(&k.reshape((b, t, (), self.head_dim))?)?
.reshape((b, t, h))?
}
};
Ok(Qkv { q, k, v })
}
pub fn post_attention(&self, x: &Tensor) -> Result<Tensor> {

View File

@ -22,7 +22,6 @@ impl MobileClipConfig {
pub fn s1() -> Self {
let text_config = text_model::Config::vit_base_patch32();
let vision_config = fastvit::Config::mci1();
Self {
text_config,
vision_config,
@ -32,7 +31,6 @@ impl MobileClipConfig {
pub fn s2() -> Self {
let text_config = text_model::Config::vit_base_patch32();
let vision_config = fastvit::Config::mci2();
Self {
text_config,
vision_config,
@ -45,12 +43,10 @@ impl MobileClipModel {
pub fn new(vs: VarBuilder, c: &MobileClipConfig) -> Result<Self> {
let vision_model = fastvit::fastvit(&c.vision_config, 512, vs.pp("visual.trunk"))?;
let text_model = text_model::OpenClipTextTransformer::new(vs.pp("text"), &c.text_config)?;
let text_projection = vs.get(
(c.text_config.embed_dim, c.text_config.projection_dim),
"text.text_projection",
)?;
let logit_scale = vs.get(&[], "logit_scale")?;
Ok(Self {
text_model,

View File

@ -5,8 +5,10 @@ pub mod bigcode;
pub mod blip;
pub mod blip_text;
pub mod chatglm;
pub mod chinese_clip;
pub mod clip;
pub mod codegeex4_9b;
pub mod colpali;
pub mod convmixer;
pub mod convnext;
pub mod dac;
@ -46,10 +48,12 @@ pub mod moondream;
pub mod mpt;
pub mod olmo;
pub mod openclip;
pub mod paligemma;
pub mod parler_tts;
pub mod persimmon;
pub mod phi;
pub mod phi3;
pub mod pixtral;
pub mod quantized_blip;
pub mod quantized_blip_text;
pub mod quantized_llama;
@ -76,9 +80,11 @@ pub mod rwkv_v5;
pub mod rwkv_v6;
pub mod segformer;
pub mod segment_anything;
pub mod siglip;
pub mod stable_diffusion;
pub mod stable_lm;
pub mod starcoder2;
pub mod stella_en_v5;
pub mod t5;
pub mod trocr;
pub mod vgg;

View File

@ -0,0 +1,154 @@
use crate::models::{gemma, siglip};
use candle::{Module, Result, Tensor};
use candle_nn::{linear, Linear, VarBuilder};
#[derive(serde::Deserialize, Clone, Debug)]
pub struct Config {
pub vision_config: siglip::VisionConfig,
pub text_config: gemma::Config,
pub projection_dim: usize,
}
impl Config {
pub fn paligemma_3b_224() -> Self {
// https://huggingface.co/google/paligemma-3b-pt-224/blob/main/config.json
Self {
vision_config: siglip::VisionConfig::paligemma_3b_224(),
text_config: gemma::Config {
hidden_size: 2048,
intermediate_size: 16384,
num_attention_heads: 8,
num_hidden_layers: 18,
num_key_value_heads: 1,
vocab_size: 257216,
// Default values.
rope_theta: 10000.,
head_dim: 256,
hidden_act: Some(candle_nn::Activation::GeluPytorchTanh),
hidden_activation: None,
attention_bias: false,
max_position_embeddings: 8192,
rms_norm_eps: 1e-6,
},
projection_dim: 2048,
}
}
pub fn paligemma_3b_448() -> Self {
Self {
vision_config: siglip::VisionConfig::paligemma_3b_448(),
text_config: gemma::Config {
hidden_size: 2048,
intermediate_size: 16384,
num_attention_heads: 8,
num_hidden_layers: 18,
num_key_value_heads: 1,
// Default values.
rope_theta: 10000.,
head_dim: 256,
hidden_act: Some(candle_nn::Activation::GeluPytorchTanh),
hidden_activation: None,
attention_bias: false,
max_position_embeddings: 8192,
rms_norm_eps: 1e-6,
vocab_size: 257216,
},
projection_dim: 2048,
}
}
}
#[derive(Clone, Debug)]
pub struct MultiModalProjector {
linear: Linear,
}
impl MultiModalProjector {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let linear = linear(
cfg.vision_config.hidden_size,
cfg.projection_dim,
vb.pp("linear"),
)?;
Ok(Self { linear })
}
}
impl Module for MultiModalProjector {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.linear)
}
}
#[derive(Clone, Debug)]
pub struct Model {
pos: usize,
vision_tower: siglip::VisionModel,
multi_modal_projector: MultiModalProjector,
language_model: gemma::Model,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vision_tower = siglip::VisionModel::new(
&cfg.vision_config,
false,
vb.pp("vision_tower.vision_model"),
)?;
let multi_modal_projector = MultiModalProjector::new(cfg, vb.pp("multi_modal_projector"))?;
let language_model = gemma::Model::new(false, &cfg.text_config, vb.pp("language_model"))?;
Ok(Self {
pos: 0,
language_model,
vision_tower,
multi_modal_projector,
})
}
pub fn setup(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
self.clear_kv_cache();
let image_features = self
.vision_tower
.forward(pixel_values)?
.apply(&self.multi_modal_projector)?;
let image_features = crate::models::clip::div_l2_norm(&image_features)?;
let text_features = self.language_model.embed_tokens().forward(input_ids)?;
let input_embeds = Tensor::cat(&[image_features, text_features], 1)?;
self.pos = input_embeds.dim(1)?;
self.language_model.forward_embeds(&input_embeds, None, 0)
}
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
let pos = self.pos;
let seq_len = input_ids.dim(1)?;
self.pos = pos + seq_len;
self.language_model.forward(input_ids, pos)
}
pub fn forward_without_projection(&mut self, input_ids: &Tensor) -> Result<Tensor> {
self.clear_kv_cache();
let input_embeds = self.language_model.embed_tokens().forward(input_ids)?;
self.language_model
.forward_embeds_without_projection(&input_embeds, None, 0)
}
pub fn setup_without_projection(
&mut self,
pixel_values: &Tensor,
input_ids: &Tensor,
) -> Result<Tensor> {
self.clear_kv_cache();
let image_features = self
.vision_tower
.forward(pixel_values)?
.apply(&self.multi_modal_projector)?;
let image_features = crate::models::clip::div_l2_norm(&image_features)?;
let text_features = self.language_model.embed_tokens().forward(input_ids)?;
let input_embeds = Tensor::cat(&[image_features, text_features], 1)?;
self.language_model
.forward_embeds_without_projection(&input_embeds, None, 0)
}
pub fn clear_kv_cache(&mut self) {
self.pos = 0;
self.language_model.clear_kv_cache()
}
}

Some files were not shown because too many files have changed in this diff Show More